mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 07:58:02 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
@ -6,10 +6,12 @@ within a single process. Each instance handles commands for one workflow executi
|
||||
"""
|
||||
|
||||
from queue import Queue
|
||||
from typing import final
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryChannel:
|
||||
"""
|
||||
In-memory command channel implementation using a thread-safe queue.
|
||||
|
||||
@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
|
||||
@ -15,6 +15,7 @@ if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
|
||||
@final
|
||||
class RedisChannel:
|
||||
"""
|
||||
Redis-based command channel implementation for distributed systems.
|
||||
@ -86,7 +87,7 @@ class RedisChannel:
|
||||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict) -> Optional[GraphEngineCommand]:
|
||||
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
|
||||
"""
|
||||
Deserialize a command from dictionary data.
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ Command handler implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
@ -11,6 +12,7 @@ from .command_processor import CommandHandler
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
"""Handles abort commands."""
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ Main command processor for handling external commands.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Protocol
|
||||
from typing import Protocol, final
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
@ -18,6 +18,7 @@ class CommandHandler(Protocol):
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
|
||||
|
||||
|
||||
@final
|
||||
class CommandProcessor:
|
||||
"""
|
||||
Processes external commands sent to the engine.
|
||||
|
||||
@ -3,7 +3,6 @@ GraphExecution aggregate root managing the overall graph execution state.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
@ -21,7 +20,7 @@ class GraphExecution:
|
||||
started: bool = False
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
error: Optional[Exception] = None
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
|
||||
|
||||
def start(self) -> None:
|
||||
|
||||
@ -3,7 +3,6 @@ NodeExecution entity representing a node's execution state.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
@ -20,8 +19,8 @@ class NodeExecution:
|
||||
node_id: str
|
||||
state: NodeState = NodeState.UNKNOWN
|
||||
retry_count: int = 0
|
||||
execution_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
execution_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
def mark_started(self, execution_id: str) -> None:
|
||||
"""Mark the node as started with an execution ID."""
|
||||
|
||||
@ -6,7 +6,7 @@ instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -23,11 +23,11 @@ class GraphEngineCommand(BaseModel):
|
||||
"""Base class for all GraphEngine commands."""
|
||||
|
||||
command_type: CommandType = Field(..., description="Type of command")
|
||||
payload: Optional[dict[str, Any]] = Field(default=None, description="Optional command payload")
|
||||
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")
|
||||
|
||||
|
||||
class AbortCommand(GraphEngineCommand):
|
||||
"""Command to abort a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
|
||||
reason: Optional[str] = Field(default=None, description="Optional reason for abort")
|
||||
reason: str | None = Field(default=None, description="Optional reason for abort")
|
||||
|
||||
@ -8,7 +8,6 @@ the Strategy pattern for clean separation of concerns.
|
||||
from .abort_strategy import AbortStrategy
|
||||
from .default_value_strategy import DefaultValueStrategy
|
||||
from .error_handler import ErrorHandler
|
||||
from .error_strategy import ErrorStrategy
|
||||
from .fail_branch_strategy import FailBranchStrategy
|
||||
from .retry_strategy import RetryStrategy
|
||||
|
||||
@ -16,7 +15,6 @@ __all__ = [
|
||||
"AbortStrategy",
|
||||
"DefaultValueStrategy",
|
||||
"ErrorHandler",
|
||||
"ErrorStrategy",
|
||||
"FailBranchStrategy",
|
||||
"RetryStrategy",
|
||||
]
|
||||
|
||||
@ -3,7 +3,7 @@ Abort error strategy implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
@ -11,6 +11,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class AbortStrategy:
|
||||
"""
|
||||
Error strategy that aborts execution on failure.
|
||||
@ -19,7 +20,7 @@ class AbortStrategy:
|
||||
It stops the entire graph execution when a node fails.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle error by aborting execution.
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Default value error strategy implementation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
@final
|
||||
class DefaultValueStrategy:
|
||||
"""
|
||||
Error strategy that uses default values on failure.
|
||||
@ -18,7 +19,7 @@ class DefaultValueStrategy:
|
||||
predefined default output values.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle error by using default values.
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Main error handler that coordinates error strategies.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
|
||||
from core.workflow.graph import Graph
|
||||
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
||||
from ..domain import GraphExecution
|
||||
|
||||
|
||||
@final
|
||||
class ErrorHandler:
|
||||
"""
|
||||
Coordinates error handling strategies for node failures.
|
||||
@ -34,16 +35,16 @@ class ErrorHandler:
|
||||
graph: The workflow graph
|
||||
graph_execution: The graph execution state
|
||||
"""
|
||||
self.graph = graph
|
||||
self.graph_execution = graph_execution
|
||||
self._graph = graph
|
||||
self._graph_execution = graph_execution
|
||||
|
||||
# Initialize strategies
|
||||
self.abort_strategy = AbortStrategy()
|
||||
self.retry_strategy = RetryStrategy()
|
||||
self.fail_branch_strategy = FailBranchStrategy()
|
||||
self.default_value_strategy = DefaultValueStrategy()
|
||||
self._abort_strategy = AbortStrategy()
|
||||
self._retry_strategy = RetryStrategy()
|
||||
self._fail_branch_strategy = FailBranchStrategy()
|
||||
self._default_value_strategy = DefaultValueStrategy()
|
||||
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]:
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
@ -56,14 +57,14 @@ class ErrorHandler:
|
||||
Returns:
|
||||
Optional new event to process, or None to abort
|
||||
"""
|
||||
node = self.graph.nodes[event.node_id]
|
||||
node = self._graph.nodes[event.node_id]
|
||||
# Get retry count from NodeExecution
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
retry_count = node_execution.retry_count
|
||||
|
||||
# First check if retry is configured and not exhausted
|
||||
if node.retry and retry_count < node.retry_config.max_retries:
|
||||
result = self.retry_strategy.handle_error(event, self.graph, retry_count)
|
||||
result = self._retry_strategy.handle_error(event, self._graph, retry_count)
|
||||
if result:
|
||||
# Retry count will be incremented when NodeRunRetryEvent is handled
|
||||
return result
|
||||
@ -71,12 +72,10 @@ class ErrorHandler:
|
||||
# Apply configured error strategy
|
||||
strategy = node.error_strategy
|
||||
|
||||
if strategy is None:
|
||||
return self.abort_strategy.handle_error(event, self.graph, retry_count)
|
||||
elif strategy == ErrorStrategyEnum.FAIL_BRANCH:
|
||||
return self.fail_branch_strategy.handle_error(event, self.graph, retry_count)
|
||||
elif strategy == ErrorStrategyEnum.DEFAULT_VALUE:
|
||||
return self.default_value_strategy.handle_error(event, self.graph, retry_count)
|
||||
else:
|
||||
# Unknown strategy, default to abort
|
||||
return self.abort_strategy.handle_error(event, self.graph, retry_count)
|
||||
match strategy:
|
||||
case None:
|
||||
return self._abort_strategy.handle_error(event, self._graph, retry_count)
|
||||
case ErrorStrategyEnum.FAIL_BRANCH:
|
||||
return self._fail_branch_strategy.handle_error(event, self._graph, retry_count)
|
||||
case ErrorStrategyEnum.DEFAULT_VALUE:
|
||||
return self._default_value_strategy.handle_error(event, self._graph, retry_count)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Fail branch error strategy implementation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
@final
|
||||
class FailBranchStrategy:
|
||||
"""
|
||||
Error strategy that continues execution via a fail branch.
|
||||
@ -18,7 +19,7 @@ class FailBranchStrategy:
|
||||
through a designated fail-branch edge.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle error by taking the fail branch.
|
||||
|
||||
|
||||
@ -3,12 +3,13 @@ Retry error strategy implementation.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent
|
||||
|
||||
|
||||
@final
|
||||
class RetryStrategy:
|
||||
"""
|
||||
Error strategy that retries failed nodes.
|
||||
@ -17,7 +18,7 @@ class RetryStrategy:
|
||||
maximum number of retries with configurable intervals.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle error by retrying the node.
|
||||
|
||||
|
||||
@ -3,12 +3,92 @@ Event collector for buffering and managing events.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import Layer
|
||||
|
||||
|
||||
@final
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock implementation that allows multiple concurrent readers
|
||||
but only one writer at a time.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._read_ready = threading.Condition(threading.RLock())
|
||||
self._readers = 0
|
||||
|
||||
def acquire_read(self) -> None:
|
||||
"""Acquire a read lock."""
|
||||
self._read_ready.acquire()
|
||||
try:
|
||||
self._readers += 1
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
self._read_ready.acquire()
|
||||
try:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._read_ready.notify_all()
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def acquire_write(self) -> None:
|
||||
"""Acquire a write lock."""
|
||||
self._read_ready.acquire()
|
||||
while self._readers > 0:
|
||||
self._read_ready.wait()
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release a write lock."""
|
||||
self._read_ready.release()
|
||||
|
||||
def read_lock(self) -> "ReadLockContext":
|
||||
"""Return a context manager for read locking."""
|
||||
return ReadLockContext(self)
|
||||
|
||||
def write_lock(self) -> "WriteLockContext":
|
||||
"""Return a context manager for write locking."""
|
||||
return WriteLockContext(self)
|
||||
|
||||
|
||||
@final
|
||||
class ReadLockContext:
|
||||
"""Context manager for read locks."""
|
||||
|
||||
def __init__(self, lock: ReadWriteLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> "ReadLockContext":
|
||||
self._lock.acquire_read()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
|
||||
self._lock.release_read()
|
||||
|
||||
|
||||
@final
|
||||
class WriteLockContext:
|
||||
"""Context manager for write locks."""
|
||||
|
||||
def __init__(self, lock: ReadWriteLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> "WriteLockContext":
|
||||
self._lock.acquire_write()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
|
||||
self._lock.release_write()
|
||||
|
||||
|
||||
@final
|
||||
class EventCollector:
|
||||
"""
|
||||
Collects and buffers events for later retrieval.
|
||||
@ -20,7 +100,7 @@ class EventCollector:
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the event collector."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = threading.Lock()
|
||||
self._lock = ReadWriteLock()
|
||||
self._layers: list[Layer] = []
|
||||
|
||||
def set_layers(self, layers: list[Layer]) -> None:
|
||||
@ -39,7 +119,7 @@ class EventCollector:
|
||||
Args:
|
||||
event: The event to collect
|
||||
"""
|
||||
with self._lock:
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
@ -50,7 +130,7 @@ class EventCollector:
|
||||
Returns:
|
||||
List of collected events
|
||||
"""
|
||||
with self._lock:
|
||||
with self._lock.read_lock():
|
||||
return list(self._events)
|
||||
|
||||
def get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
@ -63,7 +143,7 @@ class EventCollector:
|
||||
Returns:
|
||||
List of new events
|
||||
"""
|
||||
with self._lock:
|
||||
with self._lock.read_lock():
|
||||
return list(self._events[start_index:])
|
||||
|
||||
def event_count(self) -> int:
|
||||
@ -73,12 +153,12 @@ class EventCollector:
|
||||
Returns:
|
||||
Number of collected events
|
||||
"""
|
||||
with self._lock:
|
||||
with self._lock.read_lock():
|
||||
return len(self._events)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all collected events."""
|
||||
with self._lock:
|
||||
with self._lock.write_lock():
|
||||
self._events.clear()
|
||||
|
||||
def _notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
|
||||
@ -5,12 +5,14 @@ Event emitter for yielding events to external consumers.
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from .event_collector import EventCollector
|
||||
|
||||
|
||||
@final
|
||||
class EventEmitter:
|
||||
"""
|
||||
Emits collected events as a generator for external consumption.
|
||||
|
||||
@ -3,7 +3,7 @@ Event handler implementations for different event types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
@ -38,6 +38,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class EventHandlerRegistry:
|
||||
"""
|
||||
Registry of event handlers for different event types.
|
||||
@ -52,12 +53,12 @@ class EventHandlerRegistry:
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: Optional["EventCollector"] = None,
|
||||
branch_handler: Optional["BranchHandler"] = None,
|
||||
edge_processor: Optional["EdgeProcessor"] = None,
|
||||
node_state_manager: Optional["NodeStateManager"] = None,
|
||||
execution_tracker: Optional["ExecutionTracker"] = None,
|
||||
error_handler: Optional["ErrorHandler"] = None,
|
||||
event_collector: "EventCollector",
|
||||
branch_handler: "BranchHandler",
|
||||
edge_processor: "EdgeProcessor",
|
||||
node_state_manager: "NodeStateManager",
|
||||
execution_tracker: "ExecutionTracker",
|
||||
error_handler: "ErrorHandler",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the event handler registry.
|
||||
@ -67,23 +68,23 @@ class EventHandlerRegistry:
|
||||
graph_runtime_state: Runtime state with variable pool
|
||||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Optional event collector for collecting events
|
||||
branch_handler: Optional branch handler for branch node processing
|
||||
edge_processor: Optional edge processor for edge traversal
|
||||
node_state_manager: Optional node state manager
|
||||
execution_tracker: Optional execution tracker
|
||||
error_handler: Optional error handler
|
||||
event_collector: Event collector for collecting events
|
||||
branch_handler: Branch handler for branch node processing
|
||||
edge_processor: Edge processor for edge traversal
|
||||
node_state_manager: Node state manager
|
||||
execution_tracker: Execution tracker
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.graph_execution = graph_execution
|
||||
self.response_coordinator = response_coordinator
|
||||
self.event_collector = event_collector
|
||||
self.branch_handler = branch_handler
|
||||
self.edge_processor = edge_processor
|
||||
self.node_state_manager = node_state_manager
|
||||
self.execution_tracker = execution_tracker
|
||||
self.error_handler = error_handler
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._event_collector = event_collector
|
||||
self._branch_handler = branch_handler
|
||||
self._edge_processor = edge_processor
|
||||
self._node_state_manager = node_state_manager
|
||||
self._execution_tracker = execution_tracker
|
||||
self._error_handler = error_handler
|
||||
|
||||
def handle_event(self, event: GraphNodeEventBase) -> None:
|
||||
"""
|
||||
@ -93,9 +94,8 @@ class EventHandlerRegistry:
|
||||
event: The event to handle
|
||||
"""
|
||||
# Events in loops or iterations are always collected
|
||||
if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id):
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if event.in_loop_id or event.in_iteration_id:
|
||||
self._event_collector.collect(event)
|
||||
return
|
||||
|
||||
# Handle specific event types
|
||||
@ -125,12 +125,10 @@ class EventHandlerRegistry:
|
||||
),
|
||||
):
|
||||
# Iteration and loop events are collected directly
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
else:
|
||||
# Collect unhandled events
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||
|
||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
@ -141,15 +139,14 @@ class EventHandlerRegistry:
|
||||
event: The node started event
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self.response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
@ -159,12 +156,11 @@ class EventHandlerRegistry:
|
||||
event: The stream chunk event
|
||||
"""
|
||||
# Process with response coordinator
|
||||
streaming_events = list(self.response_coordinator.intercept_event(event))
|
||||
streaming_events = list(self._response_coordinator.intercept_event(event))
|
||||
|
||||
# Collect all events
|
||||
if self.event_collector:
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
@ -177,55 +173,44 @@ class EventHandlerRegistry:
|
||||
event: The node succeeded event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = list(self.response_coordinator.intercept_event(event))
|
||||
if self.event_collector:
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
# Process edges and get ready nodes
|
||||
node = self.graph.nodes[event.node_id]
|
||||
node = self._graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
if self.branch_handler:
|
||||
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = [], []
|
||||
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
if self.edge_processor:
|
||||
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = [], []
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
|
||||
# Collect streaming events from edge processing
|
||||
if self.event_collector:
|
||||
for edge_event in edge_streaming_events:
|
||||
self.event_collector.collect(edge_event)
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
if self.node_state_manager and self.execution_tracker:
|
||||
for node_id in ready_nodes:
|
||||
self.node_state_manager.enqueue_node(node_id)
|
||||
self.execution_tracker.add(node_id)
|
||||
for node_id in ready_nodes:
|
||||
self._node_state_manager.enqueue_node(node_id)
|
||||
self._execution_tracker.add(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event)
|
||||
|
||||
# Collect the event
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
@ -235,29 +220,19 @@ class EventHandlerRegistry:
|
||||
event: The node failed event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
|
||||
if self.error_handler:
|
||||
result = self.error_handler.handle_node_failure(event)
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.handle_event(result)
|
||||
else:
|
||||
# Abort execution
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.handle_event(result)
|
||||
else:
|
||||
# Without error handler, just fail
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
# Abort execution
|
||||
self._graph_execution.fail(RuntimeError(event.error))
|
||||
self._event_collector.collect(event)
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
@ -267,7 +242,7 @@ class EventHandlerRegistry:
|
||||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch, so it's technically "succeeded"
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
@ -277,7 +252,7 @@ class EventHandlerRegistry:
|
||||
Args:
|
||||
event: The node retry event
|
||||
"""
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
@ -288,16 +263,16 @@ class EventHandlerRegistry:
|
||||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self.graph_runtime_state.outputs.get("answer", "")
|
||||
existing = self._graph_runtime_state.outputs.get("answer", "")
|
||||
if existing:
|
||||
self.graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
else:
|
||||
self.graph_runtime_state.outputs["answer"] = value
|
||||
self._graph_runtime_state.outputs["answer"] = value
|
||||
else:
|
||||
self.graph_runtime_state.outputs[key] = value
|
||||
self._graph_runtime_state.outputs[key] = value
|
||||
|
||||
@ -9,7 +9,7 @@ import contextvars
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import final
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
@ -20,6 +20,7 @@ from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
@ -44,6 +45,7 @@ from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, Wo
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngine:
|
||||
"""
|
||||
Queue-based graph execution engine.
|
||||
@ -62,7 +64,7 @@ class GraphEngine:
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
graph: Graph,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: Mapping[str, object],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int,
|
||||
@ -103,7 +105,7 @@ class GraphEngine:
|
||||
|
||||
# Initialize queues
|
||||
self.ready_queue: queue.Queue[str] = queue.Queue()
|
||||
self.event_queue: queue.Queue = queue.Queue()
|
||||
self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
# Initialize subsystems
|
||||
self._initialize_subsystems()
|
||||
@ -185,7 +187,7 @@ class GraphEngine:
|
||||
event_handler=self.event_handler_registry,
|
||||
event_collector=self.event_collector,
|
||||
command_processor=self.command_processor,
|
||||
worker_pool=self.worker_pool,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
self.dispatcher = Dispatcher(
|
||||
@ -209,7 +211,7 @@ class GraphEngine:
|
||||
def _setup_worker_management(self) -> None:
|
||||
"""Initialize worker management subsystem."""
|
||||
# Capture context for workers
|
||||
flask_app: Optional[Flask] = None
|
||||
flask_app: Flask | None = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
@ -218,8 +220,8 @@ class GraphEngine:
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Create worker management components
|
||||
self.activity_tracker = ActivityTracker()
|
||||
self.dynamic_scaler = DynamicScaler(
|
||||
self._activity_tracker = ActivityTracker()
|
||||
self._dynamic_scaler = DynamicScaler(
|
||||
min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS),
|
||||
max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS),
|
||||
scale_up_threshold=(
|
||||
@ -233,15 +235,15 @@ class GraphEngine:
|
||||
else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
),
|
||||
)
|
||||
self.worker_factory = WorkerFactory(flask_app, context_vars)
|
||||
self._worker_factory = WorkerFactory(flask_app, context_vars)
|
||||
|
||||
self.worker_pool = WorkerPool(
|
||||
self._worker_pool = WorkerPool(
|
||||
ready_queue=self.ready_queue,
|
||||
event_queue=self.event_queue,
|
||||
graph=self.graph,
|
||||
worker_factory=self.worker_factory,
|
||||
dynamic_scaler=self.dynamic_scaler,
|
||||
activity_tracker=self.activity_tracker,
|
||||
worker_factory=self._worker_factory,
|
||||
dynamic_scaler=self._dynamic_scaler,
|
||||
activity_tracker=self._activity_tracker,
|
||||
)
|
||||
|
||||
def _validate_graph_state_consistency(self) -> None:
|
||||
@ -319,10 +321,10 @@ class GraphEngine:
|
||||
def _start_execution(self) -> None:
|
||||
"""Start execution subsystems."""
|
||||
# Calculate initial worker count
|
||||
initial_workers = self.dynamic_scaler.calculate_initial_workers(self.graph)
|
||||
initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph)
|
||||
|
||||
# Start worker pool
|
||||
self.worker_pool.start(initial_workers)
|
||||
self._worker_pool.start(initial_workers)
|
||||
|
||||
# Register response nodes
|
||||
for node in self.graph.nodes.values():
|
||||
@ -340,7 +342,7 @@ class GraphEngine:
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self.dispatcher.stop()
|
||||
self.worker_pool.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
|
||||
@ -2,15 +2,18 @@
|
||||
Branch node handling for graph traversal.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
|
||||
|
||||
from ..state_management import EdgeStateManager
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class BranchHandler:
|
||||
"""
|
||||
Handles branch node logic during graph traversal.
|
||||
@ -40,7 +43,9 @@ class BranchHandler:
|
||||
self.skip_propagator = skip_propagator
|
||||
self.edge_state_manager = edge_state_manager
|
||||
|
||||
def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]:
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
@ -58,10 +63,10 @@ class BranchHandler:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self.skip_propagator.skip_branch_paths(node_id, unselected_edges)
|
||||
self.skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.edge_processor.process_node_success(node_id, selected_handle)
|
||||
|
||||
@ -2,13 +2,18 @@
|
||||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
|
||||
|
||||
@final
|
||||
class EdgeProcessor:
|
||||
"""
|
||||
Processes edges during graph execution.
|
||||
@ -38,7 +43,9 @@ class EdgeProcessor:
|
||||
self.node_state_manager = node_state_manager
|
||||
self.response_coordinator = response_coordinator
|
||||
|
||||
def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]:
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges after a node succeeds.
|
||||
|
||||
@ -56,7 +63,7 @@ class EdgeProcessor:
|
||||
else:
|
||||
return self._process_non_branch_node_edges(node_id)
|
||||
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]:
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for non-branch nodes (mark all as TAKEN).
|
||||
|
||||
@ -66,8 +73,8 @@ class EdgeProcessor:
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
"""
|
||||
ready_nodes = []
|
||||
all_streaming_events = []
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
@ -77,7 +84,9 @@ class EdgeProcessor:
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]:
|
||||
def _process_branch_node_edges(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for branch nodes.
|
||||
|
||||
@ -94,8 +103,8 @@ class EdgeProcessor:
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} did not select any edge")
|
||||
|
||||
ready_nodes = []
|
||||
all_streaming_events = []
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
@ -112,7 +121,7 @@ class EdgeProcessor:
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]:
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Mark edge as taken and check downstream node.
|
||||
|
||||
@ -129,11 +138,11 @@ class EdgeProcessor:
|
||||
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes = []
|
||||
ready_nodes: list[str] = []
|
||||
if self.node_state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, list(streaming_events)
|
||||
return ready_nodes, streaming_events
|
||||
|
||||
def _process_skipped_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
|
||||
@ -2,10 +2,13 @@
|
||||
Node readiness checking for execution.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class NodeReadinessChecker:
|
||||
"""
|
||||
Checks if nodes are ready for execution based on their dependencies.
|
||||
@ -71,7 +74,7 @@ class NodeReadinessChecker:
|
||||
Returns:
|
||||
List of node IDs that are now ready
|
||||
"""
|
||||
ready_nodes = []
|
||||
ready_nodes: list[str] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
|
||||
@ -2,11 +2,15 @@
|
||||
Skip state propagation through the graph.
|
||||
"""
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
|
||||
|
||||
@final
|
||||
class SkipPropagator:
|
||||
"""
|
||||
Propagates skip states through the graph.
|
||||
@ -57,9 +61,8 @@ class SkipPropagator:
|
||||
|
||||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Check if node is ready and enqueue if so
|
||||
if self.node_state_manager.is_node_ready(downstream_node_id):
|
||||
self.node_state_manager.enqueue_node(downstream_node_id)
|
||||
# Enqueue node
|
||||
self.node_state_manager.enqueue_node(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
@ -83,12 +86,11 @@ class SkipPropagator:
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
def skip_branch_paths(self, node_id: str, unselected_edges: list) -> None:
|
||||
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
|
||||
"""
|
||||
Skip all paths from unselected branch edges.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
|
||||
@ -6,7 +6,6 @@ intercept and respond to GraphEngine events.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
@ -28,8 +27,8 @@ class Layer(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||
self.command_channel: Optional[CommandChannel] = None
|
||||
self.graph_runtime_state: GraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
@ -73,7 +72,7 @@ class Layer(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_end(self, error: Optional[Exception]) -> None:
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ graph execution for debugging purposes.
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any, final
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
@ -34,6 +34,7 @@ from core.workflow.graph_events import (
|
||||
from .base import Layer
|
||||
|
||||
|
||||
@final
|
||||
class DebugLoggingLayer(Layer):
|
||||
"""
|
||||
A layer that provides comprehensive logging of GraphEngine execution.
|
||||
@ -221,7 +222,7 @@ class DebugLoggingLayer(Layer):
|
||||
# Log unknown events at debug level
|
||||
self.logger.debug("Event: %s", event_class)
|
||||
|
||||
def on_graph_end(self, error: Optional[Exception]) -> None:
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Log graph execution end with summary statistics."""
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ When limits are exceeded, the layer automatically aborts execution.
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import Layer
|
||||
@ -29,6 +29,7 @@ class LimitType(Enum):
|
||||
TIME_LIMIT = "time_limit"
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionLimitsLayer(Layer):
|
||||
"""
|
||||
Layer that enforces execution limits for workflows.
|
||||
@ -53,7 +54,7 @@ class ExecutionLimitsLayer(Layer):
|
||||
self.max_time = max_time
|
||||
|
||||
# Runtime tracking
|
||||
self.start_time: Optional[float] = None
|
||||
self.start_time: float | None = None
|
||||
self.step_count = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
@ -94,7 +95,7 @@ class ExecutionLimitsLayer(Layer):
|
||||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
def on_graph_end(self, error: Optional[Exception]) -> None:
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
self._execution_ended = True
|
||||
|
||||
@ -6,13 +6,14 @@ using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngineManager:
|
||||
"""
|
||||
Manager for sending control commands to GraphEngine instances.
|
||||
@ -23,7 +24,7 @@ class GraphEngineManager:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: Optional[str] = None) -> None:
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
|
||||
@ -6,7 +6,9 @@ import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
|
||||
from ..event_management import EventCollector, EventEmitter
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
@ -17,6 +19,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class Dispatcher:
|
||||
"""
|
||||
Main dispatcher that processes events from the event queue.
|
||||
@ -27,12 +30,12 @@ class Dispatcher:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandlerRegistry",
|
||||
event_collector: EventCollector,
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
max_execution_time: int,
|
||||
event_emitter: Optional[EventEmitter] = None,
|
||||
event_emitter: EventEmitter | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dispatcher.
|
||||
@ -52,9 +55,9 @@ class Dispatcher:
|
||||
self.max_execution_time = max_execution_time
|
||||
self.event_emitter = event_emitter
|
||||
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: Optional[float] = None
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the dispatcher thread."""
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
@ -14,6 +14,7 @@ if TYPE_CHECKING:
|
||||
from ..event_management import EventHandlerRegistry
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionCoordinator:
|
||||
"""
|
||||
Coordinates overall execution flow between subsystems.
|
||||
|
||||
@ -7,7 +7,7 @@ thread-safe storage for node outputs.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Union, final
|
||||
|
||||
from core.variables import Segment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -18,6 +18,7 @@ if TYPE_CHECKING:
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
|
||||
@final
|
||||
class OutputRegistry:
|
||||
"""
|
||||
Thread-safe registry for storing and retrieving node outputs.
|
||||
@ -47,7 +48,7 @@ class OutputRegistry:
|
||||
with self._lock:
|
||||
self._scalars.add(selector, value)
|
||||
|
||||
def get_scalar(self, selector: Sequence[str]) -> Optional["Segment"]:
|
||||
def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
|
||||
"""
|
||||
Get a scalar value for the given selector.
|
||||
|
||||
@ -81,7 +82,7 @@ class OutputRegistry:
|
||||
except ValueError:
|
||||
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
|
||||
|
||||
def pop_chunk(self, selector: Sequence[str]) -> Optional["NodeRunStreamChunkEvent"]:
|
||||
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
|
||||
"""
|
||||
Pop the next unread NodeRunStreamChunkEvent from the stream.
|
||||
|
||||
|
||||
@ -5,12 +5,13 @@ This module contains the private Stream class used internally by OutputRegistry
|
||||
to manage streaming data chunks.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
|
||||
@final
|
||||
class Stream:
|
||||
"""
|
||||
A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
|
||||
@ -41,7 +42,7 @@ class Stream:
|
||||
raise ValueError("Cannot append to a closed stream")
|
||||
self.events.append(event)
|
||||
|
||||
def pop_next(self) -> Optional["NodeRunStreamChunkEvent"]:
|
||||
def pop_next(self) -> "NodeRunStreamChunkEvent | None":
|
||||
"""
|
||||
Pop the next unread NodeRunStreamChunkEvent from the stream.
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Base error strategy protocol.
|
||||
"""
|
||||
|
||||
from typing import Optional, Protocol
|
||||
from typing import Protocol
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
@ -16,7 +16,7 @@ class ErrorStrategy(Protocol):
|
||||
node execution failures.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
@ -9,7 +9,7 @@ import logging
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import Optional, TypeAlias
|
||||
from typing import TypeAlias, final
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
@ -28,6 +28,7 @@ NodeID: TypeAlias = str
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
@final
|
||||
class ResponseStreamCoordinator:
|
||||
"""
|
||||
Manages response streaming sessions without relying on global state.
|
||||
@ -45,7 +46,7 @@ class ResponseStreamCoordinator:
|
||||
"""
|
||||
self.registry = registry
|
||||
self.graph = graph
|
||||
self.active_session: Optional[ResponseSession] = None
|
||||
self.active_session: ResponseSession | None = None
|
||||
self.waiting_sessions: deque[ResponseSession] = deque()
|
||||
self.lock = RLock()
|
||||
|
||||
|
||||
@ -3,7 +3,8 @@ Manager for edge states during graph execution.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import TypedDict
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Edge, Graph
|
||||
@ -17,6 +18,7 @@ class EdgeStateAnalysis(TypedDict):
|
||||
all_skipped: bool
|
||||
|
||||
|
||||
@final
|
||||
class EdgeStateManager:
|
||||
"""
|
||||
Manages edge states and transitions during graph execution.
|
||||
@ -87,7 +89,7 @@ class EdgeStateManager:
|
||||
with self._lock:
|
||||
return self.graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]:
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
@ -100,8 +102,8 @@ class EdgeStateManager:
|
||||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
selected_edges = []
|
||||
unselected_edges = []
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
|
||||
@ -3,8 +3,10 @@ Tracker for currently executing nodes.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionTracker:
|
||||
"""
|
||||
Tracks nodes that are currently being executed.
|
||||
|
||||
@ -4,11 +4,13 @@ Manager for node states during graph execution.
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class NodeStateManager:
|
||||
"""
|
||||
Manages node states and the ready queue for execution.
|
||||
|
||||
@ -11,7 +11,7 @@ import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
@ -23,6 +23,7 @@ from core.workflow.nodes.base.node import Node
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
|
||||
@final
|
||||
class Worker(threading.Thread):
|
||||
"""
|
||||
Worker thread that executes nodes from the ready queue.
|
||||
@ -38,10 +39,10 @@ class Worker(threading.Thread):
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
worker_id: int = 0,
|
||||
flask_app: Optional[Flask] = None,
|
||||
context_vars: Optional[contextvars.Context] = None,
|
||||
on_idle_callback: Optional[Callable[[int], None]] = None,
|
||||
on_active_callback: Optional[Callable[[int], None]] = None,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
on_idle_callback: Callable[[int], None] | None = None,
|
||||
on_active_callback: Callable[[int], None] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize worker thread.
|
||||
|
||||
@ -4,8 +4,10 @@ Activity tracker for monitoring worker activity.
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import final
|
||||
|
||||
|
||||
@final
|
||||
class ActivityTracker:
|
||||
"""
|
||||
Tracks worker activity for scaling decisions.
|
||||
|
||||
@ -2,9 +2,12 @@
|
||||
Dynamic scaler for worker pool sizing.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class DynamicScaler:
|
||||
"""
|
||||
Manages dynamic scaling decisions for the worker pool.
|
||||
|
||||
@ -5,7 +5,7 @@ Factory for creating worker instances.
|
||||
import contextvars
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from flask import Flask
|
||||
|
||||
@ -14,6 +14,7 @@ from core.workflow.graph import Graph
|
||||
from ..worker import Worker
|
||||
|
||||
|
||||
@final
|
||||
class WorkerFactory:
|
||||
"""
|
||||
Factory for creating worker instances with proper context.
|
||||
@ -24,7 +25,7 @@ class WorkerFactory:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flask_app: Optional[Flask],
|
||||
flask_app: Flask | None,
|
||||
context_vars: contextvars.Context,
|
||||
) -> None:
|
||||
"""
|
||||
@ -43,8 +44,8 @@ class WorkerFactory:
|
||||
ready_queue: queue.Queue[str],
|
||||
event_queue: queue.Queue,
|
||||
graph: Graph,
|
||||
on_idle_callback: Optional[Callable[[int], None]] = None,
|
||||
on_active_callback: Optional[Callable[[int], None]] = None,
|
||||
on_idle_callback: Callable[[int], None] | None = None,
|
||||
on_active_callback: Callable[[int], None] | None = None,
|
||||
) -> Worker:
|
||||
"""
|
||||
Create a new worker instance.
|
||||
|
||||
@ -4,6 +4,7 @@ Worker pool management.
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
@ -13,6 +14,7 @@ from .dynamic_scaler import DynamicScaler
|
||||
from .worker_factory import WorkerFactory
|
||||
|
||||
|
||||
@final
|
||||
class WorkerPool:
|
||||
"""
|
||||
Manages a pool of worker threads for executing nodes.
|
||||
|
||||
Reference in New Issue
Block a user