feat(graph_engine): Support pausing workflow graph executions (#26585)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-10-19 21:33:41 +08:00
committed by GitHub
parent 9a5f214623
commit 578247ffbc
112 changed files with 3766 additions and 2415 deletions

View File

@ -1,18 +1,11 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .run_condition import RunCondition
from .variable_pool import VariablePool, VariableValue
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"RunCondition",
"VariablePool",
"VariableValue",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -1,160 +0,0 @@
from copy import deepcopy
from pydantic import BaseModel, PrivateAttr
from core.model_runtime.entities.llm_entities import LLMUsage
from .variable_pool import VariablePool
class GraphRuntimeState(BaseModel):
    """Mutable runtime state for a single workflow graph execution.

    All state lives in private attributes exposed through validated
    properties, so callers cannot bypass the non-negativity checks on the
    counters or alias the internal ``outputs`` dict.
    """

    # Private attributes to prevent direct modification
    _variable_pool: VariablePool = PrivateAttr()
    _start_at: float = PrivateAttr()
    _total_tokens: int = PrivateAttr(default=0)
    _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
    _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
    _node_run_steps: int = PrivateAttr(default=0)
    # Serialized snapshots used to restore a previous run; "" means "no
    # saved state" (see the corresponding read-only properties below).
    _ready_queue_json: str = PrivateAttr()
    _graph_execution_json: str = PrivateAttr()
    _response_coordinator_json: str = PrivateAttr()

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue_json: str = "",
        graph_execution_json: str = "",
        response_coordinator_json: str = "",
        **kwargs: object,
    ):
        """Initialize the GraphRuntimeState with validation.

        Args:
            variable_pool: Pool of workflow variables for this run.
            start_at: Start time as a float timestamp (the clock source is
                the caller's choice — not enforced here).
            total_tokens: Initial token count; must be non-negative.
            llm_usage: Accumulated LLM usage; a fresh empty usage object is
                created when omitted.
            outputs: Initial outputs; deep-copied so the caller's dict is
                never aliased.
            node_run_steps: Initial step counter; must be non-negative.
            ready_queue_json: Serialized ready-queue state ("" = none).
            graph_execution_json: Serialized graph execution state ("" = none).
            response_coordinator_json: Serialized response coordinator
                state ("" = none).

        Raises:
            ValueError: If ``total_tokens`` or ``node_run_steps`` is negative.
        """
        super().__init__(**kwargs)
        # Initialize private attributes with validation
        self._variable_pool = variable_pool
        self._start_at = start_at
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens
        if llm_usage is None:
            llm_usage = LLMUsage.empty_usage()
        self._llm_usage = llm_usage
        if outputs is None:
            outputs = {}
        # Deep copy so later caller-side mutation cannot leak into state.
        self._outputs = deepcopy(outputs)
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps
        self._ready_queue_json = ready_queue_json
        self._graph_execution_json = graph_execution_json
        self._response_coordinator_json = response_coordinator_json

    @property
    def variable_pool(self) -> VariablePool:
        """Get the variable pool (returned by reference, not copied)."""
        return self._variable_pool

    @property
    def start_at(self) -> float:
        """Get the start time."""
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        """Set the start time."""
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        """Get the total tokens count."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int) -> None:
        """Set the total tokens count; rejects negative values."""
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        """Get the LLM usage info."""
        # Return a copy to prevent external modification (note: model_copy
        # is shallow — nested objects are still shared).
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage) -> None:
        """Set the LLM usage info (stores a copy, not the caller's object)."""
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, object]:
        """Get a deep copy of the outputs dictionary."""
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, object]) -> None:
        """Set the outputs dictionary (stored as a deep copy)."""
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        """Set a single output value (stored as a deep copy)."""
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        """Get a deep copy of a single output value; *default* when missing."""
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        """Update multiple output values (each stored as a deep copy)."""
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        """Get the node run steps count."""
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        """Set the node run steps count; rejects negative values."""
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        """Increment the node run steps by 1."""
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        """Add tokens to the total count.

        Raises:
            ValueError: If *tokens* is negative.
        """
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    @property
    def ready_queue_json(self) -> str:
        """Serialized ready-queue state; empty string when none was saved."""
        return self._ready_queue_json

    @property
    def graph_execution_json(self) -> str:
        """Serialized graph execution state; empty string when none was saved."""
        return self._graph_execution_json

    @property
    def response_coordinator_json(self) -> str:
        """Serialized response coordinator state; empty string when none was saved."""
        return self._response_coordinator_json

View File

@ -1,21 +0,0 @@
import hashlib
from typing import Literal
from pydantic import BaseModel
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
    """Condition that gates whether a node may run."""

    # Discriminator for the condition kind.
    type: Literal["branch_identify", "condition"]
    # Branch identifier (like sourceHandle); required when type is "branch_identify".
    branch_identify: str | None = None
    # Condition expressions; required when type is "condition".
    conditions: list[Condition] | None = None

    @property
    def hash(self) -> str:
        """Stable SHA-256 fingerprint of this condition's serialized form."""
        serialized = self.model_dump_json().encode()
        return hashlib.sha256(serialized).hexdigest()

View File

@ -58,6 +58,7 @@ class NodeType(StrEnum):
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
HUMAN_INPUT = "human-input"
class NodeExecutionType(StrEnum):
@ -96,6 +97,7 @@ class WorkflowExecutionStatus(StrEnum):
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
PAUSED = "paused"
class WorkflowNodeExecutionMetadataKey(StrEnum):

View File

@ -1,16 +1,11 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph import Graph, GraphBuilder, NodeFactory
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
__all__ = [
"Edge",
"Graph",
"GraphBuilder",
"GraphTemplate",
"NodeFactory",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
]

View File

@ -195,6 +195,12 @@ class Graph:
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@classmethod
def _mark_inactive_root_branches(
cls,
@ -344,3 +350,96 @@ class Graph:
"""
edge_ids = self.in_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
@final
class GraphBuilder:
    """Fluent helper for constructing simple graphs, primarily for tests.

    Usage: call ``add_root`` once, chain ``add_node``/``connect`` calls,
    then ``build`` to materialize the graph.
    """

    def __init__(self, *, graph_cls: type[Graph]):
        self._graph_cls = graph_cls
        self._nodes: list[Node] = []
        self._nodes_by_id: dict[str, Node] = {}
        self._edges: list[Edge] = []
        self._edge_counter = 0

    def add_root(self, node: Node) -> "GraphBuilder":
        """Register the root node. Must be called exactly once."""
        if self._nodes:
            raise ValueError("Root node has already been added")
        self._track(node)
        return self

    def add_node(
        self,
        node: Node,
        *,
        from_node_id: str | None = None,
        source_handle: str = "source",
    ) -> "GraphBuilder":
        """Append a node and connect it from the specified predecessor.

        Defaults to the most recently added node when *from_node_id* is
        not given.
        """
        if not self._nodes:
            raise ValueError("Root node must be added before adding other nodes")
        predecessor_id = from_node_id or self._nodes[-1].id
        if predecessor_id not in self._nodes_by_id:
            raise ValueError(f"Predecessor node '{predecessor_id}' not found")
        self._track(node)
        self._append_edge(tail=predecessor_id, head=node.id, source_handle=source_handle)
        return self

    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
        """Connect two existing nodes without adding a new node."""
        for label, node_id in (("Tail", tail), ("Head", head)):
            if node_id not in self._nodes_by_id:
                raise ValueError(f"{label} node '{node_id}' not found")
        self._append_edge(tail=tail, head=head, source_handle=source_handle)
        return self

    def build(self) -> Graph:
        """Materialize the graph instance from the accumulated nodes and edges."""
        if not self._nodes:
            raise ValueError("Cannot build an empty graph")
        in_edges: dict[str, list[str]] = defaultdict(list)
        out_edges: dict[str, list[str]] = defaultdict(list)
        for edge in self._edges:
            out_edges[edge.tail].append(edge.id)
            in_edges[edge.head].append(edge.id)
        return self._graph_cls(
            nodes={node.id: node for node in self._nodes},
            edges={edge.id: edge for edge in self._edges},
            in_edges=dict(in_edges),
            out_edges=dict(out_edges),
            root_node=self._nodes[0],
        )

    def _track(self, node: Node) -> None:
        """Validate the node id, register it, and record insertion order."""
        if not node.id:
            raise ValueError("Node must have a non-empty id")
        if node.id in self._nodes_by_id:
            raise ValueError(f"Duplicate node id detected: {node.id}")
        self._nodes_by_id[node.id] = node
        self._nodes.append(node)

    def _append_edge(self, *, tail: str, head: str, source_handle: str) -> None:
        """Create an Edge with the next sequential id ("edge_N") and record it."""
        edge = Edge(id=f"edge_{self._edge_counter}", tail=tail, head=head, source_handle=source_handle)
        self._edge_counter += 1
        self._edges.append(edge)

View File

@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@ -111,9 +111,11 @@ class RedisChannel:
if command_type == CommandType.ABORT:
return AbortCommand.model_validate(data)
else:
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
except (ValueError, TypeError):
return None

View File

@ -5,10 +5,11 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
]

View File

@ -1,14 +1,10 @@
"""
Command handler implementations.
"""
import logging
from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@ -16,17 +12,17 @@ logger = logging.getLogger(__name__)
@final
class AbortCommandHandler(CommandHandler):
    """Handles abort commands by terminating the graph execution."""

    @override
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        """Abort *execution*, using the command's reason when one is given.

        Args:
            command: The abort command
            execution: Graph execution to abort
        """
        assert isinstance(command, AbortCommand)
        logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
        reason = command.reason
        if not reason:
            # Fall back to a generic reason when none (or "") was supplied.
            reason = "User requested abort"
        execution.abort(reason)
@final
class PauseCommandHandler(CommandHandler):
    """Handles pause commands by suspending the graph execution."""

    @override
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        """Pause the given execution, forwarding the command's optional reason.

        Args:
            command: The pause command
            execution: Graph execution to pause
        """
        assert isinstance(command, PauseCommand)
        logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
        execution.pause(command.reason)

View File

@ -40,6 +40,8 @@ class GraphExecutionState(BaseModel):
started: bool = Field(default=False)
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reason: str | None = Field(default=None)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@ -103,6 +105,8 @@ class GraphExecution:
started: bool = False
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reason: str | None = None
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
@ -126,6 +130,17 @@ class GraphExecution:
self.aborted = True
self.error = RuntimeError(f"Aborted: {reason}")
def pause(self, reason: str | None = None) -> None:
"""Pause the graph execution without marking it complete."""
if self.completed:
raise RuntimeError("Cannot pause execution that has completed")
if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted")
if self.paused:
return
self.paused = True
self.pause_reason = reason
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
self.error = error
@ -140,7 +155,12 @@ class GraphExecution:
@property
def is_running(self) -> bool:
"""Check if the execution is currently running."""
return self.started and not self.completed and not self.aborted
return self.started and not self.completed and not self.aborted and not self.paused
@property
def is_paused(self) -> bool:
"""Check if the execution is currently paused."""
return self.paused
@property
def has_error(self) -> bool:
@ -173,6 +193,8 @@ class GraphExecution:
started=self.started,
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
pause_reason=self.pause_reason,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
@ -197,6 +219,8 @@ class GraphExecution:
self.started = state.started
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
self.pause_reason = state.pause_reason
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {

View File

@ -16,7 +16,6 @@ class CommandType(StrEnum):
ABORT = "abort"
PAUSE = "pause"
RESUME = "resume"
class GraphEngineCommand(BaseModel):
@ -31,3 +30,10 @@ class AbortCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for abort")
class PauseCommand(GraphEngineCommand):
"""Command to pause a running workflow execution."""
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for pause")

View File

@ -8,8 +8,7 @@ from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
@ -24,11 +23,13 @@ from core.workflow.graph_events import (
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.runtime import GraphRuntimeState
from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator
@ -203,6 +204,18 @@ class EventHandler:
# Collect the event
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunPauseRequestedEvent) -> None:
"""Handle pause requests emitted by nodes."""
pause_reason = event.reason or "Awaiting human input"
self._graph_execution.pause(pause_reason)
self._state_manager.finish_execution(event.node_id)
if event.node_id in self._graph.nodes:
self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
self._graph_runtime_state.register_paused_node(event.node_id)
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunFailedEvent) -> None:
"""

View File

@ -97,6 +97,10 @@ class EventManager:
"""
self._layers = layers
def notify_layers(self, event: GraphEngineEvent) -> None:
"""Notify registered layers about an event without buffering it."""
self._notify_layers(event)
def collect(self, event: GraphEngineEvent) -> None:
"""
Thread-safe method to collect an event.

View File

@ -9,28 +9,29 @@ import contextvars
import logging
import queue
from collections.abc import Generator
from typing import final
from typing import TYPE_CHECKING, cast, final
from flask import Flask, current_app
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
from .command_processing import AbortCommandHandler, CommandProcessor
from .domain import GraphExecution
from .entities.commands import AbortCommand
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@ -38,10 +39,13 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state
from .response_coordinator import ResponseStreamCoordinator
from .ready_queue import ReadyQueue
from .worker_management import WorkerPool
if TYPE_CHECKING:
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
logger = logging.getLogger(__name__)
@ -67,17 +71,16 @@ class GraphEngine:
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# Graph execution tracks the overall execution state
self._graph_execution = GraphExecution(workflow_id=workflow_id)
if graph_runtime_state.graph_execution_json != "":
self._graph_execution.loads(graph_runtime_state.graph_execution_json)
# === Core Dependencies ===
# Graph structure and configuration
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
self._graph_execution.workflow_id = workflow_id
# === Worker Management Parameters ===
# Parameters for dynamic worker pool scaling
self._min_workers = min_workers
@ -86,13 +89,7 @@ class GraphEngine:
self._scale_down_idle_time = scale_down_idle_time
# === Execution Queues ===
# Create ready queue from saved state or initialize new one
self._ready_queue: ReadyQueue
if self._graph_runtime_state.ready_queue_json == "":
self._ready_queue = InMemoryReadyQueue()
else:
ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
self._ready_queue = create_ready_queue_from_state(ready_queue_state)
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
@ -103,11 +100,7 @@ class GraphEngine:
# === Response Coordination ===
# Coordinates response streaming from response nodes
self._response_coordinator = ResponseStreamCoordinator(
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
)
if graph_runtime_state.response_coordinator_json != "":
self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)
# === Event Management ===
# Event manager handles both collection and emission of events
@ -133,19 +126,6 @@ class GraphEngine:
skip_propagator=self._skip_propagator,
)
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
# === Command Processing ===
# Processes external commands (e.g., abort requests)
self._command_processor = CommandProcessor(
@ -153,12 +133,12 @@ class GraphEngine:
graph_execution=self._graph_execution,
)
# Register abort command handler
# Register command handlers
abort_handler = AbortCommandHandler()
self._command_processor.register_handler(
AbortCommand,
abort_handler,
)
self._command_processor.register_handler(AbortCommand, abort_handler)
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
@ -191,12 +171,23 @@ class GraphEngine:
self._execution_coordinator = ExecutionCoordinator(
graph_execution=self._graph_execution,
state_manager=self._state_manager,
event_handler=self._event_handler_registry,
event_collector=self._event_manager,
command_processor=self._command_processor,
worker_pool=self._worker_pool,
)
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
# Dispatches events and manages execution flow
self._dispatcher = Dispatcher(
event_queue=self._event_queue,
@ -237,26 +228,41 @@ class GraphEngine:
# Initialize layers
self._initialize_layers()
# Start execution
self._graph_execution.start()
is_resume = self._graph_execution.started
if not is_resume:
self._graph_execution.start()
else:
self._graph_execution.paused = False
self._graph_execution.pause_reason = None
start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event)
yield start_event
# Start subsystems
self._start_execution()
self._start_execution(resume=is_resume)
# Yield events as they occur
yield from self._event_manager.emit_events()
# Handle completion
if self._graph_execution.aborted:
if self._graph_execution.is_paused:
paused_event = GraphRunPausedEvent(
reason=self._graph_execution.pause_reason,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)
yield paused_event
elif self._graph_execution.aborted:
abort_reason = "Workflow execution aborted by user command"
if self._graph_execution.error:
abort_reason = str(self._graph_execution.error)
yield GraphRunAbortedEvent(
aborted_event = GraphRunAbortedEvent(
reason=abort_reason,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(aborted_event)
yield aborted_event
elif self._graph_execution.has_error:
if self._graph_execution.error:
raise self._graph_execution.error
@ -264,20 +270,26 @@ class GraphEngine:
outputs = self._graph_runtime_state.outputs
exceptions_count = self._graph_execution.exceptions_count
if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
partial_event = GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
self._event_manager.notify_layers(partial_event)
yield partial_event
else:
yield GraphRunSucceededEvent(
succeeded_event = GraphRunSucceededEvent(
outputs=outputs,
)
self._event_manager.notify_layers(succeeded_event)
yield succeeded_event
except Exception as e:
yield GraphRunFailedEvent(
failed_event = GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
self._event_manager.notify_layers(failed_event)
yield failed_event
raise
finally:
@ -299,8 +311,12 @@ class GraphEngine:
except Exception as e:
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
def _start_execution(self) -> None:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
paused_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
@ -309,10 +325,15 @@ class GraphEngine:
if node.execution_type == NodeExecutionType.RESPONSE:
self._response_coordinator.register(node.id)
# Enqueue root node
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
if not resume:
# Enqueue root node
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
else:
for node_id in paused_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Start dispatcher
self._dispatcher.start()

View File

@ -7,9 +7,9 @@ intercept and respond to GraphEngine events.
from abc import ABC, abstractmethod
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayer(ABC):

View File

@ -0,0 +1,410 @@
"""Workflow persistence layer for GraphEngine.
This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
listening to ``GraphEngineEvent`` instances directly and persisting workflow
and node execution state via the injected repositories.
The design keeps domain persistence concerns inside the engine thread, while
allowing presentation layers to remain read-only observers of repository
state.
"""
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import (
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@dataclass(slots=True)
class PersistenceWorkflowInfo:
    """Static workflow metadata required for persistence."""

    # Identifier of the workflow definition being executed.
    workflow_id: str
    # Workflow type enum value, stored on each execution record.
    workflow_type: WorkflowType
    # Version string of the workflow definition.
    version: str
    # Raw graph definition persisted alongside the execution record.
    graph_data: Mapping[str, Any]
@dataclass(slots=True)
class _NodeRuntimeSnapshot:
    """Lightweight cache to keep node metadata across event phases."""

    # Node identifier within the graph.
    node_id: str
    # Human-readable node title at start time.
    title: str
    # Id of the node that triggered this one, if any.
    predecessor_node_id: str | None
    # Enclosing iteration id, if the node ran inside an iteration.
    iteration_id: str | None
    # Enclosing loop id, if the node ran inside a loop.
    loop_id: str | None
    # Timestamp captured when the node started (naive UTC elsewhere in
    # this layer — confirm against the start handler).
    created_at: datetime
class WorkflowPersistenceLayer(GraphEngineLayer):
"""GraphEngine layer that persists workflow and node execution state."""
def __init__(
    self,
    *,
    application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
    workflow_info: PersistenceWorkflowInfo,
    workflow_execution_repository: WorkflowExecutionRepository,
    workflow_node_execution_repository: WorkflowNodeExecutionRepository,
    trace_manager: TraceQueueManager | None = None,
) -> None:
    """Wire the persistence layer to its repositories.

    Args:
        application_generate_entity: The app invocation entity for this run.
        workflow_info: Static workflow metadata persisted with each execution.
        workflow_execution_repository: Store for workflow-level execution rows.
        workflow_node_execution_repository: Store for node-level execution rows.
        trace_manager: Optional tracing queue; when present, completed runs
            are enqueued as trace tasks.
    """
    super().__init__()
    self._application_generate_entity = application_generate_entity
    self._workflow_info = workflow_info
    self._workflow_execution_repository = workflow_execution_repository
    self._workflow_node_execution_repository = workflow_node_execution_repository
    self._trace_manager = trace_manager
    # Per-run mutable state; reset in on_graph_start.
    self._workflow_execution: WorkflowExecution | None = None
    self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
    self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {}
    self._node_sequence: int = 0
# ------------------------------------------------------------------
# GraphEngineLayer lifecycle
# ------------------------------------------------------------------
def on_graph_start(self) -> None:
    """Reset all per-run caches so a new run starts from a clean slate."""
    self._workflow_execution = None
    self._node_execution_cache.clear()
    self._node_snapshots.clear()
    self._node_sequence = 0
def on_event(self, event: GraphEngineEvent) -> None:
    """Route *event* to the matching persistence handler.

    Entries are checked in order and the first isinstance match wins;
    events of any other type are ignored.
    """
    dispatch = (
        (GraphRunStartedEvent, lambda _event: self._handle_graph_run_started()),
        (GraphRunSucceededEvent, self._handle_graph_run_succeeded),
        (GraphRunPartialSucceededEvent, self._handle_graph_run_partial_succeeded),
        (GraphRunFailedEvent, self._handle_graph_run_failed),
        (GraphRunAbortedEvent, self._handle_graph_run_aborted),
        (GraphRunPausedEvent, self._handle_graph_run_paused),
        (NodeRunStartedEvent, self._handle_node_started),
        (NodeRunRetryEvent, self._handle_node_retry),
        (NodeRunSucceededEvent, self._handle_node_succeeded),
        (NodeRunFailedEvent, self._handle_node_failed),
        (NodeRunExceptionEvent, self._handle_node_exception),
        (NodeRunPauseRequestedEvent, self._handle_node_pause_requested),
    )
    for event_type, handler in dispatch:
        if isinstance(event, event_type):
            handler(event)
            return
def on_graph_end(self, error: Exception | None) -> None:
    """No-op: terminal persistence happens in the on_event handlers."""
    return
# ------------------------------------------------------------------
# Graph-level handlers
# ------------------------------------------------------------------
def _handle_graph_run_started(self) -> None:
    """Create and persist the domain record for a freshly started run."""
    execution = WorkflowExecution.new(
        id_=self._get_execution_id(),
        workflow_id=self._workflow_info.workflow_id,
        workflow_type=self._workflow_info.workflow_type,
        workflow_version=self._workflow_info.version,
        graph=self._workflow_info.graph_data,
        inputs=self._prepare_workflow_inputs(),
        started_at=naive_utc_now(),
    )
    self._workflow_execution_repository.save(execution)
    self._workflow_execution = execution
def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None:
    """Mark the run as succeeded, persist it, and enqueue a trace task."""
    execution = self._get_workflow_execution()
    execution.status = WorkflowExecutionStatus.SUCCEEDED
    execution.outputs = event.outputs
    self._populate_completion_statistics(execution)
    self._workflow_execution_repository.save(execution)
    self._enqueue_trace_task(execution)
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
    """Persist a run that finished despite some tolerated node failures."""
    execution = self._get_workflow_execution()
    execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
    execution.outputs = event.outputs
    execution.exceptions_count = event.exceptions_count
    self._populate_completion_statistics(execution)
    self._workflow_execution_repository.save(execution)
    self._enqueue_trace_task(execution)
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
    """Mark the run as failed, fail any still-running nodes, and persist."""
    execution = self._get_workflow_execution()
    execution.status = WorkflowExecutionStatus.FAILED
    execution.exceptions_count = event.exceptions_count
    execution.error_message = event.error
    self._populate_completion_statistics(execution)
    # Nodes still RUNNING at failure time are closed out with the same error.
    self._fail_running_node_executions(error_message=event.error)
    self._workflow_execution_repository.save(execution)
    self._enqueue_trace_task(execution)
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
    """Record a user-initiated abort as a STOPPED run and persist it."""
    execution = self._get_workflow_execution()
    reason = event.reason or "Workflow execution aborted"
    execution.status = WorkflowExecutionStatus.STOPPED
    execution.error_message = reason
    self._populate_completion_statistics(execution)
    self._fail_running_node_executions(error_message=execution.error_message or "")
    self._workflow_execution_repository.save(execution)
    self._enqueue_trace_task(execution)
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
    """Persist a paused snapshot without stamping a finish time."""
    execution = self._get_workflow_execution()
    execution.status = WorkflowExecutionStatus.PAUSED
    # NOTE(review): error_message doubles as the pause reason here — confirm
    # downstream consumers expect that for paused runs.
    execution.error_message = event.reason or "Workflow execution paused"
    execution.outputs = event.outputs
    # update_finished=False keeps finished_at unset while the run is resumable.
    self._populate_completion_statistics(execution, update_finished=False)
    self._workflow_execution_repository.save(execution)
# ------------------------------------------------------------------
# Node-level handlers
# ------------------------------------------------------------------
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
    """Create, cache, and persist a RUNNING node-execution record.

    Also captures a lightweight `_NodeRuntimeSnapshot` of the node's
    identity and start time so later status transitions can compute elapsed
    time even if the repository-backed record is mutated in between.
    """
    execution = self._get_workflow_execution()
    # Iteration/loop containment is recorded as metadata so consumers can
    # nest node runs under their parent construct.
    metadata = {
        WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
        WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
    }
    domain_execution = WorkflowNodeExecution(
        id=event.id,
        # The engine's event id doubles as the node-execution id.
        node_execution_id=event.id,
        workflow_id=execution.workflow_id,
        workflow_execution_id=execution.id_,
        predecessor_node_id=event.predecessor_node_id,
        # Monotonic 1-based ordering index across this run's node executions.
        index=self._next_node_sequence(),
        node_id=event.node_id,
        node_type=event.node_type,
        title=event.node_title,
        status=WorkflowNodeExecutionStatus.RUNNING,
        metadata=metadata,
        created_at=event.start_at,
    )
    # Cache before saving so event handlers arriving next can resolve it.
    self._node_execution_cache[event.id] = domain_execution
    self._workflow_node_execution_repository.save(domain_execution)
    snapshot = _NodeRuntimeSnapshot(
        node_id=event.node_id,
        title=event.node_title,
        predecessor_node_id=event.predecessor_node_id,
        iteration_id=event.in_iteration_id,
        loop_id=event.in_loop_id,
        created_at=event.start_at,
    )
    self._node_snapshots[event.id] = snapshot
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
    """Record a retry attempt (with its error) on the cached node execution."""
    execution = self._get_node_execution(event.id)
    execution.status = WorkflowNodeExecutionStatus.RETRY
    execution.error = event.error
    repository = self._workflow_node_execution_repository
    repository.save(execution)
    repository.save_execution_data(execution)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
    """Finalize the node execution as SUCCEEDED and persist its result data."""
    self._update_node_execution(
        self._get_node_execution(event.id),
        event.node_run_result,
        WorkflowNodeExecutionStatus.SUCCEEDED,
    )
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
    """Finalize the node execution as FAILED, recording the error text."""
    self._update_node_execution(
        self._get_node_execution(event.id),
        event.node_run_result,
        WorkflowNodeExecutionStatus.FAILED,
        error=event.error,
    )
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
    """Finalize the node execution as EXCEPTION, recording the error text."""
    self._update_node_execution(
        self._get_node_execution(event.id),
        event.node_run_result,
        WorkflowNodeExecutionStatus.EXCEPTION,
        error=event.error,
    )
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
    """Mark the node execution as PAUSED; inputs/outputs are left untouched."""
    self._update_node_execution(
        self._get_node_execution(event.id),
        event.node_run_result,
        WorkflowNodeExecutionStatus.PAUSED,
        error=event.reason,
        # Skip copying result data so partial outputs are not persisted.
        update_outputs=False,
    )
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _get_execution_id(self) -> str:
    """Return the externally supplied workflow execution id.

    Raises:
        ValueError: when system variables lack WORKFLOW_EXECUTION_ID,
            which pause/resume flows require to be pre-assigned.
    """
    execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID)
    if execution_id:
        return str(execution_id)
    raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows")
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
    """Merge app inputs with ``sys.*`` variables into the persisted input map."""
    merged: dict[str, Any] = dict(self._application_generate_entity.inputs)
    for name, value in self._system_variables().items():
        # Conversation IDs are session-scoped; dropping them keeps persisted
        # workflow inputs reusable without binding future runs to this conversation.
        if name == SystemVariableKey.CONVERSATION_ID.value:
            continue
        merged[f"sys.{name}"] = value
    return WorkflowEntry.handle_special_values(merged) or {}
def _get_workflow_execution(self) -> WorkflowExecution:
    """Return the active workflow execution, failing fast if none exists yet."""
    execution = self._workflow_execution
    if execution is None:
        raise ValueError("workflow execution not initialized")
    return execution
def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
    """Look up a cached node execution, failing fast on unknown ids."""
    try:
        return self._node_execution_cache[node_execution_id]
    except KeyError:
        raise ValueError(f"Node execution not found for id={node_execution_id}") from None
def _next_node_sequence(self) -> int:
self._node_sequence += 1
return self._node_sequence
def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None:
    """Copy token/step/output statistics from the runtime state onto *execution*.

    Args:
        execution: Domain record being finalized.
        update_finished: When True, stamp ``finished_at`` with the current
            naive-UTC time; paused runs pass False to keep the run open-ended.
    """
    if update_finished:
        execution.finished_at = naive_utc_now()
    state = self.graph_runtime_state
    if state is None:
        return
    execution.total_tokens = state.total_tokens
    execution.total_steps = state.node_run_steps
    # Preserve outputs already set by a run-level handler; otherwise fall back
    # to the runtime state's accumulated outputs.
    execution.outputs = execution.outputs or state.outputs
    execution.exceptions_count = state.exceptions_count
def _update_node_execution(
    self,
    domain_execution: WorkflowNodeExecution,
    node_result: NodeRunResult,
    status: WorkflowNodeExecutionStatus,
    *,
    error: str | None = None,
    update_outputs: bool = True,
) -> None:
    """Apply a terminal (or paused) status to a node execution and persist it.

    Args:
        domain_execution: Cached record created when the node started.
        node_result: Result payload emitted by the engine for this node.
        status: Status to record (succeeded/failed/exception/paused).
        error: Optional error text; only written when truthy.
        update_outputs: When False (pause flow), skip copying inputs,
            process data, outputs, and metadata onto the record.
    """
    finished_at = naive_utc_now()
    # Prefer the snapshot's start time: it is immune to any mutation of the
    # repository-backed record after the node started.
    snapshot = self._node_snapshots.get(domain_execution.id)
    start_at = snapshot.created_at if snapshot else domain_execution.created_at
    domain_execution.status = status
    domain_execution.finished_at = finished_at
    # Clamp to zero to guard against clock skew producing negative durations.
    domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
    if error:
        domain_execution.error = error
    if update_outputs:
        domain_execution.update_from_mapping(
            inputs=node_result.inputs,
            process_data=node_result.process_data,
            outputs=node_result.outputs,
            metadata=node_result.metadata,
        )
    self._workflow_node_execution_repository.save(domain_execution)
    self._workflow_node_execution_repository.save_execution_data(domain_execution)
def _fail_running_node_executions(self, *, error_message: str) -> None:
    """Force every still-RUNNING cached node execution into FAILED state."""
    finished_at = naive_utc_now()
    still_running = [
        execution
        for execution in self._node_execution_cache.values()
        if execution.status == WorkflowNodeExecutionStatus.RUNNING
    ]
    for execution in still_running:
        execution.status = WorkflowNodeExecutionStatus.FAILED
        execution.error = error_message
        execution.finished_at = finished_at
        # Clamp to zero so clock skew cannot yield a negative duration.
        execution.elapsed_time = max((finished_at - execution.created_at).total_seconds(), 0.0)
        self._workflow_node_execution_repository.save(execution)
def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:
    """Queue a workflow trace for observability backends, if tracing is enabled."""
    trace_manager = self._trace_manager
    if not trace_manager:
        return
    conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
    entity = self._application_generate_entity
    external_trace_id = None
    if isinstance(entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
        external_trace_id = entity.extras.get("external_trace_id")
    trace_manager.add_trace_task(
        TraceTask(
            TraceTaskName.WORKFLOW_TRACE,
            workflow_execution=execution,
            conversation_id=conversation_id,
            user_id=trace_manager.user_id,
            external_trace_id=external_trace_id,
        )
    )
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@ -9,7 +9,7 @@ Supports stop, pause, and resume operations.
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from extensions.ext_redis import redis_client
@ -20,7 +20,7 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop, pause, and resume operations.
Supports stop and pause operations.
"""
@staticmethod
@ -32,19 +32,29 @@ class GraphEngineManager:
task_id: The task ID of the workflow to stop
reason: Optional reason for stopping (defaults to "User requested stop")
"""
abort_command = AbortCommand(reason=reason or "User requested stop")
GraphEngineManager._send_command(task_id, abort_command)
@staticmethod
def send_pause_command(task_id: str, reason: str | None = None) -> None:
"""Send a pause command to a running workflow."""
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
if not task_id:
return
# Create Redis channel for this task
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(redis_client, channel_key)
# Create and send abort command
abort_command = AbortCommand(reason=reason or "User requested stop")
try:
channel.send_command(abort_command)
channel.send_command(command)
except Exception:
# Silently fail if Redis is unavailable
# The legacy stop flag mechanism will still work
# The legacy control mechanisms will still work
pass

View File

@ -33,6 +33,12 @@ class Dispatcher:
with timeout and completion detection.
"""
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def __init__(
self,
event_queue: queue.Queue[GraphNodeEventBase],
@ -77,33 +83,41 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for scaling
self._execution_coordinator.check_scaling()
commands_checked = False
should_check_commands = False
should_break = False
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
if self._should_check_commands(event):
self._execution_coordinator.check_commands()
self._event_queue.task_done()
except queue.Empty:
# Process commands even when no new events arrive so abort requests are not missed
if self._execution_coordinator.is_execution_complete():
should_check_commands = True
should_break = True
else:
# Check for scaling
self._execution_coordinator.check_scaling()
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
should_check_commands = self._should_check_commands(event)
self._event_queue.task_done()
except queue.Empty:
# Process commands even when no new events arrive so abort requests are not missed
should_check_commands = True
time.sleep(0.1)
if should_check_commands and not commands_checked:
self._execution_coordinator.check_commands()
# Check if execution is complete
if self._execution_coordinator.is_execution_complete():
break
commands_checked = True
if should_break:
if not commands_checked:
self._execution_coordinator.check_commands()
break
except Exception as e:
logger.exception("Dispatcher error")

View File

@ -2,17 +2,13 @@
Execution coordinator for managing overall workflow execution.
"""
from typing import TYPE_CHECKING, final
from typing import final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventManager
from ..graph_state_manager import GraphStateManager
from ..worker_management import WorkerPool
if TYPE_CHECKING:
from ..event_management import EventHandler
@final
class ExecutionCoordinator:
@ -27,8 +23,6 @@ class ExecutionCoordinator:
self,
graph_execution: GraphExecution,
state_manager: GraphStateManager,
event_handler: "EventHandler",
event_collector: EventManager,
command_processor: CommandProcessor,
worker_pool: WorkerPool,
) -> None:
@ -38,15 +32,11 @@ class ExecutionCoordinator:
Args:
graph_execution: Graph execution aggregate
state_manager: Unified state manager
event_handler: Event handler registry for processing events
event_collector: Event manager for collecting events
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self._graph_execution = graph_execution
self._state_manager = state_manager
self._event_handler = event_handler
self._event_collector = event_collector
self._command_processor = command_processor
self._worker_pool = worker_pool
@ -65,15 +55,24 @@ class ExecutionCoordinator:
Returns:
True if execution is complete
"""
# Check if aborted or failed
# Treat paused, aborted, or failed executions as terminal states
if self._graph_execution.is_paused:
return True
if self._graph_execution.aborted or self._graph_execution.has_error:
return True
# Complete if no work remains
return self._state_manager.is_execution_complete()
@property
def is_paused(self) -> bool:
"""Expose whether the underlying graph execution is paused."""
return self._graph_execution.is_paused
def mark_complete(self) -> None:
"""Mark execution as complete."""
if self._graph_execution.is_paused:
return
if not self._graph_execution.completed:
self._graph_execution.complete()
@ -85,3 +84,21 @@ class ExecutionCoordinator:
error: The error that caused failure
"""
self._graph_execution.fail(error)
def handle_pause_if_needed(self) -> None:
"""If the execution has been paused, stop workers immediately."""
if not self._graph_execution.is_paused:
return
self._worker_pool.stop()
self._state_manager.clear_executing()
def handle_abort_if_needed(self) -> None:
"""If the execution has been aborted, stop workers immediately."""
if not self._graph_execution.aborted:
return
self._worker_pool.stop()
self._state_manager.clear_executing()

View File

@ -14,11 +14,11 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
from .path import Path
from .session import ResponseSession

View File

@ -13,6 +13,7 @@ from .graph import (
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
@ -37,6 +38,7 @@ from .loop import (
from .node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
@ -51,6 +53,7 @@ __all__ = [
"GraphRunAbortedEvent",
"GraphRunFailedEvent",
"GraphRunPartialSucceededEvent",
"GraphRunPausedEvent",
"GraphRunStartedEvent",
"GraphRunSucceededEvent",
"NodeRunAgentLogEvent",
@ -64,6 +67,7 @@ __all__ = [
"NodeRunLoopNextEvent",
"NodeRunLoopStartedEvent",
"NodeRunLoopSucceededEvent",
"NodeRunPauseRequestedEvent",
"NodeRunRetrieverResourceEvent",
"NodeRunRetryEvent",
"NodeRunStartedEvent",

View File

@ -8,7 +8,12 @@ class GraphRunStartedEvent(BaseGraphEvent):
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: dict[str, object] = Field(default_factory=dict)
"""Event emitted when a run completes successfully with final outputs."""
outputs: dict[str, object] = Field(
default_factory=dict,
description="Final workflow outputs keyed by output selector.",
)
class GraphRunFailedEvent(BaseGraphEvent):
@ -17,12 +22,30 @@ class GraphRunFailedEvent(BaseGraphEvent):
class GraphRunPartialSucceededEvent(BaseGraphEvent):
"""Event emitted when a run finishes with partial success and failures."""
exceptions_count: int = Field(..., description="exception count")
outputs: dict[str, object] = Field(default_factory=dict)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs that were materialised before failures occurred.",
)
class GraphRunAbortedEvent(BaseGraphEvent):
"""Event emitted when a graph run is aborted by user command."""
reason: str | None = Field(default=None, description="reason for abort")
outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any")
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs produced before the abort was requested.",
)
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""
reason: str | None = Field(default=None, description="reason for pause")
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",
)

View File

@ -51,3 +51,7 @@ class NodeRunExceptionEvent(GraphNodeEventBase):
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: str | None = Field(default=None, description="Optional pause reason")

View File

@ -14,6 +14,7 @@ from .loop import (
)
from .node import (
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
@ -33,6 +34,7 @@ __all__ = [
"ModelInvokeCompletedEvent",
"NodeEventBase",
"NodeRunResult",
"PauseRequestedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"StreamChunkEvent",

View File

@ -40,3 +40,7 @@ class StreamChunkEvent(NodeEventBase):
class StreamCompletedEvent(NodeEventBase):
node_run_result: NodeRunResult = Field(..., description="run result")
class PauseRequestedEvent(NodeEventBase):
reason: str | None = Field(default=None, description="Optional pause reason")

View File

@ -25,7 +25,6 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
@ -44,6 +43,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy

View File

@ -6,7 +6,7 @@ from typing import Any, ClassVar
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
@ -20,6 +20,7 @@ from core.workflow.graph_events import (
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
@ -37,10 +38,12 @@ from core.workflow.node_events import (
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
PauseRequestedEvent,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
@ -385,6 +388,16 @@ class Node:
f"Node {self._node_id} does not support status {event.node_run_result.status}"
)
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
return NodeRunPauseRequestedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
reason=event.reason,
)
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(

View File

@ -19,7 +19,6 @@ from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@ -27,6 +26,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile

View File

@ -15,7 +15,7 @@ from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from .entities import (
HttpRequestNodeAuthorization,

View File

@ -0,0 +1,3 @@
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode"]

View File

@ -0,0 +1,10 @@
from pydantic import Field
from core.workflow.nodes.base import BaseNodeData
class HumanInputNodeData(BaseNodeData):
    """Configuration schema for the HumanInput node."""

    # Variable selectors ("node_id.variable") that must all be resolvable in
    # the variable pool before the node completes instead of pausing.
    required_variables: list[str] = Field(default_factory=list)
    # Optional human-readable reason attached to the emitted pause request.
    pause_reason: str | None = Field(default=None)

View File

@ -0,0 +1,132 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import HumanInputNodeData
class HumanInputNode(Node):
    """Workflow node that pauses execution until human-provided inputs arrive.

    On each run it either completes (when all required variables are present
    in the pool) and selects an outgoing branch, or yields a
    ``PauseRequestedEvent`` so the engine suspends the graph.
    """

    node_type = NodeType.HUMAN_INPUT
    execution_type = NodeExecutionType.BRANCH

    # Keys probed (in order) in the variable pool and in default values to
    # discover which outgoing branch the human selected; covers snake_case
    # and camelCase spellings used by different front-end payloads.
    _BRANCH_SELECTION_KEYS: tuple[str, ...] = (
        "edge_source_handle",
        "edgeSourceHandle",
        "source_handle",
        "selected_branch",
        "selectedBranch",
        "branch",
        "branch_id",
        "branchId",
        "handle",
    )

    _node_data: HumanInputNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        """Parse and validate the raw node config into ``HumanInputNodeData``."""
        self._node_data = HumanInputNodeData(**data)

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def _run(self):  # type: ignore[override]
        """Complete with a branch selection if inputs are ready; otherwise pause."""
        if self._is_completion_ready():
            branch_handle = self._resolve_branch_selection()
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={},
                # Fall back to the default "source" handle when no explicit
                # branch selection was supplied.
                edge_source_handle=branch_handle or "source",
            )
        return self._pause_generator()

    def _pause_generator(self):
        # Returned as a generator so the engine consumes the pause request
        # through the streaming event path.
        yield PauseRequestedEvent(reason=self._node_data.pause_reason)

    def _is_completion_ready(self) -> bool:
        """Determine whether all required inputs are satisfied."""
        # NOTE(review): with no required_variables configured this returns
        # False, so the node pauses on every visit — confirm that is intended.
        if not self._node_data.required_variables:
            return False
        variable_pool = self.graph_runtime_state.variable_pool
        for selector_str in self._node_data.required_variables:
            parts = selector_str.split(".")
            # Only two-part "node_id.variable" selectors are accepted.
            if len(parts) != 2:
                return False
            segment = variable_pool.get(parts)
            if segment is None:
                return False
        return True

    def _resolve_branch_selection(self) -> str | None:
        """Determine the branch handle selected by human input if available."""
        variable_pool = self.graph_runtime_state.variable_pool
        # Pool-provided values (live human input) take precedence...
        for key in self._BRANCH_SELECTION_KEYS:
            handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
            if handle:
                return handle
        # ...then node defaults configured at design time.
        default_values = self._node_data.default_value_dict
        for key in self._BRANCH_SELECTION_KEYS:
            handle = self._normalize_branch_value(default_values.get(key))
            if handle:
                return handle
        return None

    @staticmethod
    def _extract_branch_handle(segment: Any) -> str | None:
        """Unwrap a variable-pool segment into a normalized branch handle."""
        if segment is None:
            return None
        # Prefer the segment's to_object() accessor; fall back to .value.
        candidate = getattr(segment, "to_object", None)
        raw_value = candidate() if callable(candidate) else getattr(segment, "value", None)
        if raw_value is None:
            return None
        return HumanInputNode._normalize_branch_value(raw_value)

    @staticmethod
    def _normalize_branch_value(value: Any) -> str | None:
        """Coerce a raw selection (string or mapping) to a non-empty handle string."""
        if value is None:
            return None
        if isinstance(value, str):
            stripped = value.strip()
            return stripped or None
        if isinstance(value, Mapping):
            # Mapping payloads may carry the handle under several keys.
            for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
                candidate = value.get(key)
                if isinstance(candidate, str) and candidate:
                    return candidate
        return None

View File

@ -3,12 +3,12 @@ from typing import Any, Literal
from typing_extensions import deprecated
from core.workflow.entities import VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor

View File

@ -12,7 +12,6 @@ from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
@ -38,6 +37,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
@ -557,11 +557,12 @@ class IterationNode(Node):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(

View File

@ -9,13 +9,13 @@ from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

View File

@ -67,7 +67,7 @@ from .exc import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

View File

@ -15,9 +15,9 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation

View File

@ -52,7 +52,7 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
@ -71,6 +71,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from . import llm_utils
from .entities import (
@ -93,7 +94,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

View File

@ -406,11 +406,12 @@ class LoopNode(Node):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(

View File

@ -10,7 +10,8 @@ from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
@final

View File

@ -9,6 +9,7 @@ from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
@ -134,6 +135,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
"2": AgentNode,
"1": AgentNode,
},
NodeType.HUMAN_INPUT: {
LATEST_VERSION: HumanInputNode,
"1": HumanInputNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,

View File

@ -27,13 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData

View File

@ -41,7 +41,7 @@ from .template_prompts import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
class QuestionClassifierNode(Node):

View File

@ -36,7 +36,7 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
class ToolNode(Node):

View File

@ -18,7 +18,7 @@ from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]

View File

@ -0,0 +1,14 @@
from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue
__all__ = [
"GraphRuntimeState",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
"VariablePool",
"VariableValue",
]

View File

@ -0,0 +1,393 @@
from __future__ import annotations
import importlib
import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping
from copy import deepcopy
from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime.variable_pool import VariablePool
class ReadyQueueProtocol(Protocol):
    """Structural contract that ready-queue implementations must satisfy."""

    def put(self, item: str) -> None:
        """Add the identifier of a node that has become ready to run."""
        ...

    def get(self, timeout: float | None = None) -> str:
        """Pop the next ready node id, waiting until one arrives or the timeout elapses."""
        ...

    def task_done(self) -> None:
        """Acknowledge that the most recently dequeued node finished processing."""
        ...

    def empty(self) -> bool:
        """Report whether the queue currently holds no pending nodes."""
        ...

    def qsize(self) -> int:
        """Return an approximate count of pending nodes."""
        ...

    def dumps(self) -> str:
        """Produce a serialized representation of the queue for persistence."""
        ...

    def loads(self, data: str) -> None:
        """Rehydrate the queue contents from a serialized payload."""
        ...
class GraphExecutionProtocol(Protocol):
    """Structural contract for the graph execution aggregate."""

    workflow_id: str
    started: bool
    completed: bool
    aborted: bool
    error: Exception | None
    exceptions_count: int

    def start(self) -> None:
        """Move the execution into its running state."""
        ...

    def complete(self) -> None:
        """Record that the execution finished successfully."""
        ...

    def abort(self, reason: str) -> None:
        """Stop the execution in response to an external abort request."""
        ...

    def fail(self, error: Exception) -> None:
        """Record an unrecoverable error and terminate the execution."""
        ...

    def dumps(self) -> str:
        """Serialize the execution state into a JSON payload."""
        ...

    def loads(self, data: str) -> None:
        """Restore the execution state from a serialized payload."""
        ...
class ResponseStreamCoordinatorProtocol(Protocol):
    """Structural contract for the response stream coordinator."""

    def register(self, response_node_id: str) -> None:
        """Register a response node whose outputs should be streamed."""
        ...

    def loads(self, data: str) -> None:
        """Restore coordinator state from a serialized payload."""
        ...

    def dumps(self) -> str:
        """Serialize coordinator state for persistence."""
        ...
class GraphProtocol(Protocol):
    """Structural contract for graph instances bound to the runtime state."""

    nodes: TypingMapping[str, object]
    edges: TypingMapping[str, object]
    root_node: object

    def get_outgoing_edges(self, node_id: str) -> Sequence[object]:
        """Return the edges that leave the given node."""
        ...
class GraphRuntimeState:
    """Mutable runtime state shared across graph execution components.

    Holds the variable pool, token/step counters, workflow outputs, the set of
    paused nodes, and three lazily-built collaborators: the ready queue, the
    graph execution aggregate, and the response stream coordinator.  The whole
    state can be serialized via ``dumps`` and restored via ``loads`` so a
    workflow run can be paused and resumed later.
    """

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue: ReadyQueueProtocol | None = None,
        graph_execution: GraphExecutionProtocol | None = None,
        response_coordinator: ResponseStreamCoordinatorProtocol | None = None,
        graph: GraphProtocol | None = None,
    ) -> None:
        """Initialize the runtime state; counters must be non-negative.

        :raises ValueError: if ``total_tokens`` or ``node_run_steps`` is negative.
        """
        self._variable_pool = variable_pool
        self._start_at = start_at
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens
        # Copy so the caller's LLMUsage object is never mutated through this state.
        self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy()
        # Outputs are always deep-copied on the way in (and out — see `outputs`).
        self._outputs = deepcopy(outputs) if outputs is not None else {}
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps
        self._graph: GraphProtocol | None = None
        self._ready_queue = ready_queue
        self._graph_execution = graph_execution
        self._response_coordinator = response_coordinator
        # Serialized payloads that arrived (via `loads`) before the collaborator
        # that consumes them existed; applied once the collaborator is built.
        self._pending_response_coordinator_dump: str | None = None
        self._pending_graph_execution_workflow_id: str | None = None
        # Node ids recorded as paused, to be re-queued when execution resumes.
        self._paused_nodes: set[str] = set()
        if graph is not None:
            self.attach_graph(graph)

    # ------------------------------------------------------------------
    # Context binding helpers
    # ------------------------------------------------------------------
    def attach_graph(self, graph: GraphProtocol) -> None:
        """Attach the materialized graph to the runtime state.

        Also flushes any response-coordinator snapshot that was deferred
        because the graph was not yet attached when `loads` ran.

        :raises ValueError: if a different graph instance is already attached.
        """
        if self._graph is not None and self._graph is not graph:
            raise ValueError("GraphRuntimeState already attached to a different graph instance")
        self._graph = graph
        if self._response_coordinator is None:
            self._response_coordinator = self._build_response_coordinator(graph)
        if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
            self._response_coordinator.loads(self._pending_response_coordinator_dump)
            self._pending_response_coordinator_dump = None

    def configure(self, *, graph: GraphProtocol | None = None) -> None:
        """Ensure core collaborators are initialized with the provided context."""
        if graph is not None:
            self.attach_graph(graph)
        # Ensure collaborators are instantiated (properties build them lazily).
        _ = self.ready_queue
        _ = self.graph_execution
        if self._graph is not None:
            _ = self.response_coordinator

    # ------------------------------------------------------------------
    # Primary collaborators
    # ------------------------------------------------------------------
    @property
    def variable_pool(self) -> VariablePool:
        """The live variable pool (not a copy)."""
        return self._variable_pool

    @property
    def ready_queue(self) -> ReadyQueueProtocol:
        """The ready queue, built lazily on first access."""
        if self._ready_queue is None:
            self._ready_queue = self._build_ready_queue()
        return self._ready_queue

    @property
    def graph_execution(self) -> GraphExecutionProtocol:
        """The graph execution aggregate, built lazily on first access."""
        if self._graph_execution is None:
            self._graph_execution = self._build_graph_execution()
        return self._graph_execution

    @property
    def response_coordinator(self) -> ResponseStreamCoordinatorProtocol:
        """The response stream coordinator, built lazily once a graph is attached.

        :raises ValueError: if accessed before a graph has been attached.
        """
        if self._response_coordinator is None:
            if self._graph is None:
                raise ValueError("Graph must be attached before accessing response coordinator")
            self._response_coordinator = self._build_response_coordinator(self._graph)
        return self._response_coordinator

    # ------------------------------------------------------------------
    # Scalar state
    # ------------------------------------------------------------------
    @property
    def start_at(self) -> float:
        """Timestamp (float) at which execution started."""
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        """Cumulative token count; guaranteed non-negative."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int) -> None:
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        """A copy of the accumulated LLM usage (mutations do not affect state)."""
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage) -> None:
        # Store a copy so later mutation of `value` cannot leak into the state.
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        """A deep copy of the workflow outputs."""
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, Any]) -> None:
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        """Store a deep copy of `value` under `key` in the outputs."""
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        """Return a deep copy of the output for `key`, or `default` if absent."""
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        """Merge `updates` into outputs, deep-copying each value."""
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        """Number of node execution steps taken; guaranteed non-negative."""
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        """Advance the node step counter by one."""
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        """Add `tokens` to the running total.

        :raises ValueError: if `tokens` is negative.
        """
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------
    def dumps(self) -> str:
        """Serialize runtime state into a JSON string (snapshot version "1.0")."""
        snapshot: dict[str, Any] = {
            "version": "1.0",
            "start_at": self._start_at,
            "total_tokens": self._total_tokens,
            "node_run_steps": self._node_run_steps,
            "llm_usage": self._llm_usage.model_dump(mode="json"),
            "outputs": self.outputs,
            "variable_pool": self.variable_pool.model_dump(mode="json"),
            # Accessing the properties forces lazy construction before dumping.
            "ready_queue": self.ready_queue.dumps(),
            "graph_execution": self.graph_execution.dumps(),
            "paused_nodes": list(self._paused_nodes),
        }
        # The coordinator is only meaningful with a graph; skip it otherwise.
        if self._response_coordinator is not None and self._graph is not None:
            snapshot["response_coordinator"] = self._response_coordinator.dumps()
        return json.dumps(snapshot, default=pydantic_encoder)

    def loads(self, data: str | Mapping[str, Any]) -> None:
        """Restore runtime state from a serialized snapshot.

        :param data: JSON string or mapping produced by `dumps`.
        :raises ValueError: on unsupported snapshot version or negative counters.
        """
        payload: dict[str, Any]
        if isinstance(data, str):
            payload = json.loads(data)
        else:
            payload = dict(data)
        version = payload.get("version")
        if version != "1.0":
            raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
        self._start_at = float(payload.get("start_at", 0.0))
        total_tokens = int(payload.get("total_tokens", 0))
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens
        node_run_steps = int(payload.get("node_run_steps", 0))
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps
        llm_usage_payload = payload.get("llm_usage", {})
        self._llm_usage = LLMUsage.model_validate(llm_usage_payload)
        self._outputs = deepcopy(payload.get("outputs", {}))
        variable_pool_payload = payload.get("variable_pool")
        if variable_pool_payload is not None:
            self._variable_pool = VariablePool.model_validate(variable_pool_payload)
        ready_queue_payload = payload.get("ready_queue")
        if ready_queue_payload is not None:
            self._ready_queue = self._build_ready_queue()
            self._ready_queue.loads(ready_queue_payload)
        else:
            self._ready_queue = None
        graph_execution_payload = payload.get("graph_execution")
        self._graph_execution = None
        self._pending_graph_execution_workflow_id = None
        if graph_execution_payload is not None:
            # Peek at the workflow_id so the lazily-built GraphExecution can be
            # constructed with it; tolerate malformed payloads.
            try:
                execution_payload = json.loads(graph_execution_payload)
                self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
            except (json.JSONDecodeError, TypeError, AttributeError):
                self._pending_graph_execution_workflow_id = None
            self.graph_execution.loads(graph_execution_payload)
        response_payload = payload.get("response_coordinator")
        if response_payload is not None:
            if self._graph is not None:
                self.response_coordinator.loads(response_payload)
            else:
                # Defer until attach_graph provides the graph context.
                self._pending_response_coordinator_dump = response_payload
        else:
            self._pending_response_coordinator_dump = None
            self._response_coordinator = None
        paused_nodes_payload = payload.get("paused_nodes", [])
        self._paused_nodes = set(map(str, paused_nodes_payload))

    def register_paused_node(self, node_id: str) -> None:
        """Record a node that should resume when execution is continued."""
        self._paused_nodes.add(node_id)

    def consume_paused_nodes(self) -> list[str]:
        """Retrieve and clear the list of paused nodes awaiting resume."""
        nodes = list(self._paused_nodes)
        self._paused_nodes.clear()
        return nodes

    # ------------------------------------------------------------------
    # Builders
    # ------------------------------------------------------------------
    def _build_ready_queue(self) -> ReadyQueueProtocol:
        # Import lazily to avoid breaching architecture boundaries enforced by import-linter.
        module = importlib.import_module("core.workflow.graph_engine.ready_queue")
        in_memory_cls = module.InMemoryReadyQueue
        return in_memory_cls()

    def _build_graph_execution(self) -> GraphExecutionProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution")
        graph_execution_cls = module.GraphExecution
        # Use the workflow_id captured from a prior `loads`, if any (one-shot).
        workflow_id = self._pending_graph_execution_workflow_id or ""
        self._pending_graph_execution_workflow_id = None
        return graph_execution_cls(workflow_id=workflow_id)

    def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
        coordinator_cls = module.ResponseStreamCoordinator
        return coordinator_cls(variable_pool=self.variable_pool, graph=graph)

View File

@ -16,6 +16,10 @@ class ReadOnlyVariablePool(Protocol):
"""Get all variables for a node (read-only)."""
...
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
"""Get all variables stored under a given node prefix (read-only)."""
...
class ReadOnlyGraphRuntimeState(Protocol):
"""
@ -56,6 +60,20 @@ class ReadOnlyGraphRuntimeState(Protocol):
"""Get the node run steps count (read-only)."""
...
@property
def ready_queue_size(self) -> int:
"""Get the number of nodes currently in the ready queue."""
...
@property
def exceptions_count(self) -> int:
"""Get the number of node execution exceptions recorded."""
...
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
...
def dumps(self) -> str:
"""Serialize the runtime state into a JSON snapshot (read-only)."""
...

View File

@ -1,77 +1,82 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool
class ReadOnlyVariablePoolWrapper:
"""Wrapper that provides read-only access to VariablePool."""
"""Provide defensive, read-only access to ``VariablePool``."""
def __init__(self, variable_pool: VariablePool):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (returns a defensive copy)."""
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (returns defensive copies)."""
"""Return a copy of all variables for the specified node."""
variables: dict[str, object] = {}
if node_id in self._variable_pool.variable_dictionary:
for key, var in self._variable_pool.variable_dictionary[node_id].items():
# Variables have a value property that contains the actual data
variables[key] = deepcopy(var.value)
for key, variable in self._variable_pool.variable_dictionary[node_id].items():
variables[key] = deepcopy(variable.value)
return variables
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
"""Return a copy of all variables stored under the given prefix."""
return self._variable_pool.get_by_prefix(prefix)
class ReadOnlyGraphRuntimeStateWrapper:
"""
Wrapper that provides read-only access to GraphRuntimeState.
"""Expose a defensive, read-only view of ``GraphRuntimeState``."""
This wrapper ensures that layers can observe the state without
modifying it. All returned values are defensive copies.
"""
def __init__(self, state: GraphRuntimeState):
def __init__(self, state: GraphRuntimeState) -> None:
self._state = state
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
@property
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
"""Get read-only access to the variable pool."""
return self._variable_pool_wrapper
@property
def start_at(self) -> float:
"""Get the start time (read-only)."""
return self._state.start_at
@property
def total_tokens(self) -> int:
"""Get the total tokens count (read-only)."""
return self._state.total_tokens
@property
def llm_usage(self) -> LLMUsage:
"""Get a copy of LLM usage info (read-only)."""
# Return a copy to prevent modification
return self._state.llm_usage.model_copy()
@property
def outputs(self) -> dict[str, Any]:
"""Get a defensive copy of outputs (read-only)."""
return deepcopy(self._state.outputs)
@property
def node_run_steps(self) -> int:
"""Get the node run steps count (read-only)."""
return self._state.node_run_steps
@property
def ready_queue_size(self) -> int:
return self._state.ready_queue.qsize()
@property
def exceptions_count(self) -> int:
return self._state.graph_execution.exceptions_count
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
return self._state.get_output(key, default)
def dumps(self) -> str:
"""Serialize the underlying runtime state for external persistence."""
return self._state.dumps()

View File

@ -1,6 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
@ -235,6 +236,20 @@ class VariablePool(BaseModel):
return segment
return None
def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
    """Return deep copies of every variable stored under the given node prefix."""
    stored = self.variable_dictionary.get(prefix)
    if not stored:
        return {}
    return {name: deepcopy(entry.value) for name, entry in stored.items()}
def _add_system_variables(self, system_variable: SystemVariable):
sys_var_mapping = system_variable.to_dict()
for key, value in sys_var_mapping.items():

View File

@ -5,7 +5,7 @@ from typing import Literal, NamedTuple
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator

View File

@ -4,7 +4,7 @@ from typing import Any, Protocol
from core.variables import Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool
class VariableLoader(Protocol):

View File

@ -1,459 +0,0 @@
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.entities import (
WorkflowExecution,
WorkflowNodeExecution,
)
from core.workflow.enums import (
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
@dataclass
class CycleManagerWorkflowInfo:
    """Static descriptor of the workflow a cycle manager is driving."""

    # Identifier of the workflow definition this run belongs to.
    workflow_id: str
    # Kind of workflow (see the WorkflowType enum).
    workflow_type: WorkflowType
    # Version string of the workflow definition used for the run.
    version: str
    # Raw graph definition mapping, recorded on the execution at run start.
    graph_data: Mapping[str, Any]
class WorkflowCycleManager:
def __init__(
    self,
    *,
    application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
    workflow_system_variables: SystemVariable,
    workflow_info: CycleManagerWorkflowInfo,
    workflow_execution_repository: WorkflowExecutionRepository,
    workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
    """Wire the cycle manager to its generate entity, system variables,
    workflow metadata, and the two persistence repositories."""
    self._application_generate_entity = application_generate_entity
    self._workflow_system_variables = workflow_system_variables
    self._workflow_info = workflow_info
    self._workflow_execution_repository = workflow_execution_repository
    self._workflow_node_execution_repository = workflow_node_execution_repository
    # Initialize caches for workflow execution cycle
    # These caches avoid redundant repository calls during a single workflow execution
    self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
    self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
def handle_workflow_run_start(self) -> WorkflowExecution:
    """Create, persist and cache a new WorkflowExecution for this run.

    The execution id comes from the system variables when present,
    otherwise a fresh one is generated (see _get_or_generate_execution_id).
    """
    inputs = self._prepare_workflow_inputs()
    execution_id = self._get_or_generate_execution_id()
    execution = WorkflowExecution.new(
        id_=execution_id,
        workflow_id=self._workflow_info.workflow_id,
        workflow_type=self._workflow_info.workflow_type,
        workflow_version=self._workflow_info.version,
        graph=self._workflow_info.graph_data,
        inputs=inputs,
        started_at=naive_utc_now(),
    )
    return self._save_and_cache_workflow_execution(execution)
def handle_workflow_run_success(
    self,
    *,
    workflow_run_id: str,
    total_tokens: int,
    total_steps: int,
    outputs: Mapping[str, Any] | None = None,
    conversation_id: str | None = None,
    trace_manager: TraceQueueManager | None = None,
    external_trace_id: str | None = None,
) -> WorkflowExecution:
    """Finalize the cached run as SUCCEEDED, queue a trace task if tracing
    is enabled, and persist the updated execution.

    :raises WorkflowRunNotFoundError: if the run id is not in the cache.
    """
    workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
    self._update_workflow_execution_completion(
        workflow_execution,
        status=WorkflowExecutionStatus.SUCCEEDED,
        outputs=outputs,
        total_tokens=total_tokens,
        total_steps=total_steps,
    )
    self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
    self._workflow_execution_repository.save(workflow_execution)
    return workflow_execution
def handle_workflow_run_partial_success(
    self,
    *,
    workflow_run_id: str,
    total_tokens: int,
    total_steps: int,
    outputs: Mapping[str, Any] | None = None,
    exceptions_count: int = 0,
    conversation_id: str | None = None,
    trace_manager: TraceQueueManager | None = None,
    external_trace_id: str | None = None,
) -> WorkflowExecution:
    """Finalize the cached run as PARTIAL_SUCCEEDED, recording how many node
    exceptions occurred, then queue an optional trace task and persist.

    :raises WorkflowRunNotFoundError: if the run id is not in the cache.
    """
    execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
    self._update_workflow_execution_completion(
        execution,
        status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
        outputs=outputs,
        total_tokens=total_tokens,
        total_steps=total_steps,
        exceptions_count=exceptions_count,
    )
    self._add_trace_task_if_needed(trace_manager, execution, conversation_id, external_trace_id)
    self._workflow_execution_repository.save(execution)
    return execution
def handle_workflow_run_failed(
    self,
    *,
    workflow_run_id: str,
    total_tokens: int,
    total_steps: int,
    status: WorkflowExecutionStatus,
    error_message: str,
    conversation_id: str | None = None,
    trace_manager: TraceQueueManager | None = None,
    exceptions_count: int = 0,
    external_trace_id: str | None = None,
) -> WorkflowExecution:
    """Finalize the cached run with the given failure status and fail every
    node execution still marked RUNNING with the same error and timestamp.

    :raises WorkflowRunNotFoundError: if the run id is not in the cache.
    """
    workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
    # One timestamp for both the run and its failed nodes, for consistency.
    now = naive_utc_now()
    self._update_workflow_execution_completion(
        workflow_execution,
        status=status,
        total_tokens=total_tokens,
        total_steps=total_steps,
        error_message=error_message,
        exceptions_count=exceptions_count,
        finished_at=now,
    )
    self._fail_running_node_executions(workflow_execution.id_, error_message, now)
    self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
    self._workflow_execution_repository.save(workflow_execution)
    return workflow_execution
def handle_node_execution_start(
    self,
    *,
    workflow_execution_id: str,
    event: QueueNodeStartedEvent,
) -> WorkflowNodeExecution:
    """Create a RUNNING node execution from the started event, persist and cache it.

    :raises WorkflowRunNotFoundError: if the workflow run id is not in the cache.
    """
    workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
    domain_execution = self._create_node_execution_from_event(
        workflow_execution=workflow_execution,
        event=event,
        status=WorkflowNodeExecutionStatus.RUNNING,
    )
    return self._save_and_cache_node_execution(domain_execution)
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
    """Mark the cached node execution as SUCCEEDED and persist both the
    record and its input/output execution data.

    :raises ValueError: if the node execution id is not in the cache.
    """
    domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
    self._update_node_execution_completion(
        domain_execution,
        event=event,
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
    )
    self._workflow_node_execution_repository.save(domain_execution)
    self._workflow_node_execution_repository.save_execution_data(domain_execution)
    return domain_execution
def handle_workflow_node_execution_failed(
    self,
    *,
    event: QueueNodeFailedEvent | QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
    """Mark the cached node execution as failed and persist it.

    :param event: queue node failed (or exception) event
    :return: the updated domain node execution
    :raises ValueError: if the node execution id is not in the cache.
    """
    domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
    # Exception events map to EXCEPTION status; plain failures map to FAILED.
    status = (
        WorkflowNodeExecutionStatus.EXCEPTION
        if isinstance(event, QueueNodeExceptionEvent)
        else WorkflowNodeExecutionStatus.FAILED
    )
    self._update_node_execution_completion(
        domain_execution,
        event=event,
        status=status,
        error=event.error,
        handle_special_values=True,
    )
    self._workflow_node_execution_repository.save(domain_execution)
    self._workflow_node_execution_repository.save_execution_data(domain_execution)
    return domain_execution
def handle_workflow_node_execution_retried(
    self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
    """Create a RETRY node execution from the retry event, attach its
    inputs/outputs/metadata, then persist and cache it.

    :raises WorkflowRunNotFoundError: if the workflow run id is not in the cache.
    """
    workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
    domain_execution = self._create_node_execution_from_event(
        workflow_execution=workflow_execution,
        event=event,
        status=WorkflowNodeExecutionStatus.RETRY,
        error=event.error,
        created_at=event.start_at,
    )
    # Handle inputs and outputs
    inputs = WorkflowEntry.handle_special_values(event.inputs)
    outputs = event.outputs
    metadata = self._merge_event_metadata(event)
    domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
    execution = self._save_and_cache_node_execution(domain_execution)
    self._workflow_node_execution_repository.save_execution_data(execution)
    return execution
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
# Check cache first
if id in self._workflow_execution_cache:
return self._workflow_execution_cache[id]
raise WorkflowRunNotFoundError(id)
def _prepare_workflow_inputs(self) -> dict[str, Any]:
    """Prepare workflow inputs by merging application inputs with system variables.

    Every system variable except the conversation id is exposed under a
    ``sys.``-prefixed key; the merged mapping is passed through
    WorkflowEntry.handle_special_values before being returned.
    """
    inputs = {**self._application_generate_entity.inputs}
    if self._workflow_system_variables:
        for field_name, value in self._workflow_system_variables.to_dict().items():
            if field_name != SystemVariableKey.CONVERSATION_ID:
                inputs[f"sys.{field_name}"] = value
    return dict(WorkflowEntry.handle_special_values(inputs) or {})
def _get_or_generate_execution_id(self) -> str:
"""Get execution ID from system variables or generate a new one."""
if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
return str(self._workflow_system_variables.workflow_execution_id)
return str(uuidv7())
def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
"""Save workflow execution to repository and cache it."""
self._workflow_execution_repository.save(execution)
self._workflow_execution_cache[execution.id_] = execution
return execution
def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
"""Save node execution to repository and cache it if it has an ID.
This does not persist the `inputs` / `process_data` / `outputs` fields of the execution model.
"""
self._workflow_node_execution_repository.save(execution)
if execution.node_execution_id:
self._node_execution_cache[execution.node_execution_id] = execution
return execution
def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
"""Get node execution from cache or raise error if not found."""
domain_execution = self._node_execution_cache.get(node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {node_execution_id}")
return domain_execution
def _update_workflow_execution_completion(
self,
execution: WorkflowExecution,
*,
status: WorkflowExecutionStatus,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
error_message: str | None = None,
exceptions_count: int = 0,
finished_at: datetime | None = None,
):
"""Update workflow execution with completion data."""
execution.status = status
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = finished_at or naive_utc_now()
execution.exceptions_count = exceptions_count
if error_message:
execution.error_message = error_message
def _add_trace_task_if_needed(
self,
trace_manager: TraceQueueManager | None,
workflow_execution: WorkflowExecution,
conversation_id: str | None,
external_trace_id: str | None,
):
"""Add trace task if trace manager is provided."""
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
external_trace_id=external_trace_id,
)
)
def _fail_running_node_executions(
    self,
    workflow_execution_id: str,
    error_message: str,
    now: datetime,
):
    """Mark every cached RUNNING node execution of the workflow as FAILED.

    Each affected execution (only those with a ``node_execution_id``) gets
    the error message, the *now* finish timestamp and a computed elapsed
    time, and is persisted back to the repository.
    """
    # Iterate over a snapshot so repository saves cannot disturb the cache view.
    for node_execution in list(self._node_execution_cache.values()):
        if node_execution.workflow_execution_id != workflow_execution_id:
            continue
        if node_execution.status != WorkflowNodeExecutionStatus.RUNNING:
            continue
        if not node_execution.node_execution_id:
            continue
        node_execution.status = WorkflowNodeExecutionStatus.FAILED
        node_execution.error = error_message
        node_execution.finished_at = now
        node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
        self._workflow_node_execution_repository.save(node_execution)
def _create_node_execution_from_event(
    self,
    *,
    workflow_execution: WorkflowExecution,
    event: QueueNodeStartedEvent,
    status: WorkflowNodeExecutionStatus,
    error: str | None = None,
    created_at: datetime | None = None,
) -> WorkflowNodeExecution:
    """Build a domain WorkflowNodeExecution from a node-started queue event.

    Iteration/parallel/loop context from the event is recorded as metadata.
    RETRY executions are stamped as finished immediately, with elapsed time
    measured from *created_at* (defaulting to now).
    """
    timestamp = naive_utc_now()
    effective_created_at = created_at or timestamp
    execution = WorkflowNodeExecution(
        id=event.node_execution_id,
        workflow_id=workflow_execution.workflow_id,
        workflow_execution_id=workflow_execution.id_,
        predecessor_node_id=event.predecessor_node_id,
        index=event.node_run_index,
        node_execution_id=event.node_execution_id,
        node_id=event.node_id,
        node_type=event.node_type,
        title=event.node_title,
        status=status,
        metadata={
            WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
        },
        created_at=effective_created_at,
        error=error,
    )
    # A retry is recorded as already finished at creation time.
    if status == WorkflowNodeExecutionStatus.RETRY:
        execution.finished_at = timestamp
        execution.elapsed_time = (timestamp - effective_created_at).total_seconds()
    return execution
def _update_node_execution_completion(
    self,
    domain_execution: WorkflowNodeExecution,
    *,
    event: Union[
        QueueNodeSucceededEvent,
        QueueNodeFailedEvent,
        QueueNodeExceptionEvent,
    ],
    status: WorkflowNodeExecutionStatus,
    error: str | None = None,
    handle_special_values: bool = False,
):
    """Apply a completion event (success/failure/exception) to a node execution.

    Copies inputs/process_data/outputs plus metadata from the event, stamps
    finished_at/elapsed_time (measured from ``event.start_at``), and records
    *error* when provided. When *handle_special_values* is true, inputs and
    process data are first normalized via WorkflowEntry.handle_special_values.
    """
    finished_at = naive_utc_now()
    elapsed_time = (finished_at - event.start_at).total_seconds()
    if handle_special_values:
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        process_data = WorkflowEntry.handle_special_values(event.process_data)
    else:
        inputs, process_data = event.inputs, event.process_data
    # Event metadata may be absent; normalize to a plain dict either way.
    metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = dict(event.execution_metadata or {})
    domain_execution.status = status
    domain_execution.update_from_mapping(
        inputs=inputs,
        process_data=process_data,
        outputs=event.outputs,
        metadata=metadata,
    )
    domain_execution.finished_at = finished_at
    domain_execution.elapsed_time = elapsed_time
    if error:
        domain_execution.error = error
def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
    """Combine the retry event's own metadata with its iteration/parallel/loop context.

    The iteration, parallel-run and loop identifiers always win over
    same-named keys carried by the event's execution metadata.
    """
    base: dict[WorkflowNodeExecutionMetadataKey, str | None] = dict(event.execution_metadata or {})
    context = {
        WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
        WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
        WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
    }
    if not base:
        return context
    return base | context

View File

@ -9,7 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
@ -20,6 +20,7 @@ from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, Gra
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory