add end stream output test

This commit is contained in:
takatost
2024-07-25 04:03:53 +08:00
parent 833584ba76
commit f4eb7cd037
6 changed files with 300 additions and 154 deletions

View File

@ -29,6 +29,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
@ -82,14 +83,21 @@ class GraphEngine:
yield GraphRunStartedEvent()
try:
# run graph
generator = self._run(start_node_id=self.graph.root_node_id)
if self.init_params.workflow_type == WorkflowType.CHAT:
answer_stream_processor = AnswerStreamProcessor(
stream_processor = AnswerStreamProcessor(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
generator = answer_stream_processor.process(generator)
else:
stream_processor = EndStreamProcessor(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
# run graph
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id)
)
for item in generator:
yield item
@ -151,6 +159,11 @@ class GraphEngine:
)
raise e
# It may not be necessary, but it is necessary. :)
if (self.graph.node_id_config_mapping[next_node_id]
.get("data", {}).get("type", "").lower() == NodeType.END.value):
break
previous_route_node_state = route_node_state
# get next node ids
@ -160,11 +173,6 @@ class GraphEngine:
if len(edge_mappings) == 1:
next_node_id = edge_mappings[0].target_node_id
# It may not be necessary, but it is necessary. :)
if (self.graph.node_id_config_mapping[next_node_id]
.get("data", {}).get("type", "").lower() == NodeType.END.value):
break
else:
if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results