add parallel branch output

This commit is contained in:
takatost
2024-07-25 19:39:06 +08:00
parent f4eb7cd037
commit 4097f7c069
4 changed files with 45 additions and 16 deletions

View File

@ -31,9 +31,16 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
start_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
parallel_id = graph.node_parallel_mapping.get(next_node_id)
parallel_start_node_id = None
if parallel_id:
parallel = graph.parallel_mapping.get(parallel_id)
parallel_start_node_id = parallel.start_from_node_id if parallel else None
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_start_node_id=parallel_start_node_id
)
if 'llm' in next_node_id:
@ -43,14 +50,16 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
chunk_content=str(i),
route_node_state=route_node_state,
from_variable_selector=[next_node_id, "text"],
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield NodeRunSucceededEvent(
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)