Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

# Conflicts:
#	api/commands.py
#	api/core/app/apps/common/workflow_response_converter.py
#	api/core/llm_generator/llm_generator.py
#	api/core/plugin/entities/plugin.py
#	api/core/plugin/impl/tool.py
#	api/core/rag/index_processor/index_processor_base.py
#	api/core/workflow/entities/workflow_execution.py
#	api/core/workflow/entities/workflow_node_execution.py
#	api/core/workflow/enums.py
#	api/core/workflow/graph_engine/entities/graph.py
#	api/core/workflow/graph_engine/graph_engine.py
#	api/core/workflow/nodes/enums.py
#	api/services/dataset_service.py
This commit is contained in:
jyong
2025-08-27 16:05:59 +08:00
385 changed files with 23289 additions and 11938 deletions

View File

@ -1,3 +1,3 @@
from .enums import NodeType
from core.workflow.enums import NodeType
__all__ = ["NodeType"]

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from packaging.version import Version
from pydantic import ValidationError
@ -9,16 +9,12 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
@ -29,17 +25,19 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import AgentLogEvent, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
@ -57,13 +55,17 @@ from .exc import (
ToolFileNotFoundError,
)
if TYPE_CHECKING:
from core.agent.strategy.plugin import PluginAgentStrategy
from core.plugin.entities.request import InvokeCredentials
class AgentNode(BaseNode):
class AgentNode(Node):
"""
Agent Node
"""
_node_type = NodeType.AGENT
node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
@ -92,6 +94,8 @@ class AgentNode(BaseNode):
return "1"
def _run(self) -> Generator:
from core.plugin.impl.exc import PluginDaemonClientSideError
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
@ -99,12 +103,12 @@ class AgentNode(BaseNode):
agent_strategy_name=self._node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
)
),
)
return
@ -139,8 +143,8 @@ class AgentNode(BaseNode):
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
@ -158,16 +162,16 @@ class AgentNode(BaseNode):
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_type=self.type_,
node_id=self.node_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
@ -181,7 +185,7 @@ class AgentNode(BaseNode):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: PluginAgentStrategy,
strategy: "PluginAgentStrategy",
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -339,10 +343,11 @@ class AgentNode(BaseNode):
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
) -> "InvokeCredentials":
"""
Generate credentials based on the given agent parameters.
"""
from core.plugin.entities.request import InvokeCredentials
credentials = InvokeCredentials()
@ -388,6 +393,8 @@ class AgentNode(BaseNode):
Get agent strategy icon
:return:
"""
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
try:
@ -451,7 +458,9 @@ class AgentNode(BaseNode):
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _filter_mcp_type_tool(
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
@ -479,6 +488,8 @@ class AgentNode(BaseNode):
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
@ -492,7 +503,7 @@ class AgentNode(BaseNode):
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
@ -554,7 +565,11 @@ class AgentNode(BaseNode):
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
@ -571,7 +586,11 @@ class AgentNode(BaseNode):
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
@ -588,8 +607,10 @@ class AgentNode(BaseNode):
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
@ -640,7 +661,7 @@ class AgentNode(BaseNode):
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
@ -653,7 +674,7 @@ class AgentNode(BaseNode):
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
if log.message_id == agent_log.message_id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
@ -674,7 +695,7 @@ class AgentNode(BaseNode):
for log in agent_logs:
json_output.append(
{
"id": log.id,
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
@ -690,8 +711,24 @@ class AgentNode(BaseNode):
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,

View File

@ -1,4 +0,0 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]

View File

@ -1,24 +1,19 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER
class AnswerNode(Node):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE
_node_data: AnswerNodeData
@ -48,35 +43,29 @@ class AnswerNode(BaseNode):
return "1"
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = ""
files = []
for part in generate_routes:
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
variable = self.graph_runtime_state.variable_pool.get(value_selector)
if variable:
if isinstance(variable, FileSegment):
files.append(variable.value)
elif isinstance(variable, ArrayFileSegment):
files.extend(variable.value)
answer += variable.markdown
else:
part = cast(TextGenerateRouteChunk, part)
answer += part.text
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
files = self._extract_files_from_segments(segments.value)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
)
def _extract_files_from_segments(self, segments: Sequence[Segment]):
"""Extract all files from segments containing FileSegment or ArrayFileSegment instances.
FileSegment contains a single file, while ArrayFileSegment contains multiple files.
This method flattens all files into a single list.
"""
files = []
for segment in segments:
if isinstance(segment, FileSegment):
# Single file - wrap in list for consistency
files.append(segment.value)
elif isinstance(segment, ArrayFileSegment):
# Multiple files - extend the list
files.extend(segment.value)
return files
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -96,3 +85,12 @@ class AnswerNode(BaseNode):
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
def get_streaming_template(self) -> Template:
"""
Get the template for streaming.
Returns:
Template instance for this Answer node
"""
return Template.from_answer_template(self._node_data.answer)

View File

@ -1,174 +0,0 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
@classmethod
def init(
cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
for answer_node_id, node_config in node_id_config_mapping.items():
if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
continue
# get generate route for stream output
generate_route = cls._extract_generate_route_selectors(node_config)
answer_generate_route[answer_node_id] = generate_route
# fetch answer dependencies
answer_node_ids = list(answer_generate_route.keys())
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping,
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
generate_routes: list[GenerateRouteChunk] = []
for part in template.split("Ω"):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
else:
generate_routes.append(TextGenerateRouteChunk(text=part))
return generate_routes
@classmethod
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace("{{", "").replace("}}", "")
return part.startswith("{{") and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(
cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Fetch answer dependencies
:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
answer_dependencies: dict[str, list[str]] = {}
for answer_node_id in answer_node_ids:
if answer_dependencies.get(answer_node_id) is None:
answer_dependencies[answer_node_id] = []
cls._recursive_fetch_answer_dependencies(
current_node_id=answer_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies,
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(
cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]],
) -> None:
"""
Recursive fetch answer dependencies
:param current_node_id: current node id
:param answer_node_id: answer node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param answer_dependencies: answer dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (
source_node_type
in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.LOOP,
NodeType.VARIABLE_ASSIGNER,
}
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(
current_node_id=source_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies,
)

View File

@ -1,202 +0,0 @@
import logging
from collections.abc import Generator
from typing import cast
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunExceptionEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
for answer_node_id in self.generate_routes.answer_generate_route:
self.route_position[answer_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id or event.in_loop_id:
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
stream_out_answer_node_ids
)
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
else:
yield event
def reset(self) -> None:
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for answer_node_id in self.route_position:
# all depends on answer node id not in rest node ids
if event.route_node_state.node_id != answer_node_id and (
answer_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
)
):
continue
route_position = self.route_position[answer_node_id]
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
for route_chunk in route_chunks:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
break
value = self.variable_pool.get(value_selector)
if value is None:
break
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=list(value_selector),
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[answer_node_id] += 1
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.from_variable_selector:
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies
edge_mapping = self.graph.edge_mapping.get(event.node_id)
success_edge = (
next(
(
edge
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids):
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
continue
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
continue
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids

View File

@ -1,109 +0,0 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@abstractmethod
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids: list[str] = []
unreachable_first_node_ids: list[str] = []
if finished_node_id not in self.graph.edge_mapping:
logger.warning("node %s has no edge mapping", finished_node_id)
return
for edge in self.graph.edge_mapping[finished_node_id]:
if (
edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify
):
# remove unreachable nodes
# FIXME: because of the code branch can combine directly, so for answer node
# we remove the node maybe shortcut the answer node, so comment this code for now
# there is not effect on the answer node and the workflow, when we have a better solution
# we can open this code. Issues: #11542 #9560 #10638 #10564
# ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
# if "answer" in ids:
# continue
# else:
# reachable_node_ids.extend(ids)
# The branch_identify parameter is added to ensure that
# only nodes in the correct logical branch are included.
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
reachable_node_ids.extend(ids)
else:
# if the condition edge in parallel, and the target node is not in parallel, we should not remove it
# Issues: #13626
if (
finished_node_id in self.graph.node_parallel_mapping
and edge.target_node_id not in self.graph.node_parallel_mapping
):
continue
unreachable_first_node_ids.append(edge.target_node_id)
unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
if node_id not in self.rest_node_ids:
self.rest_node_ids.append(node_id)
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
# Only follow edges that match the branch_identify or have no run_condition
if edge.run_condition and edge.run_condition.branch_identify:
if not branch_identify or edge.run_condition.branch_identify != branch_identify:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
if node_id in reachable_node_ids:
return
self.rest_node_ids.remove(node_id)
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

View File

@ -1,11 +1,9 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .node import BaseNode
__all__ = [
"BaseIterationNodeData",
"BaseIterationState",
"BaseLoopNodeData",
"BaseLoopState",
"BaseNode",
"BaseNodeData",
]

View File

@ -1,12 +1,37 @@
import json
from abc import ABC
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Optional, Union
from pydantic import BaseModel, model_validator
from core.workflow.nodes.base.exc import DefaultValueTypeError
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.enums import ErrorStrategy
from .exc import DefaultValueTypeError
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class VariableSelector(BaseModel):
"""
Variable Selector.
"""
variable: str
value_selector: Sequence[str]
class DefaultValueType(StrEnum):
@ -19,9 +44,6 @@ class DefaultValueType(StrEnum):
ARRAY_FILES = "array[file]"
NumberType = Union[int, float]
class DefaultValue(BaseModel):
value: Any
type: DefaultValueType
@ -61,7 +83,7 @@ class DefaultValue(BaseModel):
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": NumberType,
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
@ -70,7 +92,7 @@ class DefaultValue(BaseModel):
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": NumberType,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
@ -107,18 +129,6 @@ class DefaultValue(BaseModel):
return self
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None

View File

@ -1,81 +1,168 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from uuid import uuid4
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import (
AgentLogEvent,
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
LoopFailedEvent,
LoopNextEvent,
LoopStartedEvent,
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from .entities import BaseNodeData, RetryConfig
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeType
from core.workflow.node_events import NodeRunResult
logger = logging.getLogger(__name__)
class BaseNode:
_node_type: ClassVar[NodeType]
class Node:
node_type: ClassVar["NodeType"]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
) -> None:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
self.user_from = UserFrom(graph_init_params.user_from)
self.invoke_from = InvokeFrom(graph_init_params.invoke_from)
self.workflow_call_depth = graph_init_params.call_depth
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
self.state: NodeState = NodeState.UNKNOWN # node execution state
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
self.node_id = node_id
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
def _run(self) -> "NodeRunResult | Generator[GraphNodeEventBase, None, None]":
"""
Run node
:return:
"""
raise NotImplementedError
def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
def run(self) -> "Generator[GraphNodeEventBase, None, None]":
# Generate a single node execution ID to use for all events
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
self._start_at = naive_utc_now()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.title,
in_iteration_id=None,
start_at=self._start_at,
)
# === FIXME(-LAN-): Needs to refactor.
from core.workflow.nodes.tool.tool_node import ToolNode
if isinstance(self, ToolNode):
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
from typing import cast
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.agent.entities import AgentNodeData
if isinstance(self, AgentNode):
start_event.agent_strategy = AgentNodeStrategyInit(
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
icon=self.agent_strategy_icon,
)
# ===
yield start_event
try:
result = self._run()
# Handle NodeRunResult
if isinstance(result, NodeRunResult):
yield self._convert_node_run_result_to_graph_node_event(result)
return
# Handle event stream
for event in result:
if isinstance(event, NodeEventBase):
event = self._convert_node_event_to_graph_node_event(event)
if not event.in_iteration_id and not event.in_loop_id:
event.id = self._node_execution_id
yield event
except Exception as e:
logger.exception("Node %s failed to run", self.node_id)
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type="WorkflowNodeError",
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
else:
yield from result
yield NodeRunFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=result,
error=str(e),
)
@classmethod
def extract_variable_selector_to_variable_mapping(
@ -140,14 +227,22 @@ class BaseNode:
) -> Mapping[str, Sequence[str]]:
return {}
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this node blocks the output of specific variables.
This method is used to determine if a node must complete execution before
the specified variables can be used in streaming output.
:param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str'))
:return: True if this node blocks output of any of the specified variables, False otherwise
"""
return False
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {}
@property
def type_(self) -> NodeType:
return self._node_type
@classmethod
@abstractmethod
def version(cls) -> str:
@ -158,10 +253,6 @@ class BaseNode:
# in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@property
def continue_on_error(self) -> bool:
return False
@property
def retry(self) -> bool:
return False
@ -170,7 +261,7 @@ class BaseNode:
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> Optional["ErrorStrategy"]:
"""Get the error strategy for this node."""
...
@ -201,7 +292,7 @@ class BaseNode:
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> Optional[ErrorStrategy]:
def error_strategy(self) -> Optional["ErrorStrategy"]:
"""Get the error strategy for this node."""
return self._get_error_strategy()
@ -224,3 +315,198 @@ class BaseNode:
def default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
return self._get_default_value_dict()
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=result,
error=result.error,
)
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=result,
)
raise Exception(f"result status {result.status} not supported")
def _convert_node_event_to_graph_node_event(self, event: NodeEventBase) -> GraphNodeEventBase:
handler_maps: dict[type[NodeEventBase], Callable[[Any], GraphNodeEventBase]] = {
StreamChunkEvent: self._handle_stream_chunk_event,
StreamCompletedEvent: self._handle_stream_completed_event,
AgentLogEvent: self._handle_agent_log_event,
LoopStartedEvent: self._handle_loop_started_event,
LoopNextEvent: self._handle_loop_next_event,
LoopSucceededEvent: self._handle_loop_succeeded_event,
LoopFailedEvent: self._handle_loop_failed_event,
IterationStartedEvent: self._handle_iteration_started_event,
IterationNextEvent: self._handle_iteration_next_event,
IterationSucceededEvent: self._handle_iteration_succeeded_event,
IterationFailedEvent: self._handle_iteration_failed_event,
RunRetrieverResourceEvent: self._handle_run_retriever_resource_event,
}
handler = handler_maps.get(type(event))
if not handler:
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
return handler(event)
def _handle_stream_chunk_event(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
)
def _handle_stream_completed_event(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=event.node_run_result,
)
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=event.node_run_result,
error=event.node_run_result.error,
)
raise NotImplementedError(f"Node {self._node_id} does not support status {event.node_run_result.status}")
def _handle_agent_log_event(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
message_id=event.message_id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
)
def _handle_loop_started_event(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
metadata=event.metadata,
predecessor_node_id=event.predecessor_node_id,
)
def _handle_loop_next_event(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
return NodeRunLoopNextEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.get_base_node_data().title,
index=event.index,
pre_loop_output=event.pre_loop_output,
)
def _handle_loop_succeeded_event(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
return NodeRunLoopSucceededEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
)
def _handle_loop_failed_event(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
return NodeRunLoopFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error,
)
def _handle_iteration_started_event(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
return NodeRunIterationStartedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
metadata=event.metadata,
predecessor_node_id=event.predecessor_node_id,
)
def _handle_iteration_next_event(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
return NodeRunIterationNextEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.get_base_node_data().title,
index=event.index,
pre_iteration_output=event.pre_iteration_output,
)
def _handle_iteration_succeeded_event(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
return NodeRunIterationSucceededEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
)
def _handle_iteration_failed_event(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
return NodeRunIterationFailedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error,
)
def _handle_run_retriever_resource_event(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
return NodeRunRetrieverResourceEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=event.retriever_resources,
context=event.context,
node_version=self.version(),
)

View File

@ -0,0 +1,148 @@
"""Template structures for Response nodes (Answer and End).
This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Union
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
@dataclass(frozen=True)
class TemplateSegment(ABC):
"""Base class for template segments."""
@abstractmethod
def __str__(self) -> str:
"""String representation of the segment."""
pass
@dataclass(frozen=True)
class TextSegment(TemplateSegment):
"""A text segment in a template."""
text: str
def __str__(self) -> str:
return self.text
@dataclass(frozen=True)
class VariableSegment(TemplateSegment):
"""A variable reference segment in a template."""
selector: Sequence[str]
variable_name: str | None = None # Optional variable name for End nodes
def __str__(self) -> str:
return "{{#" + ".".join(self.selector) + "#}}"
# Type alias for segments
TemplateSegmentUnion = Union[TextSegment, VariableSegment]
@dataclass(frozen=True)
class Template:
"""Unified template structure for Response nodes.
Similar to SegmentGroup, but represents the template structure
without variable values - only marking variable selectors.
"""
segments: list[TemplateSegmentUnion]
@classmethod
def from_answer_template(cls, template_str: str) -> "Template":
"""Create a Template from an Answer node template string.
Example:
"Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])]
Args:
template_str: The answer template string
Returns:
Template instance
"""
parser = VariableTemplateParser(template_str)
segments: list[TemplateSegmentUnion] = []
# Extract variable selectors to find all variables
variable_selectors = parser.extract_variable_selectors()
var_map = {var.variable: var.value_selector for var in variable_selectors}
# Parse template to get ordered segments
# We need to split the template by variable placeholders while preserving order
import re
# Create a regex pattern that matches variable placeholders
pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}"
# Split template while keeping the delimiters (variable placeholders)
parts = re.split(pattern, template_str)
for i, part in enumerate(parts):
if not part:
continue
# Check if this part is a variable reference (odd indices after split)
if i % 2 == 1: # Odd indices are variable keys
# Remove the # symbols from the variable key
var_key = part
if var_key in var_map:
segments.append(VariableSegment(selector=list(var_map[var_key])))
else:
# This shouldn't happen with valid templates
segments.append(TextSegment(text="{{" + part + "}}"))
else:
# Even indices are text segments
segments.append(TextSegment(text=part))
return cls(segments=segments)
@classmethod
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
"""Create a Template from an End node outputs configuration.
End nodes are treated as templates of concatenated variables with newlines.
Example:
[{"variable": "text", "value_selector": ["node1", "text"]},
{"variable": "result", "value_selector": ["node2", "result"]}]
->
[VariableSegment(["node1", "text"]),
TextSegment("\n"),
VariableSegment(["node2", "result"])]
Args:
outputs_config: List of output configurations with variable and value_selector
Returns:
Template instance
"""
segments: list[TemplateSegmentUnion] = []
for i, output in enumerate(outputs_config):
if i > 0:
# Add newline separator between variables
segments.append(TextSegment(text="\n"))
value_selector = output.get("value_selector", [])
variable_name = output.get("variable", "")
if value_selector:
segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name))
if len(segments) > 0 and isinstance(segments[-1], TextSegment):
segments = segments[:-1]
return cls(segments=segments)
def __str__(self) -> str:
"""String representation of the template."""
return "".join(str(segment) for segment in self.segments)

View File

@ -0,0 +1,130 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any
from .entities import VariableSelector
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]:
parts = SELECTOR_PATTERN.split(template)
selectors = []
for part in filter(lambda x: x, parts):
if "." in part and part[0] == "#" and part[-1] == "#":
selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split(".")))
return selectors
class VariableTemplateParser:
"""
!NOTE: Consider to use the new `segments` module instead of this class.
A class for parsing and manipulating template variables in a string.
Rules:
1. Template variables must be enclosed in `{{}}`.
2. The template variable Key can only be: #node_id.var1.var2#.
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
Example usage:
template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}."
parser = VariableTemplateParser(template)
# Extract template variable keys
variable_keys = parser.extract()
print(variable_keys)
# Output: ['#node_id.query.name#', '#node_id.query.age#']
# Extract variable selectors
variable_selectors = parser.extract_variable_selectors()
print(variable_selectors)
# Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']),
# VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])]
# Format the template string
inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}}
formatted_string = parser.format(inputs)
print(formatted_string)
# Output: "Hello, John! Your age is 25."
"""
def __init__(self, template: str):
self.template = template
self.variable_keys = self.extract()
def extract(self) -> list:
"""
Extracts all the template variable keys from the template string.
Returns:
A list of template variable keys.
"""
# Regular expression to match the template rules
matches = re.findall(REGEX, self.template)
first_group_matches = [match[0] for match in matches]
return list(set(first_group_matches))
def extract_variable_selectors(self) -> list[VariableSelector]:
"""
Extracts the variable selectors from the template variable keys.
Returns:
A list of VariableSelector objects representing the variable selectors.
"""
variable_selectors = []
for variable_key in self.variable_keys:
remove_hash = variable_key.replace("#", "")
split_result = remove_hash.split(".")
if len(split_result) < 2:
continue
variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))
return variable_selectors
def format(self, inputs: Mapping[str, Any]) -> str:
"""
Formats the template string by replacing the template variables with their corresponding values.
Args:
inputs: A dictionary containing the values for the template variables.
Returns:
The formatted string with template variables replaced by their values.
"""
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
if value is None:
value = ""
# convert the value to string
if isinstance(value, list | dict | bool | int | float):
value = str(value)
# remove template variables if required
return VariableTemplateParser.remove_template_variables(value)
prompt = re.sub(REGEX, replacer, self.template)
return re.sub(r"<\|.*?\|>", "", prompt)
@classmethod
def remove_template_variables(cls, text: str):
"""
Removes the template variables from the given text.
Args:
text: The text from which to remove the template variables.
Returns:
The text with template variables removed.
"""
return re.sub(REGEX, r"{\1}", text)

View File

@ -9,12 +9,11 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .exc import (
CodeNodeError,
@ -23,8 +22,8 @@ from .exc import (
)
class CodeNode(BaseNode):
_node_type = NodeType.CODE
class CodeNode(Node):
node_type = NodeType.CODE
_node_data: CodeNodeData
@ -403,6 +402,7 @@ class CodeNode(BaseNode):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
@ -411,10 +411,6 @@ class CodeNode(BaseNode):
for variable_selector in typed_node_data.variables
}
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -4,8 +4,8 @@ from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[

View File

@ -25,11 +25,10 @@ from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
@ -37,13 +36,13 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode):
class DocumentExtractorNode(Node):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
"""
_node_type = NodeType.DOCUMENT_EXTRACTOR
node_type = NodeType.DOCUMENT_EXTRACTOR
_node_data: DocumentExtractorNodeData

View File

@ -1,4 +0,0 @@
from .end_node import EndNode
from .entities import EndStreamParam
__all__ = ["EndNode", "EndStreamParam"]

View File

@ -1,16 +1,17 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import ErrorStrategy, NodeType
class EndNode(BaseNode):
_node_type = NodeType.END
class EndNode(Node):
node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE
_node_data: EndNodeData
@ -41,8 +42,10 @@ class EndNode(BaseNode):
def _run(self) -> NodeRunResult:
"""
Run node
:return:
Run node - collect all outputs at once.
This method runs after streaming is complete (if streaming was enabled).
It collects all output variables and returns them.
"""
output_variables = self._node_data.outputs
@ -57,3 +60,15 @@ class EndNode(BaseNode):
inputs=outputs,
outputs=outputs,
)
def get_streaming_template(self) -> Template:
"""
Get the template for streaming.
Returns:
Template instance for this End node
"""
outputs_config = [
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
]
return Template.from_end_outputs(outputs_config)

View File

@ -1,152 +0,0 @@
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
from core.workflow.nodes.enums import NodeType
class EndStreamGeneratorRouter:
@classmethod
def init(
cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_parallel_mapping: dict[str, str],
) -> EndStreamParam:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selector of end nodes
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
for end_node_id, node_config in node_id_config_mapping.items():
if node_config.get("data", {}).get("type") != NodeType.END.value:
continue
# skip end node in parallel
if end_node_id in node_parallel_mapping:
continue
# get generate route for stream output
stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
# fetch end dependencies
end_node_ids = list(end_stream_variable_selectors_mapping.keys())
end_dependencies = cls._fetch_ends_dependencies(
end_node_ids=end_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping,
)
return EndStreamParam(
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
end_dependencies=end_dependencies,
)
@classmethod
def extract_stream_variable_selector_from_node_data(
cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData
) -> list[list[str]]:
"""
Extract stream variable selector from node data
:param node_id_config_mapping: node id config mapping
:param node_data: node data object
:return:
"""
variable_selectors = node_data.outputs
value_selectors = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != "sys" and node_id in node_id_config_mapping:
node = node_id_config_mapping[node_id]
node_type = node.get("data", {}).get("type")
if (
variable_selector.value_selector not in value_selectors
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == "text"
):
value_selectors.append(list(variable_selector.value_selector))
return value_selectors
@classmethod
def _extract_stream_variable_selector(
cls, node_id_config_mapping: dict[str, dict], config: dict
) -> list[list[str]]:
"""
Extract stream variable selector from node config
:param node_id_config_mapping: node id config mapping
:param config: node config
:return:
"""
node_data = EndNodeData(**config.get("data", {}))
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
@classmethod
def _fetch_ends_dependencies(
cls,
end_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Fetch end dependencies
:param end_node_ids: end node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
end_dependencies: dict[str, list[str]] = {}
for end_node_id in end_node_ids:
if end_dependencies.get(end_node_id) is None:
end_dependencies[end_node_id] = []
cls._recursive_fetch_end_dependencies(
current_node_id=end_node_id,
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies,
)
return end_dependencies
@classmethod
def _recursive_fetch_end_dependencies(
cls,
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
end_dependencies: dict[str, list[str]],
) -> None:
"""
Recursive fetch end dependencies
:param current_node_id: current node id
:param end_node_id: end node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param end_dependencies: end dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
}:
end_dependencies[end_node_id].append(source_node_id)
else:
cls._recursive_fetch_end_dependencies(
current_node_id=source_node_id,
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies,
)

View File

@ -1,188 +0,0 @@
import logging
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.end_stream_param = graph.end_stream_param
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
self.has_output = False
self.output_node_ids: set[str] = set()
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id or event.in_loop_id:
if self.has_output and event.node_id not in self.output_node_ids:
event.chunk_content = "\n" + event.chunk_content
self.output_node_ids.add(event.node_id)
self.has_output = True
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
stream_out_end_node_ids
)
if stream_out_end_node_ids:
if self.has_output and event.node_id not in self.output_node_ids:
event.chunk_content = "\n" + event.chunk_content
self.output_node_ids.add(event.node_id)
self.has_output = True
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[end_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
else:
yield event
def reset(self) -> None:
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for end_node_id, position in self.route_position.items():
# all depends on end node id not in rest node ids
if event.route_node_state.node_id != end_node_id and (
end_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]
)
):
continue
route_position = self.route_position[end_node_id]
position = 0
value_selectors = []
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
if position >= route_position:
value_selectors.append(current_value_selectors)
position += 1
for value_selector in value_selectors:
if not value_selector:
continue
value = self.variable_pool.get(value_selector)
if value is None:
break
text = value.markdown
if text:
current_node_id = value_selector[0]
if self.has_output and current_node_id not in self.output_node_ids:
text = "\n" + text
self.output_node_ids.add(current_node_id)
self.has_output = True
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[end_node_id] += 1
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.from_variable_selector:
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_end_node_ids = []
for end_node_id, route_position in self.route_position.items():
if end_node_id not in self.rest_node_ids:
continue
# all depends on end node id not in rest node ids
if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
continue
position = 0
value_selector = None
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
if position == route_position:
value_selector = current_value_selectors
break
position += 1
if not value_selector:
continue
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
continue
stream_out_end_node_ids.append(end_node_id)
return stream_out_end_node_ids

View File

@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class EndNodeData(BaseNodeData):

View File

@ -1,39 +0,0 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
class ErrorStrategy(StrEnum):
FAIL_BRANCH = "fail-branch"
DEFAULT_VALUE = "default-value"
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"

View File

@ -1,17 +0,0 @@
from .event import (
ModelInvokeCompletedEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
RunStreamChunkEvent,
)
from .types import NodeEvent
__all__ = [
"ModelInvokeCompletedEvent",
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"RunStreamChunkEvent",
]

View File

@ -1,40 +0,0 @@
from collections.abc import Sequence
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult
class RunCompletedEvent(BaseModel):
run_result: NodeRunResult = Field(..., description="run result")
class RunStreamChunkEvent(BaseModel):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
class RunRetrieverResourceEvent(BaseModel):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class ModelInvokeCompletedEvent(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
finish_reason: str | None = None
class RunRetryEvent(BaseModel):
"""Node Run Retry event"""
error: str = Field(..., description="error")
retry_index: int = Field(..., description="Retry attempt number")
start_at: datetime = Field(..., description="Retry start time")

View File

@ -1,3 +0,0 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent

View File

@ -15,7 +15,7 @@ from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import VariablePool
from .entities import (
HttpRequestNodeAuthorization,
@ -329,22 +329,16 @@ class Executor:
"""
do http request depending on api bundle
"""
if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
@ -362,11 +356,11 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
response: httpx.Response = _METHOD_MAP[method_lc](**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response # type: ignore
return response
def invoke(self) -> Response:
# assemble headers

View File

@ -7,14 +7,12 @@ from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
from .entities import (
@ -33,8 +31,8 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST
class HttpRequestNode(Node):
node_type = NodeType.HTTP_REQUEST
_node_data: HttpRequestNodeData
@ -101,7 +99,7 @@ class HttpRequestNode(BaseNode):
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and (self.continue_on_error or self.retry):
if not response.response.is_success and (self.error_strategy or self.retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@ -129,7 +127,7 @@ class HttpRequestNode(BaseNode):
},
)
except HttpRequestNodeError as e:
logger.warning("http request node %s failed to run: %s", self.node_id, e)
logger.warning("http request node %s failed to run: %s", self._node_id, e)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
@ -244,10 +242,6 @@ class HttpRequestNode(BaseNode):
return ArrayFileSegment(value=files)
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -3,19 +3,19 @@ from typing import Any, Literal, Optional
from typing_extensions import deprecated
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.entities import VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE
class IfElseNode(Node):
node_type = NodeType.IF_ELSE
execution_type = NodeExecutionType.BRANCH
_node_data: IfElseNodeData
@ -49,13 +49,13 @@ class IfElseNode(BaseNode):
Run node
:return:
"""
node_inputs: dict[str, list] = {"conditions": []}
node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []}
process_data: dict[str, list] = {"condition_results": []}
input_conditions = []
input_conditions: Sequence[Mapping[str, Any]] = []
final_result = False
selected_case_id = None
selected_case_id = "false"
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used

View File

@ -1,48 +1,35 @@
import contextvars
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from flask import Flask, current_app
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import (
NodeRunResult,
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
InNodeEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
GraphRunSucceededEvent,
)
from core.workflow.node_events import (
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@ -54,17 +41,18 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
class IterationNode(BaseNode):
class IterationNode(Node):
"""
Iteration Node.
"""
_node_type = NodeType.ITERATION
node_type = NodeType.ITERATION
execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
@ -103,10 +91,7 @@ class IterationNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
def _run(self) -> Generator:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable:
@ -121,8 +106,8 @@ class IterationNode(BaseNode):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
@ -138,190 +123,76 @@ class IterationNode(BaseNode):
inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
root_node_id = self._node_data.start_node_id
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool
variable_pool.add([self.node_id, "index"], 0)
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
)
start_at = naive_utc_now()
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)},
predecessor_node_id=self.previous_node_id,
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=0,
pre_iteration_output=None,
duration=None,
)
started_at = naive_utc_now()
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
outputs: list[Any] = []
yield IterationStartedEvent(
start_at=started_at,
inputs=inputs,
metadata={"iteration_length": len(iterator_list_value)},
)
try:
if self._node_data.is_parallel:
futures: list[Future] = []
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(
max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
for index, item in enumerate(iterator_list_value):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
yield IterationNextEvent(index=index)
graph_engine = self._create_graph_engine(index, item)
# Run the iteration
yield from self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs,
graph_engine=graph_engine,
)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
flask_app=current_app._get_current_object(), # type: ignore
q=q,
context=contextvars.copy_context(),
iterator_list_value=iterator_list_value,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
index=index,
item=item,
iter_run_map=iter_run_map,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
if isinstance(event, IterationRunNextEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)
yield event
if isinstance(event, RunCompletedEvent):
q.put(None)
for f in futures:
if not f.done():
f.cancel()
yield event
if isinstance(event, IterationRunFailedEvent):
q.put(None)
yield event
except Empty:
continue
# wait all threads
wait(futures)
else:
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
output_segment = build_segment(outputs)
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
yield RunCompletedEvent(
run_result=NodeRunResult(
# Yield final success event
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": output_segment},
outputs={"output": outputs},
metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
},
)
)
except IterationNodeError as e:
# iteration run failed
logger.warning("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
yield IterationFailedEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
finally:
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, "index"])
variable_pool.remove([self.node_id, "item"])
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -339,12 +210,45 @@ class IterationNode(BaseNode):
}
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis
from core.workflow.entities import VariablePool
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
iteration_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
continue
@ -382,297 +286,120 @@ class IterationNode(BaseNode):
return variable_mapping
def _handle_event_metadata(
def _append_iteration_info_to_event(
self,
*,
event: BaseNodeEvent | InNodeEvent,
event: GraphNodeEventBase,
iter_run_index: int,
parallel_mode_run_id: str | None,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
ensures iteration context (ID, index/parallel_run_id) is added to metadata,
"""
if not isinstance(event, BaseNodeEvent):
return event
if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
):
event.in_iteration_id = self._node_id
iter_metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id,
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
}
if parallel_mode_run_id:
# for parallel, the specific branch ID is more important than the sequential index
iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
if event.route_node_state.node_run_result:
current_metadata = event.route_node_state.node_run_result.metadata or {}
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
return event
current_metadata = event.node_run_result.metadata
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **iter_metadata}
def _run_single_iter(
self,
*,
iterator_list_value: Sequence[str],
variable_pool: VariablePool,
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
iter_run_map: dict[str, float],
parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
iter_start_at = naive_utc_now()
) -> Generator[Union[GraphNodeEventBase, StreamCompletedEvent], None, None]:
rst = graph_engine.run()
# get current iteration index
index_variable = variable_pool.get([self._node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
current_index = index_variable.value
for event in rst:
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
continue
try:
rst = graph_engine.run()
# get current iteration index
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
current_index = index_variable.value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
if isinstance(event, GraphNodeEventBase):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, GraphRunSucceededEvent):
result = variable_pool.get(self._node_data.output_selector)
if result is None:
outputs.append(None)
else:
outputs.append(result.to_object())
return
elif isinstance(event, GraphRunFailedEvent):
match self._node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
raise IterationNodeError(event.error)
case ErrorHandleMode.CONTINUE_ON_ERROR:
outputs.append(None)
return
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
return
elif isinstance(event, InNodeEvent):
# event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs[current_index] = None
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs[current_index] = None
def _create_graph_engine(self, index: int, item: Any):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
# clean nodes resources
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from.value,
invoke_from=self.invoke_from.value,
call_depth=self.workflow_call_depth,
)
# Create a deep copy of the variable pool for each iteration
variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True)
# iteration run failed
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
# append iteration variable (item, index) to variable pool
variable_pool_copy.add([self._node_id, "index"], index)
variable_pool_copy.add([self._node_id, "item"], item)
# stop the iterator
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
yield metadata_event
# Create a new GraphRuntimeState for this iteration
graph_runtime_state_copy = GraphRuntimeState(
variable_pool=variable_pool_copy,
start_at=self.graph_runtime_state.start_at,
total_tokens=0,
node_run_steps=0,
)
current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
outputs[current_index] = current_iteration_output
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# move to next iteration
variable_pool.add([self.node_id, "index"], next_index)
# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (naive_utc_now() - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None,
duration=duration,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
except IterationNodeError as e:
logger.warning("Iteration run failed:%s", str(e))
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=self.graph_config,
graph_runtime_state=graph_runtime_state_copy,
max_execution_steps=10000, # Use default or config value
max_execution_time=600, # Use default or config value
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
)
def _run_single_iter_parallel(
self,
*,
flask_app: Flask,
context: contextvars.Context,
q: Queue,
iterator_list_value: Sequence[str],
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
index: int,
item: Any,
iter_run_map: dict[str, float],
):
"""
run single iteration in parallel mode
"""
with preserve_flask_contexts(flask_app, context_vars=context):
parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy()
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
variable_pool_copy.add([self.node_id, "index"], index)
variable_pool_copy.add([self.node_id, "item"], item)
for event in self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool_copy,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)
graph_engine.graph_runtime_state.total_tokens += graph_engine_copy.graph_runtime_state.total_tokens
return graph_engine

View File

@ -1,20 +1,19 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(BaseNode):
class IterationStartNode(Node):
"""
Iteration Start Node.
"""
_node_type = NodeType.ITERATION_START
node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData

View File

@ -32,14 +32,11 @@ from core.variables import (
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
)
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
@ -70,7 +67,7 @@ from .exc import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
logger = logging.getLogger(__name__)
@ -83,8 +80,8 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
class KnowledgeRetrievalNode(Node):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
@ -99,10 +96,7 @@ class KnowledgeRetrievalNode(BaseNode):
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
@ -110,10 +104,7 @@ class KnowledgeRetrievalNode(BaseNode):
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
@ -197,7 +188,7 @@ class KnowledgeRetrievalNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
process_data={},
outputs=outputs, # type: ignore
)
@ -429,7 +420,7 @@ class KnowledgeRetrievalNode(BaseNode):
Document.enabled == True,
Document.archived == False,
)
filters = [] # type: ignore
filters: list[Any] = []
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None
@ -443,7 +434,7 @@ class KnowledgeRetrievalNode(BaseNode):
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters, # type: ignore
filters,
)
conditions.append(
Condition(
@ -552,7 +543,8 @@ class KnowledgeRetrievalNode(BaseNode):
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
node_id=self._node_id,
node_type=self.node_type,
)
for event in generator:
@ -573,15 +565,15 @@ class KnowledgeRetrievalNode(BaseNode):
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
except Exception:
return []
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
):
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return
return filters
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
@ -666,6 +658,7 @@ class KnowledgeRetrievalNode(BaseNode):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)

View File

@ -4,11 +4,10 @@ from typing import Any, Optional, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
@ -36,8 +35,8 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
return wrapper
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
class ListOperatorNode(Node):
node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData

View File

@ -5,8 +5,8 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class ModelConfig(BaseModel):

View File

@ -8,7 +8,7 @@ from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
from extensions.ext_database import db as global_db
class LLMFileSaver(tp.Protocol):

View File

@ -13,16 +13,16 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError

View File

@ -3,7 +3,7 @@ import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -50,22 +50,25 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
NodeEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunStreamChunkEvent,
from core.workflow.entities import GraphInitParams, VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.node_events import (
ModelInvokeCompletedEvent,
NodeEventBase,
NodeRunResult,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from . import llm_utils
from .entities import (
@ -88,14 +91,13 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.entities import GraphRuntimeState
logger = logging.getLogger(__name__)
class LLMNode(BaseNode):
_node_type = NodeType.LLM
class LLMNode(Node):
node_type = NodeType.LLM
_node_data: LLMNodeData
@ -110,10 +112,7 @@ class LLMNode(BaseNode):
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
@ -121,10 +120,7 @@ class LLMNode(BaseNode):
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
@ -161,9 +157,9 @@ class LLMNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
@ -182,8 +178,6 @@ class LLMNode(BaseNode):
# merge inputs
inputs.update(jinja_inputs)
node_inputs = {}
# fetch files
files = (
llm_utils.fetch_files(
@ -255,13 +249,14 @@ class LLMNode(BaseNode):
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
node_id=self._node_id,
node_type=self.node_type,
)
structured_output: LLMStructuredOutput | None = None
for event in generator:
if isinstance(event, RunStreamChunkEvent):
if isinstance(event, StreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
@ -290,8 +285,15 @@ class LLMNode(BaseNode):
if self._file_outputs is not None:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
yield RunCompletedEvent(
run_result=NodeRunResult(
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_data,
@ -305,8 +307,8 @@ class LLMNode(BaseNode):
)
)
except ValueError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
@ -316,8 +318,8 @@ class LLMNode(BaseNode):
)
except Exception as e:
logger.exception("error while executing llm node")
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
@ -338,7 +340,8 @@ class LLMNode(BaseNode):
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
node_type: NodeType,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
)
@ -374,6 +377,7 @@ class LLMNode(BaseNode):
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
node_type=node_type,
)
@staticmethod
@ -383,7 +387,8 @@ class LLMNode(BaseNode):
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
node_type: NodeType,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
event = LLMNode.handle_blocking_result(
@ -414,7 +419,11 @@ class LLMNode(BaseNode):
file_outputs=file_outputs,
):
full_text_buffer.write(text_part)
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=text_part,
is_final=False,
)
# Update the whole metadata
if not model and result.model:
@ -811,6 +820,8 @@ class LLMNode(BaseNode):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
@ -1070,10 +1081,6 @@ class LLMNode(BaseNode):
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -1,7 +1,6 @@
from collections.abc import Mapping
from typing import Annotated, Any, Literal, Optional
from pydantic import AfterValidator, BaseModel, Field
from pydantic import AfterValidator, BaseModel, Field, field_validator
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
@ -35,19 +34,22 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None
value: Any = None
class LoopNodeData(BaseLoopNodeData):
"""
Loop Node Data.
"""
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
outputs: Optional[Mapping[str, Any]] = None
outputs: dict[str, Any] = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@classmethod
def validate_outputs(cls, v):
if v is None:
return {}
return v
class LoopStartNodeData(BaseNodeData):

View File

@ -1,20 +1,19 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(BaseNode):
class LoopEndNode(Node):
"""
Loop End Node.
"""
_node_type = NodeType.LOOP_END
node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData

View File

@ -1,58 +1,53 @@
import json
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config
from core.variables import (
IntegerSegment,
Segment,
SegmentType,
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
InNodeEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.node_events import (
LoopFailedEvent,
LoopNextEvent,
LoopStartedEvent,
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
class LoopNode(BaseNode):
class LoopNode(Node):
"""
Loop Node.
"""
_node_type = NodeType.LOOP
node_type = NodeType.LOOP
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopNodeData.model_validate(data)
@ -79,7 +74,7 @@ class LoopNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def _run(self) -> Generator:
"""Run the node."""
# Get inputs
loop_count = self._node_data.loop_count
@ -89,144 +84,126 @@ class LoopNode(BaseNode):
inputs = {"loop_count": loop_count}
if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
# Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
root_node_id = self._node_data.start_node_id
# Initialize variable pool
variable_pool = self.graph_runtime_state.variable_pool
variable_pool.add([self.node_id, "index"], 0)
# Initialize loop variables
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
if self._node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value),
}
for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
}
if loop_variable.value_type not in value_processor:
raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
)
processed_segment = value_processor[loop_variable.value_type]()
processed_segment = value_processor[loop_variable.value_type](loop_variable)
if not processed_segment:
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
variable_selector = [self.node_id, loop_variable.label]
variable_pool.add(variable_selector, processed_segment.value)
variable_selector = [self._node_id, loop_variable.label]
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
self.graph_runtime_state.variable_pool.add(variable_selector, variable)
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
)
start_at = naive_utc_now()
condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
# Start Loop event
yield LoopRunStartedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopStartedEvent(
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
predecessor_node_id=self.previous_node_id,
)
# yield LoopRunNextEvent(
# loop_id=self.id,
# loop_node_id=self.node_id,
# loop_node_type=self.node_type,
# loop_node_data=self.node_data,
# index=0,
# pre_loop_output=None,
# )
loop_duration_map = {}
single_loop_variable_map = {} # single loop variable output
try:
check_break_result = False
for i in range(loop_count):
loop_start_time = naive_utc_now()
# run single loop
loop_result = yield from self._run_single_loop(
graph_engine=graph_engine,
loop_graph=loop_graph,
variable_pool=variable_pool,
loop_variable_selectors=loop_variable_selectors,
break_conditions=break_conditions,
logical_operator=logical_operator,
condition_processor=condition_processor,
current_index=i,
start_at=start_at,
inputs=inputs,
reach_break_condition = False
if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
loop_end_time = naive_utc_now()
if reach_break_condition:
loop_count = 0
cost_tokens = 0
for i in range(loop_count):
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
loop_start_time = naive_utc_now()
reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
# Track loop duration
loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds()
# Accumulate outputs from the sub-graph's response nodes
for key, value in graph_engine.graph_runtime_state.outputs.items():
if key == "answer":
# Concatenate answer outputs with newline
existing_answer = self.graph_runtime_state.outputs.get("answer", "")
if existing_answer:
self.graph_runtime_state.outputs["answer"] = f"{existing_answer}{value}"
else:
self.graph_runtime_state.outputs["answer"] = value
else:
# For other outputs, just update
self.graph_runtime_state.outputs[key] = value
# Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens
# Collect loop variable values after iteration
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
item = variable_pool.get(selector)
if item:
single_loop_variable[key] = item.value
else:
single_loop_variable[key] = None
segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds()
single_loop_variable_map[str(i)] = single_loop_variable
check_break_result = loop_result.get("check_break_result", False)
if check_break_result:
if reach_break_node:
break
if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=break_conditions,
operator=logical_operator,
)
if reach_break_condition:
break
yield LoopNextEvent(
index=i + 1,
pre_loop_output=self._node_data.outputs,
)
self.graph_runtime_state.total_tokens += cost_tokens
# Loop completed successfully
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopSucceededEvent(
start_at=start_at,
inputs=inputs,
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "loop_break" if check_break_result else "loop_completed",
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: cost_tokens,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
@ -236,18 +213,12 @@ class LoopNode(BaseNode):
)
except Exception as e:
# Loop failed
logger.exception("Loop run failed")
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
yield LoopFailedEvent(
start_at=start_at,
inputs=inputs,
steps=loop_count,
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
"completed_reason": "error",
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
@ -255,207 +226,60 @@ class LoopNode(BaseNode):
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
)
)
finally:
# Clean up
variable_pool.remove([self.node_id, "index"])
def _run_single_loop(
self,
*,
graph_engine: "GraphEngine",
loop_graph: Graph,
variable_pool: "VariablePool",
loop_variable_selectors: dict,
break_conditions: list,
logical_operator: Literal["and", "or"],
condition_processor: ConditionProcessor,
current_index: int,
start_at: datetime,
inputs: dict,
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
"""Run a single loop iteration.
Returns:
dict: {'check_break_result': bool}
"""
# Run workflow
rst = graph_engine.run()
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"loop {self.node_id} current index not found")
current_index = current_index_variable.value
) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]:
reach_break_node = False
for event in graph_engine.run():
if isinstance(event, GraphNodeEventBase):
self._append_loop_info_to_event(event=event, loop_run_index=current_index)
check_break_result = False
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
event.in_loop_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.LOOP_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START:
continue
if isinstance(event, GraphNodeEventBase):
yield event
if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
reach_break_node = True
if isinstance(event, GraphRunFailedEvent):
raise Exception(event.error)
if (
isinstance(event, NodeRunSucceededEvent)
and event.node_type == NodeType.LOOP_END
and not isinstance(event, NodeRunStreamChunkEvent)
):
# Check if variables in break conditions exist and process conditions
# Allow loop internal variables to be used in break conditions
available_conditions = []
for condition in break_conditions:
variable = self.graph_runtime_state.variable_pool.get(condition.variable_selector)
if variable:
available_conditions.append(condition)
for loop_var in self._node_data.loop_variables or []:
key, sel = loop_var.label, [self._node_id, loop_var.label]
segment = self.graph_runtime_state.variable_pool.get(sel)
self._node_data.outputs[key] = segment.value if segment else None
self._node_data.outputs["loop_round"] = current_index + 1
# Process conditions if at least one variable is available
if available_conditions:
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=available_conditions,
operator=logical_operator,
)
if check_break_result:
break
else:
check_break_result = True
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
break
return reach_break_node
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# Loop run failed
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
),
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: (
graph_engine.graph_runtime_state.total_tokens
)
},
)
)
return {"check_break_result": True}
elif isinstance(event, NodeRunFailedEvent):
# Loop run failed
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
"completed_reason": "error",
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
},
)
)
return {"check_break_result": True}
else:
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
# Remove all nodes outputs from variable pool
for node_id in loop_graph.node_ids:
variable_pool.remove([node_id])
_outputs: dict[str, Segment | int | None] = {}
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
_loop_variable_segment = variable_pool.get(loop_variable_selector)
if _loop_variable_segment:
_outputs[loop_variable_key] = _loop_variable_segment
else:
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self._node_data.outputs = _outputs
if check_break_result:
return {"check_break_result": True}
# Move to next loop
next_index = current_index + 1
variable_pool.add([self.node_id, "index"], next_index)
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.type_,
loop_node_data=self._node_data,
index=next_index,
pre_loop_output=self._node_data.outputs,
)
return {"check_break_result": False}
def _handle_event_metadata(
def _append_loop_info_to_event(
self,
*,
event: BaseNodeEvent | InNodeEvent,
iter_run_index: int,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
"""
if not isinstance(event, BaseNodeEvent):
return event
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata:
metadata = {
**metadata,
WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id,
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
return event
event: GraphNodeEventBase,
loop_run_index: int,
):
event.in_loop_id = self._node_id
loop_metadata = {
WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id,
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index,
}
current_metadata = event.node_run_result.metadata
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -471,12 +295,43 @@ class LoopNode(BaseNode):
variable_mapping = {}
# init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create minimal GraphInitParams for static analysis
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
loop_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not loop_graph:
raise ValueError("loop graph not found")
for sub_node_id, sub_node_config in loop_graph.node_id_config_mapping.items():
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
continue
@ -524,7 +379,12 @@ class LoopNode(BaseNode):
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
if var_type in [
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
# 2. Consider moving this method to LoopVariableData class for better encapsulation
if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
elif var_type in [
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
@ -534,8 +394,6 @@ class LoopNode(BaseNode):
else:
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
value = []
elif var_type == SegmentType.ARRAY_BOOLEAN:
value = original_value
else:
raise AssertionError("this statement should be unreachable.")
try:
@ -549,3 +407,56 @@ class LoopNode(BaseNode):
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from.value,
invoke_from=self.invoke_from.value,
call_depth=self.workflow_call_depth,
)
# Create a new GraphRuntimeState for this iteration
graph_runtime_state_copy = GraphRuntimeState(
variable_pool=self.graph_runtime_state.variable_pool,
start_at=start_at.timestamp(),
)
# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)
# Initialize the loop graph with the new node factory
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
graph_runtime_state=graph_runtime_state_copy,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
)
return graph_engine

View File

@ -1,20 +1,19 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(BaseNode):
class LoopStartNode(Node):
"""
Loop Start Node.
"""
_node_type = NodeType.LOOP_START
node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData

View File

@ -0,0 +1,81 @@
from typing import TYPE_CHECKING, Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
class DifyNodeFactory(NodeFactory):
"""
Default implementation of NodeFactory that uses the traditional node mapping.
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
and instantiating the appropriate node class.
"""
def __init__(
self,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
def create_node(
self,
node_config: dict[str, Any],
) -> Node:
"""
Create a Node instance from node configuration data using the traditional mapping.
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
"""
# Get node_id from config
node_id = node_config.get("id")
if not node_id:
raise ValueError("Node config missing id")
# Get node type from config
node_data = node_config.get("data", {})
node_type_str = node_data.get("type")
if not node_type_str:
raise ValueError(f"Node {node_id} missing type information")
try:
node_type = NodeType(node_type_str)
except ValueError:
raise ValueError(f"Unknown node type: {node_type_str}")
# Get node class
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
node_class = node_mapping.get(LATEST_VERSION)
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
node_instance = node_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
# Initialize node with provided data
node_data = node_config.get("data", {})
node_instance.init_node_data(node_data)
# If node has fail branch, change execution type to branch
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
node_instance.execution_type = NodeExecutionType.BRANCH
return node_instance

View File

@ -1,13 +1,13 @@
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end import EndNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
@ -32,7 +32,7 @@ LATEST_VERSION = "latest"
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,

View File

@ -27,14 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
@ -85,12 +84,12 @@ def extract_json(text):
return None
class ParameterExtractorNode(BaseNode):
class ParameterExtractorNode(Node):
"""
Parameter Extractor Node.
"""
_node_type = NodeType.PARAMETER_EXTRACTOR
node_type = NodeType.PARAMETER_EXTRACTOR
_node_data: ParameterExtractorNodeData

View File

@ -10,21 +10,20 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
llm_utils,
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
@ -41,11 +40,12 @@ from .template_prompts import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER
class QuestionClassifierNode(Node):
node_type = NodeType.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_node_data: QuestionClassifierNodeData
@ -57,10 +57,7 @@ class QuestionClassifierNode(BaseNode):
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
@ -68,10 +65,7 @@ class QuestionClassifierNode(BaseNode):
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
@ -187,7 +181,8 @@ class QuestionClassifierNode(BaseNode):
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
node_id=self._node_id,
node_type=self.node_type,
)
for event in generator:
@ -259,6 +254,7 @@ class QuestionClassifierNode(BaseNode):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
@ -278,9 +274,10 @@ class QuestionClassifierNode(BaseNode):
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
Get default config of node.
:param filters: filter by node config parameters.
:param filters: filter by node config parameters (not used in this implementation).
:return:
"""
# filters parameter is not used in this node type
return {"type": "question-classifier", "config": {"instructions": ""}}
def _calculate_rest_token(

View File

@ -2,16 +2,15 @@ from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode):
_node_type = NodeType.START
class StartNode(Node):
node_type = NodeType.START
_node_data: StartNodeData

View File

@ -1,5 +1,5 @@
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class TemplateTransformNodeData(BaseNodeData):

View File

@ -3,18 +3,17 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM
class TemplateTransformNode(Node):
node_type = NodeType.TEMPLATE_TRANSFORM
_node_data: TemplateTransformNodeData
@ -57,7 +56,7 @@ class TemplateTransformNode(BaseNode):
def _run(self) -> NodeRunResult:
# Get variables
variables = {}
variables: dict[str, Any] = {}
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)

View File

@ -1,28 +1,28 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
@ -35,13 +35,16 @@ from .exc import (
ToolParameterError,
)
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
class ToolNode(BaseNode):
class ToolNode(Node):
"""
Tool Node
"""
_node_type = NodeType.TOOL
node_type = NodeType.TOOL
_node_data: ToolNodeData
@ -56,6 +59,7 @@ class ToolNode(BaseNode):
"""
Run the tool node
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
node_data = cast(ToolNodeData, self._node_data)
@ -78,11 +82,11 @@ class ToolNode(BaseNode):
if node_data.version != "1" or node_data.tool_node_version != "1":
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -115,13 +119,12 @@ class ToolNode(BaseNode):
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -139,11 +142,11 @@ class ToolNode(BaseNode):
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self.node_id,
node_id=self._node_id,
)
except ToolInvokeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -152,8 +155,8 @@ class ToolNode(BaseNode):
)
)
except PluginInvokeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -165,8 +168,8 @@ class ToolNode(BaseNode):
)
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
@ -179,7 +182,7 @@ class ToolNode(BaseNode):
self,
*,
tool_parameters: Sequence[ToolParameter],
variable_pool: VariablePool,
variable_pool: "VariablePool",
node_data: ToolNodeData,
for_log: bool = False,
) -> dict[str, Any]:
@ -220,7 +223,7 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
@ -238,6 +241,8 @@ class ToolNode(BaseNode):
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
@ -310,7 +315,11 @@ class ToolNode(BaseNode):
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
# JSON message handling for tool node
@ -320,7 +329,11 @@ class ToolNode(BaseNode):
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
@ -332,8 +345,10 @@ class ToolNode(BaseNode):
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
@ -393,8 +408,24 @@ class ToolNode(BaseNode):
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[self._node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
@ -457,10 +488,6 @@ class ToolNode(BaseNode):
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -2,16 +2,15 @@ from collections.abc import Mapping
from typing import Any, Optional
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR
class VariableAggregatorNode(Node):
node_type = NodeType.VARIABLE_AGGREGATOR
_node_data: VariableAssignerNodeData

View File

@ -1,29 +1,19 @@
from sqlalchemy import Engine, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from models.engine import db
from models.workflow import ConversationVariable
from extensions.ext_database import db
from models import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
_engine: Engine | None
def __init__(self, engine: Engine | None = None) -> None:
self._engine = engine
def _get_engine(self) -> Engine:
if self._engine:
return self._engine
return db.engine
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(self._get_engine()) as session:
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")

View File

@ -5,11 +5,11 @@ from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
@ -18,14 +18,14 @@ from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
class VariableAssignerNode(Node):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
_node_data: VariableAssignerData
@ -56,20 +56,14 @@ class VariableAssignerNode(BaseNode):
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
self._conv_var_updater_factory = conv_var_updater_factory

View File

@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
@ -23,4 +23,4 @@ class VariableOperationItem(BaseModel):
class VariableAssignerNodeData(BaseNodeData):
version: str = "2"
items: Sequence[VariableOperationItem]
items: Sequence[VariableOperationItem] = Field(default_factory=list)

View File

@ -7,11 +7,10 @@ from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@ -53,8 +52,8 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
class VariableAssignerNode(Node):
node_type = NodeType.VARIABLE_ASSIGNER
_node_data: VariableAssignerNodeData
@ -79,6 +78,23 @@ class VariableAssignerNode(BaseNode):
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
Returns True if this node updates any of the requested conversation variables.
"""
# Check each item in this Variable Assigner node
for item in self._node_data.items:
# Convert the item's variable_selector to tuple for comparison
item_selector_tuple = tuple(item.variable_selector)
# Check if this item updates any of the requested variables
if item_selector_tuple in variable_selectors:
return True
return False
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()