mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 06:28:05 +08:00
finished answer stream output
This commit is contained in:
@ -50,7 +50,8 @@ class NodeRunStartedEvent(BaseNodeEvent):
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
from_variable_selector: list[str] = Field(..., description="from variable selector")
|
||||
from_variable_selector: Optional[list[str]] = None
|
||||
"""from variable selector"""
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
|
||||
@ -5,21 +5,24 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||
|
||||
|
||||
class GraphEdge(BaseModel):
|
||||
source_node_id: str = Field(..., description="source node id")
|
||||
target_node_id: str = Field(..., description="target node id")
|
||||
run_condition: Optional[RunCondition] = Field(None, description="run condition")
|
||||
run_condition: Optional[RunCondition] = None
|
||||
"""run condition"""
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
|
||||
start_from_node_id: str = Field(..., description="start from node id")
|
||||
parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
|
||||
end_to_node_id: Optional[str] = Field(None, description="end to node id")
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id"""
|
||||
end_to_node_id: Optional[str] = None
|
||||
"""end to node id"""
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
@ -33,6 +36,10 @@ class Graph(BaseModel):
|
||||
default_factory=dict,
|
||||
description="graph edge mapping (source node id: edges)"
|
||||
)
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict,
|
||||
description="reverse graph edge mapping (target node id: edges)"
|
||||
)
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(
|
||||
default_factory=dict,
|
||||
description="graph parallel mapping (parallel id: parallel)"
|
||||
@ -41,8 +48,8 @@ class Graph(BaseModel):
|
||||
default_factory=dict,
|
||||
description="graph node parallel mapping (node id: parallel id)"
|
||||
)
|
||||
answer_stream_generate_routes: dict[str, AnswerStreamGenerateRoute] = Field(
|
||||
default_factory=dict,
|
||||
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
|
||||
...,
|
||||
description="answer stream generate routes"
|
||||
)
|
||||
|
||||
@ -66,6 +73,7 @@ class Graph(BaseModel):
|
||||
|
||||
# reorganize edges mapping
|
||||
edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
target_edge_ids = set()
|
||||
for edge_config in edge_configs:
|
||||
source_node_id = edge_config.get('source')
|
||||
@ -79,6 +87,9 @@ class Graph(BaseModel):
|
||||
if not target_node_id:
|
||||
continue
|
||||
|
||||
if target_node_id not in reverse_edge_mapping:
|
||||
reverse_edge_mapping[target_node_id] = []
|
||||
|
||||
target_edge_ids.add(target_node_id)
|
||||
|
||||
# parse run condition
|
||||
@ -91,11 +102,12 @@ class Graph(BaseModel):
|
||||
|
||||
graph_edge = GraphEdge(
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=edge_config.get('target'),
|
||||
target_node_id=target_node_id,
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
edge_mapping[source_node_id].append(graph_edge)
|
||||
reverse_edge_mapping[target_node_id].append(graph_edge)
|
||||
|
||||
# node configs
|
||||
node_configs = graph_config.get('nodes')
|
||||
@ -149,9 +161,9 @@ class Graph(BaseModel):
|
||||
)
|
||||
|
||||
# init answer stream generate routes
|
||||
answer_stream_generate_routes = AnswerStreamOutputManager.init_stream_generate_routes(
|
||||
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
edge_mapping=edge_mapping
|
||||
reverse_edge_mapping=reverse_edge_mapping
|
||||
)
|
||||
|
||||
# init graph
|
||||
@ -160,6 +172,7 @@ class Graph(BaseModel):
|
||||
node_ids=node_ids,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
answer_stream_generate_routes=answer_stream_generate_routes
|
||||
|
||||
@ -8,7 +8,11 @@ class GraphRuntimeState(BaseModel):
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
|
||||
start_at: float = Field(..., description="start time")
|
||||
total_tokens: int = Field(0, description="total tokens")
|
||||
node_run_steps: int = Field(0, description="node run steps")
|
||||
total_tokens: int = 0
|
||||
"""total tokens"""
|
||||
|
||||
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
|
||||
node_run_state: RuntimeRouteState = RuntimeRouteState()
|
||||
"""node run state"""
|
||||
|
||||
@ -28,6 +28,9 @@ from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
|
||||
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from extensions.ext_database import db
|
||||
@ -81,6 +84,10 @@ class GraphEngine:
|
||||
try:
|
||||
# run graph
|
||||
generator = self._run(start_node_id=self.graph.root_node_id)
|
||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||
answer_stream_processor = AnswerStreamProcessor(self.graph)
|
||||
generator = answer_stream_processor.process(generator)
|
||||
|
||||
for item in generator:
|
||||
yield item
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
@ -314,8 +321,6 @@ class GraphEngine:
|
||||
|
||||
db.session.close()
|
||||
|
||||
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
|
||||
|
||||
self.graph_runtime_state.node_run_steps += 1
|
||||
|
||||
try:
|
||||
@ -335,7 +340,7 @@ class GraphEngine:
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# append node output variables to variable pool
|
||||
@ -397,7 +402,7 @@ class GraphEngine:
|
||||
self.graph_runtime_state.variable_pool.append_variable(
|
||||
node_id=node_id,
|
||||
variable_key_list=variable_key_list,
|
||||
value=variable_value
|
||||
value=variable_value # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
|
||||
Reference in New Issue
Block a user