feat(workflow): integrate parallel into workflow apps

This commit is contained in:
takatost
2024-08-16 21:33:09 +08:00
parent 1973f5003b
commit 352c45c8a2
13 changed files with 233 additions and 94 deletions

View File

@ -99,6 +99,10 @@ class Graph(BaseModel):
if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = []
# is target node id in source node id edge mapping
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
continue
target_edge_ids.add(target_node_id)
# parse run condition

View File

@ -244,6 +244,7 @@ class GraphEngine:
if len(edge_mappings) == 1:
edge = edge_mappings[0]
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@ -296,14 +297,20 @@ class GraphEngine:
# run parallel nodes, run in new thread and use queue to get results
q: queue.Queue = queue.Queue()
# Create a list to store the threads
threads = []
# new thread
for edge in edge_mappings:
threading.Thread(target=self._run_parallel_node, kwargs={
thread = threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'q': q
}).start()
})
threads.append(thread)
thread.start()
succeeded_count = 0
while True:
@ -315,8 +322,8 @@ class GraphEngine:
yield event
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(edge_mappings):
break
if succeeded_count == len(threads):
q.put(None)
continue
elif isinstance(event, ParallelBranchRunFailedEvent):
@ -324,6 +331,10 @@ class GraphEngine:
except queue.Empty:
continue
# Join all threads
for thread in threads:
thread.join()
# get final node id
final_node_id = parallel.end_to_node_id
if not final_node_id:
@ -331,8 +342,8 @@ class GraphEngine:
next_node_id = final_node_id
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
# break
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
break
def _run_parallel_node(self,
flask_app: Flask,
@ -449,6 +460,14 @@ class GraphEngine:
variable_value=variable_value
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,