Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

This commit is contained in:
jyong
2025-08-28 18:12:49 +08:00
71 changed files with 801 additions and 2326 deletions

View File

@ -6,10 +6,12 @@ within a single process. Each instance handles commands for one workflow executi
"""
from queue import Queue
from typing import final
from ..entities.commands import GraphEngineCommand
@final
class InMemoryChannel:
"""
In-memory command channel implementation using a thread-safe queue.

View File

@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
"""
import json
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
@ -15,6 +15,7 @@ if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@final
class RedisChannel:
"""
Redis-based command channel implementation for distributed systems.
@ -86,7 +87,7 @@ class RedisChannel:
pipe.expire(self._key, self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict) -> Optional[GraphEngineCommand]:
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
"""
Deserialize a command from dictionary data.

View File

@ -3,6 +3,7 @@ Command handler implementations.
"""
import logging
from typing import final
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
@ -11,6 +12,7 @@ from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@final
class AbortCommandHandler(CommandHandler):
"""Handles abort commands."""

View File

@ -3,7 +3,7 @@ Main command processor for handling external commands.
"""
import logging
from typing import Protocol
from typing import Protocol, final
from ..domain.graph_execution import GraphExecution
from ..entities.commands import GraphEngineCommand
@ -18,6 +18,7 @@ class CommandHandler(Protocol):
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
@final
class CommandProcessor:
"""
Processes external commands sent to the engine.

View File

@ -3,7 +3,6 @@ GraphExecution aggregate root managing the overall graph execution state.
"""
from dataclasses import dataclass, field
from typing import Optional
from .node_execution import NodeExecution
@ -21,7 +20,7 @@ class GraphExecution:
started: bool = False
completed: bool = False
aborted: bool = False
error: Optional[Exception] = None
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
def start(self) -> None:

View File

@ -3,7 +3,6 @@ NodeExecution entity representing a node's execution state.
"""
from dataclasses import dataclass
from typing import Optional
from core.workflow.enums import NodeState
@ -20,8 +19,8 @@ class NodeExecution:
node_id: str
state: NodeState = NodeState.UNKNOWN
retry_count: int = 0
execution_id: Optional[str] = None
error: Optional[str] = None
execution_id: str | None = None
error: str | None = None
def mark_started(self, execution_id: str) -> None:
"""Mark the node as started with an execution ID."""

View File

@ -6,7 +6,7 @@ instance to control its execution flow.
"""
from enum import Enum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -23,11 +23,11 @@ class GraphEngineCommand(BaseModel):
"""Base class for all GraphEngine commands."""
command_type: CommandType = Field(..., description="Type of command")
payload: Optional[dict[str, Any]] = Field(default=None, description="Optional command payload")
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")
class AbortCommand(GraphEngineCommand):
"""Command to abort a running workflow execution."""
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
reason: Optional[str] = Field(default=None, description="Optional reason for abort")
reason: str | None = Field(default=None, description="Optional reason for abort")

View File

@ -8,7 +8,6 @@ the Strategy pattern for clean separation of concerns.
from .abort_strategy import AbortStrategy
from .default_value_strategy import DefaultValueStrategy
from .error_handler import ErrorHandler
from .error_strategy import ErrorStrategy
from .fail_branch_strategy import FailBranchStrategy
from .retry_strategy import RetryStrategy
@ -16,7 +15,6 @@ __all__ = [
"AbortStrategy",
"DefaultValueStrategy",
"ErrorHandler",
"ErrorStrategy",
"FailBranchStrategy",
"RetryStrategy",
]

View File

@ -3,7 +3,7 @@ Abort error strategy implementation.
"""
import logging
from typing import Optional
from typing import final
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
@ -11,6 +11,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
logger = logging.getLogger(__name__)
@final
class AbortStrategy:
"""
Error strategy that aborts execution on failure.
@ -19,7 +20,7 @@ class AbortStrategy:
It stops the entire graph execution when a node fails.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle error by aborting execution.

View File

@ -2,7 +2,7 @@
Default value error strategy implementation.
"""
from typing import Optional
from typing import final
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent
from core.workflow.node_events import NodeRunResult
@final
class DefaultValueStrategy:
"""
Error strategy that uses default values on failure.
@ -18,7 +19,7 @@ class DefaultValueStrategy:
predefined default output values.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle error by using default values.

View File

@ -2,7 +2,7 @@
Main error handler that coordinates error strategies.
"""
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
from core.workflow.graph import Graph
@ -17,6 +17,7 @@ if TYPE_CHECKING:
from ..domain import GraphExecution
@final
class ErrorHandler:
"""
Coordinates error handling strategies for node failures.
@ -34,16 +35,16 @@ class ErrorHandler:
graph: The workflow graph
graph_execution: The graph execution state
"""
self.graph = graph
self.graph_execution = graph_execution
self._graph = graph
self._graph_execution = graph_execution
# Initialize strategies
self.abort_strategy = AbortStrategy()
self.retry_strategy = RetryStrategy()
self.fail_branch_strategy = FailBranchStrategy()
self.default_value_strategy = DefaultValueStrategy()
self._abort_strategy = AbortStrategy()
self._retry_strategy = RetryStrategy()
self._fail_branch_strategy = FailBranchStrategy()
self._default_value_strategy = DefaultValueStrategy()
def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]:
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
"""
Handle a node failure event.
@ -56,14 +57,14 @@ class ErrorHandler:
Returns:
Optional new event to process, or None to abort
"""
node = self.graph.nodes[event.node_id]
node = self._graph.nodes[event.node_id]
# Get retry count from NodeExecution
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
retry_count = node_execution.retry_count
# First check if retry is configured and not exhausted
if node.retry and retry_count < node.retry_config.max_retries:
result = self.retry_strategy.handle_error(event, self.graph, retry_count)
result = self._retry_strategy.handle_error(event, self._graph, retry_count)
if result:
# Retry count will be incremented when NodeRunRetryEvent is handled
return result
@ -71,12 +72,10 @@ class ErrorHandler:
# Apply configured error strategy
strategy = node.error_strategy
if strategy is None:
return self.abort_strategy.handle_error(event, self.graph, retry_count)
elif strategy == ErrorStrategyEnum.FAIL_BRANCH:
return self.fail_branch_strategy.handle_error(event, self.graph, retry_count)
elif strategy == ErrorStrategyEnum.DEFAULT_VALUE:
return self.default_value_strategy.handle_error(event, self.graph, retry_count)
else:
# Unknown strategy, default to abort
return self.abort_strategy.handle_error(event, self.graph, retry_count)
match strategy:
case None:
return self._abort_strategy.handle_error(event, self._graph, retry_count)
case ErrorStrategyEnum.FAIL_BRANCH:
return self._fail_branch_strategy.handle_error(event, self._graph, retry_count)
case ErrorStrategyEnum.DEFAULT_VALUE:
return self._default_value_strategy.handle_error(event, self._graph, retry_count)

View File

@ -2,7 +2,7 @@
Fail branch error strategy implementation.
"""
from typing import Optional
from typing import final
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent
from core.workflow.node_events import NodeRunResult
@final
class FailBranchStrategy:
"""
Error strategy that continues execution via a fail branch.
@ -18,7 +19,7 @@ class FailBranchStrategy:
through a designated fail-branch edge.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle error by taking the fail branch.

View File

@ -3,12 +3,13 @@ Retry error strategy implementation.
"""
import time
from typing import Optional
from typing import final
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent
@final
class RetryStrategy:
"""
Error strategy that retries failed nodes.
@ -17,7 +18,7 @@ class RetryStrategy:
maximum number of retries with configurable intervals.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle error by retrying the node.

View File

@ -3,12 +3,92 @@ Event collector for buffering and managing events.
"""
import threading
from typing import final
from core.workflow.graph_events import GraphEngineEvent
from ..layers.base import Layer
@final
class ReadWriteLock:
    """
    A read-write lock that allows multiple concurrent readers but only one
    writer at a time.

    Readers increment a counter guarded by a condition variable. A writer
    acquires the underlying condition lock and keeps holding it for the whole
    write, first waiting until all active readers have drained.

    NOTE(review): writers can starve if new readers arrive continuously —
    acceptable for the event-collector use case, but not a general-purpose
    fair rwlock.
    """

    def __init__(self) -> None:
        # The condition guards _readers; an RLock backs it so acquire/wait
        # calls nest safely.
        self._read_ready = threading.Condition(threading.RLock())
        self._readers = 0

    def acquire_read(self) -> None:
        """Acquire a read lock. Multiple readers may hold it concurrently."""
        self._read_ready.acquire()
        try:
            self._readers += 1
        finally:
            self._read_ready.release()

    def release_read(self) -> None:
        """Release a read lock, waking waiting writers when the last reader leaves."""
        self._read_ready.acquire()
        try:
            self._readers -= 1
            if self._readers == 0:
                self._read_ready.notify_all()
        finally:
            self._read_ready.release()

    def acquire_write(self) -> None:
        """Acquire a write lock, blocking until all readers have released."""
        self._read_ready.acquire()
        try:
            while self._readers > 0:
                self._read_ready.wait()
        except BaseException:
            # If the wait is interrupted (e.g. KeyboardInterrupt), release the
            # underlying lock so every other thread is not deadlocked forever.
            self._read_ready.release()
            raise

    def release_write(self) -> None:
        """Release a write lock."""
        self._read_ready.release()

    def read_lock(self) -> "ReadLockContext":
        """Return a context manager for read locking."""
        return ReadLockContext(self)

    def write_lock(self) -> "WriteLockContext":
        """Return a context manager for write locking."""
        return WriteLockContext(self)
@final
class ReadLockContext:
"""Context manager for read locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "ReadLockContext":
self._lock.acquire_read()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_read()
@final
class WriteLockContext:
"""Context manager for write locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "WriteLockContext":
self._lock.acquire_write()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_write()
@final
class EventCollector:
"""
Collects and buffers events for later retrieval.
@ -20,7 +100,7 @@ class EventCollector:
def __init__(self) -> None:
"""Initialize the event collector."""
self._events: list[GraphEngineEvent] = []
self._lock = threading.Lock()
self._lock = ReadWriteLock()
self._layers: list[Layer] = []
def set_layers(self, layers: list[Layer]) -> None:
@ -39,7 +119,7 @@ class EventCollector:
Args:
event: The event to collect
"""
with self._lock:
with self._lock.write_lock():
self._events.append(event)
self._notify_layers(event)
@ -50,7 +130,7 @@ class EventCollector:
Returns:
List of collected events
"""
with self._lock:
with self._lock.read_lock():
return list(self._events)
def get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
@ -63,7 +143,7 @@ class EventCollector:
Returns:
List of new events
"""
with self._lock:
with self._lock.read_lock():
return list(self._events[start_index:])
def event_count(self) -> int:
@ -73,12 +153,12 @@ class EventCollector:
Returns:
Number of collected events
"""
with self._lock:
with self._lock.read_lock():
return len(self._events)
def clear(self) -> None:
"""Clear all collected events."""
with self._lock:
with self._lock.write_lock():
self._events.clear()
def _notify_layers(self, event: GraphEngineEvent) -> None:

View File

@ -5,12 +5,14 @@ Event emitter for yielding events to external consumers.
import threading
import time
from collections.abc import Generator
from typing import final
from core.workflow.graph_events import GraphEngineEvent
from .event_collector import EventCollector
@final
class EventEmitter:
"""
Emits collected events as a generator for external consumption.

View File

@ -3,7 +3,7 @@ Event handler implementations for different event types.
"""
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
@ -38,6 +38,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@final
class EventHandlerRegistry:
"""
Registry of event handlers for different event types.
@ -52,12 +53,12 @@ class EventHandlerRegistry:
graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator,
event_collector: Optional["EventCollector"] = None,
branch_handler: Optional["BranchHandler"] = None,
edge_processor: Optional["EdgeProcessor"] = None,
node_state_manager: Optional["NodeStateManager"] = None,
execution_tracker: Optional["ExecutionTracker"] = None,
error_handler: Optional["ErrorHandler"] = None,
event_collector: "EventCollector",
branch_handler: "BranchHandler",
edge_processor: "EdgeProcessor",
node_state_manager: "NodeStateManager",
execution_tracker: "ExecutionTracker",
error_handler: "ErrorHandler",
) -> None:
"""
Initialize the event handler registry.
@ -67,23 +68,23 @@ class EventHandlerRegistry:
graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator
event_collector: Optional event collector for collecting events
branch_handler: Optional branch handler for branch node processing
edge_processor: Optional edge processor for edge traversal
node_state_manager: Optional node state manager
execution_tracker: Optional execution tracker
error_handler: Optional error handler
event_collector: Event collector for collecting events
branch_handler: Branch handler for branch node processing
edge_processor: Edge processor for edge traversal
node_state_manager: Node state manager
execution_tracker: Execution tracker
error_handler: Error handler
"""
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.graph_execution = graph_execution
self.response_coordinator = response_coordinator
self.event_collector = event_collector
self.branch_handler = branch_handler
self.edge_processor = edge_processor
self.node_state_manager = node_state_manager
self.execution_tracker = execution_tracker
self.error_handler = error_handler
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._event_collector = event_collector
self._branch_handler = branch_handler
self._edge_processor = edge_processor
self._node_state_manager = node_state_manager
self._execution_tracker = execution_tracker
self._error_handler = error_handler
def handle_event(self, event: GraphNodeEventBase) -> None:
"""
@ -93,9 +94,8 @@ class EventHandlerRegistry:
event: The event to handle
"""
# Events in loops or iterations are always collected
if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id):
if self.event_collector:
self.event_collector.collect(event)
if event.in_loop_id or event.in_iteration_id:
self._event_collector.collect(event)
return
# Handle specific event types
@ -125,12 +125,10 @@ class EventHandlerRegistry:
),
):
# Iteration and loop events are collected directly
if self.event_collector:
self.event_collector.collect(event)
self._event_collector.collect(event)
else:
# Collect unhandled events
if self.event_collector:
self.event_collector.collect(event)
self._event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__)
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
@ -141,15 +139,14 @@ class EventHandlerRegistry:
event: The node started event
"""
# Track execution in domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering
self.response_coordinator.track_node_execution(event.node_id, event.id)
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
if self.event_collector:
self.event_collector.collect(event)
self._event_collector.collect(event)
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
@ -159,12 +156,11 @@ class EventHandlerRegistry:
event: The stream chunk event
"""
# Process with response coordinator
streaming_events = list(self.response_coordinator.intercept_event(event))
streaming_events = list(self._response_coordinator.intercept_event(event))
# Collect all events
if self.event_collector:
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
for stream_event in streaming_events:
self._event_collector.collect(stream_event)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
@ -177,55 +173,44 @@ class EventHandlerRegistry:
event: The node succeeded event
"""
# Update domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Store outputs in variable pool
self._store_node_outputs(event)
# Forward to response coordinator and emit streaming events
streaming_events = list(self.response_coordinator.intercept_event(event))
if self.event_collector:
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
streaming_events = self._response_coordinator.intercept_event(event)
for stream_event in streaming_events:
self._event_collector.collect(stream_event)
# Process edges and get ready nodes
node = self.graph.nodes[event.node_id]
node = self._graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
if self.branch_handler:
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
ready_nodes, edge_streaming_events = [], []
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
if self.edge_processor:
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
else:
ready_nodes, edge_streaming_events = [], []
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
# Collect streaming events from edge processing
if self.event_collector:
for edge_event in edge_streaming_events:
self.event_collector.collect(edge_event)
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
# Enqueue ready nodes
if self.node_state_manager and self.execution_tracker:
for node_id in ready_nodes:
self.node_state_manager.enqueue_node(node_id)
self.execution_tracker.add(node_id)
for node_id in ready_nodes:
self._node_state_manager.enqueue_node(node_id)
self._execution_tracker.add(node_id)
# Update execution tracking
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
self._execution_tracker.remove(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)
# Collect the event
if self.event_collector:
self.event_collector.collect(event)
self._event_collector.collect(event)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
@ -235,29 +220,19 @@ class EventHandlerRegistry:
event: The node failed event
"""
# Update domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
if self.error_handler:
result = self.error_handler.handle_node_failure(event)
result = self._error_handler.handle_node_failure(event)
if result:
# Process the resulting event (retry, exception, etc.)
self.handle_event(result)
else:
# Abort execution
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
if result:
# Process the resulting event (retry, exception, etc.)
self.handle_event(result)
else:
# Without error handler, just fail
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
# Abort execution
self._graph_execution.fail(RuntimeError(event.error))
self._event_collector.collect(event)
self._execution_tracker.remove(event.node_id)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
"""
@ -267,7 +242,7 @@ class EventHandlerRegistry:
event: The node exception event
"""
# Node continues via fail-branch, so it's technically "succeeded"
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
@ -277,7 +252,7 @@ class EventHandlerRegistry:
Args:
event: The node retry event
"""
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
@ -288,16 +263,16 @@ class EventHandlerRegistry:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in event.node_run_result.outputs.items():
self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
"""Update response outputs for response nodes."""
for key, value in event.node_run_result.outputs.items():
if key == "answer":
existing = self.graph_runtime_state.outputs.get("answer", "")
existing = self._graph_runtime_state.outputs.get("answer", "")
if existing:
self.graph_runtime_state.outputs["answer"] = f"{existing}{value}"
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
else:
self.graph_runtime_state.outputs["answer"] = value
self._graph_runtime_state.outputs["answer"] = value
else:
self.graph_runtime_state.outputs[key] = value
self._graph_runtime_state.outputs[key] = value

View File

@ -9,7 +9,7 @@ import contextvars
import logging
import queue
from collections.abc import Generator, Mapping
from typing import Any, Optional
from typing import final
from flask import Flask, current_app
@ -20,6 +20,7 @@ from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
@ -44,6 +45,7 @@ from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, Wo
logger = logging.getLogger(__name__)
@final
class GraphEngine:
"""
Queue-based graph execution engine.
@ -62,7 +64,7 @@ class GraphEngine:
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
graph_config: Mapping[str, object],
graph_runtime_state: GraphRuntimeState,
max_execution_steps: int,
max_execution_time: int,
@ -103,7 +105,7 @@ class GraphEngine:
# Initialize queues
self.ready_queue: queue.Queue[str] = queue.Queue()
self.event_queue: queue.Queue = queue.Queue()
self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
# Initialize subsystems
self._initialize_subsystems()
@ -185,7 +187,7 @@ class GraphEngine:
event_handler=self.event_handler_registry,
event_collector=self.event_collector,
command_processor=self.command_processor,
worker_pool=self.worker_pool,
worker_pool=self._worker_pool,
)
self.dispatcher = Dispatcher(
@ -209,7 +211,7 @@ class GraphEngine:
def _setup_worker_management(self) -> None:
"""Initialize worker management subsystem."""
# Capture context for workers
flask_app: Optional[Flask] = None
flask_app: Flask | None = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
@ -218,8 +220,8 @@ class GraphEngine:
context_vars = contextvars.copy_context()
# Create worker management components
self.activity_tracker = ActivityTracker()
self.dynamic_scaler = DynamicScaler(
self._activity_tracker = ActivityTracker()
self._dynamic_scaler = DynamicScaler(
min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS),
max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS),
scale_up_threshold=(
@ -233,15 +235,15 @@ class GraphEngine:
else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
),
)
self.worker_factory = WorkerFactory(flask_app, context_vars)
self._worker_factory = WorkerFactory(flask_app, context_vars)
self.worker_pool = WorkerPool(
self._worker_pool = WorkerPool(
ready_queue=self.ready_queue,
event_queue=self.event_queue,
graph=self.graph,
worker_factory=self.worker_factory,
dynamic_scaler=self.dynamic_scaler,
activity_tracker=self.activity_tracker,
worker_factory=self._worker_factory,
dynamic_scaler=self._dynamic_scaler,
activity_tracker=self._activity_tracker,
)
def _validate_graph_state_consistency(self) -> None:
@ -319,10 +321,10 @@ class GraphEngine:
def _start_execution(self) -> None:
"""Start execution subsystems."""
# Calculate initial worker count
initial_workers = self.dynamic_scaler.calculate_initial_workers(self.graph)
initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph)
# Start worker pool
self.worker_pool.start(initial_workers)
self._worker_pool.start(initial_workers)
# Register response nodes
for node in self.graph.nodes.values():
@ -340,7 +342,7 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self.dispatcher.stop()
self.worker_pool.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers

View File

@ -2,15 +2,18 @@
Branch node handling for graph traversal.
"""
from typing import Optional
from collections.abc import Sequence
from typing import final
from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
from ..state_management import EdgeStateManager
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator
@final
class BranchHandler:
"""
Handles branch node logic during graph traversal.
@ -40,7 +43,9 @@ class BranchHandler:
self.skip_propagator = skip_propagator
self.edge_state_manager = edge_state_manager
def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]:
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.
@ -58,10 +63,10 @@ class BranchHandler:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges into selected and unselected
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self.skip_propagator.skip_branch_paths(node_id, unselected_edges)
self.skip_propagator.skip_branch_paths(unselected_edges)
# Process selected edges and get ready nodes and streaming events
return self.edge_processor.process_node_success(node_id, selected_handle)

View File

@ -2,13 +2,18 @@
Edge processing logic for graph traversal.
"""
from collections.abc import Sequence
from typing import final
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Edge, Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent
from ..response_coordinator import ResponseStreamCoordinator
from ..state_management import EdgeStateManager, NodeStateManager
@final
class EdgeProcessor:
"""
Processes edges during graph execution.
@ -38,7 +43,9 @@ class EdgeProcessor:
self.node_state_manager = node_state_manager
self.response_coordinator = response_coordinator
def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]:
def process_node_success(
self, node_id: str, selected_handle: str | None = None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges after a node succeeds.
@ -56,7 +63,7 @@ class EdgeProcessor:
else:
return self._process_non_branch_node_edges(node_id)
def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]:
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for non-branch nodes (mark all as TAKEN).
@ -66,8 +73,8 @@ class EdgeProcessor:
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
"""
ready_nodes = []
all_streaming_events = []
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self.graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
@ -77,7 +84,9 @@ class EdgeProcessor:
return ready_nodes, all_streaming_events
def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]:
def _process_branch_node_edges(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for branch nodes.
@@ -94,8 +103,8 @@
if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge")
ready_nodes = []
all_streaming_events = []
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
@@ -112,7 +121,7 @@
return ready_nodes, all_streaming_events
def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]:
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Mark edge as taken and check downstream node.
@@ -129,11 +138,11 @@
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes = []
ready_nodes: list[str] = []
if self.node_state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, list(streaming_events)
return ready_nodes, streaming_events
def _process_skipped_edge(self, edge: Edge) -> None:
"""

View File

@@ -2,10 +2,13 @@
Node readiness checking for execution.
"""
from typing import final
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
@final
class NodeReadinessChecker:
"""
Checks if nodes are ready for execution based on their dependencies.
@@ -71,7 +74,7 @@ class NodeReadinessChecker:
Returns:
List of node IDs that are now ready
"""
ready_nodes = []
ready_nodes: list[str] = []
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
for edge in outgoing_edges:

View File

@@ -2,11 +2,15 @@
Skip state propagation through the graph.
"""
from core.workflow.graph import Graph
from collections.abc import Sequence
from typing import final
from core.workflow.graph import Edge, Graph
from ..state_management import EdgeStateManager, NodeStateManager
@final
class SkipPropagator:
"""
Propagates skip states through the graph.
@@ -57,9 +61,8 @@
# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Check if node is ready and enqueue if so
if self.node_state_manager.is_node_ready(downstream_node_id):
self.node_state_manager.enqueue_node(downstream_node_id)
# Enqueue node
self.node_state_manager.enqueue_node(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
@@ -83,12 +86,11 @@
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)
def skip_branch_paths(self, node_id: str, unselected_edges: list) -> None:
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
"""
Skip all paths from unselected branch edges.
Args:
node_id: The ID of the branch node
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:

View File

@@ -6,7 +6,6 @@ intercept and respond to GraphEngine events.
"""
from abc import ABC, abstractmethod
from typing import Optional
from core.workflow.entities import GraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
@@ -28,8 +27,8 @@ class Layer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: Optional[GraphRuntimeState] = None
self.command_channel: Optional[CommandChannel] = None
self.graph_runtime_state: GraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None:
"""
@@ -73,7 +72,7 @@
pass
@abstractmethod
def on_graph_end(self, error: Optional[Exception]) -> None:
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.

View File

@@ -7,7 +7,7 @@ graph execution for debugging purposes.
import logging
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any, final
from core.workflow.graph_events import (
GraphEngineEvent,
@@ -34,6 +34,7 @@ from core.workflow.graph_events import (
from .base import Layer
@final
class DebugLoggingLayer(Layer):
"""
A layer that provides comprehensive logging of GraphEngine execution.
@@ -221,7 +222,7 @@ class DebugLoggingLayer(Layer):
# Log unknown events at debug level
self.logger.debug("Event: %s", event_class)
def on_graph_end(self, error: Optional[Exception]) -> None:
def on_graph_end(self, error: Exception | None) -> None:
"""Log graph execution end with summary statistics."""
self.logger.info("=" * 80)

View File

@@ -11,7 +11,7 @@ When limits are exceeded, the layer automatically aborts execution.
import logging
import time
from enum import Enum
from typing import Optional
from typing import final
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import Layer
@@ -29,6 +29,7 @@ class LimitType(Enum):
TIME_LIMIT = "time_limit"
@final
class ExecutionLimitsLayer(Layer):
"""
Layer that enforces execution limits for workflows.
@@ -53,7 +54,7 @@
self.max_time = max_time
# Runtime tracking
self.start_time: Optional[float] = None
self.start_time: float | None = None
self.step_count = 0
self.logger = logging.getLogger(__name__)
@@ -94,7 +95,7 @@
if self._reached_time_limitation():
self._send_abort_command(LimitType.TIME_LIMIT)
def on_graph_end(self, error: Optional[Exception]) -> None:
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._execution_started and not self._execution_ended:
self._execution_ended = True

View File

@@ -6,13 +6,14 @@ using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
from typing import Optional
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from extensions.ext_redis import redis_client
@final
class GraphEngineManager:
"""
Manager for sending control commands to GraphEngine instances.
@@ -23,7 +24,7 @@
"""
@staticmethod
def send_stop_command(task_id: str, reason: Optional[str] = None) -> None:
def send_stop_command(task_id: str, reason: str | None = None) -> None:
"""
Send a stop command to a running workflow.

View File

@@ -6,7 +6,9 @@ import logging
import queue
import threading
import time
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final
from core.workflow.graph_events.base import GraphNodeEventBase
from ..event_management import EventCollector, EventEmitter
from .execution_coordinator import ExecutionCoordinator
@@ -17,6 +19,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@final
class Dispatcher:
"""
Main dispatcher that processes events from the event queue.
@@ -27,12 +30,12 @@ class Dispatcher:
def __init__(
self,
event_queue: queue.Queue,
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
execution_coordinator: ExecutionCoordinator,
max_execution_time: int,
event_emitter: Optional[EventEmitter] = None,
event_emitter: EventEmitter | None = None,
) -> None:
"""
Initialize the dispatcher.
@@ -52,9 +55,9 @@
self.max_execution_time = max_execution_time
self.event_emitter = event_emitter
self._thread: Optional[threading.Thread] = None
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._start_time: Optional[float] = None
self._start_time: float | None = None
def start(self) -> None:
"""Start the dispatcher thread."""

View File

@@ -2,7 +2,7 @@
Execution coordinator for managing overall workflow execution.
"""
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
@@ -14,6 +14,7 @@
from ..event_management import EventHandlerRegistry
@final
class ExecutionCoordinator:
"""
Coordinates overall execution flow between subsystems.

View File

@@ -7,7 +7,7 @@ thread-safe storage for node outputs.
from collections.abc import Sequence
from threading import RLock
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Union, final
from core.variables import Segment
from core.workflow.entities.variable_pool import VariablePool
@@ -18,6 +18,7 @@
from core.workflow.graph_events import NodeRunStreamChunkEvent
@final
class OutputRegistry:
"""
Thread-safe registry for storing and retrieving node outputs.
@@ -47,7 +48,7 @@
with self._lock:
self._scalars.add(selector, value)
def get_scalar(self, selector: Sequence[str]) -> Optional["Segment"]:
def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
"""
Get a scalar value for the given selector.
@@ -81,7 +82,7 @@
except ValueError:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
def pop_chunk(self, selector: Sequence[str]) -> Optional["NodeRunStreamChunkEvent"]:
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.

View File

@@ -5,12 +5,13 @@ This module contains the private Stream class used internally by OutputRegistry
to manage streaming data chunks.
"""
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final
if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent
@final
class Stream:
"""
A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
@@ -41,7 +42,7 @@
raise ValueError("Cannot append to a closed stream")
self.events.append(event)
def pop_next(self) -> Optional["NodeRunStreamChunkEvent"]:
def pop_next(self) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.

View File

@@ -2,7 +2,7 @@
Base error strategy protocol.
"""
from typing import Optional, Protocol
from typing import Protocol
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
@@ -16,7 +16,7 @@
node execution failures.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle a node failure event.

View File

@@ -9,7 +9,7 @@
from collections import deque
from collections.abc import Sequence
from threading import RLock
from typing import Optional, TypeAlias
from typing import TypeAlias, final
from uuid import uuid4
from core.workflow.enums import NodeExecutionType, NodeState
@@ -28,6 +28,7 @@ NodeID: TypeAlias = str
EdgeID: TypeAlias = str
@final
class ResponseStreamCoordinator:
"""
Manages response streaming sessions without relying on global state.
@@ -45,7 +46,7 @@
"""
self.registry = registry
self.graph = graph
self.active_session: Optional[ResponseSession] = None
self.active_session: ResponseSession | None = None
self.waiting_sessions: deque[ResponseSession] = deque()
self.lock = RLock()

View File

@@ -3,7 +3,8 @@ Manager for edge states during graph execution.
"""
import threading
from typing import TypedDict
from collections.abc import Sequence
from typing import TypedDict, final
from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
@@ -17,6 +18,7 @@ class EdgeStateAnalysis(TypedDict):
all_skipped: bool
@final
class EdgeStateManager:
"""
Manages edge states and transitions during graph execution.
@@ -87,7 +89,7 @@
with self._lock:
return self.graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]:
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
@@ -100,8 +102,8 @@
"""
with self._lock:
outgoing_edges = self.graph.get_outgoing_edges(node_id)
selected_edges = []
unselected_edges = []
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:

View File

@@ -3,8 +3,10 @@ Tracker for currently executing nodes.
"""
import threading
from typing import final
@final
class ExecutionTracker:
"""
Tracks nodes that are currently being executed.

View File

@@ -4,11 +4,13 @@ Manager for node states during graph execution.
import queue
import threading
from typing import final
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
@final
class NodeStateManager:
"""
Manages node states and the ready queue for execution.

View File

@@ -11,7 +11,7 @@
import time
from collections.abc import Callable
from datetime import datetime
from typing import Optional
from typing import final
from uuid import uuid4
from flask import Flask
@@ -23,6 +23,7 @@ from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
@final
class Worker(threading.Thread):
"""
Worker thread that executes nodes from the ready queue.
@@ -38,10 +39,10 @@
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
worker_id: int = 0,
flask_app: Optional[Flask] = None,
context_vars: Optional[contextvars.Context] = None,
on_idle_callback: Optional[Callable[[int], None]] = None,
on_active_callback: Optional[Callable[[int], None]] = None,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None,
) -> None:
"""
Initialize worker thread.

View File

@@ -4,8 +4,10 @@ Activity tracker for monitoring worker activity.
import threading
import time
from typing import final
@final
class ActivityTracker:
"""
Tracks worker activity for scaling decisions.

View File

@@ -2,9 +2,12 @@
Dynamic scaler for worker pool sizing.
"""
from typing import final
from core.workflow.graph import Graph
@final
class DynamicScaler:
"""
Manages dynamic scaling decisions for the worker pool.

View File

@@ -5,7 +5,7 @@ Factory for creating worker instances.
import contextvars
import queue
from collections.abc import Callable
from typing import Optional
from typing import final
from flask import Flask
@@ -14,6 +14,7 @@ from core.workflow.graph import Graph
from ..worker import Worker
@final
class WorkerFactory:
"""
Factory for creating worker instances with proper context.
@@ -24,7 +25,7 @@
def __init__(
self,
flask_app: Optional[Flask],
flask_app: Flask | None,
context_vars: contextvars.Context,
) -> None:
"""
@@ -43,8 +44,8 @@
ready_queue: queue.Queue[str],
event_queue: queue.Queue,
graph: Graph,
on_idle_callback: Optional[Callable[[int], None]] = None,
on_active_callback: Optional[Callable[[int], None]] = None,
on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None,
) -> Worker:
"""
Create a new worker instance.

View File

@@ -4,6 +4,7 @@ Worker pool management.
import queue
import threading
from typing import final
from core.workflow.graph import Graph
@@ -13,6 +14,7 @@ from .dynamic_scaler import DynamicScaler
from .worker_factory import WorkerFactory
@final
class WorkerPool:
"""
Manages a pool of worker threads for executing nodes.