mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 05:48:40 +08:00
add end stream output test
This commit is contained in:
@ -29,6 +29,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
|
||||
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
@ -82,14 +83,21 @@ class GraphEngine:
|
||||
yield GraphRunStartedEvent()
|
||||
|
||||
try:
|
||||
# run graph
|
||||
generator = self._run(start_node_id=self.graph.root_node_id)
|
||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||
answer_stream_processor = AnswerStreamProcessor(
|
||||
stream_processor = AnswerStreamProcessor(
|
||||
graph=self.graph,
|
||||
variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
generator = answer_stream_processor.process(generator)
|
||||
else:
|
||||
stream_processor = EndStreamProcessor(
|
||||
graph=self.graph,
|
||||
variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
# run graph
|
||||
generator = stream_processor.process(
|
||||
self._run(start_node_id=self.graph.root_node_id)
|
||||
)
|
||||
|
||||
for item in generator:
|
||||
yield item
|
||||
@ -151,6 +159,11 @@ class GraphEngine:
|
||||
)
|
||||
raise e
|
||||
|
||||
# It may not be necessary, but it is necessary. :)
|
||||
if (self.graph.node_id_config_mapping[next_node_id]
|
||||
.get("data", {}).get("type", "").lower() == NodeType.END.value):
|
||||
break
|
||||
|
||||
previous_route_node_state = route_node_state
|
||||
|
||||
# get next node ids
|
||||
@ -160,11 +173,6 @@ class GraphEngine:
|
||||
|
||||
if len(edge_mappings) == 1:
|
||||
next_node_id = edge_mappings[0].target_node_id
|
||||
|
||||
# It may not be necessary, but it is necessary. :)
|
||||
if (self.graph.node_id_config_mapping[next_node_id]
|
||||
.get("data", {}).get("type", "").lower() == NodeType.END.value):
|
||||
break
|
||||
else:
|
||||
if any(edge.run_condition for edge in edge_mappings):
|
||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||
|
||||
Reference in New Issue
Block a user