mirror of
https://github.com/langgenius/dify.git
synced 2026-05-28 04:43:33 +08:00
feat(workflow): integrate parallel into workflow apps
This commit is contained in:
@ -99,6 +99,10 @@ class Graph(BaseModel):
|
||||
if target_node_id not in reverse_edge_mapping:
|
||||
reverse_edge_mapping[target_node_id] = []
|
||||
|
||||
# is target node id in source node id edge mapping
|
||||
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
|
||||
continue
|
||||
|
||||
target_edge_ids.add(target_node_id)
|
||||
|
||||
# parse run condition
|
||||
|
||||
@ -244,6 +244,7 @@ class GraphEngine:
|
||||
|
||||
if len(edge_mappings) == 1:
|
||||
edge = edge_mappings[0]
|
||||
|
||||
if edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
@ -296,14 +297,20 @@ class GraphEngine:
|
||||
# run parallel nodes, run in new thread and use queue to get results
|
||||
q: queue.Queue = queue.Queue()
|
||||
|
||||
# Create a list to store the threads
|
||||
threads = []
|
||||
|
||||
# new thread
|
||||
for edge in edge_mappings:
|
||||
threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
'parallel_id': parallel_id,
|
||||
'parallel_start_node_id': edge.target_node_id,
|
||||
'q': q
|
||||
}).start()
|
||||
})
|
||||
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
@ -315,8 +322,8 @@ class GraphEngine:
|
||||
yield event
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(edge_mappings):
|
||||
break
|
||||
if succeeded_count == len(threads):
|
||||
q.put(None)
|
||||
|
||||
continue
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
@ -324,6 +331,10 @@ class GraphEngine:
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# Join all threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
if not final_node_id:
|
||||
@ -331,8 +342,8 @@ class GraphEngine:
|
||||
|
||||
next_node_id = final_node_id
|
||||
|
||||
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
||||
# break
|
||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
||||
break
|
||||
|
||||
def _run_parallel_node(self,
|
||||
flask_app: Flask,
|
||||
@ -449,6 +460,14 @@ class GraphEngine:
|
||||
variable_value=variable_value
|
||||
)
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
|
||||
Reference in New Issue
Block a user