add parallel branch output

This commit is contained in:
takatost
2024-07-25 19:39:06 +08:00
parent f4eb7cd037
commit 4097f7c069
4 changed files with 45 additions and 16 deletions

View File

@ -40,7 +40,10 @@ class GraphRunFailedEvent(BaseGraphEvent):
class BaseNodeEvent(GraphEngineEvent):
route_node_state: RouteNodeState = Field(..., description="route node state")
parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
# iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")

View File

@ -115,6 +115,10 @@ class GraphEngine:
raise e
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
parallel_start_node_id = None
if in_parallel_id:
parallel_start_node_id = start_node_id
next_node_id = start_node_id
previous_route_node_state: Optional[RouteNodeState] = None
while True:
@ -139,7 +143,8 @@ class GraphEngine:
yield from self._run_node(
route_node_state=route_node_state,
previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None,
parallel_id=in_parallel_id
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
@ -155,7 +160,8 @@ class GraphEngine:
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=in_parallel_id
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
raise e
@ -287,14 +293,16 @@ class GraphEngine:
def _run_node(self,
route_node_state: RouteNodeState,
previous_node_id: Optional[str] = None,
parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
"""
# trigger node run start event
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
# get node config
@ -305,7 +313,8 @@ class GraphEngine:
route_node_state.failed_reason = f'Node {node_id} config not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
return
@ -317,7 +326,8 @@ class GraphEngine:
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
return
@ -344,8 +354,9 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
route_node_state=route_node_state
parallel_start_node_id=parallel_start_node_id
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
@ -365,24 +376,27 @@ class GraphEngine:
)
yield NodeRunSucceededEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
route_node_state=route_node_state
parallel_start_node_id=parallel_start_node_id
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
retriever_resources=item.retriever_resources,
context=item.context
parallel_start_node_id=parallel_start_node_id,
)
except GenerateTaskStoppedException:
# trigger node run failed event
@ -390,7 +404,8 @@ class GraphEngine:
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
return
except Exception as e: