use callback to filter workflow stream output

This commit is contained in:
takatost
2024-03-07 09:55:29 +08:00
parent 6372183471
commit 79f0e894e9
7 changed files with 138 additions and 57 deletions

View File

@ -1,9 +1,9 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from models.workflow import WorkflowNodeExecution, WorkflowRun
class BaseWorkflowCallback:
class BaseWorkflowCallback(ABC):
@abstractmethod
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
"""
@ -33,7 +33,7 @@ class BaseWorkflowCallback:
raise NotImplementedError
@abstractmethod
def on_text_chunk(self, text: str) -> None:
def on_node_text_chunk(self, node_id: str, text: str) -> None:
"""
Publish text chunk
"""

View File

@ -16,7 +16,6 @@ class BaseNode:
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
stream_output_supported: bool = False
callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict,
@ -71,10 +70,12 @@ class BaseNode:
:param text: chunk text
:return:
"""
if self.stream_output_supported:
if self.callbacks:
for callback in self.callbacks:
callback.on_text_chunk(text)
if self.callbacks:
for callback in self.callbacks:
callback.on_node_text_chunk(
node_id=self.node_id,
text=text
)
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:

View File

@ -32,7 +32,6 @@ from models.workflow import (
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
WorkflowType,
)
node_classes = {
@ -171,9 +170,6 @@ class WorkflowEngineManager:
)
)
# fetch predecessor node ids before end node (include: llm, direct answer)
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
try:
predecessor_node = None
while True:
@ -187,10 +183,6 @@ class WorkflowEngineManager:
if not next_node:
break
# check if node is streamable
if next_node.node_id in streamable_node_ids:
next_node.stream_output_supported = True
# max steps 30 reached
if len(workflow_run_state.workflow_node_executions) > 30:
raise ValueError('Max steps 30 reached.')
@ -233,34 +225,6 @@ class WorkflowEngineManager:
callbacks=callbacks
)
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
"""
Fetch streamable node ids
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
:param workflow: Workflow instance
:param graph: workflow graph
:return:
"""
workflow_type = WorkflowType.value_of(workflow.type)
streamable_node_ids = []
end_node_ids = []
for node_config in graph.get('nodes'):
if node_config.get('type') == NodeType.END.value:
if workflow_type == WorkflowType.WORKFLOW:
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
end_node_ids.append(node_config.get('id'))
else:
end_node_ids.append(node_config.get('id'))
for edge_config in graph.get('edges'):
if edge_config.get('target') in end_node_ids:
streamable_node_ids.append(edge_config.get('source'))
return streamable_node_ids
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],