Merge remote-tracking branch 'upstream/main' into feat/human-input-merge-again

This commit is contained in:
QuantumGhost
2026-01-28 16:21:37 +08:00
4167 changed files with 345823 additions and 171263 deletions

View File

@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
if command_type == CommandType.UPDATE_VARIABLES:
return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)

View File

@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
"UpdateVariablesCommandHandler",
]

View File

@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
@final
class UpdateVariablesCommandHandler(CommandHandler):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, UpdateVariablesCommand)
for update in command.updates:
try:
variable = update.value
self._variable_pool.add(variable.selector, variable)
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
except ValueError as exc:
logger.warning(
"Skipping invalid variable selector %s for workflow %s: %s",
getattr(update.value, "selector", None),
execution.workflow_id,
exc,
)

View File

@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from enum import StrEnum
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import Variable
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = "abort"
PAUSE = "pause"
ABORT = auto()
PAUSE = auto()
UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
"""Command to update a group of variables in the variable pool."""
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")

View File

@ -5,14 +5,15 @@ This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
import contextvars
from __future__ import annotations
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from flask import Flask, current_app
from core.workflow.context import capture_current_context
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
@ -31,8 +32,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .command_processing import (
AbortCommandHandler,
CommandProcessor,
PauseCommandHandler,
UpdateVariablesCommandHandler,
)
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@ -71,10 +77,13 @@ class GraphEngine:
scale_down_idle_time: float | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
@ -141,22 +150,16 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
try:
app = current_app._get_current_object() # type: ignore
if isinstance(app, Flask):
flask_app = app
except RuntimeError:
pass
# Capture context variables for worker threads
context_vars = contextvars.copy_context()
# Capture execution context for worker threads
execution_context = capture_current_context()
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
@ -164,12 +167,12 @@ class GraphEngine:
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
flask_app=flask_app,
context_vars=context_vars,
execution_context=execution_context,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
stop_event=self._stop_event,
)
# === Orchestration ===
@ -200,6 +203,7 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@ -213,9 +217,16 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
def _bind_layer_context(
self,
layer: GraphEngineLayer,
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
@ -304,14 +315,7 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
# Create a read-only wrapper for the runtime state
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
try:
layer.initialize(read_only_state, self._command_channel)
except Exception:
logger.exception("Failed to initialize layer %s", layer.__class__.__name__)
try:
layer.on_graph_start()
except Exception:
@ -319,6 +323,7 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
@ -352,13 +357,12 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
logger = logging.getLogger(__name__)
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)

View File

@ -60,6 +60,7 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node

View File

@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
Abstract base class for layers.
- `initialize()` - Receive runtime context
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```
`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.
## Custom Layers
```python

View File

@ -8,11 +8,9 @@ with middleware-like components that can observe events and interact with execut
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
from .observability import ObservabilityLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
"ObservabilityLayer",
]

View File

@ -8,11 +8,19 @@ intercept and respond to GraphEngine events.
from abc import ABC, abstractmethod
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayerNotInitializedError(Exception):
"""Raised when a layer's runtime state is accessed before initialization."""
def __init__(self, layer_name: str | None = None) -> None:
name = layer_name or "GraphEngineLayer"
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
@property
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
if self._graph_runtime_state is None:
raise GraphEngineLayerNotInitializedError(type(self).__name__)
return self._graph_runtime_state
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.
Called by GraphEngine to inject the read-only runtime state and command channel.
This is invoked when the layer is registered with a `GraphEngine` instance.
Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod
@ -85,7 +98,7 @@ class GraphEngineLayer(ABC):
"""
pass
def on_node_run_start(self, node: Node) -> None: # noqa: B027
def on_node_run_start(self, node: Node) -> None:
"""
Called immediately before a node begins execution.
@ -96,9 +109,11 @@ class GraphEngineLayer(ABC):
Args:
node: The node instance about to be executed
"""
pass
return
def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
"""
Called after a node finishes execution.
@ -108,5 +123,6 @@ class GraphEngineLayer(ABC):
Args:
node: The node instance that just finished execution
error: Exception instance if the node failed, otherwise None
result_event: The final result event from node execution (succeeded/failed/paused), if any
"""
pass
return

View File

@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
if self.include_outputs and self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)

View File

@ -1,61 +0,0 @@
"""
Node-level OpenTelemetry parser interfaces and defaults.
"""
import json
from typing import Protocol
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
class NodeOTelParser(Protocol):
"""Parser interface for node-specific OpenTelemetry enrichment."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
class DefaultNodeOTelParser:
"""Fallback parser used when no node-specific parser is registered."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
span.set_attribute("node.id", node.id)
if node.execution_id:
span.set_attribute("node.execution_id", node.execution_id)
if hasattr(node, "node_type") and node.node_type:
span.set_attribute("node.type", node.node_type.value)
if error:
span.record_exception(error)
span.set_status(Status(StatusCode.ERROR, str(error)))
else:
span.set_status(Status(StatusCode.OK))
class ToolNodeOTelParser:
"""Parser for tool nodes that captures tool-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
self._delegate.parse(node=node, span=span, error=error)
tool_data = getattr(node, "_node_data", None)
if not isinstance(tool_data, ToolNodeData):
return
span.set_attribute("tool.provider.id", tool_data.provider_id)
span.set_attribute("tool.provider.type", tool_data.provider_type.value)
span.set_attribute("tool.provider.name", tool_data.provider_name)
span.set_attribute("tool.name", tool_data.tool_name)
span.set_attribute("tool.label", tool_data.tool_label)
if tool_data.plugin_unique_identifier:
span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
if tool_data.credential_id:
span.set_attribute("tool.credential.id", tool_data.credential_id)
if tool_data.tool_configurations:
span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))

View File

@ -1,169 +0,0 @@
"""
Observability layer for GraphEngine.
This layer creates OpenTelemetry spans for node execution, enabling distributed
tracing of workflow execution. It establishes OTel context during node execution
so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
associates with the node span.
"""
import logging
from dataclasses import dataclass
from typing import cast, final
from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.node_parsers import (
DefaultNodeOTelParser,
NodeOTelParser,
ToolNodeOTelParser,
)
from core.workflow.nodes.base.node import Node
from extensions.otel.runtime import is_instrument_flag_enabled
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _NodeSpanContext:
span: "Span"
token: object
@final
class ObservabilityLayer(GraphEngineLayer):
"""
Layer that creates OpenTelemetry spans for node execution.
This layer:
- Creates a span when a node starts execution
- Establishes OTel context so automatic instrumentation associates with the span
- Sets complete attributes and status when node execution ends
"""
def __init__(self) -> None:
super().__init__()
self._node_contexts: dict[str, _NodeSpanContext] = {}
self._parsers: dict[NodeType, NodeOTelParser] = {}
self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
self._is_disabled: bool = False
self._tracer: Tracer | None = None
self._build_parser_registry()
self._init_tracer()
def _init_tracer(self) -> None:
"""Initialize OpenTelemetry tracer in constructor."""
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
self._is_disabled = True
return
try:
self._tracer = get_tracer(__name__)
except Exception as e:
logger.warning("Failed to get OpenTelemetry tracer: %s", e)
self._is_disabled = True
def _build_parser_registry(self) -> None:
"""Initialize parser registry for node types."""
self._parsers = {
NodeType.TOOL: ToolNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
return self._parsers.get(node_type, self._default_parser)
return self._default_parser
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self._node_contexts.clear()
@override
def on_node_run_start(self, node: Node) -> None:
"""
Called when a node starts execution.
Creates a span and establishes OTel context for automatic instrumentation.
"""
if self._is_disabled:
return
try:
if not self._tracer:
return
execution_id = node.execution_id
if not execution_id:
return
parent_context = context_api.get_current()
span = self._tracer.start_span(
f"{node.title}",
kind=SpanKind.INTERNAL,
context=parent_context,
)
new_context = set_span_in_context(span)
token = context_api.attach(new_context)
self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)
except Exception as e:
logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_node_run_end(self, node: Node, error: Exception | None) -> None:
"""
Called when a node finishes execution.
Sets complete attributes, records exceptions, and ends the span.
"""
if self._is_disabled:
return
try:
execution_id = node.execution_id
if not execution_id:
return
node_context = self._node_contexts.get(execution_id)
if not node_context:
return
span = node_context.span
parser = self._get_parser(node)
try:
parser.parse(node=node, span=span, error=error)
span.end()
finally:
token = node_context.token
if token is not None:
try:
context_api.detach(token)
except Exception:
logger.warning("Failed to detach OpenTelemetry token: %s", token)
self._node_contexts.pop(execution_id, None)
except Exception as e:
logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_event(self, event) -> None:
"""Not used in this layer."""
pass
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._node_contexts:
logger.warning(
"ObservabilityLayer: %d node spans were not properly ended",
len(self._node_contexts),
)
self._node_contexts.clear()

View File

@ -1,409 +0,0 @@
"""Workflow persistence layer for GraphEngine.
This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
listening to ``GraphEngineEvent`` instances directly and persisting workflow
and node execution state via the injected repositories.
The design keeps domain persistence concerns inside the engine thread, while
allowing presentation layers to remain read-only observers of repository
state.
"""
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import (
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@dataclass(slots=True)
class PersistenceWorkflowInfo:
"""Static workflow metadata required for persistence."""
workflow_id: str
workflow_type: WorkflowType
version: str
graph_data: Mapping[str, Any]
@dataclass(slots=True)
class _NodeRuntimeSnapshot:
"""Lightweight cache to keep node metadata across event phases."""
node_id: str
title: str
predecessor_node_id: str | None
iteration_id: str | None
loop_id: str | None
created_at: datetime
class WorkflowPersistenceLayer(GraphEngineLayer):
"""GraphEngine layer that persists workflow and node execution state."""
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_info: PersistenceWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
trace_manager: TraceQueueManager | None = None,
) -> None:
super().__init__()
self._application_generate_entity = application_generate_entity
self._workflow_info = workflow_info
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._trace_manager = trace_manager
self._workflow_execution: WorkflowExecution | None = None
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {}
self._node_sequence: int = 0
# ------------------------------------------------------------------
# GraphEngineLayer lifecycle
# ------------------------------------------------------------------
def on_graph_start(self) -> None:
self._workflow_execution = None
self._node_execution_cache.clear()
self._node_snapshots.clear()
self._node_sequence = 0
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
self._handle_graph_run_started()
return
if isinstance(event, GraphRunSucceededEvent):
self._handle_graph_run_succeeded(event)
return
if isinstance(event, GraphRunPartialSucceededEvent):
self._handle_graph_run_partial_succeeded(event)
return
if isinstance(event, GraphRunFailedEvent):
self._handle_graph_run_failed(event)
return
if isinstance(event, GraphRunAbortedEvent):
self._handle_graph_run_aborted(event)
return
if isinstance(event, GraphRunPausedEvent):
self._handle_graph_run_paused(event)
return
if isinstance(event, NodeRunStartedEvent):
self._handle_node_started(event)
return
if isinstance(event, NodeRunRetryEvent):
self._handle_node_retry(event)
return
if isinstance(event, NodeRunSucceededEvent):
self._handle_node_succeeded(event)
return
if isinstance(event, NodeRunFailedEvent):
self._handle_node_failed(event)
return
if isinstance(event, NodeRunExceptionEvent):
self._handle_node_exception(event)
return
if isinstance(event, NodeRunPauseRequestedEvent):
self._handle_node_pause_requested(event)
def on_graph_end(self, error: Exception | None) -> None:
return
# ------------------------------------------------------------------
# Graph-level handlers
# ------------------------------------------------------------------
def _handle_graph_run_started(self) -> None:
execution_id = self._get_execution_id()
workflow_execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
workflow_type=self._workflow_info.workflow_type,
workflow_version=self._workflow_info.version,
graph=self._workflow_info.graph_data,
inputs=self._prepare_workflow_inputs(),
started_at=naive_utc_now(),
)
self._workflow_execution_repository.save(workflow_execution)
self._workflow_execution = workflow_execution
def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None:
execution = self._get_workflow_execution()
execution.outputs = event.outputs
execution.status = WorkflowExecutionStatus.SUCCEEDED
self._populate_completion_statistics(execution)
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
execution = self._get_workflow_execution()
execution.outputs = event.outputs
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.exceptions_count = event.exceptions_count
self._populate_completion_statistics(execution)
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
execution = self._get_workflow_execution()
execution.status = WorkflowExecutionStatus.FAILED
execution.error_message = event.error
execution.exceptions_count = event.exceptions_count
self._populate_completion_statistics(execution)
self._fail_running_node_executions(error_message=event.error)
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
execution = self._get_workflow_execution()
execution.status = WorkflowExecutionStatus.STOPPED
execution.error_message = event.reason or "Workflow execution aborted"
self._populate_completion_statistics(execution)
self._fail_running_node_executions(error_message=execution.error_message or "")
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
execution = self._get_workflow_execution()
execution.status = WorkflowExecutionStatus.PAUSED
execution.outputs = event.outputs
self._populate_completion_statistics(execution, update_finished=False)
self._workflow_execution_repository.save(execution)
# ------------------------------------------------------------------
# Node-level handlers
# ------------------------------------------------------------------
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
execution = self._get_workflow_execution()
metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
domain_execution = WorkflowNodeExecution(
id=event.id,
node_execution_id=event.id,
workflow_id=execution.workflow_id,
workflow_execution_id=execution.id_,
predecessor_node_id=event.predecessor_node_id,
index=self._next_node_sequence(),
node_id=event.node_id,
node_type=event.node_type,
title=event.node_title,
status=WorkflowNodeExecutionStatus.RUNNING,
metadata=metadata,
created_at=event.start_at,
)
self._node_execution_cache[event.id] = domain_execution
self._workflow_node_execution_repository.save(domain_execution)
snapshot = _NodeRuntimeSnapshot(
node_id=event.node_id,
title=event.node_title,
predecessor_node_id=event.predecessor_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=event.start_at,
)
self._node_snapshots[event.id] = snapshot
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
domain_execution = self._get_node_execution(event.id)
domain_execution.status = WorkflowNodeExecutionStatus.RETRY
domain_execution.error = event.error
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.EXCEPTION,
error=event.error,
)
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.PAUSED,
error="",
update_outputs=False,
)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _get_execution_id(self) -> str:
workflow_execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID)
if not workflow_execution_id:
raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows")
return str(workflow_execution_id)
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
inputs = {**self._application_generate_entity.inputs}
for field_name, value in self._system_variables().items():
if field_name == SystemVariableKey.CONVERSATION_ID.value:
# Conversation IDs are tied to the current session; omit them so persisted
# workflow inputs stay reusable without binding future runs to this conversation.
continue
inputs[f"sys.{field_name}"] = value
handled = WorkflowEntry.handle_special_values(inputs)
return handled or {}
def _get_workflow_execution(self) -> WorkflowExecution:
if self._workflow_execution is None:
raise ValueError("workflow execution not initialized")
return self._workflow_execution
def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
if node_execution_id not in self._node_execution_cache:
raise ValueError(f"Node execution not found for id={node_execution_id}")
return self._node_execution_cache[node_execution_id]
def _next_node_sequence(self) -> int:
self._node_sequence += 1
return self._node_sequence
def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None:
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
if runtime_state is None:
return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
execution.exceptions_count = runtime_state.exceptions_count
def _update_node_execution(
self,
domain_execution: WorkflowNodeExecution,
node_result: NodeRunResult,
status: WorkflowNodeExecutionStatus,
*,
error: str | None = None,
update_outputs: bool = True,
) -> None:
finished_at = naive_utc_now()
snapshot = self._node_snapshots.get(domain_execution.id)
start_at = snapshot.created_at if snapshot else domain_execution.created_at
domain_execution.status = status
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
if error:
domain_execution.error = error
if update_outputs:
domain_execution.update_from_mapping(
inputs=node_result.inputs,
process_data=node_result.process_data,
outputs=node_result.outputs,
metadata=node_result.metadata,
)
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
def _fail_running_node_executions(self, *, error_message: str) -> None:
now = naive_utc_now()
for execution in self._node_execution_cache.values():
if execution.status == WorkflowNodeExecutionStatus.RUNNING:
execution.status = WorkflowNodeExecutionStatus.FAILED
execution.error = error_message
execution.finished_at = now
execution.elapsed_time = max((now - execution.created_at).total_seconds(), 0.0)
self._workflow_node_execution_repository.save(execution)
def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:
if not self._trace_manager:
return
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
external_trace_id = None
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
trace_task = TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=self._trace_manager.user_id,
external_trace_id=external_trace_id,
)
self._trace_manager.add_trace_task(trace_task)
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop and pause operations.
"""
@staticmethod
@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""

View File

@ -44,6 +44,7 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@ -61,7 +62,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._stop_event = stop_event
self._start_time: float | None = None
def start(self) -> None:
@ -69,16 +70,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
self._thread.join(timeout=2.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""

View File

@ -2,6 +2,8 @@
Factory for creating ReadyQueue instances from serialized state.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
@ -11,7 +13,7 @@ if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
"""
Create a ReadyQueue instance from a serialized state.

View File

@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from __future__ import annotations
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
@ -27,7 +29,7 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> "ResponseSession":
def from_node(cls, node: Node) -> ResponseSession:
"""
Create a ResponseSession from an AnswerNode or EndNode.

View File

@ -5,26 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
from typing import final
from uuid import uuid4
from typing import TYPE_CHECKING, final
from flask import Flask
from typing_extensions import override
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
if TYPE_CHECKING:
pass
@final
class Worker(threading.Thread):
@ -42,9 +42,9 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
execution_context: IExecutionContext | None = None,
) -> None:
"""
Initialize worker thread.
@ -55,23 +55,24 @@ class Worker(threading.Thread):
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
execution_context: Optional execution context for context preservation
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
self._execution_context = execution_context
self._stop_event = stop_event
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
@property
def is_idle(self) -> bool:
@ -111,7 +112,7 @@ class Worker(threading.Thread):
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
id=str(uuid4()),
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
@ -130,33 +131,36 @@ class Worker(threading.Thread):
node.ensure_execution_id()
error: Exception | None = None
result_event: GraphNodeEventBase | None = None
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node with preserved context if execution context is provided
if self._execution_context is not None:
with self._execution_context:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
self._invoke_node_run_end_hooks(node, error, result_event)
else:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
self._invoke_node_run_end_hooks(node, error, result_event)
def _invoke_node_run_start_hooks(self, node: Node) -> None:
"""Invoke on_node_run_start hooks for all layers."""
@ -167,11 +171,13 @@ class Worker(threading.Thread):
# Silently ignore layer errors to prevent disrupting node execution
continue
def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
def _invoke_node_run_end_hooks(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
"""Invoke on_node_run_end hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_end(node, error)
layer.on_node_run_end(node, error, result_event)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue

View File

@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class.
import logging
import queue
import threading
from typing import TYPE_CHECKING, final
from typing import final
from configs import dify_config
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
@ -20,11 +21,6 @@ from ..worker import Worker
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from contextvars import Context
from flask import Flask
@final
class WorkerPool:
@ -41,8 +37,8 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
stop_event: threading.Event,
execution_context: IExecutionContext | None = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
@ -56,8 +52,7 @@ class WorkerPool:
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
execution_context: Optional execution context for context preservation
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
@ -66,8 +61,7 @@ class WorkerPool:
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
self._execution_context = execution_context
self._layers = layers
# Scaling parameters with defaults
@ -81,6 +75,7 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@ -135,7 +130,7 @@ class WorkerPool:
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=10.0)
worker.join(timeout=2.0)
self._workers.clear()
@ -150,8 +145,8 @@ class WorkerPool:
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
execution_context=self._execution_context,
stop_event=self._stop_event,
)
worker.start()