fix iteration

This commit is contained in:
takatost
2024-07-26 02:43:40 +08:00
parent ae351bd40e
commit a31feacf28
7 changed files with 283 additions and 27 deletions

View File

@ -4,6 +4,7 @@ from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
@ -26,7 +27,12 @@ class EndStreamProcessor:
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStreamChunkEvent):
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
@ -87,6 +93,9 @@ class EndStreamProcessor:
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
@ -95,6 +104,9 @@ class EndStreamProcessor:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids: