finished answer stream output

This commit is contained in:
takatost
2024-07-20 00:49:46 +08:00
parent 7ad77e9e77
commit dad1a967ee
15 changed files with 989 additions and 522 deletions

View File

@@ -50,7 +50,8 @@ class NodeRunStartedEvent(BaseNodeEvent):
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
class NodeRunRetrieverResourceEvent(BaseNodeEvent):

View File

@@ -5,21 +5,24 @@ from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
class GraphEdge(BaseModel):
source_node_id: str = Field(..., description="source node id")
target_node_id: str = Field(..., description="target node id")
run_condition: Optional[RunCondition] = Field(None, description="run condition")
run_condition: Optional[RunCondition] = None
"""run condition"""
class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
end_to_node_id: Optional[str] = Field(None, description="end to node id")
parent_parallel_id: Optional[str] = None
"""parent parallel id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
class Graph(BaseModel):
@@ -33,6 +36,10 @@ class Graph(BaseModel):
default_factory=dict,
description="graph edge mapping (source node id: edges)"
)
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="reverse graph edge mapping (target node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict,
description="graph parallel mapping (parallel id: parallel)"
@@ -41,8 +48,8 @@ class Graph(BaseModel):
default_factory=dict,
description="graph node parallel mapping (node id: parallel id)"
)
answer_stream_generate_routes: dict[str, AnswerStreamGenerateRoute] = Field(
default_factory=dict,
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
...,
description="answer stream generate routes"
)
@@ -66,6 +73,7 @@ class Graph(BaseModel):
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
for edge_config in edge_configs:
source_node_id = edge_config.get('source')
@@ -79,6 +87,9 @@ class Graph(BaseModel):
if not target_node_id:
continue
if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = []
target_edge_ids.add(target_node_id)
# parse run condition
@@ -91,11 +102,12 @@ class Graph(BaseModel):
graph_edge = GraphEdge(
source_node_id=source_node_id,
target_node_id=edge_config.get('target'),
target_node_id=target_node_id,
run_condition=run_condition
)
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# node configs
node_configs = graph_config.get('nodes')
@@ -149,9 +161,9 @@ class Graph(BaseModel):
)
# init answer stream generate routes
answer_stream_generate_routes = AnswerStreamOutputManager.init_stream_generate_routes(
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping
reverse_edge_mapping=reverse_edge_mapping
)
# init graph
@@ -160,6 +172,7 @@ class Graph(BaseModel):
node_ids=node_ids,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes

View File

@@ -8,7 +8,11 @@ class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
start_at: float = Field(..., description="start time")
total_tokens: int = Field(0, description="total tokens")
node_run_steps: int = Field(0, description="node run steps")
total_tokens: int = 0
"""total tokens"""
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")
node_run_steps: int = 0
"""node run steps"""
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""

View File

@@ -28,6 +28,9 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db
@@ -81,6 +84,10 @@ class GraphEngine:
try:
# run graph
generator = self._run(start_node_id=self.graph.root_node_id)
if self.init_params.workflow_type == WorkflowType.CHAT:
answer_stream_processor = AnswerStreamProcessor(self.graph)
generator = answer_stream_processor.process(generator)
for item in generator:
yield item
if isinstance(item, NodeRunFailedEvent):
@@ -314,8 +321,6 @@ class GraphEngine:
db.session.close()
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
self.graph_runtime_state.node_run_steps += 1
try:
@@ -335,7 +340,7 @@ class GraphEngine:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
# append node output variables to variable pool
@@ -397,7 +402,7 @@ class GraphEngine:
self.graph_runtime_state.variable_pool.append_variable(
node_id=node_id,
variable_key_list=variable_key_list,
value=variable_value
value=variable_value # type: ignore[arg-type]
)
# if variable_value is a dict, then recursively append variables