mirror of
https://github.com/langgenius/dify.git
synced 2026-03-13 19:17:43 +08:00
add end stream output test
This commit is contained in:
@ -29,6 +29,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
|
||||
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
@ -82,14 +83,21 @@ class GraphEngine:
|
||||
yield GraphRunStartedEvent()
|
||||
|
||||
try:
|
||||
# run graph
|
||||
generator = self._run(start_node_id=self.graph.root_node_id)
|
||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||
answer_stream_processor = AnswerStreamProcessor(
|
||||
stream_processor = AnswerStreamProcessor(
|
||||
graph=self.graph,
|
||||
variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
generator = answer_stream_processor.process(generator)
|
||||
else:
|
||||
stream_processor = EndStreamProcessor(
|
||||
graph=self.graph,
|
||||
variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
# run graph
|
||||
generator = stream_processor.process(
|
||||
self._run(start_node_id=self.graph.root_node_id)
|
||||
)
|
||||
|
||||
for item in generator:
|
||||
yield item
|
||||
@ -151,6 +159,11 @@ class GraphEngine:
|
||||
)
|
||||
raise e
|
||||
|
||||
# It may not be necessary, but it is necessary. :)
|
||||
if (self.graph.node_id_config_mapping[next_node_id]
|
||||
.get("data", {}).get("type", "").lower() == NodeType.END.value):
|
||||
break
|
||||
|
||||
previous_route_node_state = route_node_state
|
||||
|
||||
# get next node ids
|
||||
@ -160,11 +173,6 @@ class GraphEngine:
|
||||
|
||||
if len(edge_mappings) == 1:
|
||||
next_node_id = edge_mappings[0].target_node_id
|
||||
|
||||
# It may not be necessary, but it is necessary. :)
|
||||
if (self.graph.node_id_config_mapping[next_node_id]
|
||||
.get("data", {}).get("type", "").lower() == NodeType.END.value):
|
||||
break
|
||||
else:
|
||||
if any(edge.run_condition for edge in edge_mappings):
|
||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||
|
||||
@ -66,6 +66,7 @@ class AnswerStreamProcessor:
|
||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||
finished_node_id = event.route_node_state.node_id
|
||||
@ -179,14 +180,13 @@ class AnswerStreamProcessor:
|
||||
return []
|
||||
|
||||
stream_out_answer_node_ids = []
|
||||
for answer_node_id, position in self.route_position.items():
|
||||
for answer_node_id, route_position in self.route_position.items():
|
||||
if answer_node_id not in self.rest_node_ids:
|
||||
continue
|
||||
|
||||
# all depends on answer node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
||||
route_position = self.route_position[answer_node_id]
|
||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
||||
continue
|
||||
|
||||
|
||||
@ -31,49 +31,6 @@ class EndNode(BaseNode):
|
||||
outputs=outputs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
|
||||
"""
|
||||
Extract generate nodes
|
||||
:param graph: graph
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
node_data = cast(EndNodeData, node_data)
|
||||
|
||||
return cls.extract_generate_nodes_from_node_data(graph, node_data)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
|
||||
"""
|
||||
Extract generate nodes from node data
|
||||
:param graph: graph
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes', [])
|
||||
node_mapping = {node.get('id'): node for node in nodes}
|
||||
|
||||
variable_selectors = node_data.outputs
|
||||
|
||||
generate_nodes = []
|
||||
for variable_selector in variable_selectors:
|
||||
if not variable_selector.value_selector:
|
||||
continue
|
||||
|
||||
node_id = variable_selector.value_selector[0]
|
||||
if node_id != 'sys' and node_id in node_mapping:
|
||||
node = node_mapping[node_id]
|
||||
node_type = node.get('data', {}).get('type')
|
||||
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
|
||||
generate_nodes.append(node_id)
|
||||
|
||||
# remove duplicates
|
||||
generate_nodes = list(set(generate_nodes))
|
||||
|
||||
return generate_nodes
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
|
||||
@ -61,7 +61,9 @@ class EndStreamGeneratorRouter:
|
||||
value_selectors.append(variable_selector.value_selector)
|
||||
|
||||
# remove duplicates
|
||||
value_selectors = list(set(value_selectors))
|
||||
value_selector_tuples = [tuple(item) for item in value_selectors]
|
||||
unique_value_selector_tuples = list(set(value_selector_tuples))
|
||||
value_selectors = [list(item) for item in unique_value_selector_tuples]
|
||||
|
||||
return value_selectors
|
||||
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@ -9,7 +8,6 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -20,10 +18,7 @@ class EndStreamProcessor:
|
||||
self.graph = graph
|
||||
self.variable_pool = variable_pool
|
||||
self.stream_param = graph.end_stream_param
|
||||
self.end_streamed_variable_selectors: dict[str, list[str]] = {
|
||||
end_node_id: [] for end_node_id in graph.end_stream_param.end_stream_variable_selector_mapping
|
||||
}
|
||||
|
||||
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
|
||||
self.rest_node_ids = graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
|
||||
@ -33,43 +28,37 @@ class EndStreamProcessor:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
]
|
||||
else:
|
||||
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
|
||||
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
] = stream_out_answer_node_ids
|
||||
] = stream_out_end_node_ids
|
||||
|
||||
for _ in stream_out_answer_node_ids:
|
||||
for _ in stream_out_end_node_ids:
|
||||
yield event
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
yield event
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
# update self.route_position after all stream event finished
|
||||
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||
|
||||
# remove unreachable nodes
|
||||
self._remove_unreachable_nodes(event)
|
||||
|
||||
# generate stream outputs
|
||||
yield from self._generate_stream_outputs_when_node_finished(event)
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
self.route_position = {}
|
||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.end_streamed_variable_selectors = {}
|
||||
self.end_streamed_variable_selectors: dict[str, list[str]] = {
|
||||
end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
|
||||
}
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||
finished_node_id = event.route_node_state.node_id
|
||||
|
||||
if finished_node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
@ -113,59 +102,7 @@ class EndStreamProcessor:
|
||||
|
||||
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
:return:
|
||||
"""
|
||||
for answer_node_id, position in self.route_position.items():
|
||||
# all depends on answer node id not in rest node ids
|
||||
if (event.route_node_state.node_id != answer_node_id
|
||||
and (answer_node_id not in self.rest_node_ids
|
||||
or not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[answer_node_id]
|
||||
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
yield NodeRunStreamChunkEvent(
|
||||
chunk_content=route_chunk.text,
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
)
|
||||
else:
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
if not value_selector:
|
||||
break
|
||||
|
||||
value = self.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
|
||||
text = value.markdown
|
||||
|
||||
if text:
|
||||
yield NodeRunStreamChunkEvent(
|
||||
chunk_content=text,
|
||||
from_variable_selector=value_selector,
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
)
|
||||
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
||||
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
@ -178,30 +115,17 @@ class EndStreamProcessor:
|
||||
if not stream_output_value_selector:
|
||||
return []
|
||||
|
||||
stream_out_answer_node_ids = []
|
||||
for answer_node_id, position in self.route_position.items():
|
||||
if answer_node_id not in self.rest_node_ids:
|
||||
stream_out_end_node_ids = []
|
||||
for end_node_id, variable_selectors in self.end_streamed_variable_selectors.items():
|
||||
if end_node_id not in self.rest_node_ids:
|
||||
continue
|
||||
|
||||
# all depends on answer node id not in rest node ids
|
||||
# all depends on end node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
||||
route_position = self.route_position[answer_node_id]
|
||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
||||
for dep_id in self.stream_param.end_dependencies[end_node_id]):
|
||||
if stream_output_value_selector not in variable_selectors:
|
||||
continue
|
||||
|
||||
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
|
||||
stream_out_end_node_ids.append(end_node_id)
|
||||
|
||||
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
|
||||
continue
|
||||
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if value_selector != stream_output_value_selector:
|
||||
continue
|
||||
|
||||
stream_out_answer_node_ids.append(answer_node_id)
|
||||
|
||||
return stream_out_answer_node_ids
|
||||
return stream_out_end_node_ids
|
||||
|
||||
Reference in New Issue
Block a user