add chatflow app event convert

This commit is contained in:
takatost
2024-07-31 02:21:35 +08:00
parent 0818b7b078
commit 917aacbf7f
19 changed files with 1566 additions and 239 deletions

View File

@ -11,21 +11,20 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor:
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
@ -74,58 +73,6 @@ class AnswerStreamProcessor:
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
@ -138,8 +85,8 @@ class AnswerStreamProcessor:
# all depends on answer node id not in rest node ids
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
@ -149,6 +96,9 @@ class AnswerStreamProcessor:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
@ -171,6 +121,9 @@ class AnswerStreamProcessor:
if text:
yield NodeRunStreamChunkEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,

View File

@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@abstractmethod
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

View File

@ -4,6 +4,7 @@ from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -42,14 +43,14 @@ class BaseNode(ABC):
@abstractmethod
def _run(self) \
-> NodeRunResult | Generator[RunEvent, None, None]:
-> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:return:
"""
raise NotImplementedError
def run(self) -> Generator[RunEvent, None, None]:
def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node entry
:return:

View File

@ -9,18 +9,17 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
class EndStreamProcessor:
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
super().__init__(graph, variable_pool)
self.stream_param = graph.end_stream_param
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
@ -56,64 +55,10 @@ class EndStreamProcessor:
yield event
def reset(self) -> None:
self.end_streamed_variable_selectors = {}
self.end_streamed_variable_selectors: dict[str, list[str]] = {
end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
}
self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Is stream out support

View File

@ -1,13 +1,16 @@
import logging
from collections.abc import Generator
from typing import Any, cast
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
GraphRunFailedEvent,
InNodeEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
@ -17,7 +20,7 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus
@ -32,7 +35,7 @@ class IterationNode(BaseNode):
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
def _run(self) -> BaseIterationState:
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
@ -42,6 +45,10 @@ class IterationNode(BaseNode):
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {
"iterator_selector": iterator_list_value
}
root_node_id = self.node_data.start_node_id
graph_config = self.graph_config
@ -117,21 +124,42 @@ class IterationNode(BaseNode):
)
yield IterationRunStartedEvent(
iteration_id=self.node_id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
metadata={
"iterator_length": 1
},
predecessor_node_id=self.previous_node_id
)
yield IterationRunNextEvent(
iteration_id=self.node_id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=0,
output=None
pre_iteration_output=None
)
outputs: list[Any] = []
try:
# run workflow
rst = graph_engine.run()
outputs: list[Any] = []
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if isinstance(event, NodeRunSucceededEvent):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
event.route_node_state.node_run_result.metadata = metadata
yield event
# handle iteration run result
@ -158,22 +186,35 @@ class IterationNode(BaseNode):
)
yield IterationRunNextEvent(
iteration_id=self.node_id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None
pre_iteration_output=jsonable_encoder(
current_iteration_output) if current_iteration_output else None
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.node_id,
reason=event.reason,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.reason,
error=event.error,
)
)
break
@ -181,7 +222,17 @@ class IterationNode(BaseNode):
yield event
yield IterationRunSucceededEvent(
iteration_id=self.node_id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
}
)
yield RunCompletedEvent(
@ -196,8 +247,18 @@ class IterationNode(BaseNode):
# iteration run failed
logger.exception("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.node_id,
reason=str(e),
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=str(e),
)
yield RunCompletedEvent(

View File

@ -27,6 +27,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
@ -42,11 +43,19 @@ from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> Generator[RunEvent, None, None]:
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:return:
@ -167,7 +176,7 @@ class LLMNode(BaseNode):
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) \
-> Generator[RunEvent, None, None]:
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Invoke large language model
:param node_data_model: node data model
@ -201,7 +210,7 @@ class LLMNode(BaseNode):
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunEvent, None, None]:
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
@ -762,11 +771,3 @@ class LLMNode(BaseNode):
}
}
}
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage