Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

# Conflicts:
#	api/core/memory/token_buffer_memory.py
#	api/core/rag/extractor/notion_extractor.py
#	api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#	api/core/variables/variables.py
#	api/core/workflow/graph/graph.py
#	api/core/workflow/graph_engine/entities/event.py
#	api/services/dataset_service.py
#	web/app/components/app-sidebar/index.tsx
#	web/app/components/base/tag-management/selector.tsx
#	web/app/components/base/toast/index.tsx
#	web/app/components/datasets/create/website/index.tsx
#	web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx
#	web/app/components/workflow/header/version-history-button.tsx
#	web/app/components/workflow/hooks/use-inspect-vars-crud-common.ts
#	web/app/components/workflow/hooks/use-workflow-interactions.ts
#	web/app/components/workflow/panel/version-history-panel/index.tsx
#	web/service/base.ts
This commit is contained in:
jyong
2025-09-03 15:01:06 +08:00
572 changed files with 16030 additions and 7973 deletions

View File

@ -1,187 +0,0 @@
# Graph Engine
Queue-based workflow execution engine for parallel graph processing.
## Architecture
The engine uses a modular architecture with specialized packages:
### Core Components
- **Domain** (`domain/`) - Core models: ExecutionContext, GraphExecution, NodeExecution
- **Event Management** (`event_management/`) - Event handling, collection, and emission
- **State Management** (`state_management/`) - Thread-safe state tracking for nodes and edges
- **Error Handling** (`error_handling/`) - Strategy-based error recovery (retry, abort, fail-branch, default-value)
- **Graph Traversal** (`graph_traversal/`) - Node readiness, edge processing, branch handling
- **Command Processing** (`command_processing/`) - External command handling (abort, pause, resume)
- **Worker Management** (`worker_management/`) - Dynamic worker pool with auto-scaling
- **Orchestration** (`orchestration/`) - Main event loop and execution coordination
### Supporting Components
- **Output Registry** (`output_registry/`) - Thread-safe storage for node outputs
- **Response Coordinator** (`response_coordinator/`) - Ordered streaming of response nodes
- **Command Channels** (`command_channels/`) - Command transport (InMemory/Redis)
- **Layers** (`layers/`) - Pluggable middleware for extensions
## Architecture Diagram
```mermaid
classDiagram
    class GraphEngine {
        +run()
        +add_layer()
    }
    class Domain {
        ExecutionContext
        GraphExecution
        NodeExecution
    }
    class EventManagement {
        EventHandlerRegistry
        EventCollector
        EventEmitter
    }
    class StateManagement {
        NodeStateManager
        EdgeStateManager
        ExecutionTracker
    }
    class WorkerManagement {
        WorkerPool
        WorkerFactory
        DynamicScaler
        ActivityTracker
    }
    class GraphTraversal {
        NodeReadinessChecker
        EdgeProcessor
        BranchHandler
        SkipPropagator
    }
    class Orchestration {
        Dispatcher
        ExecutionCoordinator
    }
    class ErrorHandling {
        ErrorHandler
        RetryStrategy
        AbortStrategy
        FailBranchStrategy
    }
    class CommandProcessing {
        CommandProcessor
        AbortCommandHandler
    }
    class CommandChannels {
        InMemoryChannel
        RedisChannel
    }
    class OutputRegistry {
        <<Storage>>
        Scalar Values
        Streaming Data
    }
    class ResponseCoordinator {
        Session Management
        Path Analysis
    }
    class Layers {
        <<Plugin>>
        DebugLoggingLayer
    }
    GraphEngine --> Orchestration : coordinates
    GraphEngine --> Layers : extends
    Orchestration --> EventManagement : processes events
    Orchestration --> WorkerManagement : manages scaling
    Orchestration --> CommandProcessing : checks commands
    Orchestration --> StateManagement : monitors state
    WorkerManagement --> StateManagement : consumes ready queue
    WorkerManagement --> EventManagement : produces events
    WorkerManagement --> Domain : executes nodes
    EventManagement --> ErrorHandling : failed events
    EventManagement --> GraphTraversal : success events
    EventManagement --> ResponseCoordinator : stream events
    EventManagement --> Layers : notifies
    GraphTraversal --> StateManagement : updates states
    GraphTraversal --> Domain : checks graph
    CommandProcessing --> CommandChannels : fetches commands
    CommandProcessing --> Domain : modifies execution
    ErrorHandling --> Domain : handles failures
    StateManagement --> Domain : tracks entities
    ResponseCoordinator --> OutputRegistry : reads outputs
    Domain --> OutputRegistry : writes outputs
```
## Package Relationships
### Core Dependencies
- **Orchestration** acts as the central coordinator, managing all subsystems
- **Domain** provides the core business entities used by all packages
- **EventManagement** serves as the communication backbone between components
- **StateManagement** maintains thread-safe state for the entire system
### Data Flow
1. **Commands** flow from CommandChannels → CommandProcessing → Domain
2. **Events** flow from Workers → EventHandlerRegistry → State updates
3. **Node outputs** flow from Workers → OutputRegistry → ResponseCoordinator
4. **Ready nodes** flow from GraphTraversal → StateManagement → WorkerManagement
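The flows above can be condensed into a toy producer/consumer loop. Everything here (the node ids, the string-valued events, `execute_node`) is an illustrative stand-in for the real domain types, not the engine's actual API:

```python
import queue
import threading

def execute_node(node_id: str) -> str:
    """Stand-in for real node execution: return a 'succeeded' event."""
    return f"{node_id}:succeeded"

def worker(ready_queue: queue.Queue, event_queue: queue.Queue) -> None:
    """Pull ready node ids, execute them, push the resulting events."""
    while True:
        node_id = ready_queue.get()
        if node_id is None:  # sentinel: no more work
            break
        event_queue.put(execute_node(node_id))

ready_queue: queue.Queue = queue.Queue()
event_queue: queue.Queue = queue.Queue()

t = threading.Thread(target=worker, args=(ready_queue, event_queue))
t.start()

# Dispatcher side: enqueue ready nodes, then drain their events.
for node_id in ("start", "llm", "answer"):
    ready_queue.put(node_id)
ready_queue.put(None)
t.join()

events = [event_queue.get_nowait() for _ in range(3)]
```

In the real engine the dispatcher also feeds each event back into GraphTraversal so that newly ready nodes land on the ready queue; the sketch keeps only the two queues.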
### Extension Points
- **Layers** observe all events for monitoring, logging, and custom logic
- **ErrorHandling** strategies can be extended for custom failure recovery
- **CommandChannels** can be implemented for different transport mechanisms
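A layer is just an object the engine calls back into. As a sketch, a minimal recording layer could look like this; `initialize` and `on_graph_end` match calls visible in the engine diff in this commit, while the per-event hook name `on_event` and the class itself are assumptions:

```python
class RecordingLayer:
    """Hypothetical layer that records everything it observes."""

    def __init__(self) -> None:
        self.events: list = []
        self.ended_with: object = "unset"

    def initialize(self, graph_runtime_state, command_channel) -> None:
        # Called once by the engine before execution starts.
        self.state = graph_runtime_state
        self.channel = command_channel

    def on_event(self, event) -> None:
        # Notified for every collected event.
        self.events.append(event)

    def on_graph_end(self, error) -> None:
        # Called after execution with the terminal error, if any.
        self.ended_with = error

layer = RecordingLayer()
layer.initialize(graph_runtime_state=None, command_channel=None)
layer.on_event("node_started")
layer.on_event("node_succeeded")
layer.on_graph_end(None)
```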
## Execution Flow
1. **Initialization**: GraphEngine creates all subsystems with the workflow graph
2. **Node Discovery**: Traversal components identify ready nodes
3. **Worker Execution**: Workers pull from ready queue and execute nodes
4. **Event Processing**: Dispatcher routes events to appropriate handlers
5. **State Updates**: Managers track node/edge states for next steps
6. **Completion**: Coordinator detects when all nodes are done
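The completion check in the last step reduces to a single invariant: nothing is queued and nothing is in flight. A toy version (the tracker's fields and the helper below are assumptions, not the coordinator's real code):

```python
import queue

class ExecutionTracker:
    """Toy tracker of in-flight node ids (the real class is richer)."""

    def __init__(self) -> None:
        self._executing: set = set()

    def add(self, node_id: str) -> None:
        self._executing.add(node_id)

    def remove(self, node_id: str) -> None:
        self._executing.discard(node_id)

    def is_idle(self) -> bool:
        return not self._executing

def execution_complete(ready_queue: queue.Queue, tracker: ExecutionTracker) -> bool:
    # Done when no node is waiting to run and none is still running.
    return ready_queue.empty() and tracker.is_idle()

rq: queue.Queue = queue.Queue()
tracker = ExecutionTracker()
tracker.add("start")
before = execution_complete(rq, tracker)  # "start" still running
tracker.remove("start")
after = execution_complete(rq, tracker)   # everything drained
```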
## Usage
```python
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel

# Create and run engine
engine = GraphEngine(
    tenant_id="tenant_1",
    app_id="app_1",
    workflow_id="workflow_1",
    graph=graph,
    command_channel=InMemoryChannel(),
)

# Stream execution events
for event in engine.run():
    handle_event(event)
```

View File

@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
"""
import json
from typing import TYPE_CHECKING, final
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
@ -87,7 +87,7 @@ class RedisChannel:
pipe.expire(self._key, self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
"""
Deserialize a command from dictionary data.

View File

@ -5,6 +5,8 @@ Command handler implementations.
import logging
from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from .command_processor import CommandHandler
@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
class AbortCommandHandler(CommandHandler):
"""Handles abort commands."""
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
"""
Handle an abort command.

View File

@ -39,8 +39,8 @@ class CommandProcessor:
command_channel: Channel for receiving commands
graph_execution: Graph execution aggregate
"""
self.command_channel = command_channel
self.graph_execution = graph_execution
self._command_channel = command_channel
self._graph_execution = graph_execution
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
@ -56,7 +56,7 @@ class CommandProcessor:
def process_commands(self) -> None:
"""Check for and process any pending commands."""
try:
commands = self.command_channel.fetch_commands()
commands = self._command_channel.fetch_commands()
for command in commands:
self._handle_command(command)
except Exception as e:
@ -72,8 +72,8 @@ class CommandProcessor:
handler = self._handlers.get(type(command))
if handler:
try:
handler.handle(command, self.graph_execution)
except Exception as e:
handler.handle(command, self._graph_execution)
except Exception:
logger.exception("Error handling command %s", command.__class__.__name__)
else:
logger.warning("No handler registered for command: %s", command.__class__.__name__)
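The hunk above shows the type-keyed dispatch at the heart of `CommandProcessor`. Stripped of logging and the engine's real classes, the pattern is a `dict[type, handler]` lookup; all names below are illustrative stand-ins:

```python
class Command:
    pass

class AbortCommand(Command):
    def __init__(self, reason: str) -> None:
        self.reason = reason

class Execution:
    def __init__(self) -> None:
        self.aborted = False
        self.abort_reason = ""

def handle_abort(command: AbortCommand, execution: Execution) -> None:
    # Mark the execution aborted; the engine's dispatcher sees this flag.
    execution.aborted = True
    execution.abort_reason = command.reason

handlers = {AbortCommand: handle_abort}

def process(command: Command, execution: Execution) -> None:
    handler = handlers.get(type(command))
    if handler:
        handler(command, execution)
    # Unknown command types are logged and skipped in the real processor.

execution = Execution()
process(AbortCommand("user clicked stop"), execution)
```

Registering by concrete type keeps each handler trivially small and lets new command kinds be added without touching the dispatch loop.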

View File

@ -32,6 +32,8 @@ class AbortStrategy:
Returns:
None - signals abortion
"""
_ = graph
_ = retry_count
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
# Return None to signal that execution should stop

View File

@ -31,6 +31,7 @@ class DefaultValueStrategy:
Returns:
NodeRunExceptionEvent with default values
"""
_ = retry_count
node = graph.nodes[event.node_id]
outputs = {

View File

@ -31,6 +31,8 @@ class FailBranchStrategy:
Returns:
NodeRunExceptionEvent to continue via fail branch
"""
_ = graph
_ = retry_count
outputs = {
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,

View File

@ -5,12 +5,10 @@ This package handles event routing, collection, and emission for
workflow graph execution events.
"""
from .event_collector import EventCollector
from .event_emitter import EventEmitter
from .event_handlers import EventHandlerRegistry
from .event_handlers import EventHandler
from .event_manager import EventManager
__all__ = [
"EventCollector",
"EventEmitter",
"EventHandlerRegistry",
"EventHandler",
"EventManager",
]

View File

@ -1,58 +0,0 @@
"""
Event emitter for yielding events to external consumers.
"""
import threading
import time
from collections.abc import Generator
from typing import final
from core.workflow.graph_events import GraphEngineEvent
from .event_collector import EventCollector
@final
class EventEmitter:
"""
Emits collected events as a generator for external consumption.
This provides a generator interface for yielding events as they're
collected, with proper synchronization for multi-threaded access.
"""
def __init__(self, event_collector: EventCollector) -> None:
"""
Initialize the event emitter.
Args:
event_collector: The collector to emit events from
"""
self.event_collector = event_collector
self._execution_complete = threading.Event()
def mark_complete(self) -> None:
"""Mark execution as complete to stop the generator."""
self._execution_complete.set()
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
"""
Generator that yields events as they're collected.
Yields:
GraphEngineEvent instances as they're processed
"""
yielded_count = 0
while not self._execution_complete.is_set() or yielded_count < self.event_collector.event_count():
# Get new events since last yield
new_events = self.event_collector.get_new_events(yielded_count)
# Yield any new events
for event in new_events:
yield event
yielded_count += 1
# Small sleep to avoid busy waiting
if not self._execution_complete.is_set() and not new_events:
time.sleep(0.001)
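The deleted `EventEmitter` pairs a growing buffer with a `threading.Event` completion flag and a polling generator. A condensed, runnable sketch of the same pattern (omitting the reader-writer lock the real collector uses for its buffer):

```python
import threading
import time

class PollingEmitter:
    """Buffer plus completion flag, drained by a polling generator."""

    def __init__(self) -> None:
        self._events: list = []
        self._done = threading.Event()

    def collect(self, event: str) -> None:
        self._events.append(event)

    def mark_complete(self) -> None:
        self._done.set()

    def emit(self):
        yielded = 0
        # Keep draining until complete AND every buffered event is out.
        while not self._done.is_set() or yielded < len(self._events):
            new = self._events[yielded:]
            for event in new:
                yield event
                yielded += 1
            if not self._done.is_set() and not new:
                time.sleep(0.001)  # avoid busy-waiting

emitter = PollingEmitter()

def produce() -> None:
    for i in range(3):
        emitter.collect(f"event_{i}")
    emitter.mark_complete()

t = threading.Thread(target=produce)
t.start()
received = list(emitter.emit())
t.join()
```

The loop condition is the important part: checking the buffered count after the flag is set guarantees no event collected just before completion is dropped.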

View File

@ -10,6 +10,7 @@ from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
@ -31,15 +32,15 @@ from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from ..error_handling import ErrorHandler
from ..graph_traversal import BranchHandler, EdgeProcessor
from ..state_management import ExecutionTracker, NodeStateManager
from .event_collector import EventCollector
from ..graph_traversal import EdgeProcessor
from ..state_management import UnifiedStateManager
from .event_manager import EventManager
logger = logging.getLogger(__name__)
@final
class EventHandlerRegistry:
class EventHandler:
"""
Registry of event handlers for different event types.
@ -53,11 +54,9 @@ class EventHandlerRegistry:
graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator,
event_collector: "EventCollector",
branch_handler: "BranchHandler",
event_collector: "EventManager",
edge_processor: "EdgeProcessor",
node_state_manager: "NodeStateManager",
execution_tracker: "ExecutionTracker",
state_manager: "UnifiedStateManager",
error_handler: "ErrorHandler",
) -> None:
"""
@ -68,11 +67,9 @@ class EventHandlerRegistry:
graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator
event_collector: Event collector for collecting events
branch_handler: Branch handler for branch node processing
event_collector: Event manager for collecting events
edge_processor: Edge processor for edge traversal
node_state_manager: Node state manager
execution_tracker: Execution tracker
state_manager: Unified state manager
error_handler: Error handler
"""
self._graph = graph
@ -80,10 +77,8 @@ class EventHandlerRegistry:
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._event_collector = event_collector
self._branch_handler = branch_handler
self._edge_processor = edge_processor
self._node_state_manager = node_state_manager
self._execution_tracker = execution_tracker
self._state_manager = state_manager
self._error_handler = error_handler
def handle_event(self, event: GraphNodeEventBase) -> None:
@ -122,6 +117,7 @@ class EventHandlerRegistry:
NodeRunLoopNextEvent,
NodeRunLoopSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunAgentLogEvent,
),
):
# Iteration and loop events are collected directly
@ -187,7 +183,7 @@ class EventHandlerRegistry:
# Process edges and get ready nodes
node = self._graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
@ -199,11 +195,11 @@ class EventHandlerRegistry:
# Enqueue ready nodes
for node_id in ready_nodes:
self._node_state_manager.enqueue_node(node_id)
self._execution_tracker.add(node_id)
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._execution_tracker.remove(event.node_id)
self._state_manager.finish_execution(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
@ -232,7 +228,7 @@ class EventHandlerRegistry:
# Abort execution
self._graph_execution.fail(RuntimeError(event.error))
self._event_collector.collect(event)
self._execution_tracker.remove(event.node_id)
self._state_manager.finish_execution(event.node_id)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
"""
@ -267,6 +263,8 @@ class EventHandlerRegistry:
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
"""Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in event.node_run_result.outputs.items():
if key == "answer":
existing = self._graph_runtime_state.outputs.get("answer", "")

View File

@ -1,8 +1,10 @@
"""
Event collector for buffering and managing events.
Unified event manager for collecting and emitting events.
"""
import threading
import time
from collections.abc import Generator
from typing import final
from core.workflow.graph_events import GraphEngineEvent
@ -23,7 +25,7 @@ class ReadWriteLock:
def acquire_read(self) -> None:
"""Acquire a read lock."""
self._read_ready.acquire()
_ = self._read_ready.acquire()
try:
self._readers += 1
finally:
@ -31,7 +33,7 @@ class ReadWriteLock:
def release_read(self) -> None:
"""Release a read lock."""
self._read_ready.acquire()
_ = self._read_ready.acquire()
try:
self._readers -= 1
if self._readers == 0:
@ -41,9 +43,9 @@ class ReadWriteLock:
def acquire_write(self) -> None:
"""Acquire a write lock."""
self._read_ready.acquire()
_ = self._read_ready.acquire()
while self._readers > 0:
self._read_ready.wait()
_ = self._read_ready.wait()
def release_write(self) -> None:
"""Release a write lock."""
@ -89,19 +91,21 @@ class WriteLockContext:
@final
class EventCollector:
class EventManager:
"""
Collects and buffers events for later retrieval.
Unified event manager that collects, buffers, and emits events.
This provides thread-safe event collection with support for
notifying layers about events as they're collected.
This class combines event collection with event emission, providing
thread-safe event management with support for notifying layers and
streaming events to external consumers.
"""
def __init__(self) -> None:
"""Initialize the event collector."""
"""Initialize the event manager."""
self._events: list[GraphEngineEvent] = []
self._lock = ReadWriteLock()
self._layers: list[Layer] = []
self._execution_complete = threading.Event()
def set_layers(self, layers: list[Layer]) -> None:
"""
@ -123,17 +127,7 @@ class EventCollector:
self._events.append(event)
self._notify_layers(event)
def get_events(self) -> list[GraphEngineEvent]:
"""
Get all collected events.
Returns:
List of collected events
"""
with self._lock.read_lock():
return list(self._events)
def get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""
Get new events starting from a specific index.
@ -146,7 +140,7 @@ class EventCollector:
with self._lock.read_lock():
return list(self._events[start_index:])
def event_count(self) -> int:
def _event_count(self) -> int:
"""
Get the current count of collected events.
@ -156,10 +150,31 @@ class EventCollector:
with self._lock.read_lock():
return len(self._events)
def clear(self) -> None:
"""Clear all collected events."""
with self._lock.write_lock():
self._events.clear()
def mark_complete(self) -> None:
"""Mark execution as complete to stop the event emission generator."""
self._execution_complete.set()
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
"""
Generator that yields events as they're collected.
Yields:
GraphEngineEvent instances as they're processed
"""
yielded_count = 0
while not self._execution_complete.is_set() or yielded_count < self._event_count():
# Get new events since last yield
new_events = self._get_new_events(yielded_count)
# Yield any new events
for event in new_events:
yield event
yielded_count += 1
# Small sleep to avoid busy waiting
if not self._execution_complete.is_set() and not new_events:
time.sleep(0.001)
def _notify_layers(self, event: GraphEngineEvent) -> None:
"""

View File

@ -13,7 +13,6 @@ from typing import final
from flask import Flask, current_app
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
@ -32,15 +31,14 @@ from .command_processing import AbortCommandHandler, CommandProcessor
from .domain import ExecutionContext, GraphExecution
from .entities.commands import AbortCommand
from .error_handling import ErrorHandler
from .event_management import EventCollector, EventEmitter, EventHandlerRegistry
from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator
from .event_management import EventHandler, EventManager
from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import Layer
from .orchestration import Dispatcher, ExecutionCoordinator
from .output_registry import OutputRegistry
from .protocols.command_channel import CommandChannel
from .response_coordinator import ResponseStreamCoordinator
from .state_management import EdgeStateManager, ExecutionTracker, NodeStateManager
from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, WorkerPool
from .state_management import UnifiedStateManager
from .worker_management import SimpleWorkerPool
logger = logging.getLogger(__name__)
@ -74,10 +72,11 @@ class GraphEngine:
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None:
"""Initialize the graph engine with separated concerns."""
"""Initialize the graph engine with all subsystems and dependencies."""
# Create domain models
self.execution_context = ExecutionContext(
# === Domain Models ===
# Execution context encapsulates workflow execution metadata
self._execution_context = ExecutionContext(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
@ -89,167 +88,149 @@ class GraphEngine:
max_execution_time=max_execution_time,
)
self.graph_execution = GraphExecution(workflow_id=workflow_id)
# Graph execution tracks the overall execution state
self._graph_execution = GraphExecution(workflow_id=workflow_id)
# Store core dependencies
self.graph = graph
self.graph_config = graph_config
self.graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
# === Core Dependencies ===
# Graph structure and configuration
self._graph = graph
self._graph_config = graph_config
self._graph_runtime_state = graph_runtime_state
self._command_channel = command_channel
# Store worker management parameters
# === Worker Management Parameters ===
# Parameters for dynamic worker pool scaling
self._min_workers = min_workers
self._max_workers = max_workers
self._scale_up_threshold = scale_up_threshold
self._scale_down_idle_time = scale_down_idle_time
# Initialize queues
self.ready_queue: queue.Queue[str] = queue.Queue()
self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
# === Execution Queues ===
# Queue for nodes ready to execute
self._ready_queue: queue.Queue[str] = queue.Queue()
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
# Initialize subsystems
self._initialize_subsystems()
# === State Management ===
# Unified state manager handles all node state transitions and queue operations
self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
# Layers for extensibility
self._layers: list[Layer] = []
# Validate graph state consistency
self._validate_graph_state_consistency()
def _initialize_subsystems(self) -> None:
"""Initialize all subsystems with proper dependency injection."""
# State management
self.node_state_manager = NodeStateManager(self.graph, self.ready_queue)
self.edge_state_manager = EdgeStateManager(self.graph)
self.execution_tracker = ExecutionTracker()
# Response coordination
self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
self.response_coordinator = ResponseStreamCoordinator(registry=self.output_registry, graph=self.graph)
# Event management
self.event_collector = EventCollector()
self.event_emitter = EventEmitter(self.event_collector)
# Error handling
self.error_handler = ErrorHandler(self.graph, self.graph_execution)
# Graph traversal
self.node_readiness_checker = NodeReadinessChecker(self.graph)
self.edge_processor = EdgeProcessor(
graph=self.graph,
edge_state_manager=self.edge_state_manager,
node_state_manager=self.node_state_manager,
response_coordinator=self.response_coordinator,
)
self.skip_propagator = SkipPropagator(
graph=self.graph,
edge_state_manager=self.edge_state_manager,
node_state_manager=self.node_state_manager,
)
self.branch_handler = BranchHandler(
graph=self.graph,
edge_processor=self.edge_processor,
skip_propagator=self.skip_propagator,
edge_state_manager=self.edge_state_manager,
# === Response Coordination ===
# Coordinates response streaming from response nodes
self._response_coordinator = ResponseStreamCoordinator(
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
)
# Event handler registry with all dependencies
self.event_handler_registry = EventHandlerRegistry(
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
graph_execution=self.graph_execution,
response_coordinator=self.response_coordinator,
event_collector=self.event_collector,
branch_handler=self.branch_handler,
edge_processor=self.edge_processor,
node_state_manager=self.node_state_manager,
execution_tracker=self.execution_tracker,
error_handler=self.error_handler,
# === Event Management ===
# Event manager handles both collection and emission of events
self._event_manager = EventManager()
# === Error Handling ===
# Centralized error handler for graph execution errors
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
# === Graph Traversal Components ===
# Propagates skip status through the graph when conditions aren't met
self._skip_propagator = SkipPropagator(
graph=self._graph,
state_manager=self._state_manager,
)
# Command processing
self.command_processor = CommandProcessor(
command_channel=self.command_channel,
graph_execution=self.graph_execution,
)
self._setup_command_handlers()
# Worker management
self._setup_worker_management()
# Orchestration
self.execution_coordinator = ExecutionCoordinator(
graph_execution=self.graph_execution,
node_state_manager=self.node_state_manager,
execution_tracker=self.execution_tracker,
event_handler=self.event_handler_registry,
event_collector=self.event_collector,
command_processor=self.command_processor,
worker_pool=self._worker_pool,
# Processes edges to determine next nodes after execution
# Also handles conditional branching and route selection
self._edge_processor = EdgeProcessor(
graph=self._graph,
state_manager=self._state_manager,
response_coordinator=self._response_coordinator,
skip_propagator=self._skip_propagator,
)
self.dispatcher = Dispatcher(
event_queue=self.event_queue,
event_handler=self.event_handler_registry,
event_collector=self.event_collector,
execution_coordinator=self.execution_coordinator,
max_execution_time=self.execution_context.max_execution_time,
event_emitter=self.event_emitter,
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
def _setup_command_handlers(self) -> None:
"""Configure command handlers."""
# Create handler instance that follows the protocol
# === Command Processing ===
# Processes external commands (e.g., abort requests)
self._command_processor = CommandProcessor(
command_channel=self._command_channel,
graph_execution=self._graph_execution,
)
# Register abort command handler
abort_handler = AbortCommandHandler()
self.command_processor.register_handler(
self._command_processor.register_handler(
AbortCommand,
abort_handler,
)
def _setup_worker_management(self) -> None:
"""Initialize worker management subsystem."""
# Capture context for workers
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
try:
flask_app = current_app._get_current_object() # type: ignore
app = current_app._get_current_object() # type: ignore
if isinstance(app, Flask):
flask_app = app
except RuntimeError:
pass
# Capture context variables for worker threads
context_vars = contextvars.copy_context()
# Create worker management components
self._activity_tracker = ActivityTracker()
self._dynamic_scaler = DynamicScaler(
min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS),
max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS),
scale_up_threshold=(
self._scale_up_threshold
if self._scale_up_threshold is not None
else dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
),
scale_down_idle_time=(
self._scale_down_idle_time
if self._scale_down_idle_time is not None
else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
),
# Create worker pool for parallel node execution
self._worker_pool = SimpleWorkerPool(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
flask_app=flask_app,
context_vars=context_vars,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
)
self._worker_factory = WorkerFactory(flask_app, context_vars)
self._worker_pool = WorkerPool(
ready_queue=self.ready_queue,
event_queue=self.event_queue,
graph=self.graph,
worker_factory=self._worker_factory,
dynamic_scaler=self._dynamic_scaler,
activity_tracker=self._activity_tracker,
# === Orchestration ===
# Coordinates the overall execution lifecycle
self._execution_coordinator = ExecutionCoordinator(
graph_execution=self._graph_execution,
state_manager=self._state_manager,
event_handler=self._event_handler_registry,
event_collector=self._event_manager,
command_processor=self._command_processor,
worker_pool=self._worker_pool,
)
# Dispatches events and manages execution flow
self._dispatcher = Dispatcher(
event_queue=self._event_queue,
event_handler=self._event_handler_registry,
event_collector=self._event_manager,
execution_coordinator=self._execution_coordinator,
max_execution_time=self._execution_context.max_execution_time,
event_emitter=self._event_manager,
)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[Layer] = []
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
self._validate_graph_state_consistency()
def _validate_graph_state_consistency(self) -> None:
"""Validate that all nodes share the same GraphRuntimeState."""
expected_state_id = id(self.graph_runtime_state)
for node in self.graph.nodes.values():
expected_state_id = id(self._graph_runtime_state)
for node in self._graph.nodes.values():
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
@ -270,7 +251,7 @@ class GraphEngine:
self._initialize_layers()
# Start execution
self.graph_execution.start()
self._graph_execution.start()
start_event = GraphRunStartedEvent()
yield start_event
@ -278,23 +259,23 @@ class GraphEngine:
self._start_execution()
# Yield events as they occur
yield from self.event_emitter.emit_events()
yield from self._event_manager.emit_events()
# Handle completion
if self.graph_execution.aborted:
if self._graph_execution.aborted:
abort_reason = "Workflow execution aborted by user command"
if self.graph_execution.error:
abort_reason = str(self.graph_execution.error)
if self._graph_execution.error:
abort_reason = str(self._graph_execution.error)
yield GraphRunAbortedEvent(
reason=abort_reason,
outputs=self.graph_runtime_state.outputs,
outputs=self._graph_runtime_state.outputs,
)
elif self.graph_execution.has_error:
if self.graph_execution.error:
raise self.graph_execution.error
elif self._graph_execution.has_error:
if self._graph_execution.error:
raise self._graph_execution.error
else:
yield GraphRunSucceededEvent(
outputs=self.graph_runtime_state.outputs,
outputs=self._graph_runtime_state.outputs,
)
except Exception as e:
@ -306,10 +287,10 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self.event_collector.set_layers(self._layers)
self._event_manager.set_layers(self._layers)
for layer in self._layers:
try:
layer.initialize(self.graph_runtime_state, self.command_channel)
layer.initialize(self._graph_runtime_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
@ -320,28 +301,25 @@ class GraphEngine:
def _start_execution(self) -> None:
"""Start execution subsystems."""
# Calculate initial worker count
initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph)
# Start worker pool
self._worker_pool.start(initial_workers)
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
# Register response nodes
for node in self.graph.nodes.values():
for node in self._graph.nodes.values():
if node.execution_type == NodeExecutionType.RESPONSE:
self.response_coordinator.register(node.id)
self._response_coordinator.register(node.id)
# Enqueue root node
root_node = self.graph.root_node
self.node_state_manager.enqueue_node(root_node.id)
self.execution_tracker.add(root_node.id)
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
# Start dispatcher
self.dispatcher.start()
self._dispatcher.start()
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self.dispatcher.stop()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
@@ -350,6 +328,12 @@ class GraphEngine:
for layer in self._layers:
try:
layer.on_graph_end(self.graph_execution.error)
layer.on_graph_end(self._graph_execution.error)
except Exception as e:
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
# Public property accessors for attributes that need external access
@property
def graph_runtime_state(self) -> GraphRuntimeState:
"""Get the graph runtime state."""
return self._graph_runtime_state


@@ -5,14 +5,10 @@ This package handles graph navigation, edge processing,
and skip propagation logic.
"""
from .branch_handler import BranchHandler
from .edge_processor import EdgeProcessor
from .node_readiness import NodeReadinessChecker
from .skip_propagator import SkipPropagator
__all__ = [
"BranchHandler",
"EdgeProcessor",
"NodeReadinessChecker",
"SkipPropagator",
]


@@ -1,87 +0,0 @@
"""
Branch node handling for graph traversal.
"""
from collections.abc import Sequence
from typing import final
from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
from ..state_management import EdgeStateManager
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator
@final
class BranchHandler:
"""
Handles branch node logic during graph traversal.
Branch nodes select one of multiple paths based on conditions,
requiring special handling for edge selection and skip propagation.
"""
def __init__(
self,
graph: Graph,
edge_processor: EdgeProcessor,
skip_propagator: SkipPropagator,
edge_state_manager: EdgeStateManager,
) -> None:
"""
Initialize the branch handler.
Args:
graph: The workflow graph
edge_processor: Processor for edges
skip_propagator: Propagator for skip states
edge_state_manager: Manager for edge states
"""
self.graph = graph
self.edge_processor = edge_processor
self.skip_propagator = skip_propagator
self.edge_state_manager = edge_state_manager
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected branch
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no branch was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges into selected and unselected
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self.skip_propagator.skip_branch_paths(unselected_edges)
# Process selected edges and get ready nodes and streaming events
return self.edge_processor.process_node_success(node_id, selected_handle)
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
"""
Validate that a branch selection is valid.
Args:
node_id: The ID of the branch node
selected_handle: The handle to validate
Returns:
True if the selection is valid
"""
outgoing_edges = self.graph.get_outgoing_edges(node_id)
valid_handles = {edge.source_handle for edge in outgoing_edges}
return selected_handle in valid_handles


@@ -3,14 +3,17 @@ Edge processing logic for graph traversal.
"""
from collections.abc import Sequence
from typing import final
from typing import TYPE_CHECKING, final
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Edge, Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent
from ..response_coordinator import ResponseStreamCoordinator
from ..state_management import EdgeStateManager, NodeStateManager
from ..state_management import UnifiedStateManager
if TYPE_CHECKING:
from .skip_propagator import SkipPropagator
@final
@@ -19,29 +22,30 @@ class EdgeProcessor:
Processes edges during graph execution.
This handles marking edges as taken or skipped, notifying
the response coordinator, and triggering downstream node execution.
the response coordinator, triggering downstream node execution,
and managing branch node logic.
"""
def __init__(
self,
graph: Graph,
edge_state_manager: EdgeStateManager,
node_state_manager: NodeStateManager,
state_manager: UnifiedStateManager,
response_coordinator: ResponseStreamCoordinator,
skip_propagator: "SkipPropagator",
) -> None:
"""
Initialize the edge processor.
Args:
graph: The workflow graph
edge_state_manager: Manager for edge states
node_state_manager: Manager for node states
state_manager: Unified state manager
response_coordinator: Response stream coordinator
skip_propagator: Propagator for skip states
"""
self.graph = graph
self.edge_state_manager = edge_state_manager
self.node_state_manager = node_state_manager
self.response_coordinator = response_coordinator
self._graph = graph
self._state_manager = state_manager
self._response_coordinator = response_coordinator
self._skip_propagator = skip_propagator
def process_node_success(
self, node_id: str, selected_handle: str | None = None
@@ -56,7 +60,7 @@ class EdgeProcessor:
Returns:
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
"""
node = self.graph.nodes[node_id]
node = self._graph.nodes[node_id]
if node.execution_type == NodeExecutionType.BRANCH:
return self._process_branch_node_edges(node_id, selected_handle)
@@ -75,7 +79,7 @@ class EdgeProcessor:
"""
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self.graph.get_outgoing_edges(node_id)
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
nodes, events = self._process_taken_edge(edge)
@@ -107,7 +111,7 @@ class EdgeProcessor:
all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Process unselected edges first (mark as skipped)
for edge in unselected_edges:
@@ -132,14 +136,14 @@ class EdgeProcessor:
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
"""
# Mark edge as taken
self.edge_state_manager.mark_edge_taken(edge.id)
self._state_manager.mark_edge_taken(edge.id)
# Notify response coordinator and get streaming events
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes: list[str] = []
if self.node_state_manager.is_node_ready(edge.head):
if self._state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, streaming_events
@@ -151,4 +155,47 @@ class EdgeProcessor:
Args:
edge: The edge to skip
"""
self.edge_state_manager.mark_edge_skipped(edge.id)
self._state_manager.mark_edge_skipped(edge.id)
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected branch
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no branch was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges into selected and unselected
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self._skip_propagator.skip_branch_paths(unselected_edges)
# Process selected edges and get ready nodes and streaming events
return self.process_node_success(node_id, selected_handle)
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
"""
Validate that a branch selection is valid.
Args:
node_id: The ID of the branch node
selected_handle: The handle to validate
Returns:
True if the selection is valid
"""
outgoing_edges = self._graph.get_outgoing_edges(node_id)
valid_handles = {edge.source_handle for edge in outgoing_edges}
return selected_handle in valid_handles


@@ -1,86 +0,0 @@
"""
Node readiness checking for execution.
"""
from typing import final
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
@final
class NodeReadinessChecker:
"""
Checks if nodes are ready for execution based on their dependencies.
A node is ready when its dependencies (incoming edges) have been
satisfied according to the graph's execution rules.
"""
def __init__(self, graph: Graph) -> None:
"""
Initialize the readiness checker.
Args:
graph: The workflow graph
"""
self.graph = graph
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when:
- It has no incoming edges (root or isolated node), OR
- At least one incoming edge is TAKEN and none are UNKNOWN
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
incoming_edges = self.graph.get_incoming_edges(node_id)
# No dependencies means always ready
if not incoming_edges:
return True
# Check edge states
has_unknown = False
has_taken = False
for edge in incoming_edges:
if edge.state == NodeState.UNKNOWN:
has_unknown = True
break
elif edge.state == NodeState.TAKEN:
has_taken = True
# Not ready if any dependency is still unknown
if has_unknown:
return False
# Ready if at least one path is taken
return has_taken
def get_ready_downstream_nodes(self, from_node_id: str) -> list[str]:
"""
Get all downstream nodes that are ready after a node completes.
Args:
from_node_id: The ID of the completed node
Returns:
List of node IDs that are now ready
"""
ready_nodes: list[str] = []
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
for edge in outgoing_edges:
if edge.state == NodeState.TAKEN:
downstream_node_id = edge.head
if self.is_node_ready(downstream_node_id):
ready_nodes.append(downstream_node_id)
return ready_nodes


@@ -7,7 +7,7 @@ from typing import final
from core.workflow.graph import Edge, Graph
from ..state_management import EdgeStateManager, NodeStateManager
from ..state_management import UnifiedStateManager
@final
@@ -22,20 +22,17 @@ class SkipPropagator:
def __init__(
self,
graph: Graph,
edge_state_manager: EdgeStateManager,
node_state_manager: NodeStateManager,
state_manager: UnifiedStateManager,
) -> None:
"""
Initialize the skip propagator.
Args:
graph: The workflow graph
edge_state_manager: Manager for edge states
node_state_manager: Manager for node states
state_manager: Unified state manager
"""
self.graph = graph
self.edge_state_manager = edge_state_manager
self.node_state_manager = node_state_manager
self._graph = graph
self._state_manager = state_manager
def propagate_skip_from_edge(self, edge_id: str) -> None:
"""
@@ -49,11 +46,11 @@ class SkipPropagator:
Args:
edge_id: The ID of the skipped edge to start from
"""
downstream_node_id = self.graph.edges[edge_id].head
incoming_edges = self.graph.get_incoming_edges(downstream_node_id)
downstream_node_id = self._graph.edges[edge_id].head
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
# Analyze edge states
edge_states = self.edge_state_manager.analyze_edge_states(incoming_edges)
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
# Stop if there are unknown edges (not yet processed)
if edge_states["has_unknown"]:
@@ -62,7 +59,7 @@ class SkipPropagator:
# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Enqueue node
self.node_state_manager.enqueue_node(downstream_node_id)
self._state_manager.enqueue_node(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
@@ -77,12 +74,12 @@ class SkipPropagator:
node_id: The ID of the node to skip
"""
# Mark node as skipped
self.node_state_manager.mark_node_skipped(node_id)
self._state_manager.mark_node_skipped(node_id)
# Mark all outgoing edges as skipped and propagate
outgoing_edges = self.graph.get_outgoing_edges(node_id)
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
self.edge_state_manager.mark_edge_skipped(edge.id)
self._state_manager.mark_edge_skipped(edge.id)
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)
@@ -94,5 +91,5 @@ class SkipPropagator:
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:
self.edge_state_manager.mark_edge_skipped(edge.id)
self._state_manager.mark_edge_skipped(edge.id)
self.propagate_skip_from_edge(edge.id)
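The recursion above stops at any node with an UNKNOWN incoming edge and spares any node reachable via a TAKEN edge; only nodes whose every incoming edge ends up skipped are themselves skipped. A minimal sketch over a hypothetical diamond graph (node and edge names are illustrative, and the dict-based graph is a stand-in for the real `Graph`):

```python
from enum import Enum

class State(Enum):
    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"

# Diamond: "start" branches to "left" and "right"; both feed "join".
# Edges are (tail, head) -> state.
edges = {
    ("start", "left"): State.UNKNOWN,
    ("start", "right"): State.UNKNOWN,
    ("left", "join"): State.UNKNOWN,
    ("right", "join"): State.UNKNOWN,
}
node_states: dict[str, State] = {}

def incoming(node): return [e for e in edges if e[1] == node]
def outgoing(node): return [e for e in edges if e[0] == node]

def propagate_skip(edge) -> None:
    """Mirror of propagate_skip_from_edge: a node is skipped only once
    every incoming edge is decided and none of them was taken."""
    node = edge[1]
    states = [edges[e] for e in incoming(node)]
    if any(s == State.UNKNOWN for s in states):
        return  # wait for remaining upstream decisions
    if any(s == State.TAKEN for s in states):
        return  # node still runs via the taken path
    node_states[node] = State.SKIPPED
    for out in outgoing(node):
        edges[out] = State.SKIPPED
        propagate_skip(out)

# Branch selects "left": the right path is skipped, but "join" survives
# because its left incoming edge is still undecided.
edges[("start", "left")] = State.TAKEN
edges[("start", "right")] = State.SKIPPED
propagate_skip(("start", "right"))
```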


@@ -9,6 +9,8 @@ import logging
from collections.abc import Mapping
from typing import Any, final
from typing_extensions import override
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
@@ -93,13 +95,14 @@ class DebugLoggingLayer(Layer):
if not data:
return "{}"
formatted_items = []
formatted_items: list[str] = []
for key, value in data.items():
formatted_value = self._truncate_value(value)
formatted_items.append(f" {key}: {formatted_value}")
return "{\n" + ",\n".join(formatted_items) + "\n}"
@override
def on_graph_start(self) -> None:
"""Log graph execution start."""
self.logger.info("=" * 80)
@@ -112,7 +115,7 @@ class DebugLoggingLayer(Layer):
# Log inputs if available
if self.graph_runtime_state.variable_pool:
initial_vars = {}
initial_vars: dict[str, Any] = {}
# Access the variable dictionary directly
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
for var_key, var in variables.items():
@@ -121,6 +124,7 @@ class DebugLoggingLayer(Layer):
if initial_vars:
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""Log individual events based on their type."""
event_class = event.__class__.__name__
@@ -222,6 +226,7 @@ class DebugLoggingLayer(Layer):
# Log unknown events at debug level
self.logger.debug("Event: %s", event_class)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Log graph execution end with summary statistics."""
self.logger.info("=" * 80)


@@ -13,6 +13,8 @@ import time
from enum import Enum
from typing import final
from typing_extensions import override
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import Layer
from core.workflow.graph_events import (
@@ -63,6 +65,7 @@ class ExecutionLimitsLayer(Layer):
self._execution_ended = False
self._abort_sent = False # Track if abort command has been sent
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self.start_time = time.time()
@@ -73,6 +76,7 @@ class ExecutionLimitsLayer(Layer):
self.logger.debug("Execution limits monitoring started")
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
@@ -95,6 +99,7 @@ class ExecutionLimitsLayer(Layer):
if self._reached_time_limitation():
self._send_abort_command(LimitType.TIME_LIMIT)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._execution_started and not self._execution_ended:


@@ -10,11 +10,11 @@ from typing import TYPE_CHECKING, final
from core.workflow.graph_events.base import GraphNodeEventBase
from ..event_management import EventCollector, EventEmitter
from ..event_management import EventManager
from .execution_coordinator import ExecutionCoordinator
if TYPE_CHECKING:
from ..event_management import EventHandlerRegistry
from ..event_management import EventHandler
logger = logging.getLogger(__name__)
@@ -31,11 +31,11 @@ class Dispatcher:
def __init__(
self,
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
event_handler: "EventHandler",
event_collector: EventManager,
execution_coordinator: ExecutionCoordinator,
max_execution_time: int,
event_emitter: EventEmitter | None = None,
event_emitter: EventManager | None = None,
) -> None:
"""
Initialize the dispatcher.
@@ -43,17 +43,17 @@
Args:
event_queue: Queue of events from workers
event_handler: Event handler registry for processing events
event_collector: Event collector for collecting unhandled events
event_collector: Event manager for collecting unhandled events
execution_coordinator: Coordinator for execution flow
max_execution_time: Maximum execution time in seconds
event_emitter: Optional event emitter to signal completion
event_emitter: Optional event manager to signal completion
"""
self.event_queue = event_queue
self.event_handler = event_handler
self.event_collector = event_collector
self.execution_coordinator = execution_coordinator
self.max_execution_time = max_execution_time
self.event_emitter = event_emitter
self._event_queue = event_queue
self._event_handler = event_handler
self._event_collector = event_collector
self._execution_coordinator = execution_coordinator
self._max_execution_time = max_execution_time
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
@@ -80,28 +80,28 @@
try:
while not self._stop_event.is_set():
# Check for commands
self.execution_coordinator.check_commands()
self._execution_coordinator.check_commands()
# Check for scaling
self.execution_coordinator.check_scaling()
self._execution_coordinator.check_scaling()
# Process events
try:
event = self.event_queue.get(timeout=0.1)
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self.event_handler.handle_event(event)
self.event_queue.task_done()
self._event_handler.handle_event(event)
self._event_queue.task_done()
except queue.Empty:
# Check if execution is complete
if self.execution_coordinator.is_execution_complete():
if self._execution_coordinator.is_execution_complete():
break
except Exception as e:
logger.exception("Dispatcher error")
self.execution_coordinator.mark_failed(e)
self._execution_coordinator.mark_failed(e)
finally:
self.execution_coordinator.mark_complete()
self._execution_coordinator.mark_complete()
# Signal the event emitter that execution is complete
if self.event_emitter:
self.event_emitter.mark_complete()
if self._event_emitter:
self._event_emitter.mark_complete()
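The dispatcher loop above is a standard poll pattern: block briefly on the queue, route each event, and only test for completion while idle so a quiet moment between events is not mistaken for the end of the run. A sketch with plain callables standing in for the handler and coordinator (the function name and parameters here are illustrative, not the engine's API):

```python
import queue

def run_dispatch_loop(event_queue: queue.Queue, handle_event, is_complete) -> int:
    """Drain events until the completion predicate holds while the queue is idle.

    Mirrors the dispatcher's loop: a 0.1s timeout on get() keeps the thread
    responsive, and queue.Empty is where completion is checked.
    """
    handled = 0
    while True:
        try:
            event = event_queue.get(timeout=0.1)
        except queue.Empty:
            if is_complete():
                return handled
            continue  # still running; poll again
        handle_event(event)
        event_queue.task_done()
        handled += 1

q: queue.Queue[int] = queue.Queue()
for i in range(3):
    q.put(i)
seen: list[int] = []
count = run_dispatch_loop(q, seen.append, lambda: q.empty())
```

In the real dispatcher the idle branch also drives command processing and worker scaling before re-polling.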


@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING, final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventCollector
from ..state_management import ExecutionTracker, NodeStateManager
from ..worker_management import WorkerPool
from ..event_management import EventManager
from ..state_management import UnifiedStateManager
from ..worker_management import SimpleWorkerPool
if TYPE_CHECKING:
from ..event_management import EventHandlerRegistry
from ..event_management import EventHandler
@final
@@ -26,42 +26,37 @@
def __init__(
self,
graph_execution: GraphExecution,
node_state_manager: NodeStateManager,
execution_tracker: ExecutionTracker,
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
state_manager: UnifiedStateManager,
event_handler: "EventHandler",
event_collector: EventManager,
command_processor: CommandProcessor,
worker_pool: WorkerPool,
worker_pool: SimpleWorkerPool,
) -> None:
"""
Initialize the execution coordinator.
Args:
graph_execution: Graph execution aggregate
node_state_manager: Manager for node states
execution_tracker: Tracker for executing nodes
state_manager: Unified state manager
event_handler: Event handler registry for processing events
event_collector: Event collector for collecting events
event_collector: Event manager for collecting events
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self.graph_execution = graph_execution
self.node_state_manager = node_state_manager
self.execution_tracker = execution_tracker
self.event_handler = event_handler
self.event_collector = event_collector
self.command_processor = command_processor
self.worker_pool = worker_pool
self._graph_execution = graph_execution
self._state_manager = state_manager
self._event_handler = event_handler
self._event_collector = event_collector
self._command_processor = command_processor
self._worker_pool = worker_pool
def check_commands(self) -> None:
"""Process any pending commands."""
self.command_processor.process_commands()
self._command_processor.process_commands()
def check_scaling(self) -> None:
"""Check and perform worker scaling if needed."""
queue_depth = self.node_state_manager.ready_queue.qsize()
executing_count = self.execution_tracker.count()
self.worker_pool.check_scaling(queue_depth, executing_count)
self._worker_pool.check_and_scale()
def is_execution_complete(self) -> bool:
"""
@@ -71,16 +66,16 @@
True if execution is complete
"""
# Check if aborted or failed
if self.graph_execution.aborted or self.graph_execution.has_error:
if self._graph_execution.aborted or self._graph_execution.has_error:
return True
# Complete if no work remains
return self.node_state_manager.ready_queue.empty() and self.execution_tracker.is_empty()
return self._state_manager.is_execution_complete()
def mark_complete(self) -> None:
"""Mark execution as complete."""
if not self.graph_execution.completed:
self.graph_execution.complete()
if not self._graph_execution.completed:
self._graph_execution.complete()
def mark_failed(self, error: Exception) -> None:
"""
@@ -89,4 +84,4 @@
Args:
error: The error that caused failure
"""
self.graph_execution.fail(error)
self._graph_execution.fail(error)


@@ -1,10 +0,0 @@
"""
OutputRegistry - Thread-safe storage for node outputs (streams and scalars)
This component provides thread-safe storage and retrieval of node outputs,
supporting both scalar values and streaming chunks with proper state management.
"""
from .registry import OutputRegistry
__all__ = ["OutputRegistry"]


@@ -1,146 +0,0 @@
"""
Main OutputRegistry implementation.
This module contains the public OutputRegistry class that provides
thread-safe storage for node outputs.
"""
from collections.abc import Sequence
from threading import RLock
from typing import TYPE_CHECKING, Union, final
from core.variables import Segment
from core.workflow.entities.variable_pool import VariablePool
from .stream import Stream
if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent
@final
class OutputRegistry:
"""
Thread-safe registry for storing and retrieving node outputs.
Supports both scalar values and streaming chunks with proper state management.
All operations are thread-safe using internal locking.
"""
def __init__(self, variable_pool: VariablePool) -> None:
"""Initialize empty registry with thread-safe storage."""
self._lock = RLock()
self._scalars = variable_pool
self._streams: dict[tuple, Stream] = {}
def _selector_to_key(self, selector: Sequence[str]) -> tuple:
"""Convert selector list to tuple key for internal storage."""
return tuple(selector)
def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None:
"""
Set a scalar value for the given selector.
Args:
selector: List of strings identifying the output location
value: The scalar value to store
"""
with self._lock:
self._scalars.add(selector, value)
def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
"""
Get a scalar value for the given selector.
Args:
selector: List of strings identifying the output location
Returns:
The stored Variable object, or None if not found
"""
with self._lock:
return self._scalars.get(selector)
def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None:
"""
Append a NodeRunStreamChunkEvent to the stream for the given selector.
Args:
selector: List of strings identifying the stream location
event: The NodeRunStreamChunkEvent to append
Raises:
ValueError: If the stream is already closed
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
self._streams[key] = Stream()
try:
self._streams[key].append(event)
except ValueError:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.
Args:
selector: List of strings identifying the stream location
Returns:
The next event, or None if no unread events available
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return None
return self._streams[key].pop_next()
def has_unread(self, selector: Sequence[str]) -> bool:
"""
Check if the stream has unread events.
Args:
selector: List of strings identifying the stream location
Returns:
True if there are unread events, False otherwise
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return False
return self._streams[key].has_unread()
def close_stream(self, selector: Sequence[str]) -> None:
"""
Mark a stream as closed (no more chunks can be appended).
Args:
selector: List of strings identifying the stream location
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
self._streams[key] = Stream()
self._streams[key].close()
def stream_closed(self, selector: Sequence[str]) -> bool:
"""
Check if a stream is closed.
Args:
selector: List of strings identifying the stream location
Returns:
True if the stream is closed, False otherwise
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return False
return self._streams[key].is_closed


@@ -1,70 +0,0 @@
"""
Internal stream implementation for OutputRegistry.
This module contains the private Stream class used internally by OutputRegistry
to manage streaming data chunks.
"""
from typing import TYPE_CHECKING, final
if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent
@final
class Stream:
"""
A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
This class encapsulates stream-specific data and operations,
including event storage, read position tracking, and closed state.
Note: This is an internal class not exposed in the public API.
"""
def __init__(self) -> None:
"""Initialize an empty stream."""
self.events: list[NodeRunStreamChunkEvent] = []
self.read_position: int = 0
self.is_closed: bool = False
def append(self, event: "NodeRunStreamChunkEvent") -> None:
"""
Append a NodeRunStreamChunkEvent to the stream.
Args:
event: The NodeRunStreamChunkEvent to append
Raises:
ValueError: If the stream is already closed
"""
if self.is_closed:
raise ValueError("Cannot append to a closed stream")
self.events.append(event)
def pop_next(self) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.
Returns:
The next event, or None if no unread events available
"""
if self.read_position >= len(self.events):
return None
event = self.events[self.read_position]
self.read_position += 1
return event
def has_unread(self) -> bool:
"""
Check if the stream has unread events.
Returns:
True if there are unread events, False otherwise
"""
return self.read_position < len(self.events)
def close(self) -> None:
"""Mark the stream as closed (no more chunks can be appended)."""
self.is_closed = True


@@ -12,12 +12,12 @@ from threading import RLock
from typing import TypeAlias, final
from uuid import uuid4
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from ..output_registry import OutputRegistry
from .path import Path
from .session import ResponseSession
@@ -36,19 +36,24 @@
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
def __init__(self, registry: OutputRegistry, graph: "Graph") -> None:
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
"""
Initialize coordinator with output registry.
Initialize coordinator with variable pool.
Args:
registry: OutputRegistry instance for accessing node outputs
variable_pool: VariablePool instance for accessing node variables
graph: Graph instance for looking up node information
"""
self.registry = registry
self.graph = graph
self.active_session: ResponseSession | None = None
self.waiting_sessions: deque[ResponseSession] = deque()
self.lock = RLock()
self._variable_pool = variable_pool
self._graph = graph
self._active_session: ResponseSession | None = None
self._waiting_sessions: deque[ResponseSession] = deque()
self._lock = RLock()
# Internal stream management (replacing OutputRegistry)
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
self._stream_positions: dict[tuple[str, ...], int] = {}
self._closed_streams: set[tuple[str, ...]] = set()
# Track response nodes
self._response_nodes: set[NodeID] = set()
@@ -63,7 +68,7 @@
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
def register(self, response_node_id: NodeID) -> None:
with self.lock:
with self._lock:
self._response_nodes.add(response_node_id)
# Build and save paths map for this response node
@@ -71,7 +76,7 @@
self._paths_maps[response_node_id] = paths_map
# Create and store response session for this node
response_node = self.graph.nodes[response_node_id]
response_node = self._graph.nodes[response_node_id]
session = ResponseSession.from_node(response_node)
self._response_sessions[response_node_id] = session
@@ -82,7 +87,7 @@
node_id: The ID of the node
execution_id: The execution ID from NodeRunStartedEvent
"""
with self.lock:
with self._lock:
self._node_execution_ids[node_id] = execution_id
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
@@ -94,7 +99,7 @@
Returns:
The execution ID for the node
"""
with self.lock:
with self._lock:
if node_id not in self._node_execution_ids:
self._node_execution_ids[node_id] = str(uuid4())
return self._node_execution_ids[node_id]
@@ -111,14 +116,14 @@
List of Path objects, where each path contains branch edge IDs
"""
# Get root node ID
root_node_id = self.graph.root_node.id
root_node_id = self._graph.root_node.id
# If root is the response node, return empty path
if root_node_id == response_node_id:
return [Path()]
# Extract variable selectors from the response node's template
response_node = self.graph.nodes[response_node_id]
response_node = self._graph.nodes[response_node_id]
response_session = ResponseSession.from_node(response_node)
template = response_session.template
@@ -144,7 +149,7 @@
visited.add(current_node_id)
# Explore outgoing edges
outgoing_edges = self.graph.get_outgoing_edges(current_node_id)
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
for edge in outgoing_edges:
edge_id = edge.id
next_node_id = edge.head
@@ -161,10 +166,10 @@
# Step 2: For each complete path, filter edges based on node blocking behavior
filtered_paths: list[Path] = []
for path in all_complete_paths:
blocking_edges = []
blocking_edges: list[str] = []
for edge_id in path:
edge = self.graph.edges[edge_id]
source_node = self.graph.nodes[edge.tail]
edge = self._graph.edges[edge_id]
source_node = self._graph.nodes[edge.tail]
# Check if node is a branch/container (original behavior)
if source_node.execution_type in {
@@ -194,7 +199,7 @@ class ResponseStreamCoordinator:
"""
events: list[NodeRunStreamChunkEvent] = []
with self.lock:
with self._lock:
# Check each response node in order
for response_node_id in self._response_nodes:
if response_node_id not in self._paths_maps:
@@ -240,33 +245,32 @@ class ResponseStreamCoordinator:
# Remove from map to ensure it won't be activated again
del self._response_sessions[node_id]
if self.active_session is None:
self.active_session = session
if self._active_session is None:
self._active_session = session
# Try to flush immediately
events.extend(self.try_flush())
else:
# Queue the session if another is active
self.waiting_sessions.append(session)
self._waiting_sessions.append(session)
return events
def intercept_event(
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
) -> Sequence[NodeRunStreamChunkEvent]:
with self.lock:
with self._lock:
if isinstance(event, NodeRunStreamChunkEvent):
self.registry.append_chunk(event.selector, event)
self._append_stream_chunk(event.selector, event)
if event.is_final:
self.registry.close_stream(event.selector)
self._close_stream(event.selector)
return self.try_flush()
elif isinstance(event, NodeRunSucceededEvent):
else:
# Skip because we share the same variable pool.
#
# for variable_name, variable_value in event.node_run_result.outputs.items():
# self.registry.set_scalar((event.node_id, variable_name), variable_value)
# self._variable_pool.add((event.node_id, variable_name), variable_value)
return self.try_flush()
return []
def _create_stream_chunk_event(
self,
@@ -282,9 +286,9 @@ class ResponseStreamCoordinator:
active response node's information since these are not actual node IDs.
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self.graph.nodes and self.active_session:
if selector and selector[0] not in self._graph.nodes and self._active_session:
# Use the active response node for special selectors
response_node = self.graph.nodes[self.active_session.node_id]
response_node = self._graph.nodes[self._active_session.node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
@@ -295,7 +299,7 @@ class ResponseStreamCoordinator:
)
# Standard case: selector refers to an actual node
node = self.graph.nodes[node_id]
node = self._graph.nodes[node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=node.id,
@@ -318,21 +322,21 @@ class ResponseStreamCoordinator:
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self.active_session and source_selector_prefix not in self.graph.nodes:
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self.active_session.node_id
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)
# Stream all available chunks
while self.registry.has_unread(segment.selector):
if event := self.registry.pop_chunk(segment.selector):
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self.active_session and source_selector_prefix not in self.graph.nodes:
response_node = self.graph.nodes[self.active_session.node_id]
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
@@ -349,15 +353,15 @@ class ResponseStreamCoordinator:
events.append(event)
# Check if this is the last chunk by looking ahead
stream_closed = self.registry.stream_closed(segment.selector)
stream_closed = self._is_stream_closed(segment.selector)
# Check if stream is closed to determine if segment is complete
if stream_closed:
is_complete = True
elif value := self.registry.get_scalar(segment.selector):
elif value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
@@ -374,13 +378,13 @@ class ResponseStreamCoordinator:
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment and return the chunk events to emit."""
assert self.active_session is not None
current_response_node = self.graph.nodes[self.active_session.node_id]
assert self._active_session is not None
current_response_node = self._graph.nodes[self._active_session.node_id]
# Use get_or_create_execution_id to ensure we have a consistent ID
execution_id = self._get_or_create_execution_id(current_response_node.id)
is_last_segment = self.active_session.index == len(self.active_session.template.segments) - 1
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
event = self._create_stream_chunk_event(
node_id=current_response_node.id,
execution_id=execution_id,
@@ -391,29 +395,29 @@ class ResponseStreamCoordinator:
return [event]
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
with self.lock:
if not self.active_session:
with self._lock:
if not self._active_session:
return []
template = self.active_session.template
response_node_id = self.active_session.node_id
template = self._active_session.template
response_node_id = self._active_session.node_id
events: list[NodeRunStreamChunkEvent] = []
# Process segments sequentially from current index
while self.active_session.index < len(template.segments):
segment = template.segments[self.active_session.index]
while self._active_session.index < len(template.segments):
segment = template.segments[self._active_session.index]
if isinstance(segment, VariableSegment):
# Check if the source node for this variable is skipped
# Only check for actual nodes, not special selectors (sys, env, conversation)
source_selector_prefix = segment.selector[0] if segment.selector else ""
if source_selector_prefix in self.graph.nodes:
source_node = self.graph.nodes[source_selector_prefix]
if source_selector_prefix in self._graph.nodes:
source_node = self._graph.nodes[source_selector_prefix]
if source_node.state == NodeState.SKIPPED:
# Skip this variable segment if the source node is skipped
self.active_session.index += 1
self._active_session.index += 1
continue
segment_events, is_complete = self._process_variable_segment(segment)
@@ -421,17 +425,17 @@ class ResponseStreamCoordinator:
# Only advance index if this variable segment is complete
if is_complete:
self.active_session.index += 1
self._active_session.index += 1
else:
# Wait for more data
break
elif isinstance(segment, TextSegment):
else:
segment_events = self._process_text_segment(segment)
events.extend(segment_events)
self.active_session.index += 1
self._active_session.index += 1
if self.active_session.is_complete():
if self._active_session.is_complete():
# End current session and get events from starting next session
next_session_events = self.end_session(response_node_id)
events.extend(next_session_events)
@@ -449,18 +453,108 @@ class ResponseStreamCoordinator:
Returns:
List of events from starting the next session
"""
with self.lock:
with self._lock:
events: list[NodeRunStreamChunkEvent] = []
if self.active_session and self.active_session.node_id == node_id:
self.active_session = None
if self._active_session and self._active_session.node_id == node_id:
self._active_session = None
# Try to start next waiting session
if self.waiting_sessions:
next_session = self.waiting_sessions.popleft()
self.active_session = next_session
if self._waiting_sessions:
next_session = self._waiting_sessions.popleft()
self._active_session = next_session
# Immediately try to flush any available segments
events = self.try_flush()
return events
# ============= Internal Stream Management Methods =============
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.
Args:
selector: List of strings identifying the stream location
event: The NodeRunStreamChunkEvent to append
Raises:
ValueError: If the stream is already closed
"""
key = tuple(selector)
if key in self._closed_streams:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
if key not in self._stream_buffers:
self._stream_buffers[key] = []
self._stream_positions[key] = 0
self._stream_buffers[key].append(event)
def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
"""
Pop the next unread stream chunk from the buffer.
Args:
selector: List of strings identifying the stream location
Returns:
The next event, or None if no unread events available
"""
key = tuple(selector)
if key not in self._stream_buffers:
return None
position = self._stream_positions.get(key, 0)
buffer = self._stream_buffers[key]
if position >= len(buffer):
return None
event = buffer[position]
self._stream_positions[key] = position + 1
return event
def _has_unread_stream(self, selector: Sequence[str]) -> bool:
"""
Check if the stream has unread events.
Args:
selector: List of strings identifying the stream location
Returns:
True if there are unread events, False otherwise
"""
key = tuple(selector)
if key not in self._stream_buffers:
return False
position = self._stream_positions.get(key, 0)
return position < len(self._stream_buffers[key])
def _close_stream(self, selector: Sequence[str]) -> None:
"""
Mark a stream as closed (no more chunks can be appended).
Args:
selector: List of strings identifying the stream location
"""
key = tuple(selector)
self._closed_streams.add(key)
def _is_stream_closed(self, selector: Sequence[str]) -> bool:
"""
Check if a stream is closed.
Args:
selector: List of strings identifying the stream location
Returns:
True if the stream is closed, False otherwise
"""
key = tuple(selector)
return key in self._closed_streams

View File

@@ -5,12 +5,8 @@ This package manages node states, edge states, and execution tracking
during workflow graph execution.
"""
from .edge_state_manager import EdgeStateManager
from .execution_tracker import ExecutionTracker
from .node_state_manager import NodeStateManager
from .unified_state_manager import UnifiedStateManager
__all__ = [
"EdgeStateManager",
"ExecutionTracker",
"NodeStateManager",
"UnifiedStateManager",
]

View File

@@ -1,114 +0,0 @@
"""
Manager for edge states during graph execution.
"""
import threading
from collections.abc import Sequence
from typing import TypedDict, final
from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
class EdgeStateAnalysis(TypedDict):
"""Analysis result for edge states."""
has_unknown: bool
has_taken: bool
all_skipped: bool
@final
class EdgeStateManager:
"""
Manages edge states and transitions during graph execution.
This handles edge state changes and provides analysis of edge
states for decision making during execution.
"""
def __init__(self, graph: Graph) -> None:
"""
Initialize the edge state manager.
Args:
graph: The workflow graph
"""
self.graph = graph
self._lock = threading.RLock()
def mark_edge_taken(self, edge_id: str) -> None:
"""
Mark an edge as TAKEN.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self.graph.edges[edge_id].state = NodeState.TAKEN
def mark_edge_skipped(self, edge_id: str) -> None:
"""
Mark an edge as SKIPPED.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self.graph.edges[edge_id].state = NodeState.SKIPPED
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
"""
Analyze the states of edges and return summary flags.
Args:
edges: List of edges to analyze
Returns:
Analysis result with state flags
"""
with self._lock:
states = {edge.state for edge in edges}
return EdgeStateAnalysis(
has_unknown=NodeState.UNKNOWN in states,
has_taken=NodeState.TAKEN in states,
all_skipped=states == {NodeState.SKIPPED} if states else True,
)
def get_edge_state(self, edge_id: str) -> NodeState:
"""
Get the current state of an edge.
Args:
edge_id: The ID of the edge
Returns:
The current edge state
"""
with self._lock:
return self.graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
A tuple of (selected_edges, unselected_edges)
"""
with self._lock:
outgoing_edges = self.graph.get_outgoing_edges(node_id)
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
selected_edges.append(edge)
else:
unselected_edges.append(edge)
return selected_edges, unselected_edges
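The `EdgeStateAnalysis` flags above feed downstream branch decisions. A reduced sketch of the analysis itself (a local `State` enum stands in for `NodeState`; note the empty-set case deliberately reports `all_skipped=True`):

```python
from enum import Enum


class State(Enum):
    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"


def analyze(states: list[State]) -> dict[str, bool]:
    """Summarize a set of edge states into the three decision flags."""
    s = set(states)
    return {
        "has_unknown": State.UNKNOWN in s,
        "has_taken": State.TAKEN in s,
        # An empty edge list counts as "all skipped" by convention.
        "all_skipped": s == {State.SKIPPED} if s else True,
    }
```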

View File

@@ -1,89 +0,0 @@
"""
Tracker for currently executing nodes.
"""
import threading
from typing import final
@final
class ExecutionTracker:
"""
Tracks nodes that are currently being executed.
This replaces the ExecutingNodesManager with a cleaner interface
focused on tracking which nodes are in progress.
"""
def __init__(self) -> None:
"""Initialize the execution tracker."""
self._executing_nodes: set[str] = set()
self._lock = threading.RLock()
def add(self, node_id: str) -> None:
"""
Mark a node as executing.
Args:
node_id: The ID of the node starting execution
"""
with self._lock:
self._executing_nodes.add(node_id)
def remove(self, node_id: str) -> None:
"""
Mark a node as no longer executing.
Args:
node_id: The ID of the node finishing execution
"""
with self._lock:
self._executing_nodes.discard(node_id)
def is_executing(self, node_id: str) -> bool:
"""
Check if a node is currently executing.
Args:
node_id: The ID of the node to check
Returns:
True if the node is executing
"""
with self._lock:
return node_id in self._executing_nodes
def is_empty(self) -> bool:
"""
Check if no nodes are currently executing.
Returns:
True if no nodes are executing
"""
with self._lock:
return len(self._executing_nodes) == 0
def count(self) -> int:
"""
Get the count of currently executing nodes.
Returns:
Number of executing nodes
"""
with self._lock:
return len(self._executing_nodes)
def get_executing_nodes(self) -> set[str]:
"""
Get a copy of the set of executing node IDs.
Returns:
Set of node IDs currently executing
"""
with self._lock:
return self._executing_nodes.copy()
def clear(self) -> None:
"""Clear all executing nodes."""
with self._lock:
self._executing_nodes.clear()

View File

@@ -1,97 +0,0 @@
"""
Manager for node states during graph execution.
"""
import queue
import threading
from typing import final
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
@final
class NodeStateManager:
"""
Manages node states and the ready queue for execution.
This centralizes node state transitions and enqueueing logic,
ensuring thread-safe operations on node states.
"""
def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
"""
Initialize the node state manager.
Args:
graph: The workflow graph
ready_queue: Queue for nodes ready to execute
"""
self.graph = graph
self.ready_queue = ready_queue
self._lock = threading.RLock()
def enqueue_node(self, node_id: str) -> None:
"""
Mark a node as TAKEN and add it to the ready queue.
This combines the state transition and enqueueing operations
that always occur together when preparing a node for execution.
Args:
node_id: The ID of the node to enqueue
"""
with self._lock:
self.graph.nodes[node_id].state = NodeState.TAKEN
self.ready_queue.put(node_id)
def mark_node_skipped(self, node_id: str) -> None:
"""
Mark a node as SKIPPED.
Args:
node_id: The ID of the node to skip
"""
with self._lock:
self.graph.nodes[node_id].state = NodeState.SKIPPED
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when none of its incoming edges are still
UNKNOWN and at least one incoming edge has been TAKEN.
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
with self._lock:
# Get all incoming edges to this node
incoming_edges = self.graph.get_incoming_edges(node_id)
# If no incoming edges, node is always ready
if not incoming_edges:
return True
# If any edge is UNKNOWN, node is not ready
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
return False
# Node is ready if at least one edge is TAKEN
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
def get_node_state(self, node_id: str) -> NodeState:
"""
Get the current state of a node.
Args:
node_id: The ID of the node
Returns:
The current node state
"""
with self._lock:
return self.graph.nodes[node_id].state
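The readiness rule in `is_node_ready` is the heart of the traversal: wait while any incoming edge is still undecided, run once at least one edge is TAKEN. Isolated as a pure function (a local `EdgeState` enum stands in for `NodeState`):

```python
from enum import Enum


class EdgeState(Enum):
    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"


def is_ready(incoming: list[EdgeState]) -> bool:
    """Readiness rule for a node given its incoming edge states."""
    # No incoming edges: a root node is always ready.
    if not incoming:
        return True
    # Any undecided edge means we must keep waiting.
    if EdgeState.UNKNOWN in incoming:
        return False
    # At least one TAKEN edge (the rest may be SKIPPED) makes it ready.
    return EdgeState.TAKEN in incoming
```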

View File

@@ -0,0 +1,304 @@
"""
Unified state manager that combines node, edge, and execution tracking.
This is a proposed simplification that merges NodeStateManager, EdgeStateManager,
and ExecutionTracker into a single cohesive class.
"""
import queue
import threading
from collections.abc import Sequence
from typing import TypedDict, final
from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
class EdgeStateAnalysis(TypedDict):
"""Analysis result for edge states."""
has_unknown: bool
has_taken: bool
all_skipped: bool
@final
class UnifiedStateManager:
"""
Unified manager for all graph state operations.
This class combines the responsibilities of:
- NodeStateManager: Node state transitions and ready queue
- EdgeStateManager: Edge state transitions and analysis
- ExecutionTracker: Tracking executing nodes
Benefits:
- Single lock for all state operations (reduced contention)
- Cohesive state management interface
- Simplified dependency injection
"""
def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
"""
Initialize the unified state manager.
Args:
graph: The workflow graph
ready_queue: Queue for nodes ready to execute
"""
self._graph = graph
self._ready_queue = ready_queue
self._lock = threading.RLock()
# Execution tracking state
self._executing_nodes: set[str] = set()
# ============= Node State Operations =============
def enqueue_node(self, node_id: str) -> None:
"""
Mark a node as TAKEN and add it to the ready queue.
This combines the state transition and enqueueing operations
that always occur together when preparing a node for execution.
Args:
node_id: The ID of the node to enqueue
"""
with self._lock:
self._graph.nodes[node_id].state = NodeState.TAKEN
self._ready_queue.put(node_id)
def mark_node_skipped(self, node_id: str) -> None:
"""
Mark a node as SKIPPED.
Args:
node_id: The ID of the node to skip
"""
with self._lock:
self._graph.nodes[node_id].state = NodeState.SKIPPED
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when none of its incoming edges are still
UNKNOWN and at least one incoming edge has been TAKEN.
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
with self._lock:
# Get all incoming edges to this node
incoming_edges = self._graph.get_incoming_edges(node_id)
# If no incoming edges, node is always ready
if not incoming_edges:
return True
# If any edge is UNKNOWN, node is not ready
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
return False
# Node is ready if at least one edge is TAKEN
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
def get_node_state(self, node_id: str) -> NodeState:
"""
Get the current state of a node.
Args:
node_id: The ID of the node
Returns:
The current node state
"""
with self._lock:
return self._graph.nodes[node_id].state
# ============= Edge State Operations =============
def mark_edge_taken(self, edge_id: str) -> None:
"""
Mark an edge as TAKEN.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self._graph.edges[edge_id].state = NodeState.TAKEN
def mark_edge_skipped(self, edge_id: str) -> None:
"""
Mark an edge as SKIPPED.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self._graph.edges[edge_id].state = NodeState.SKIPPED
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
"""
Analyze the states of edges and return summary flags.
Args:
edges: List of edges to analyze
Returns:
Analysis result with state flags
"""
with self._lock:
states = {edge.state for edge in edges}
return EdgeStateAnalysis(
has_unknown=NodeState.UNKNOWN in states,
has_taken=NodeState.TAKEN in states,
all_skipped=states == {NodeState.SKIPPED} if states else True,
)
def get_edge_state(self, edge_id: str) -> NodeState:
"""
Get the current state of an edge.
Args:
edge_id: The ID of the edge
Returns:
The current edge state
"""
with self._lock:
return self._graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
A tuple of (selected_edges, unselected_edges)
"""
with self._lock:
outgoing_edges = self._graph.get_outgoing_edges(node_id)
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
selected_edges.append(edge)
else:
unselected_edges.append(edge)
return selected_edges, unselected_edges
# ============= Execution Tracking Operations =============
def start_execution(self, node_id: str) -> None:
"""
Mark a node as executing.
Args:
node_id: The ID of the node starting execution
"""
with self._lock:
self._executing_nodes.add(node_id)
def finish_execution(self, node_id: str) -> None:
"""
Mark a node as no longer executing.
Args:
node_id: The ID of the node finishing execution
"""
with self._lock:
self._executing_nodes.discard(node_id)
def is_executing(self, node_id: str) -> bool:
"""
Check if a node is currently executing.
Args:
node_id: The ID of the node to check
Returns:
True if the node is executing
"""
with self._lock:
return node_id in self._executing_nodes
def get_executing_count(self) -> int:
"""
Get the count of currently executing nodes.
Returns:
Number of executing nodes
"""
with self._lock:
return len(self._executing_nodes)
def get_executing_nodes(self) -> set[str]:
"""
Get a copy of the set of executing node IDs.
Returns:
Set of node IDs currently executing
"""
with self._lock:
return self._executing_nodes.copy()
def clear_executing(self) -> None:
"""Clear all executing nodes."""
with self._lock:
self._executing_nodes.clear()
# ============= Composite Operations =============
def is_execution_complete(self) -> bool:
"""
Check if graph execution is complete.
Execution is complete when:
- Ready queue is empty
- No nodes are executing
Returns:
True if execution is complete
"""
with self._lock:
return self._ready_queue.empty() and len(self._executing_nodes) == 0
def get_queue_depth(self) -> int:
"""
Get the current depth of the ready queue.
Returns:
Number of nodes in the ready queue
"""
return self._ready_queue.qsize()
def get_execution_stats(self) -> dict[str, int]:
"""
Get execution statistics.
Returns:
Dictionary with execution statistics
"""
with self._lock:
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
return {
"queue_depth": self._ready_queue.qsize(),
"executing": len(self._executing_nodes),
"taken_nodes": taken_nodes,
"skipped_nodes": skipped_nodes,
"unknown_nodes": unknown_nodes,
}
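The composite `is_execution_complete` check pairs queue emptiness with the executing set under a single lock. A reduced sketch without the `Graph` dependency (class and method names here are illustrative):

```python
import queue
import threading


class CompletionTracker:
    """Tracks the ready queue and executing set behind one lock."""

    def __init__(self) -> None:
        self._ready: queue.Queue[str] = queue.Queue()
        self._executing: set[str] = set()
        self._lock = threading.RLock()

    def enqueue(self, node_id: str) -> None:
        self._ready.put(node_id)

    def start(self) -> str:
        # Move a node from the queue into the executing set.
        node_id = self._ready.get()
        with self._lock:
            self._executing.add(node_id)
        return node_id

    def finish(self, node_id: str) -> None:
        with self._lock:
            self._executing.discard(node_id)
        self._ready.task_done()

    def is_complete(self) -> bool:
        # Complete only when nothing is queued AND nothing is running;
        # checking just the queue would declare victory too early.
        with self._lock:
            return self._ready.empty() and not self._executing
```

The executing set matters because a drained queue is not the end state: a running node may still enqueue successors.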

View File

@@ -15,6 +15,7 @@ from typing import final
from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
@@ -58,21 +59,22 @@ class Worker(threading.Thread):
on_active_callback: Optional callback when worker becomes active
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self.ready_queue = ready_queue
self.event_queue = event_queue
self.graph = graph
self.worker_id = worker_id
self.flask_app = flask_app
self.context_vars = context_vars
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._stop_event = threading.Event()
self.on_idle_callback = on_idle_callback
self.on_active_callback = on_active_callback
self.last_task_time = time.time()
self._on_idle_callback = on_idle_callback
self._on_active_callback = on_active_callback
self._last_task_time = time.time()
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
@override
def run(self) -> None:
"""
Main worker loop.
@@ -83,22 +85,22 @@ class Worker(threading.Thread):
while not self._stop_event.is_set():
# Try to get a node ID from the ready queue (with timeout)
try:
node_id = self.ready_queue.get(timeout=0.1)
node_id = self._ready_queue.get(timeout=0.1)
except queue.Empty:
# Notify that worker is idle
if self.on_idle_callback:
self.on_idle_callback(self.worker_id)
if self._on_idle_callback:
self._on_idle_callback(self._worker_id)
continue
# Notify that worker is active
if self.on_active_callback:
self.on_active_callback(self.worker_id)
if self._on_active_callback:
self._on_active_callback(self._worker_id)
self.last_task_time = time.time()
node = self.graph.nodes[node_id]
self._last_task_time = time.time()
node = self._graph.nodes[node_id]
try:
self._execute_node(node)
self.ready_queue.task_done()
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
id=str(uuid4()),
@@ -108,7 +110,7 @@ class Worker(threading.Thread):
error=str(e),
start_at=datetime.now(),
)
self.event_queue.put(error_event)
self._event_queue.put(error_event)
def _execute_node(self, node: Node) -> None:
"""
@@ -118,19 +120,19 @@ class Worker(threading.Thread):
node: The node instance to execute
"""
# Execute the node with preserved context if Flask app is provided
if self.flask_app and self.context_vars:
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self.flask_app,
context_vars=self.context_vars,
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self.event_queue.put(event)
self._event_queue.put(event)
else:
# Execute without context preservation
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self.event_queue.put(event)
self._event_queue.put(event)
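The worker's poll-with-timeout pattern (catching `queue.Empty` to fire the idle callback and re-check the stop event) can be sketched without the node and Flask machinery; the function and parameter names here are illustrative:

```python
import queue
import threading


def worker_loop(
    ready: queue.Queue,
    results: queue.Queue,
    stop: threading.Event,
    on_idle=None,
    on_active=None,
) -> None:
    """Poll with a short timeout so the stop event is observed promptly
    even while the ready queue stays empty."""
    while not stop.is_set():
        try:
            item = ready.get(timeout=0.1)
        except queue.Empty:
            if on_idle:
                on_idle()
            continue
        if on_active:
            on_active()
        try:
            results.put(("ok", item))
        finally:
            # Always balance get() so ready.join() can unblock.
            ready.task_done()
```

Blocking forever in `get()` would make shutdown depend on a poison-pill item; the 0.1 s timeout keeps the stop check cheap and prompt.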

View File

@@ -1,81 +0,0 @@
# Worker Management
Dynamic worker pool for node execution.
## Components
### WorkerPool
Manages worker thread lifecycle.
- `start/stop/wait()` - Control workers
- `scale_up/down()` - Adjust pool size
- `get_worker_count()` - Current count
### WorkerFactory
Creates workers with Flask context.
- `create_worker()` - Build with dependencies
- Preserves request context
### DynamicScaler
Determines scaling decisions.
- `min/max_workers` - Pool bounds
- `scale_up_threshold` - Queue trigger
- `should_scale_up/down()` - Check conditions
### ActivityTracker
Tracks worker activity.
- `track_activity(worker_id)` - Record activity
- `get_idle_workers(threshold)` - Find idle
- `get_active_count()` - Active count
## Usage
```python
scaler = DynamicScaler(
min_workers=2,
max_workers=10,
scale_up_threshold=5
)
pool = WorkerPool(
ready_queue=ready_queue,
worker_factory=factory,
dynamic_scaler=scaler
)
pool.start()
# Scale based on load
if scaler.should_scale_up(queue_size, active):
pool.scale_up()
pool.stop()
```
## Scaling Strategy
**Scale Up**: Queue size > threshold AND workers < max
**Scale Down**: Idle workers exist AND workers > min
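
As two predicates, the rules above look like this (a sketch of the README's simplified rules; the actual `DynamicScaler` additionally requires most workers to be busy before scaling up):

```python
def should_scale_up(queue_size: int, workers: int, *, max_workers: int, threshold: int) -> bool:
    # Scale up when the backlog exceeds the threshold and there is headroom.
    return queue_size > threshold and workers < max_workers


def should_scale_down(idle_workers: int, workers: int, *, min_workers: int) -> bool:
    # Scale down when idle workers exist and we stay above the floor.
    return idle_workers > 0 and workers > min_workers
```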
## Parameters
- `min_workers` - Minimum pool size
- `max_workers` - Maximum pool size
- `scale_up_threshold` - Queue trigger
- `scale_down_threshold` - Idle seconds
## Flask Context
WorkerFactory preserves request context across threads:
```python
context_vars = {"request_id": request.id}
# Workers receive same context
```

View File

@@ -5,14 +5,8 @@ This package manages the worker pool, including creation,
scaling, and activity tracking.
"""
from .activity_tracker import ActivityTracker
from .dynamic_scaler import DynamicScaler
from .worker_factory import WorkerFactory
from .worker_pool import WorkerPool
from .simple_worker_pool import SimpleWorkerPool
__all__ = [
"ActivityTracker",
"DynamicScaler",
"WorkerFactory",
"WorkerPool",
"SimpleWorkerPool",
]

View File

@@ -1,76 +0,0 @@
"""
Activity tracker for monitoring worker activity.
"""
import threading
import time
from typing import final
@final
class ActivityTracker:
"""
Tracks worker activity for scaling decisions.
This monitors which workers are active or idle to support
dynamic scaling decisions.
"""
def __init__(self, idle_threshold: float = 30.0) -> None:
"""
Initialize the activity tracker.
Args:
idle_threshold: Seconds before a worker is considered idle
"""
self.idle_threshold = idle_threshold
self._worker_activity: dict[int, tuple[bool, float]] = {}
self._lock = threading.RLock()
def track_activity(self, worker_id: int, is_active: bool) -> None:
"""
Track worker activity state.
Args:
worker_id: ID of the worker
is_active: Whether the worker is active
"""
with self._lock:
self._worker_activity[worker_id] = (is_active, time.time())
def get_idle_workers(self) -> list[int]:
"""
Get list of workers that have been idle too long.
Returns:
List of idle worker IDs
"""
current_time = time.time()
idle_workers = []
with self._lock:
for worker_id, (is_active, last_change) in self._worker_activity.items():
if not is_active and (current_time - last_change) > self.idle_threshold:
idle_workers.append(worker_id)
return idle_workers
def remove_worker(self, worker_id: int) -> None:
"""
Remove a worker from tracking.
Args:
worker_id: ID of the worker to remove
"""
with self._lock:
self._worker_activity.pop(worker_id, None)
def get_active_count(self) -> int:
"""
Get count of currently active workers.
Returns:
Number of active workers
"""
with self._lock:
return sum(1 for is_active, _ in self._worker_activity.values() if is_active)

View File

@@ -1,101 +0,0 @@
"""
Dynamic scaler for worker pool sizing.
"""
from typing import final
from core.workflow.graph import Graph
@final
class DynamicScaler:
"""
Manages dynamic scaling decisions for the worker pool.
This encapsulates the logic for when to scale up or down
based on workload and configuration.
"""
def __init__(
self,
min_workers: int = 2,
max_workers: int = 10,
scale_up_threshold: int = 5,
scale_down_idle_time: float = 30.0,
) -> None:
"""
Initialize the dynamic scaler.
Args:
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Idle time before scaling down
"""
self.min_workers = min_workers
self.max_workers = max_workers
self.scale_up_threshold = scale_up_threshold
self.scale_down_idle_time = scale_down_idle_time
def calculate_initial_workers(self, graph: Graph) -> int:
"""
Calculate initial worker count based on graph complexity.
Args:
graph: The workflow graph
Returns:
Initial number of workers to create
"""
node_count = len(graph.nodes)
# Simple heuristic: more nodes = more workers
if node_count < 10:
initial = self.min_workers
elif node_count < 50:
initial = min(4, self.max_workers)
elif node_count < 100:
initial = min(6, self.max_workers)
else:
initial = min(8, self.max_workers)
return max(self.min_workers, initial)
def should_scale_up(self, current_workers: int, queue_depth: int, executing_count: int) -> bool:
"""
Determine if scaling up is needed.
Args:
current_workers: Current number of workers
queue_depth: Number of nodes waiting
executing_count: Number of nodes executing
Returns:
True if should scale up
"""
if current_workers >= self.max_workers:
return False
# Scale up if queue is deep and workers are busy
if queue_depth > self.scale_up_threshold:
if executing_count >= current_workers * 0.8:
return True
return False
def should_scale_down(self, current_workers: int, idle_workers: list[int]) -> bool:
"""
Determine if scaling down is appropriate.
Args:
current_workers: Current number of workers
idle_workers: List of idle worker IDs
Returns:
True if should scale down
"""
if current_workers <= self.min_workers:
return False
# Scale down if we have idle workers
return len(idle_workers) > 0
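The scale-up rule combines two signals: the ready queue must be deeper than the threshold *and* at least 80% of workers must already be executing. A pure-function sketch of that decision (standalone, mirroring the logic above):

```python
def should_scale_up(
    current_workers: int,
    max_workers: int,
    queue_depth: int,
    executing_count: int,
    scale_up_threshold: int = 5,
) -> bool:
    """Grow the pool only when the queue is deep AND most workers are busy."""
    if current_workers >= max_workers:
        return False
    return queue_depth > scale_up_threshold and executing_count >= current_workers * 0.8


# Deep queue, all 4 workers busy -> scale up
print(should_scale_up(4, 10, queue_depth=8, executing_count=4))    # True
# Deep queue, but most workers are free -> adding threads won't help
print(should_scale_up(4, 10, queue_depth=8, executing_count=1))    # False
# Already at the cap -> never scale up
print(should_scale_up(10, 10, queue_depth=50, executing_count=10)) # False
```

Requiring both conditions avoids spawning threads when the queue is deep merely because existing workers haven't picked work up yet.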


@ -0,0 +1,168 @@
"""
Simple worker pool that consolidates functionality.
This is a simpler implementation that merges WorkerPool, ActivityTracker,
DynamicScaler, and WorkerFactory into a single class.
"""
import queue
import threading
from typing import TYPE_CHECKING, final
from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..worker import Worker
if TYPE_CHECKING:
from contextvars import Context
from flask import Flask
@final
class SimpleWorkerPool:
"""
Simple worker pool with integrated management.
This class consolidates all worker management functionality into
a single, simpler implementation without excessive abstraction.
"""
def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None:
"""
Initialize the simple worker pool.
Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Seconds before scaling down idle workers
"""
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
# Scaling parameters with defaults (explicit None checks so 0 remains a valid override)
self._min_workers = min_workers if min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS
self._max_workers = max_workers if max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS
self._scale_up_threshold = scale_up_threshold if scale_up_threshold is not None else dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
self._scale_down_idle_time = scale_down_idle_time if scale_down_idle_time is not None else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
# Worker management
self._workers: list[Worker] = []
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
def start(self, initial_count: int | None = None) -> None:
"""
Start the worker pool.
Args:
initial_count: Number of workers to start with (auto-calculated if None)
"""
with self._lock:
if self._running:
return
self._running = True
# Calculate initial worker count
if initial_count is None:
node_count = len(self._graph.nodes)
if node_count < 10:
initial_count = self._min_workers
elif node_count < 50:
initial_count = min(self._min_workers + 1, self._max_workers)
else:
initial_count = min(self._min_workers + 2, self._max_workers)
# Create initial workers
for _ in range(initial_count):
self._create_worker()
def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False
# Stop all workers
for worker in self._workers:
worker.stop()
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=10.0)
self._workers.clear()
def _create_worker(self) -> None:
"""Create and start a new worker."""
worker_id = self._worker_counter
self._worker_counter += 1
worker = Worker(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
)
worker.start()
self._workers.append(worker)
def check_and_scale(self) -> None:
"""Check and perform scaling if needed."""
with self._lock:
if not self._running:
return
current_count = len(self._workers)
queue_depth = self._ready_queue.qsize()
# Simple scaling logic
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
self._create_worker()
def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self._workers)
def get_status(self) -> dict[str, int]:
"""
Get pool status information.
Returns:
Dictionary with status information
"""
with self._lock:
return {
"total_workers": len(self._workers),
"queue_depth": self._ready_queue.qsize(),
"min_workers": self._min_workers,
"max_workers": self._max_workers,
}
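The pool's core mechanic — worker threads draining a shared ready queue of node ids and reporting results on an event queue — can be sketched independently of the engine. The worker body and the `STOP` sentinel here are illustrative stand-ins for the real `Worker` class:

```python
import queue
import threading

STOP = "<stop>"  # sentinel telling a worker to exit


def worker_loop(ready_q: "queue.Queue[str]", event_q: "queue.Queue[str]") -> None:
    """Drain node ids from the ready queue, emit one event per node."""
    while True:
        node_id = ready_q.get()
        if node_id == STOP:
            break
        event_q.put(f"executed:{node_id}")  # stand-in for real node execution


ready_q: "queue.Queue[str]" = queue.Queue()
event_q: "queue.Queue[str]" = queue.Queue()

workers = [threading.Thread(target=worker_loop, args=(ready_q, event_q)) for _ in range(2)]
for w in workers:
    w.start()

for node_id in ("start", "llm", "end"):
    ready_q.put(node_id)
for _ in workers:  # one sentinel per worker, queued after all real work
    ready_q.put(STOP)
for w in workers:
    w.join()

events = sorted(event_q.get_nowait() for _ in range(3))
print(events)  # ['executed:end', 'executed:llm', 'executed:start']
```

Because `queue.Queue` is FIFO, the sentinels are only consumed after every real node id, so no work is dropped during shutdown; the production pool uses `Worker.stop()` plus `join(timeout=...)` for the same purpose.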



@ -1,75 +0,0 @@
"""
Factory for creating worker instances.
"""
import contextvars
import queue
from collections.abc import Callable
from typing import final
from flask import Flask
from core.workflow.graph import Graph
from ..worker import Worker
@final
class WorkerFactory:
"""
Factory for creating worker instances with proper context.
This encapsulates worker creation logic and ensures all workers
are created with the necessary Flask and context variable setup.
"""
def __init__(
self,
flask_app: Flask | None,
context_vars: contextvars.Context,
) -> None:
"""
Initialize the worker factory.
Args:
flask_app: Flask application context
context_vars: Context variables to propagate
"""
self.flask_app = flask_app
self.context_vars = context_vars
self._next_worker_id = 0
def create_worker(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue,
graph: Graph,
on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None,
) -> Worker:
"""
Create a new worker instance.
Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
on_idle_callback: Callback when worker becomes idle
on_active_callback: Callback when worker becomes active
Returns:
Configured worker instance
"""
worker_id = self._next_worker_id
self._next_worker_id += 1
return Worker(
ready_queue=ready_queue,
event_queue=event_queue,
graph=graph,
worker_id=worker_id,
flask_app=self.flask_app,
context_vars=self.context_vars,
on_idle_callback=on_idle_callback,
on_active_callback=on_active_callback,
)
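The factory exists largely to carry `context_vars` (and the Flask app) into worker threads, since a new thread does not inherit the caller's `contextvars` state. The underlying mechanism is `contextvars.copy_context()` plus `Context.run`; the `request_id` variable below is a hypothetical example, not an engine identifier:

```python
import contextvars
import threading

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")


def handle() -> list[str]:
    request_id.set("req-42")
    ctx = contextvars.copy_context()  # snapshot the caller's context

    results: list[str] = []
    # Context.run executes the target with the snapshot active, so the
    # thread sees the same request_id as the caller; a bare thread would
    # raise LookupError on request_id.get().
    t = threading.Thread(target=ctx.run, args=(lambda: results.append(request_id.get()),))
    t.start()
    t.join()
    return results


print(handle())  # ['req-42']
```

This is why both `WorkerFactory` and `SimpleWorkerPool` thread `context_vars` through to every `Worker` they create.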


@ -1,147 +0,0 @@
"""
Worker pool management.
"""
import queue
import threading
from typing import final
from core.workflow.graph import Graph
from ..worker import Worker
from .activity_tracker import ActivityTracker
from .dynamic_scaler import DynamicScaler
from .worker_factory import WorkerFactory
@final
class WorkerPool:
"""
Manages a pool of worker threads for executing nodes.
This provides dynamic scaling, activity tracking, and lifecycle
management for worker threads.
"""
def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue,
graph: Graph,
worker_factory: WorkerFactory,
dynamic_scaler: DynamicScaler,
activity_tracker: ActivityTracker,
) -> None:
"""
Initialize the worker pool.
Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
worker_factory: Factory for creating workers
dynamic_scaler: Scaler for dynamic sizing
activity_tracker: Tracker for worker activity
"""
self.ready_queue = ready_queue
self.event_queue = event_queue
self.graph = graph
self.worker_factory = worker_factory
self.dynamic_scaler = dynamic_scaler
self.activity_tracker = activity_tracker
self.workers: list[Worker] = []
self._lock = threading.RLock()
self._running = False
def start(self, initial_count: int) -> None:
"""
Start the worker pool with initial workers.
Args:
initial_count: Number of workers to start with
"""
with self._lock:
if self._running:
return
self._running = True
# Create initial workers
for _ in range(initial_count):
worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
worker.start()
self.workers.append(worker)
def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False
# Stop all workers
for worker in self.workers:
worker.stop()
# Wait for workers to finish
for worker in self.workers:
if worker.is_alive():
worker.join(timeout=10.0)
self.workers.clear()
def scale_up(self) -> None:
"""Add a worker to the pool if allowed."""
with self._lock:
if not self._running:
return
if len(self.workers) >= self.dynamic_scaler.max_workers:
return
worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
worker.start()
self.workers.append(worker)
def scale_down(self, worker_ids: list[int]) -> None:
"""
Remove specific workers from the pool.
Args:
worker_ids: IDs of workers to remove
"""
with self._lock:
if not self._running:
return
if len(self.workers) <= self.dynamic_scaler.min_workers:
return
workers_to_remove = [w for w in self.workers if w.worker_id in worker_ids]
for worker in workers_to_remove:
worker.stop()
self.workers.remove(worker)
if worker.is_alive():
worker.join(timeout=1.0)
def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self.workers)
def check_scaling(self, queue_depth: int, executing_count: int) -> None:
"""
Check and perform scaling if needed.
Args:
queue_depth: Current queue depth
executing_count: Number of executing nodes
"""
current_count = self.get_worker_count()
if self.dynamic_scaler.should_scale_up(current_count, queue_depth, executing_count):
self.scale_up()
idle_workers = self.activity_tracker.get_idle_workers()
if idle_workers:
self.scale_down(idle_workers)
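`check_scaling` couples the scaler's decision with the tracker's idle list. A pure-function sketch of the scale-down side — note the explicit cap that keeps the pool at `min_workers` even when many workers are idle at once, an assumption the per-call guard in `scale_down` above does not enforce per removal:

```python
def pick_workers_to_stop(worker_ids: list[int], idle_ids: list[int], min_workers: int) -> list[int]:
    """Choose idle workers to stop without shrinking the pool below min_workers."""
    removable = max(0, len(worker_ids) - min_workers)  # how many we may stop at most
    candidates = [wid for wid in idle_ids if wid in worker_ids]
    return candidates[:removable]


# 4 workers, ids 1 and 3 idle, but the floor is 3 -> stop only one of them
print(pick_workers_to_stop([0, 1, 2, 3], [1, 3], min_workers=3))  # [1]
```

Keeping this as a pure function makes the policy easy to unit-test separately from the thread lifecycle code that actually stops and joins workers.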