mirror of
https://github.com/tencentmusic/cube-studio.git
synced 2025-01-24 14:04:01 +08:00
添加pipeline自动排版
This commit is contained in:
parent
02a120f2c2
commit
054c0cd24c
@ -402,6 +402,7 @@ class Pipeline(Model,ImportMixin,AuditMixinNullable,MyappModelBase):
|
|||||||
|
|
||||||
# 生成前端锁需要的扩展字段
|
# 生成前端锁需要的扩展字段
|
||||||
def fix_expand(self,dbsession=db.session):
|
def fix_expand(self,dbsession=db.session):
|
||||||
|
# 补充expand 的基本节点信息(节点和关系)
|
||||||
tasks_src = self.get_tasks(dbsession)
|
tasks_src = self.get_tasks(dbsession)
|
||||||
tasks = {}
|
tasks = {}
|
||||||
for task in tasks_src:
|
for task in tasks_src:
|
||||||
@ -414,9 +415,12 @@ class Pipeline(Model,ImportMixin,AuditMixinNullable,MyappModelBase):
|
|||||||
|
|
||||||
# 已经不存在的task要删掉
|
# 已经不存在的task要删掉
|
||||||
for item in expand_copy:
|
for item in expand_copy:
|
||||||
|
# 节点类型
|
||||||
if "data" in item:
|
if "data" in item:
|
||||||
if item['id'] not in tasks:
|
if item['id'] not in tasks:
|
||||||
expand_tasks.remove(item)
|
expand_tasks.remove(item)
|
||||||
|
|
||||||
|
# 上下游关系类型
|
||||||
else:
|
else:
|
||||||
# if item['source'] not in tasks or item['target'] not in tasks:
|
# if item['source'] not in tasks or item['target'] not in tasks:
|
||||||
expand_tasks.remove(item) # 删除所有的上下游关系,后面全部重新
|
expand_tasks.remove(item) # 删除所有的上下游关系,后面全部重新
|
||||||
@ -438,10 +442,10 @@ class Pipeline(Model,ImportMixin,AuditMixinNullable,MyappModelBase):
|
|||||||
"y": random.randint(100,1000)
|
"y": random.randint(100,1000)
|
||||||
},
|
},
|
||||||
"data": {
|
"data": {
|
||||||
"taskId": task_id,
|
# "taskId": task_id,
|
||||||
"taskName": tasks[task_id].name,
|
# "taskName": tasks[task_id].name,
|
||||||
"name": tasks[task_id].name,
|
"name": tasks[task_id].name,
|
||||||
"describe": tasks[task_id].label
|
"label": tasks[task_id].label
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -457,7 +461,7 @@ class Pipeline(Model,ImportMixin,AuditMixinNullable,MyappModelBase):
|
|||||||
expand_tasks.append(
|
expand_tasks.append(
|
||||||
{
|
{
|
||||||
"source": str(upstream_task_id),
|
"source": str(upstream_task_id),
|
||||||
# "sourceHandle": None,
|
"arrowHeadType": 'arrow',
|
||||||
"target": str(task_id),
|
"target": str(task_id),
|
||||||
# "targetHandle": None,
|
# "targetHandle": None,
|
||||||
"id": self.name + "__edge-%snull-%snull" % (upstream_task_id, task_id)
|
"id": self.name + "__edge-%snull-%snull" % (upstream_task_id, task_id)
|
||||||
|
@ -1963,56 +1963,33 @@ def sort_expand_index(items,dbsession):
|
|||||||
# pass
|
# pass
|
||||||
return back
|
return back
|
||||||
|
|
||||||
|
|
||||||
# 生成前端锁需要的扩展字段
|
# 生成前端锁需要的扩展字段
|
||||||
def fix_task_position(pipeline,tasks):
|
# @pysnooper.snoop()
|
||||||
expand_tasks = []
|
def fix_task_position(pipeline,tasks,expand_tasks):
|
||||||
|
|
||||||
for task_name in tasks:
|
for task_name in tasks:
|
||||||
task = tasks[task_name]
|
task = tasks[task_name]
|
||||||
expand_task = {
|
|
||||||
"id": str(task['id']),
|
|
||||||
"type": "dataSet",
|
|
||||||
"position": {
|
|
||||||
"x": 0,
|
|
||||||
"y": 0
|
|
||||||
},
|
|
||||||
"data": {
|
|
||||||
"taskId": task['id'],
|
|
||||||
"taskName": task['name'],
|
|
||||||
"name": task['name'],
|
|
||||||
"describe": task['label']
|
|
||||||
}
|
|
||||||
}
|
|
||||||
expand_tasks.append(expand_task)
|
|
||||||
# print(pipeline['dag_json'])
|
# print(pipeline['dag_json'])
|
||||||
dag_json = json.loads(pipeline['dag_json'])
|
dag_json = json.loads(pipeline['dag_json'])
|
||||||
dag_json_sorted = sorted(dag_json.items(), key=lambda item: item[0])
|
dag_json_sorted = sorted(dag_json.items(), key=lambda item: item[0])
|
||||||
dag_json = {}
|
dag_json = {}
|
||||||
for item in dag_json_sorted:
|
for item in dag_json_sorted:
|
||||||
dag_json[item[0]] = item[1]
|
dag_json[item[0]] = item[1]
|
||||||
# print(dag_json)
|
|
||||||
for task_name in dag_json:
|
|
||||||
upstreams = dag_json[task_name].get("upstream", [])
|
|
||||||
if upstreams:
|
|
||||||
for upstream_name in upstreams:
|
|
||||||
upstream_task_id = tasks[upstream_name]['id']
|
|
||||||
task_id = tasks[task_name]['id']
|
|
||||||
if upstream_task_id and task_id:
|
|
||||||
expand_tasks.append(
|
|
||||||
{
|
|
||||||
"source": str(upstream_task_id),
|
|
||||||
# "sourceHandle": None,
|
|
||||||
"target": str(task_id),
|
|
||||||
# "targetHandle": None,
|
|
||||||
"id": pipeline['name'] + "__edge-%snull-%snull" % (upstream_task_id, task_id)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# 设置节点的位置
|
||||||
def set_position(task_id, x, y):
|
def set_position(task_id, x, y):
|
||||||
for task in expand_tasks:
|
for task in expand_tasks:
|
||||||
if str(task_id) == task['id']:
|
if str(task_id) == task['id']:
|
||||||
task['position']['x'] = x
|
task['position']['x'] = x
|
||||||
task['position']['y'] = y
|
task['position']['y'] = y
|
||||||
|
def read_position(task_id):
|
||||||
|
for task in expand_tasks:
|
||||||
|
if str(task_id) == task['id']:
|
||||||
|
return task['position']['x'],task['position']['y']
|
||||||
|
|
||||||
|
# 检查指定位置是否存在指定节点
|
||||||
def has_exist_node(x, y, task_id):
|
def has_exist_node(x, y, task_id):
|
||||||
for task in expand_tasks:
|
for task in expand_tasks:
|
||||||
if 'position' in task:
|
if 'position' in task:
|
||||||
@ -2028,25 +2005,75 @@ def fix_task_position(pipeline,tasks):
|
|||||||
if task_name in dag_json[task_name1].get("upstream", []):
|
if task_name in dag_json[task_name1].get("upstream", []):
|
||||||
dag_json[task_name]['downstream'].append(task_name1)
|
dag_json[task_name]['downstream'].append(task_name1)
|
||||||
|
|
||||||
# 计算每个节点的最大深度
|
# 获取节点下游节点总数目
|
||||||
|
def get_down_node_num(task_name):
|
||||||
|
down_nodes = dag_json[task_name].get('downstream',[])
|
||||||
|
if down_nodes:
|
||||||
|
return len(down_nodes)+sum([get_down_node_num(node) for node in down_nodes])
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
# 计算每个根节点的下游叶子总数
|
||||||
has_change = True
|
has_change = True
|
||||||
|
root_num=0
|
||||||
|
root_nodes = []
|
||||||
|
for task_name in dag_json:
|
||||||
|
task = dag_json[task_name]
|
||||||
|
# 为根节点记录第几颗树和deep
|
||||||
|
if not task.get("upstream",[]):
|
||||||
|
root_num+=1
|
||||||
|
task['deep'] = 1
|
||||||
|
root_nodes.append(task_name)
|
||||||
|
dag_json[task_name]['total_down_num']=get_down_node_num(task_name)
|
||||||
|
|
||||||
|
|
||||||
|
root_nodes = sorted(root_nodes, key=lambda task_name: dag_json[task_name]['total_down_num'],reverse=True) # 按子孙数量排序
|
||||||
|
print(root_nodes)
|
||||||
|
for i in range(len(root_nodes)):
|
||||||
|
dag_json[root_nodes[i]]['index']=i
|
||||||
|
|
||||||
|
|
||||||
|
# 更新叶子深度和树index,下游节点总数目
|
||||||
|
max_deep=1
|
||||||
while (has_change):
|
while (has_change):
|
||||||
has_change = False
|
has_change = False
|
||||||
for task_name in dag_json:
|
for task_name in dag_json:
|
||||||
task = dag_json[task_name]
|
task = dag_json[task_name]
|
||||||
if 'deep' not in task:
|
downstream_tasks = dag_json[task_name]['downstream']
|
||||||
|
|
||||||
|
# 配置全部下游节点总数
|
||||||
|
if 'total_down_num' not in dag_json[task_name]:
|
||||||
has_change = True
|
has_change = True
|
||||||
task['deep'] = 1
|
dag_json[task_name]['total_down_num'] = get_down_node_num(task_name)
|
||||||
else:
|
|
||||||
downstream_tasks = dag_json[task_name]['downstream']
|
for downstream_task_name in downstream_tasks:
|
||||||
for downstream_task_name in downstream_tasks:
|
# 新出现的叶子节点,直接deep+1
|
||||||
if 'deep' not in dag_json[downstream_task_name]:
|
if 'deep' not in dag_json[downstream_task_name]:
|
||||||
|
has_change = True
|
||||||
|
if 'deep' in task:
|
||||||
|
dag_json[downstream_task_name]['deep'] = 1 + task['deep']
|
||||||
|
if max_deep<(1 + task['deep']):
|
||||||
|
max_deep = 1 + task['deep']
|
||||||
|
else:
|
||||||
|
# 旧叶子,可能节点被多个不同deep的上游引导,使用deep最大的做为引导
|
||||||
|
if dag_json[downstream_task_name]['deep'] < task['deep'] + 1:
|
||||||
has_change = True
|
has_change = True
|
||||||
dag_json[downstream_task_name]['deep'] = 1 + task['deep']
|
dag_json[downstream_task_name]['deep'] = 1 + task['deep']
|
||||||
else:
|
if max_deep<(1 + task['deep']):
|
||||||
if dag_json[downstream_task_name]['deep'] < task['deep'] + 1:
|
max_deep = 1 + task['deep']
|
||||||
has_change = True
|
|
||||||
dag_json[downstream_task_name]['deep'] = 1 + task['deep']
|
|
||||||
|
# 叶子节点直接采用根节点的信息。有可能是多个根长出来的,选择index最小的根
|
||||||
|
if 'index' not in dag_json[downstream_task_name]:
|
||||||
|
has_change = True
|
||||||
|
if 'index' in task:
|
||||||
|
dag_json[downstream_task_name]['index']=task['index']
|
||||||
|
else:
|
||||||
|
if task['index']>dag_json[downstream_task_name]['index']:
|
||||||
|
has_change = True
|
||||||
|
dag_json[downstream_task_name]['index'] = task['index']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# print(dag_json)
|
# print(dag_json)
|
||||||
@ -2054,85 +2081,49 @@ def fix_task_position(pipeline,tasks):
|
|||||||
start_x = 50
|
start_x = 50
|
||||||
start_y = 50
|
start_y = 50
|
||||||
|
|
||||||
# 先把根的位置弄好
|
# 先把根的位置弄好,子节点多的排在左侧前方。
|
||||||
root_nodes = {}
|
|
||||||
for task_name in dag_json:
|
# @pysnooper.snoop()
|
||||||
|
def set_downstream_position(task_name):
|
||||||
task_id = str(tasks[task_name]['id'])
|
task_id = str(tasks[task_name]['id'])
|
||||||
if dag_json[task_name]['deep'] == 1:
|
downstream_tasks = [x for x in dag_json[task_name]['downstream'] if dag_json[x]['index']==dag_json[task_name]['index']] # 获取相同树的下游节点
|
||||||
set_position(task_id, start_x, start_y)
|
downstream_tasks = sorted(downstream_tasks, key=lambda temp: dag_json[temp]['total_down_num'],reverse=True) # 按子孙数目排序
|
||||||
start_x += 400
|
for i in range(len(downstream_tasks)):
|
||||||
root_nodes[task_name] = dag_json[task_name]
|
downstream_task = downstream_tasks[i]
|
||||||
|
y = dag_json[downstream_task]['deep']*150-100
|
||||||
|
# 获取前面的树有多少同一层叶子
|
||||||
|
front_task_num=0
|
||||||
|
for temp in dag_json:
|
||||||
|
# print(dag_json[temp]['index'],dag_json[task_name]['index'], dag_json[temp]['deep'],dag_json[task_name]['deep'])
|
||||||
|
if dag_json[temp]['index']<dag_json[downstream_task]['index'] and dag_json[temp]['deep']==dag_json[downstream_task]['deep']:
|
||||||
|
front_task_num+=1
|
||||||
|
front_task_num+=i
|
||||||
|
# y至少要操作他的上游节点的最小值。下游节点有多上上游节点时,靠左排布
|
||||||
|
up = min([read_position(tasks[task_name]['id'])[0] for task_name in dag_json[downstream_task]['upstream']]) # 获取这个下游节点的全部上游节点的x值
|
||||||
|
x = max(up,400*front_task_num+50)
|
||||||
|
# x = 400*front_task_num+50
|
||||||
|
set_position(str(tasks[downstream_task]['id']),x,y)
|
||||||
|
|
||||||
deep = 2
|
# 布局下一层
|
||||||
|
for temp in downstream_tasks:
|
||||||
# 广度遍历配置节点位置
|
set_downstream_position(temp)
|
||||||
while (root_nodes):
|
|
||||||
# print(root_nodes.keys())
|
|
||||||
for task_name in root_nodes:
|
|
||||||
task_id = str(tasks[task_name]['id'])
|
|
||||||
task_x = 0
|
|
||||||
task_y = 0
|
|
||||||
for task_item in expand_tasks:
|
|
||||||
if task_item['id'] == task_id:
|
|
||||||
task_x = task_item['position']['x']
|
|
||||||
task_y = task_item['position']['y']
|
|
||||||
# 只有当当前task位置定了,才能确定下面的节点的位置
|
|
||||||
if task_x >= 1 and task_y >= 1:
|
|
||||||
downstream_tasks_names = dag_json[task_name].get("downstream", [])
|
|
||||||
min_downstream_task_x = []
|
|
||||||
if downstream_tasks_names:
|
|
||||||
for downstream_task_name in downstream_tasks_names:
|
|
||||||
downstream_task_id = str(tasks[downstream_task_name]['id'])
|
|
||||||
|
|
||||||
downstream_task_x = 0
|
|
||||||
downstream_task_y = 0
|
|
||||||
for task_item in expand_tasks:
|
|
||||||
if task_item['id'] == downstream_task_id:
|
|
||||||
downstream_task_x = task_item['position']['x']
|
|
||||||
downstream_task_y = task_item['position']['y']
|
|
||||||
# new_x = downstream_task_x
|
|
||||||
# new_y = downstream_task_y
|
|
||||||
|
|
||||||
if downstream_task_y == 0:
|
|
||||||
new_x = task_x
|
|
||||||
new_y = task_y + 100
|
|
||||||
while has_exist_node(new_x, new_y, downstream_task_id):
|
|
||||||
print('%s %s exist node' % (new_x, new_y))
|
|
||||||
new_x += 300
|
|
||||||
|
|
||||||
print(downstream_task_name, new_x, new_y)
|
|
||||||
set_position(downstream_task_id, new_x, new_y)
|
|
||||||
min_downstream_task_x.append(new_x)
|
|
||||||
else:
|
|
||||||
# 子节点 由父节点产生
|
|
||||||
|
|
||||||
new_x = min(downstream_task_x, task_x)
|
|
||||||
new_y = task_y + 100
|
|
||||||
while has_exist_node(new_x, new_y, downstream_task_id):
|
|
||||||
print('%s %s exist node' % (new_x, new_y))
|
|
||||||
new_x += 300
|
|
||||||
|
|
||||||
print(downstream_task_name, new_x, new_y)
|
|
||||||
set_position(downstream_task_id, new_x, new_y)
|
|
||||||
# 在右下方的子节点,又会影响父节点的位置。其他位置的子节点不影响父节点的位置
|
|
||||||
if new_y == (task_y + 100) and new_x >= task_x:
|
|
||||||
min_downstream_task_x.append(new_x)
|
|
||||||
|
|
||||||
|
|
||||||
if min_downstream_task_x:
|
# print(dag_json)
|
||||||
new_x = min(min_downstream_task_x)
|
# 一棵树一棵树的构建。优先布局下游叶子节点数量大的
|
||||||
print('父节点%s %s %s' % (task_name, new_x, task_y))
|
for task_name in root_nodes:
|
||||||
set_position(task_id, new_x, task_y)
|
task_id = str(tasks[task_name]['id'])
|
||||||
|
set_position(task_id, start_x, start_y)
|
||||||
|
start_x += 400
|
||||||
|
set_downstream_position(task_name)
|
||||||
|
|
||||||
|
|
||||||
root_nodes = {}
|
|
||||||
for task_name in dag_json:
|
|
||||||
if dag_json[task_name]['deep'] == deep:
|
|
||||||
root_nodes[task_name] = dag_json[task_name]
|
|
||||||
deep += 1
|
|
||||||
|
|
||||||
|
|
||||||
return expand_tasks
|
return expand_tasks
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# import yaml
|
# import yaml
|
||||||
# # @pysnooper.snoop(watch_explode=())
|
# # @pysnooper.snoop(watch_explode=())
|
||||||
# def merge_tf_experiment_template(worker_num,node_selector,volume_mount,image,workingDir,image_pull_policy,memory,cpu,command,parameters):
|
# def merge_tf_experiment_template(worker_num,node_selector,volume_mount,image,workingDir,image_pull_policy,memory,cpu,command,parameters):
|
||||||
|
@ -862,7 +862,7 @@ class Pipeline_ModelView_Base():
|
|||||||
|
|
||||||
def pre_update_get(self,item):
|
def pre_update_get(self,item):
|
||||||
item.dag_json = item.fix_dag_json()
|
item.dag_json = item.fix_dag_json()
|
||||||
# item.expand = json.dumps(item.fix_expand(),indent=4,ensure_ascii=False)
|
item.expand = json.dumps(item.fix_expand(),indent=4,ensure_ascii=False)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# 删除前先把下面的task删除了
|
# 删除前先把下面的task删除了
|
||||||
@ -1098,21 +1098,22 @@ class Pipeline_ModelView_Base():
|
|||||||
def web(self,pipeline_id):
|
def web(self,pipeline_id):
|
||||||
pipeline = db.session.query(Pipeline).filter_by(id=pipeline_id).first()
|
pipeline = db.session.query(Pipeline).filter_by(id=pipeline_id).first()
|
||||||
|
|
||||||
pipeline.dag_json = pipeline.fix_dag_json()
|
pipeline.dag_json = pipeline.fix_dag_json() # 修正 dag_json
|
||||||
# pipeline.expand = json.dumps(pipeline.fix_expand(), indent=4, ensure_ascii=False)
|
pipeline.expand = json.dumps(pipeline.fix_expand(), indent=4, ensure_ascii=False) # 修正 前端expand字段缺失
|
||||||
pipeline.expand = json.dumps(pipeline.fix_position(), indent=4, ensure_ascii=False)
|
pipeline.expand = json.dumps(pipeline.fix_position(), indent=4, ensure_ascii=False) # 修正 节点中心位置到视图中间
|
||||||
|
|
||||||
# db_tasks = pipeline.get_tasks(db.session)
|
# 自动排版
|
||||||
# if db_tasks:
|
db_tasks = pipeline.get_tasks(db.session)
|
||||||
# try:
|
if db_tasks:
|
||||||
# tasks={}
|
try:
|
||||||
# for task in db_tasks:
|
tasks={}
|
||||||
# tasks[task.name]=task.to_json()
|
for task in db_tasks:
|
||||||
# expand = core.fix_task_position(pipeline.to_json(),tasks)
|
tasks[task.name]=task.to_json()
|
||||||
# pipeline.expand=json.dumps(expand,indent=4,ensure_ascii=False)
|
expand = core.fix_task_position(pipeline.to_json(),tasks,json.loads(pipeline.expand))
|
||||||
# db.session.commit()
|
pipeline.expand=json.dumps(expand,indent=4,ensure_ascii=False)
|
||||||
# except Exception as e:
|
db.session.commit()
|
||||||
# print(e)
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
print(pipeline_id)
|
print(pipeline_id)
|
||||||
|
@ -363,7 +363,6 @@ class Task_ModelView_Base():
|
|||||||
# item.pipeline.pipeline_argo_id = pipeline_argo_id
|
# item.pipeline.pipeline_argo_id = pipeline_argo_id
|
||||||
# if version_id:
|
# if version_id:
|
||||||
# item.pipeline.version_id = version_id
|
# item.pipeline.version_id = version_id
|
||||||
# # db.session.update(item)
|
|
||||||
# db.session.commit()
|
# db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user