add chatflow app event convert

This commit is contained in:
takatost
2024-07-31 02:21:35 +08:00
parent 0818b7b078
commit 917aacbf7f
19 changed files with 1566 additions and 239 deletions

View File

@ -9,18 +9,17 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
class EndStreamProcessor:
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
super().__init__(graph, variable_pool)
self.stream_param = graph.end_stream_param
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
@ -56,64 +55,10 @@ class EndStreamProcessor:
yield event
def reset(self) -> None:
self.end_streamed_variable_selectors = {}
self.end_streamed_variable_selectors: dict[str, list[str]] = {
end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
}
self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Is stream out support