mirror of
https://github.com/langgenius/dify.git
synced 2026-04-22 03:37:44 +08:00
use callback to filter workflow stream output
This commit is contained in:
@ -1,9 +1,9 @@
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
|
||||
|
||||
class BaseWorkflowCallback:
|
||||
class BaseWorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
@ -33,7 +33,7 @@ class BaseWorkflowCallback:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_text_chunk(self, text: str) -> None:
|
||||
def on_node_text_chunk(self, node_id: str, text: str) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
|
||||
@ -16,7 +16,6 @@ class BaseNode:
|
||||
node_data: BaseNodeData
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
|
||||
stream_output_supported: bool = False
|
||||
callbacks: list[BaseWorkflowCallback]
|
||||
|
||||
def __init__(self, config: dict,
|
||||
@ -71,10 +70,12 @@ class BaseNode:
|
||||
:param text: chunk text
|
||||
:return:
|
||||
"""
|
||||
if self.stream_output_supported:
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_text_chunk(text)
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_node_text_chunk(
|
||||
node_id=self.node_id,
|
||||
text=text
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
|
||||
@ -32,7 +32,6 @@ from models.workflow import (
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
node_classes = {
|
||||
@ -171,9 +170,6 @@ class WorkflowEngineManager:
|
||||
)
|
||||
)
|
||||
|
||||
# fetch predecessor node ids before end node (include: llm, direct answer)
|
||||
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
|
||||
|
||||
try:
|
||||
predecessor_node = None
|
||||
while True:
|
||||
@ -187,10 +183,6 @@ class WorkflowEngineManager:
|
||||
if not next_node:
|
||||
break
|
||||
|
||||
# check if node is streamable
|
||||
if next_node.node_id in streamable_node_ids:
|
||||
next_node.stream_output_supported = True
|
||||
|
||||
# max steps 30 reached
|
||||
if len(workflow_run_state.workflow_node_executions) > 30:
|
||||
raise ValueError('Max steps 30 reached.')
|
||||
@ -233,34 +225,6 @@ class WorkflowEngineManager:
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
|
||||
"""
|
||||
Fetch streamable node ids
|
||||
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||
|
||||
:param workflow: Workflow instance
|
||||
:param graph: workflow graph
|
||||
:return:
|
||||
"""
|
||||
workflow_type = WorkflowType.value_of(workflow.type)
|
||||
|
||||
streamable_node_ids = []
|
||||
end_node_ids = []
|
||||
for node_config in graph.get('nodes'):
|
||||
if node_config.get('type') == NodeType.END.value:
|
||||
if workflow_type == WorkflowType.WORKFLOW:
|
||||
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
else:
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
|
||||
for edge_config in graph.get('edges'):
|
||||
if edge_config.get('target') in end_node_ids:
|
||||
streamable_node_ids.append(edge_config.get('source'))
|
||||
|
||||
return streamable_node_ids
|
||||
|
||||
def _init_workflow_run(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
|
||||
Reference in New Issue
Block a user