Merge branch 'deploy/dev' into feat/knowledge-summary-index

This commit is contained in:
FFXN
2026-01-12 23:44:10 +08:00
committed by GitHub
5709 changed files with 406424 additions and 240364 deletions

View File

@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
can assume `graph_runtime_state` is available.
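For example, a layer registered via `engine.layer()` can read the runtime state from its very first hook without a `None` check. A minimal sketch (`StateInspectionLayer` is a hypothetical name, and it assumes `on_event` is the only hook a subclass must override; `engine` is the instance from the example above):

```python
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent


class StateInspectionLayer(GraphEngineLayer):
    def on_graph_start(self) -> None:
        # Bound by `engine.layer()`, so accessing the state here never raises.
        print("initial outputs:", self.graph_runtime_state.outputs)

    def on_event(self, event: GraphEngineEvent) -> None:
        pass


engine.layer(StateInspectionLayer())
```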
### Event-Driven Architecture
All node executions emit events for monitoring and integration:

View File

@ -1,11 +1,16 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"ToolCall",
"ToolCallResult",
"ToolResult",
"ToolResultStatus",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -0,0 +1,39 @@
from enum import StrEnum
from pydantic import BaseModel, Field
from core.file import File
class ToolResultStatus(StrEnum):
SUCCESS = "success"
ERROR = "error"
class ToolCall(BaseModel):
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class ToolResult(BaseModel):
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
name: str | None = Field(default=None, description="Name of the tool")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[str] = Field(default_factory=list, description="Files produced by the tool")
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class ToolCallResult(BaseModel):
id: str | None = Field(default=None, description="Identifier for the tool call")
name: str | None = Field(default=None, description="Name of the tool")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[File] = Field(default_factory=list, description="Files produced by the tool")
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
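A hedged usage sketch of these entities (field values are illustrative; note that `ToolResult.files` carries plain strings while `ToolCallResult.files` carries `File` objects):

```python
from core.workflow.entities import ToolCall, ToolCallResult, ToolResultStatus

call = ToolCall(id="call_1", name="web_search", arguments='{"query": "dify"}')
result = ToolCallResult(
    id=call.id,
    name=call.name,
    arguments=call.arguments,
    output="3 results found",
    status=ToolResultStatus.SUCCESS,
    elapsed_time=0.42,
)
```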

View File

@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime
from typing import Any
@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
) -> WorkflowExecution:
return WorkflowExecution(
id_=id_,
workflow_id=workflow_id,

View File

@ -247,6 +247,9 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy"  # nodes in continue-on-error mode return this field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
LLM_TRACE = "llm_trace"
COMPLETED_REASON = "completed_reason"  # reason the loop node completed
class WorkflowNodeExecutionStatus(StrEnum):

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
@ -175,7 +177,7 @@ class Graph:
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_factory: "NodeFactory",
node_factory: NodeFactory,
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
@ -197,7 +199,7 @@ class Graph:
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
def new(cls) -> GraphBuilder:
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@ -284,9 +286,9 @@ class Graph:
cls,
*,
graph_config: Mapping[str, object],
node_factory: "NodeFactory",
node_factory: NodeFactory,
root_node_id: str | None = None,
) -> "Graph":
) -> Graph:
"""
Initialize graph
@ -383,7 +385,7 @@ class GraphBuilder:
self._edges: list[Edge] = []
self._edge_counter = 0
def add_root(self, node: Node) -> "GraphBuilder":
def add_root(self, node: Node) -> GraphBuilder:
"""Register the root node. Must be called exactly once."""
if self._nodes:
@ -398,7 +400,7 @@ class GraphBuilder:
*,
from_node_id: str | None = None,
source_handle: str = "source",
) -> "GraphBuilder":
) -> GraphBuilder:
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
@ -419,7 +421,7 @@ class GraphBuilder:
return self
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:

View File

@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
if command_type == CommandType.UPDATE_VARIABLES:
return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)

View File

@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
"UpdateVariablesCommandHandler",
]

View File

@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
@final
class UpdateVariablesCommandHandler(CommandHandler):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, UpdateVariablesCommand)
for update in command.updates:
try:
variable = update.value
self._variable_pool.add(variable.selector, variable)
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
except ValueError as exc:
logger.warning(
"Skipping invalid variable selector %s for workflow %s: %s",
getattr(update.value, "selector", None),
execution.workflow_id,
exc,
)

View File

@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from enum import StrEnum
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = "abort"
PAUSE = "pause"
ABORT = auto()
PAUSE = auto()
UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
"""Command to update a group of variables in the variable pool."""
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
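A hedged construction sketch (it assumes `StringVariable` from `core.variables.variables` is a concrete member of `VariableUnion` that accepts `name`, `value`, and `selector` keyword arguments):

```python
from core.variables.variables import StringVariable  # concrete VariableUnion member (assumed)

command = UpdateVariablesCommand(
    updates=[
        VariableUpdate(
            value=StringVariable(name="topic", value="weather", selector=["conversation", "topic"])
        )
    ]
)
# command.command_type == CommandType.UPDATE_VARIABLES
```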

View File

@ -5,9 +5,12 @@ This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
from __future__ import annotations
import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@ -30,8 +33,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .command_processing import (
AbortCommandHandler,
CommandProcessor,
PauseCommandHandler,
UpdateVariablesCommandHandler,
)
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@ -70,10 +78,13 @@ class GraphEngine:
scale_down_idle_time: float | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
@ -140,6 +151,13 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
@ -158,12 +176,14 @@ class GraphEngine:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
flask_app=flask_app,
context_vars=context_vars,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
stop_event=self._stop_event,
)
# === Orchestration ===
@ -194,12 +214,9 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
self._validate_graph_state_consistency()
@ -211,9 +228,16 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
def _bind_layer_context(
self,
layer: GraphEngineLayer,
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
@ -300,14 +324,7 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
for layer in self._layers:
try:
layer.on_graph_start()
except Exception as e:
@ -315,6 +332,7 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
@ -342,13 +360,12 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
logger = logging.getLogger(__name__)
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)

View File

@ -60,6 +60,7 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node

View File

@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
Abstract base class for layers.
- `initialize()` - Receive runtime context
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```
`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.
## Custom Layers
```python

View File

@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
from .observability import ObservabilityLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
"ObservabilityLayer",
]

View File

@ -9,9 +9,18 @@ from abc import ABC, abstractmethod
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayerNotInitializedError(Exception):
"""Raised when a layer's runtime state is accessed before initialization."""
def __init__(self, layer_name: str | None = None) -> None:
name = layer_name or "GraphEngineLayer"
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@ -27,22 +36,27 @@ class GraphEngineLayer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
@property
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
if self._graph_runtime_state is None:
raise GraphEngineLayerNotInitializedError(type(self).__name__)
return self._graph_runtime_state
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.
Called by GraphEngine to inject the read-only runtime state and command channel.
This is invoked when the layer is registered with a `GraphEngine` instance.
Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod
@ -83,3 +97,29 @@ class GraphEngineLayer(ABC):
error: The exception that caused execution to fail, or None if successful
"""
pass
def on_node_run_start(self, node: Node) -> None: # noqa: B027
"""
Called immediately before a node begins execution.
Layers can override to inject behavior (e.g., start spans) prior to node execution.
The node's execution ID is available via `node.execution_id` and will be
consistent with all events emitted by this node execution.
Args:
node: The node instance about to be executed
"""
pass
def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027
"""
Called after a node finishes execution.
The node's execution ID is available via `node.execution_id` and matches
the `id` field in all events emitted by this node execution.
Args:
node: The node instance that just finished execution
error: Exception instance if the node failed, otherwise None
"""
pass
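These no-op defaults make the hooks opt-in. A minimal sketch of a layer that uses them, reusing this module's imports (`NodeTimingLayer` is a hypothetical name, and it assumes `on_event` is the only hook a subclass must override):

```python
import time


class NodeTimingLayer(GraphEngineLayer):
    def __init__(self) -> None:
        super().__init__()
        self._started: dict[str, float] = {}

    def on_event(self, event: GraphEngineEvent) -> None:
        pass

    def on_node_run_start(self, node: Node) -> None:
        self._started[node.execution_id] = time.perf_counter()

    def on_node_run_end(self, node: Node, error: Exception | None) -> None:
        started = self._started.pop(node.execution_id, None)
        if started is not None:
            elapsed = time.perf_counter() - started
            print(f"node {node.id} finished in {elapsed:.3f}s (error={error})")
```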

View File

@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.include_outputs and self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)

View File

@ -0,0 +1,61 @@
"""
Node-level OpenTelemetry parser interfaces and defaults.
"""
import json
from typing import Protocol
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
class NodeOTelParser(Protocol):
"""Parser interface for node-specific OpenTelemetry enrichment."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
class DefaultNodeOTelParser:
"""Fallback parser used when no node-specific parser is registered."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
span.set_attribute("node.id", node.id)
if node.execution_id:
span.set_attribute("node.execution_id", node.execution_id)
if hasattr(node, "node_type") and node.node_type:
span.set_attribute("node.type", node.node_type.value)
if error:
span.record_exception(error)
span.set_status(Status(StatusCode.ERROR, str(error)))
else:
span.set_status(Status(StatusCode.OK))
class ToolNodeOTelParser:
"""Parser for tool nodes that captures tool-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
self._delegate.parse(node=node, span=span, error=error)
tool_data = getattr(node, "_node_data", None)
if not isinstance(tool_data, ToolNodeData):
return
span.set_attribute("tool.provider.id", tool_data.provider_id)
span.set_attribute("tool.provider.type", tool_data.provider_type.value)
span.set_attribute("tool.provider.name", tool_data.provider_name)
span.set_attribute("tool.name", tool_data.tool_name)
span.set_attribute("tool.label", tool_data.tool_label)
if tool_data.plugin_unique_identifier:
span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
if tool_data.credential_id:
span.set_attribute("tool.credential.id", tool_data.credential_id)
if tool_data.tool_configurations:
span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))

View File

@ -0,0 +1,169 @@
"""
Observability layer for GraphEngine.
This layer creates OpenTelemetry spans for node execution, enabling distributed
tracing of workflow execution. It establishes OTel context during node execution
so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
associates with the node span.
"""
import logging
from dataclasses import dataclass
from typing import cast, final
from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.node_parsers import (
DefaultNodeOTelParser,
NodeOTelParser,
ToolNodeOTelParser,
)
from core.workflow.nodes.base.node import Node
from extensions.otel.runtime import is_instrument_flag_enabled
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _NodeSpanContext:
span: "Span"
token: object
@final
class ObservabilityLayer(GraphEngineLayer):
"""
Layer that creates OpenTelemetry spans for node execution.
This layer:
- Creates a span when a node starts execution
- Establishes OTel context so automatic instrumentation associates with the span
- Sets complete attributes and status when node execution ends
"""
def __init__(self) -> None:
super().__init__()
self._node_contexts: dict[str, _NodeSpanContext] = {}
self._parsers: dict[NodeType, NodeOTelParser] = {}
self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
self._is_disabled: bool = False
self._tracer: Tracer | None = None
self._build_parser_registry()
self._init_tracer()
def _init_tracer(self) -> None:
"""Initialize OpenTelemetry tracer in constructor."""
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
self._is_disabled = True
return
try:
self._tracer = get_tracer(__name__)
except Exception as e:
logger.warning("Failed to get OpenTelemetry tracer: %s", e)
self._is_disabled = True
def _build_parser_registry(self) -> None:
"""Initialize parser registry for node types."""
self._parsers = {
NodeType.TOOL: ToolNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
return self._parsers.get(node_type, self._default_parser)
return self._default_parser
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self._node_contexts.clear()
@override
def on_node_run_start(self, node: Node) -> None:
"""
Called when a node starts execution.
Creates a span and establishes OTel context for automatic instrumentation.
"""
if self._is_disabled:
return
try:
if not self._tracer:
return
execution_id = node.execution_id
if not execution_id:
return
parent_context = context_api.get_current()
span = self._tracer.start_span(
f"{node.title}",
kind=SpanKind.INTERNAL,
context=parent_context,
)
new_context = set_span_in_context(span)
token = context_api.attach(new_context)
self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)
except Exception as e:
logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_node_run_end(self, node: Node, error: Exception | None) -> None:
"""
Called when a node finishes execution.
Sets complete attributes, records exceptions, and ends the span.
"""
if self._is_disabled:
return
try:
execution_id = node.execution_id
if not execution_id:
return
node_context = self._node_contexts.get(execution_id)
if not node_context:
return
span = node_context.span
parser = self._get_parser(node)
try:
parser.parse(node=node, span=span, error=error)
span.end()
finally:
token = node_context.token
if token is not None:
try:
context_api.detach(token)
except Exception:
logger.warning("Failed to detach OpenTelemetry token: %s", token)
self._node_contexts.pop(execution_id, None)
except Exception as e:
logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_event(self, event) -> None:
"""Not used in this layer."""
pass
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._node_contexts:
logger.warning(
"ObservabilityLayer: %d node spans were not properly ended",
len(self._node_contexts),
)
self._node_contexts.clear()
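Wiring the layer is a one-liner; a hedged sketch, assuming an already-constructed `engine`:

```python
engine.layer(ObservabilityLayer())
for event in engine.run():
    ...  # node spans are created per execution when OTel is enabled
```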

View File

@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
if runtime_state is None:
return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop and pause operations.
"""
@staticmethod
@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""

View File

@ -44,6 +44,7 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@ -61,7 +62,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = stop_event
self._start_time: float | None = None
def start(self) -> None:
@ -69,16 +70,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
self._thread.join(timeout=2.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""

View File

@ -2,6 +2,8 @@
Factory for creating ReadyQueue instances from serialized state.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
@ -11,7 +13,7 @@ if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
"""
Create a ReadyQueue instance from a serialized state.

View File

@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
@ -321,11 +327,24 @@ class ResponseStreamCoordinator:
selector: Sequence[str],
chunk: str,
is_final: bool = False,
chunk_type: ChunkType = ChunkType.TEXT,
tool_call: ToolCall | None = None,
tool_result: ToolResult | None = None,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.
Args:
node_id: The node ID to attribute the event to
execution_id: The execution ID for this node
selector: The variable selector
chunk: The chunk content
is_final: Whether this is the final chunk
chunk_type: The semantic type of the chunk being streamed
tool_call: Structured data for tool_call chunks
tool_result: Structured data for tool_result chunks
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
@ -338,6 +357,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
# Standard case: selector refers to an actual node
@ -349,6 +371,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
@ -356,6 +381,8 @@ class ResponseStreamCoordinator:
Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.
For object-type variables, automatically streams all child fields that have stream events.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""
@ -364,60 +391,81 @@ class ResponseStreamCoordinator:
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
        active_session = self._active_session
        special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
        output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
        execution_id = self._get_or_create_execution_id(output_node_id)
        # Check if there's a direct stream for this selector
        has_direct_stream = (
            tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
        )
        stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
        is_complete = False
        if stream_targets:
            all_complete = True
            for target_selector in stream_targets:
                while self._has_unread_stream(target_selector):
                    if event := self._pop_stream_chunk(target_selector):
                        events.append(
                            self._rewrite_stream_event(
                                event=event,
                                output_node_id=output_node_id,
                                execution_id=execution_id,
                                special_selector=bool(special_selector),
                            )
                        )
                if not self._is_stream_closed(target_selector):
                    all_complete = False
            is_complete = all_complete
        # Fallback: check if scalar value exists in variable pool
        if not is_complete and not has_direct_stream:
            if value := self._variable_pool.get(segment.selector):
                # Process scalar value
                is_last_segment = bool(
                    self._active_session
                    and self._active_session.index == len(self._active_session.template.segments) - 1
                )
                events.append(
                    self._create_stream_chunk_event(
                        node_id=output_node_id,
                        execution_id=execution_id,
                        selector=segment.selector,
                        chunk=value.markdown,
                        is_final=is_last_segment,
                    )
                )
                is_complete = True
return events, is_complete
def _rewrite_stream_event(
self,
event: NodeRunStreamChunkEvent,
output_node_id: str,
execution_id: str,
special_selector: bool,
) -> NodeRunStreamChunkEvent:
"""Rewrite event to attribute to active response node when selector is special."""
if not special_selector:
return event
return self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call=event.tool_call,
tool_result=event.tool_result,
)
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns (events, is_complete)."""
assert self._active_session is not None
@ -513,6 +561,36 @@ class ResponseStreamCoordinator:
# ============= Internal Stream Management Methods =============
def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
"""Find all child stream selectors that are descendants of the parent selector.
For example, if parent_selector is ['llm', 'generation'], this will find:
- ['llm', 'generation', 'content']
- ['llm', 'generation', 'tool_calls']
- ['llm', 'generation', 'tool_results']
- ['llm', 'generation', 'thought']
Args:
parent_selector: The parent selector to search for children
Returns:
List of child selector tuples found in stream buffers or closed streams
"""
parent_key = tuple(parent_selector)
parent_len = len(parent_key)
child_streams: set[tuple[str, ...]] = set()
# Search in both active buffers and closed streams
all_selectors = set(self._stream_buffers.keys()) | self._closed_streams
for selector_key in all_selectors:
# Check if this selector is a direct child of the parent
# Direct child means: len(child) == len(parent) + 1 and child starts with parent
if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
child_streams.add(selector_key)
return sorted(child_streams)
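# Illustration (editor's sketch, not part of the module): the direct-child
# rule above, applied to plain tuples.
#
#   parent = ("llm", "generation")
#   keys = {
#       ("llm", "generation", "content"),       # direct child -> kept
#       ("llm", "generation", "content", "x"),  # grandchild   -> skipped
#       ("llm", "text"),                        # sibling      -> skipped
#   }
#   [k for k in keys if len(k) == len(parent) + 1 and k[:len(parent)] == parent]
#   # -> [("llm", "generation", "content")]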
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.

View File

@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from __future__ import annotations
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
@ -27,7 +29,7 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> "ResponseSession":
def from_node(cls, node: Node) -> ResponseSession:
"""
Create a ResponseSession from an AnswerNode or EndNode.

View File

@ -9,6 +9,7 @@ import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
from typing import final
from uuid import uuid4
@ -17,6 +18,7 @@ from flask import Flask
from typing_extensions import override
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
@ -39,6 +41,8 @@ class Worker(threading.Thread):
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
@ -50,6 +54,7 @@ class Worker(threading.Thread):
ready_queue: Ready queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
stop_event: Shared stop signal set by GraphEngine on shutdown
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
@ -61,12 +66,16 @@ class Worker(threading.Thread):
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._last_task_time = time.time()
self._stop_event = stop_event
self._layers = layers if layers is not None else []
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
@property
def is_idle(self) -> bool:
@ -122,20 +131,51 @@ class Worker(threading.Thread):
Args:
node: The node instance to execute
"""
node.ensure_execution_id()
error: Exception | None = None
# Execute the node with preserved context if Flask app is provided
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
        else:
            self._invoke_node_run_start_hooks(node)
            try:
                node_events = node.run()
                for event in node_events:
                    # Forward event to dispatcher immediately for streaming
                    self._event_queue.put(event)
            except Exception as exc:
                error = exc
                raise
            finally:
                self._invoke_node_run_end_hooks(node, error)
def _invoke_node_run_start_hooks(self, node: Node) -> None:
"""Invoke on_node_run_start hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_start(node)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
"""Invoke on_node_run_end hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_end(node, error)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue

View File

@ -14,6 +14,7 @@ from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..layers.base import GraphEngineLayer
from ..ready_queue import ReadyQueue
from ..worker import Worker
@ -39,6 +40,8 @@ class WorkerPool:
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
@ -53,6 +56,7 @@ class WorkerPool:
ready_queue: Ready queue for nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
stop_event: Shared stop signal set by GraphEngine on shutdown
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
@ -65,6 +69,7 @@ class WorkerPool:
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
self._layers = layers
# Scaling parameters with defaults
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
@ -77,6 +82,7 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@ -131,7 +137,7 @@ class WorkerPool:
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=10.0)
worker.join(timeout=2.0)
self._workers.clear()
@ -144,9 +150,11 @@ class WorkerPool:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
stop_event=self._stop_event,
)
worker.start()

View File

@ -36,6 +36,7 @@ from .loop import (
# Node events
from .node import (
ChunkType,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
@ -44,10 +45,13 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
__all__ = [
"BaseGraphEvent",
"ChunkType",
"GraphEngineEvent",
"GraphNodeEventBase",
"GraphRunAbortedEvent",
@ -73,4 +77,6 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"ToolCall",
"ToolResult",
]

View File

@ -1,10 +1,11 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@ -21,13 +22,39 @@ class NodeRunStartedEvent(GraphNodeEventBase):
provider_id: str = ""
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields
"""Stream chunk event for workflow node execution."""
# Base fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call: ToolCall | None = Field(
default=None,
description="structured payload for tool_call chunks",
)
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_result: ToolResult | None = Field(
default=None,
description="structured payload for tool_result chunks",
)
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):

View File

@ -13,16 +13,21 @@ from .loop import (
LoopSucceededEvent,
)
from .node import (
ChunkType,
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
__all__ = [
"AgentLogEvent",
"ChunkType",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",
@ -39,4 +44,7 @@ __all__ = [
"RunRetryEvent",
"StreamChunkEvent",
"StreamCompletedEvent",
"ThoughtChunkEvent",
"ToolCallChunkEvent",
"ToolResultChunkEvent",
]

View File

@ -1,11 +1,13 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult
@ -32,13 +34,60 @@ class RunRetryEvent(NodeEventBase):
start_at: datetime = Field(..., description="Retry start time")
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class StreamChunkEvent(NodeEventBase):
# Spec-compliant fields
"""Base stream chunk event - normal text streaming output."""
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
class ToolCallChunkEvent(StreamChunkEvent):
"""Tool call streaming event - tool call arguments streaming output."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
class ToolResultChunkEvent(StreamChunkEvent):
"""Tool result event - tool execution result."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
class ThoughtStartChunkEvent(StreamChunkEvent):
"""Agent thought start streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_START, frozen=True)
class ThoughtEndChunkEvent(StreamChunkEvent):
"""Agent thought end streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_END, frozen=True)
class ThoughtChunkEvent(StreamChunkEvent):
"""Agent thought streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
class StreamCompletedEvent(NodeEventBase):
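A hedged emission sketch: inside a node implementation, a typed chunk event pins `chunk_type` via the frozen default, so downstream dispatch can rely on it without inspecting the payload (values are illustrative):

```python
from core.workflow.entities import ToolCall
from core.workflow.node_events import ChunkType, ToolCallChunkEvent

event = ToolCallChunkEvent(
    selector=["my_node", "generation", "tool_calls"],
    chunk='{"query": "dify"}',
    tool_call=ToolCall(id="call_1", name="web_search", arguments='{"query": "dify"}'),
)
assert event.chunk_type == ChunkType.TOOL_CALL
```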

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: "PluginAgentStrategy",
strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]):
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> "InvokeCredentials":
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]):
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy

View File

@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import importlib
import logging
import operator
@ -46,6 +48,9 @@ from core.workflow.node_events import (
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
@ -59,7 +64,7 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
node_type: ClassVar["NodeType"]
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@ -198,14 +203,14 @@ class Node(Generic[NodeDataT]):
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self.id = id
@ -241,9 +246,18 @@ class Node(Generic[NodeDataT]):
return
@property
def graph_init_params(self) -> "GraphInitParams":
def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
@property
def execution_id(self) -> str:
return self._node_execution_id
def ensure_execution_id(self) -> str:
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
return self._node_execution_id
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@ -255,15 +269,17 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def run(self) -> Generator[GraphNodeEventBase, None, None]:
# Generate a single node execution ID to use for all events
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=self._node_execution_id,
id=execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.title,
@ -321,10 +337,25 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self._node_execution_id
event.id = self.execution_id
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(
@ -333,7 +364,7 @@ class Node(Generic[NodeDataT]):
error_type="WorkflowNodeError",
)
yield NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -431,7 +462,7 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.
@ -512,7 +543,7 @@ class Node(Generic[NodeDataT]):
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@ -521,7 +552,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@ -536,6 +567,24 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
@_dispatch.register
def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
@ -543,6 +592,44 @@ class Node(Generic[NodeDataT]):
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_CALL,
tool_call=event.tool_call,
)
@_dispatch.register
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.entities import ToolResult, ToolResultStatus
from core.workflow.graph_events import ChunkType
tool_result = event.tool_result or ToolResult()
status: ToolResultStatus = tool_result.status or ToolResultStatus.SUCCESS
tool_result = tool_result.model_copy(
update={"status": status, "files": tool_result.files or []},
)
return NodeRunStreamChunkEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_RESULT,
tool_result=tool_result,
)
@_dispatch.register
def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.THOUGHT,
)
@_dispatch.register
@ -550,7 +637,7 @@ class Node(Generic[NodeDataT]):
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -558,7 +645,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -573,7 +660,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
return NodeRunPauseRequestedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
@ -583,7 +670,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
message_id=event.message_id,
@ -599,7 +686,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -612,7 +699,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
return NodeRunLoopNextEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -623,7 +710,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
return NodeRunLoopSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -637,7 +724,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
return NodeRunLoopFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -652,7 +739,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
return NodeRunIterationStartedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -665,7 +752,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
return NodeRunIterationNextEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -676,7 +763,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
return NodeRunIterationSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -690,7 +777,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
return NodeRunIterationFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -705,7 +792,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
return NodeRunRetrieverResourceEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=event.retriever_resources,
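
For readers unfamiliar with the `@_dispatch.register` pattern used throughout this class: it follows `functools.singledispatchmethod` (an assumption based on the decorator shape), which routes each event to the handler registered for its annotated type. A standalone sketch with illustrative names:

```python
from functools import singledispatchmethod


class EventConverter:
    @singledispatchmethod
    def convert(self, event: object) -> str:
        raise NotImplementedError(f"no handler for {type(event).__name__}")

    @convert.register
    def _(self, event: int) -> str:
        # Selected purely by the type annotation on `event`.
        return f"int event: {event}"


assert EventConverter().convert(3) == "int event: 3"
```
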

View File

@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
@ -58,7 +60,7 @@ class Template:
segments: list[TemplateSegmentUnion]
@classmethod
def from_answer_template(cls, template_str: str) -> "Template":
def from_answer_template(cls, template_str: str) -> Template:
"""Create a Template from an Answer node template string.
Example:
@ -107,7 +109,7 @@ class Template:
return cls(segments=segments)
@classmethod
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
"""Create a Template from an End node outputs configuration.
End nodes are treated as templates of concatenated variables with newlines.

View File

@ -1,8 +1,7 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from .exc import (
CodeNodeError,
@ -20,9 +20,41 @@ from .exc import (
OutputValidationError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
Python3CodeProvider,
JavascriptCodeProvider,
)
_limits: CodeNodeLimits
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
)
self._limits = code_limits
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
code_provider: type[CodeNodeProvider] = next(
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
)
return code_provider.get_default_config()
@classmethod
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
return cls._DEFAULT_CODE_PROVIDERS
@classmethod
def version(cls) -> str:
return "1"
@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
_ = self._select_code_provider(code_language)  # validate the language is supported before executing
result = self._code_executor.execute_workflow_code_template(
language=code_language,
code=code,
inputs=variables,
@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
for provider in self._code_providers:
if provider.is_accept_language(code_language):
return provider
raise CodeNodeError(f"Unsupported code language: {code_language}")
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
if len(value) > self._limits.max_string_length:
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
f" less than {self._limits.max_string_length} characters"
)
return value.replace("\x00", "")
@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
if value > self._limits.max_number or value < self._limits.min_number:
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
if precision > dify_config.CODE_MAX_PRECISION:
if precision > self._limits.max_precision:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
f" it must be less than {self._limits.max_precision} digits."
)
return value
@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
if depth > self._limits.max_depth:
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
if len(value) > self._limits.max_number_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
f" less than {self._limits.max_number_array_length} elements."
)
for i, inner_value in enumerate(value):
@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
if len(result[output_name]) > self._limits.max_string_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
f" less than {self._limits.max_string_array_length} elements."
)
transformed_result[output_name] = [
@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
if len(result[output_name]) > self._limits.max_object_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
f" less than {self._limits.max_object_array_length} elements."
)
for i, value in enumerate(result[output_name]):

View File

@ -0,0 +1,13 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class CodeNodeLimits:
max_string_length: int
max_number: int | float
min_number: int | float
max_precision: int
max_depth: int
max_number_array_length: int
max_string_array_length: int
max_object_array_length: int
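
A minimal sketch of constructing these limits by hand, e.g. for a test harness; the values below are illustrative, and production code fills them from `dify_config` (see the node factory diff later in this commit):

```python
test_limits = CodeNodeLimits(
    max_string_length=1_000,
    max_number=1e9,
    min_number=-1e9,
    max_precision=6,
    max_depth=5,
    max_number_array_length=100,
    max_string_array_length=100,
    max_object_array_length=50,
)
```
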

View File

@ -86,6 +86,11 @@ class Executor:
node_data.authorization.config.api_key = variable_pool.convert_template(
node_data.authorization.config.api_key
).text
# Validate that API key is not empty after template conversion
if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip():
raise AuthorizationConfigError(
"API key is required for authorization but was empty. Please provide a valid API key."
)
self.url = node_data.url
self.method = node_data.method

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -460,7 +460,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
self._process_metadata_filter_func(
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
@ -504,7 +504,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@ -603,87 +603,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return [], usage
return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return filters
json_field = Document.doc_metadata[metadata_name].as_string()
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(json_field.like(f"%{value}"))
case "in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(False))
else:
filters.append(json_field.in_(value_list))
case "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(True))
else:
filters.append(json_field.notin_(value_list))
case "is" | "=":
if isinstance(value, str):
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -3,6 +3,7 @@ from .entities import (
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
ToolMetadata,
VisionConfig,
)
from .node import LLMNode
@ -13,5 +14,6 @@ __all__ = [
"LLMNodeCompletionModelPromptTemplate",
"LLMNodeData",
"ModelConfig",
"ToolMetadata",
"VisionConfig",
]

View File

@ -1,10 +1,17 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCall, ToolCallResult
from core.workflow.node_events import AgentLogEvent
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
@ -58,6 +65,268 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: str | None = None
class ToolMetadata(BaseModel):
"""
Tool metadata for an LLM node with tool support.
Defines the essential fields needed for tool configuration,
particularly the 'type' field, which identifies the tool provider type.
"""
# Core fields
enabled: bool = True
type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow")
provider_name: str = Field(..., description="Tool provider name/identifier")
tool_name: str = Field(..., description="Tool name")
# Optional fields
plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools")
credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication")
# Configuration fields
parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters")
settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration")
extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")
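
An illustrative instance (the builtin time tool and the `BUILT_IN` member name are assumptions; only `type`, `provider_name`, and `tool_name` are required):

```python
tool = ToolMetadata(
    type=ToolProviderType.BUILT_IN,  # assumed enum member for "builtin"
    provider_name="time",            # hypothetical provider
    tool_name="current_time",        # hypothetical tool
    parameters={"timezone": "UTC"},
)
```
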
class ModelTraceSegment(BaseModel):
"""Model invocation trace segment with token usage and output."""
text: str | None = Field(None, description="Model output text content")
reasoning: str | None = Field(None, description="Reasoning/thought content from model")
tool_calls: list[ToolCall] = Field(default_factory=list, description="Tool calls made by the model")
class ToolTraceSegment(BaseModel):
"""Tool invocation trace segment with call details and result."""
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool call result")
class LLMTraceSegment(BaseModel):
"""
Streaming trace segment for LLM tool-enabled runs.
Represents alternating model and tool invocations in sequence:
model -> tool -> model -> tool -> ...
Each segment records its execution duration.
"""
type: Literal["model", "tool"]
duration: float = Field(..., description="Execution duration in seconds")
usage: LLMUsage | None = Field(default=None, description="Token usage statistics for this model call")
output: ModelTraceSegment | ToolTraceSegment = Field(..., description="Output of the segment")
# Common metadata for both model and tool segments
provider: str | None = Field(default=None, description="Model or tool provider identifier")
name: str | None = Field(default=None, description="Name of the model or tool")
icon: str | None = Field(default=None, description="Icon for the provider")
icon_dark: str | None = Field(default=None, description="Dark theme icon for the provider")
error: str | None = Field(default=None, description="Error message if segment failed")
status: Literal["success", "error"] | None = Field(default=None, description="Tool execution status")
class LLMGenerationData(BaseModel):
"""Generation data from LLM invocation with tools.
For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
- reasoning_contents: [thought1, thought2, ...] - one element per turn
- tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
"""
text: str = Field(..., description="Accumulated text content from all turns")
reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
usage: LLMUsage = Field(..., description="LLM usage statistics")
finish_reason: str | None = Field(None, description="Finish reason from LLM")
files: list[File] = Field(default_factory=list, description="Generated files")
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
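
An illustrative `LLMGenerationData` for a single tool-calling turn, with placeholder values; `LLMUsage.empty_usage()` is the zero-usage constructor used elsewhere in this diff:

```python
generation = LLMGenerationData(
    text="The time in UTC is 12:00.",
    reasoning_contents=["I should call the time tool."],
    tool_calls=[ToolCallResult(name="current_time", output="12:00")],
    usage=LLMUsage.empty_usage(),
)
```
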
class ThinkTagStreamParser:
"""Lightweight state machine to split streaming chunks by <think> tags."""
_START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
_END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
_START_PREFIX = "<think"
_END_PREFIX = "</think"
def __init__(self):
self._buffer = ""
self._in_think = False
@staticmethod
def _suffix_prefix_len(text: str, prefix: str) -> int:
"""Return length of the longest suffix of `text` that is a prefix of `prefix`."""
max_len = min(len(text), len(prefix))
for i in range(max_len, 0, -1):
if text[-i:].lower() == prefix[:i].lower():
return i
return 0
def process(self, chunk: str) -> list[tuple[str, str]]:
"""
Split the incoming chunk into (kind, content) tuples, where kind is
'text', 'thought', 'thought_start', or 'thought_end'.
Content excludes the <think> tags themselves and handles tags split across chunks.
"""
parts: list[tuple[str, str]] = []
self._buffer += chunk
while self._buffer:
if self._in_think:
end_match = self._END_PATTERN.search(self._buffer)
if end_match:
thought_text = self._buffer[: end_match.start()]
if thought_text:
parts.append(("thought", thought_text))
parts.append(("thought_end", ""))
self._buffer = self._buffer[end_match.end() :]
self._in_think = False
continue
hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
emit = self._buffer[: len(self._buffer) - hold_len]
if emit:
parts.append(("thought", emit))
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
break
start_match = self._START_PATTERN.search(self._buffer)
if start_match:
prefix = self._buffer[: start_match.start()]
if prefix:
parts.append(("text", prefix))
self._buffer = self._buffer[start_match.end() :]
parts.append(("thought_start", ""))
self._in_think = True
continue
hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
emit = self._buffer[: len(self._buffer) - hold_len]
if emit:
parts.append(("text", emit))
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
break
cleaned_parts: list[tuple[str, str]] = []
for kind, content in parts:
# Extra safeguard: strip any stray tags that slipped through.
content = self._START_PATTERN.sub("", content)
content = self._END_PATTERN.sub("", content)
if content or kind in {"thought_start", "thought_end"}:
cleaned_parts.append((kind, content))
return cleaned_parts
def flush(self) -> list[tuple[str, str]]:
"""Flush remaining buffer when the stream ends."""
if not self._buffer:
return []
kind = "thought" if self._in_think else "text"
content = self._buffer
# Drop dangling partial tags instead of emitting them
if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
content = ""
self._buffer = ""
if not content and not self._in_think:
return []
# Strip any complete tags that might still be present.
content = self._START_PATTERN.sub("", content)
content = self._END_PATTERN.sub("", content)
result: list[tuple[str, str]] = []
if content:
result.append((kind, content))
if self._in_think:
result.append(("thought_end", ""))
self._in_think = False
return result
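
A short walk-through on a stream whose tags split across chunk boundaries; the held-back partial tags never leak into the output:

```python
parser = ThinkTagStreamParser()
assert parser.process("Hello <thi") == [("text", "Hello ")]  # "<thi" held back
assert parser.process("nk>reasoning</th") == [
    ("thought_start", ""),
    ("thought", "reasoning"),  # "</th" held back
]
assert parser.process("ink> done") == [("thought_end", ""), ("text", " done")]
assert parser.flush() == []
```
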
class StreamBuffers(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser)
pending_thought: list[str] = Field(default_factory=list)
pending_content: list[str] = Field(default_factory=list)
pending_tool_calls: list[ToolCall] = Field(default_factory=list)
current_turn_reasoning: list[str] = Field(default_factory=list)
reasoning_per_turn: list[str] = Field(default_factory=list)
class TraceState(BaseModel):
trace_segments: list[LLMTraceSegment] = Field(default_factory=list)
tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
tool_call_index_map: dict[str, int] = Field(default_factory=dict)
model_segment_start_time: float | None = Field(default=None, description="Start time for current model segment")
pending_usage: LLMUsage | None = Field(default=None, description="Pending usage for current model segment")
class AggregatedResult(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
text: str = ""
files: list[File] = Field(default_factory=list)
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
finish_reason: str | None = None
class AgentContext(BaseModel):
agent_logs: list[AgentLogEvent] = Field(default_factory=list)
agent_result: AgentResult | None = None
class ToolOutputState(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
stream: StreamBuffers = Field(default_factory=StreamBuffers)
trace: TraceState = Field(default_factory=TraceState)
aggregate: AggregatedResult = Field(default_factory=AggregatedResult)
agent: AgentContext = Field(default_factory=AgentContext)
class ToolLogPayload(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
tool_name: str = ""
tool_call_id: str = ""
tool_args: dict[str, Any] = Field(default_factory=dict)
tool_output: Any = None
tool_error: Any = None
files: list[Any] = Field(default_factory=list)
meta: dict[str, Any] = Field(default_factory=dict)
@classmethod
def from_log(cls, log: AgentLog) -> "ToolLogPayload":
# Identical field extraction to from_mapping, so delegate to it.
return cls.from_mapping(log.data or {})
@classmethod
def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
return cls(
tool_name=data.get("tool_name", ""),
tool_call_id=data.get("tool_call_id", ""),
tool_args=data.get("tool_args") or {},
tool_output=data.get("output"),
tool_error=data.get("error"),
files=data.get("files") or [],
meta=data.get("meta") or {},
)
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
@ -86,6 +355,10 @@ class LLMNodeData(BaseNodeData):
),
)
# Tool support
tools: Sequence[ToolMetadata] = Field(default_factory=list)
max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node")
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@ -136,21 +136,37 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
session.execute(stmt)
session.commit()
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

File diff suppressed because it is too large

View File

@ -1,3 +1,4 @@
from enum import StrEnum
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
@ -96,3 +97,8 @@ class LoopState(BaseLoopState):
Get current output.
"""
return self.current_output
class LoopCompletedReason(StrEnum):
LOOP_BREAK = "loop_break"
LOOP_COMPLETED = "loop_completed"

View File

@ -29,7 +29,7 @@ from core.workflow.node_events import (
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
@ -96,6 +96,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
# Start Loop event
yield LoopStartedEvent(
@ -118,6 +119,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_count = 0
for i in range(loop_count):
# Clear stale variables from previous loop iterations to avoid streaming old values
self._clear_loop_subgraph_variables(loop_node_ids)
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
loop_start_time = naive_utc_now()
@ -177,7 +180,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: (
LoopCompletedReason.LOOP_BREAK
if reach_break_condition
else LoopCompletedReason.LOOP_COMPLETED
),
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
@ -274,6 +281,17 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None:
"""
Remove variables produced by loop sub-graph nodes from previous iterations.
Keeping stale variables causes a freshly created response coordinator in the
next iteration to fall back to outdated values when no stream chunks exist.
"""
variable_pool = self.graph_runtime_state.variable_pool
for node_id in loop_node_ids:
variable_pool.remove([node_id])
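
A sketch of the `VariablePool.remove` semantics this relies on (assumed from the usage here: a one-element selector clears every variable scoped under that node id; `VariablePool.empty()` appears later in this diff):

```python
pool = VariablePool.empty()
pool.add(("code_1", "result"), "stale value")  # hypothetical node id
pool.remove(["code_1"])  # drops all of code_1's outputs at once
assert pool.get(["code_1", "result"]) is None
```
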
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -1,10 +1,21 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory):
self,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
)
self._code_limits = code_limits or CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -72,6 +103,25 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)
return node_class(
id=node_id,
config=node_config,

View File

@ -281,7 +281,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# handle invoke result
text = invoke_result.message.content or ""
text = invoke_result.message.get_text_content()
if not isinstance(text, str):
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")

View File

@ -42,9 +42,17 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
if not isinstance(value, dict):
raise ValueError(f"{key} must be a JSON object")
# If no value provided, skip further processing for this key
if not value:
continue
if not isinstance(value, dict):
raise ValueError(f"JSON object for '{key}' must be an object")
# Overwrite with normalized dict to ensure downstream consistency
node_inputs[key] = value
# If schema exists, then validate against it
schema = variable.json_schema
if not schema:
continue
@ -53,4 +61,3 @@ class StartNode(Node[StartNodeData]):
Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = value
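
The schema check in isolation, as a minimal runnable sketch with a hypothetical schema:

```python
from jsonschema import Draft7Validator, ValidationError

schema = {"type": "object", "required": ["name"]}
value = {"name": "dify"}
try:
    Draft7Validator(schema).validate(value)  # raises ValidationError on mismatch
except ValidationError as e:
    raise ValueError(f"JSON object does not match schema: {e.message}")
```
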

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: type[CodeExecutor]
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc
rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered
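
Because `Jinja2TemplateRenderer` is a structural `Protocol`, a test double only needs the one method. A minimal sketch (the naive substitution below is not real Jinja2):

```python
class FakeRenderer:
    def render_template(self, template, variables):
        out = template
        for key, val in variables.items():
            out = out.replace("{{ %s }}" % key, str(val))  # spaces required
        return out


assert FakeRenderer().render_template("hi {{ name }}", {"name": "dify"}) == "hi dify"
```
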

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod

View File

@ -1,28 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from extensions.ext_database import db
from models import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def flush(self):
pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()

View File

@ -1,9 +1,8 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
def __init__(
self,
@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
):
super().__init__(
id=id,
@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._conv_var_updater_factory = conv_var_updater_factory
@classmethod
def version(cls) -> str:
@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={

View File

@ -1,24 +1,20 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
@ -26,6 +22,10 @@ from .exc import (
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
return False
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()
@classmethod
def version(cls) -> str:
return "2"
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
if self.invoke_from != InvokeFrom.DEBUGGER:
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)
conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import abc
from collections.abc import Mapping
from typing import Any, Protocol
@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> "DraftVariableSaver":
) -> DraftVariableSaver:
pass

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import importlib
import json
import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@ -168,6 +169,7 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self.stop_event: threading.Event = threading.Event()
if graph is not None:
self.attach_graph(graph)
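
With `stop_event` living on the runtime state, a wall-clock timeout reduces to a timer that sets it; a hedged sketch (the helper name is illustrative):

```python
import threading


def arm_timeout(state, seconds: float) -> threading.Timer:
    # When the timer fires, running nodes observe the event via _should_stop().
    timer = threading.Timer(seconds, state.stop_event.set)
    timer.daemon = True
    timer.start()
    return timer
```
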

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Get a variable value (read-only)."""
...

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any
@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
value = self._variable_pool.get(selector)
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
@ -267,6 +269,6 @@ class VariablePool(BaseModel):
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":
def empty(cls) -> VariablePool:
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any
@ -70,7 +72,7 @@ class SystemVariable(BaseModel):
return data
@classmethod
def empty(cls) -> "SystemVariable":
def empty(cls) -> SystemVariable:
return cls()
def to_dict(self) -> dict[SystemVariableKey, Any]:
@ -114,7 +116,7 @@ class SystemVariable(BaseModel):
d[SystemVariableKey.TIMESTAMP] = self.timestamp
return d
def as_view(self) -> "SystemVariableReadOnlyView":
def as_view(self) -> SystemVariableReadOnlyView:
return SystemVariableReadOnlyView(self)

View File

@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from extensions.otel.runtime import is_instrument_flag_enabled
from factories import file_factory
from models.enums import UserFrom
from models.workflow import Workflow
@ -98,6 +99,10 @@ class WorkflowEntry:
)
self.graph_engine.layer(limits_layer)
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
self.graph_engine.layer(ObservabilityLayer())
def run(self) -> Generator[GraphEngineEvent, None, None]:
graph_engine = self.graph_engine