mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 01:46:14 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
# Conflicts: # api/core/memory/token_buffer_memory.py # api/core/rag/extractor/notion_extractor.py # api/core/repositories/sqlalchemy_workflow_node_execution_repository.py # api/core/variables/variables.py # api/core/workflow/graph/graph.py # api/core/workflow/graph_engine/entities/event.py # api/services/dataset_service.py # web/app/components/app-sidebar/index.tsx # web/app/components/base/tag-management/selector.tsx # web/app/components/base/toast/index.tsx # web/app/components/datasets/create/website/index.tsx # web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx # web/app/components/workflow/header/version-history-button.tsx # web/app/components/workflow/hooks/use-inspect-vars-crud-common.ts # web/app/components/workflow/hooks/use-workflow-interactions.ts # web/app/components/workflow/panel/version-history-panel/index.tsx # web/service/base.ts
This commit is contained in:
@ -1,187 +0,0 @@
|
||||
# Graph Engine
|
||||
|
||||
Queue-based workflow execution engine for parallel graph processing.
|
||||
|
||||
## Architecture
|
||||
|
||||
The engine uses a modular architecture with specialized packages:
|
||||
|
||||
### Core Components
|
||||
|
||||
- **Domain** (`domain/`) - Core models: ExecutionContext, GraphExecution, NodeExecution
|
||||
- **Event Management** (`event_management/`) - Event handling, collection, and emission
|
||||
- **State Management** (`state_management/`) - Thread-safe state tracking for nodes and edges
|
||||
- **Error Handling** (`error_handling/`) - Strategy-based error recovery (retry, abort, fail-branch, default-value)
|
||||
- **Graph Traversal** (`graph_traversal/`) - Node readiness, edge processing, branch handling
|
||||
- **Command Processing** (`command_processing/`) - External command handling (abort, pause, resume)
|
||||
- **Worker Management** (`worker_management/`) - Dynamic worker pool with auto-scaling
|
||||
- **Orchestration** (`orchestration/`) - Main event loop and execution coordination
|
||||
|
||||
### Supporting Components
|
||||
|
||||
- **Output Registry** (`output_registry/`) - Thread-safe storage for node outputs
|
||||
- **Response Coordinator** (`response_coordinator/`) - Ordered streaming of response nodes
|
||||
- **Command Channels** (`command_channels/`) - Command transport (InMemory/Redis)
|
||||
- **Layers** (`layers/`) - Pluggable middleware for extensions
|
||||
|
||||
## Architecture Diagram
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class GraphEngine {
|
||||
+run()
|
||||
+add_layer()
|
||||
}
|
||||
|
||||
class Domain {
|
||||
ExecutionContext
|
||||
GraphExecution
|
||||
NodeExecution
|
||||
}
|
||||
|
||||
class EventManagement {
|
||||
EventHandlerRegistry
|
||||
EventCollector
|
||||
EventEmitter
|
||||
}
|
||||
|
||||
class StateManagement {
|
||||
NodeStateManager
|
||||
EdgeStateManager
|
||||
ExecutionTracker
|
||||
}
|
||||
|
||||
class WorkerManagement {
|
||||
WorkerPool
|
||||
WorkerFactory
|
||||
DynamicScaler
|
||||
ActivityTracker
|
||||
}
|
||||
|
||||
class GraphTraversal {
|
||||
NodeReadinessChecker
|
||||
EdgeProcessor
|
||||
BranchHandler
|
||||
SkipPropagator
|
||||
}
|
||||
|
||||
class Orchestration {
|
||||
Dispatcher
|
||||
ExecutionCoordinator
|
||||
}
|
||||
|
||||
class ErrorHandling {
|
||||
ErrorHandler
|
||||
RetryStrategy
|
||||
AbortStrategy
|
||||
FailBranchStrategy
|
||||
}
|
||||
|
||||
class CommandProcessing {
|
||||
CommandProcessor
|
||||
AbortCommandHandler
|
||||
}
|
||||
|
||||
class CommandChannels {
|
||||
InMemoryChannel
|
||||
RedisChannel
|
||||
}
|
||||
|
||||
class OutputRegistry {
|
||||
<<Storage>>
|
||||
Scalar Values
|
||||
Streaming Data
|
||||
}
|
||||
|
||||
class ResponseCoordinator {
|
||||
Session Management
|
||||
Path Analysis
|
||||
}
|
||||
|
||||
class Layers {
|
||||
<<Plugin>>
|
||||
DebugLoggingLayer
|
||||
}
|
||||
|
||||
GraphEngine --> Orchestration : coordinates
|
||||
GraphEngine --> Layers : extends
|
||||
|
||||
Orchestration --> EventManagement : processes events
|
||||
Orchestration --> WorkerManagement : manages scaling
|
||||
Orchestration --> CommandProcessing : checks commands
|
||||
Orchestration --> StateManagement : monitors state
|
||||
|
||||
WorkerManagement --> StateManagement : consumes ready queue
|
||||
WorkerManagement --> EventManagement : produces events
|
||||
WorkerManagement --> Domain : executes nodes
|
||||
|
||||
EventManagement --> ErrorHandling : failed events
|
||||
EventManagement --> GraphTraversal : success events
|
||||
EventManagement --> ResponseCoordinator : stream events
|
||||
EventManagement --> Layers : notifies
|
||||
|
||||
GraphTraversal --> StateManagement : updates states
|
||||
GraphTraversal --> Domain : checks graph
|
||||
|
||||
CommandProcessing --> CommandChannels : fetches commands
|
||||
CommandProcessing --> Domain : modifies execution
|
||||
|
||||
ErrorHandling --> Domain : handles failures
|
||||
|
||||
StateManagement --> Domain : tracks entities
|
||||
|
||||
ResponseCoordinator --> OutputRegistry : reads outputs
|
||||
|
||||
Domain --> OutputRegistry : writes outputs
|
||||
```
|
||||
|
||||
## Package Relationships
|
||||
|
||||
### Core Dependencies
|
||||
|
||||
- **Orchestration** acts as the central coordinator, managing all subsystems
|
||||
- **Domain** provides the core business entities used by all packages
|
||||
- **EventManagement** serves as the communication backbone between components
|
||||
- **StateManagement** maintains thread-safe state for the entire system
|
||||
|
||||
### Data Flow
|
||||
|
||||
1. **Commands** flow from CommandChannels → CommandProcessing → Domain
|
||||
1. **Events** flow from Workers → EventHandlerRegistry → State updates
|
||||
1. **Node outputs** flow from Workers → OutputRegistry → ResponseCoordinator
|
||||
1. **Ready nodes** flow from GraphTraversal → StateManagement → WorkerManagement
|
||||
|
||||
### Extension Points
|
||||
|
||||
- **Layers** observe all events for monitoring, logging, and custom logic
|
||||
- **ErrorHandling** strategies can be extended for custom failure recovery
|
||||
- **CommandChannels** can be implemented for different transport mechanisms
|
||||
|
||||
## Execution Flow
|
||||
|
||||
1. **Initialization**: GraphEngine creates all subsystems with the workflow graph
|
||||
1. **Node Discovery**: Traversal components identify ready nodes
|
||||
1. **Worker Execution**: Workers pull from ready queue and execute nodes
|
||||
1. **Event Processing**: Dispatcher routes events to appropriate handlers
|
||||
1. **State Updates**: Managers track node/edge states for next steps
|
||||
1. **Completion**: Coordinator detects when all nodes are done
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
|
||||
# Create and run engine
|
||||
engine = GraphEngine(
|
||||
tenant_id="tenant_1",
|
||||
app_id="app_1",
|
||||
workflow_id="workflow_1",
|
||||
graph=graph,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Stream execution events
|
||||
for event in engine.run():
|
||||
handle_event(event)
|
||||
```
|
||||
@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, final
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
|
||||
@ -87,7 +87,7 @@ class RedisChannel:
|
||||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
|
||||
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||
"""
|
||||
Deserialize a command from dictionary data.
|
||||
|
||||
|
||||
@ -5,6 +5,8 @@ Command handler implementations.
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
from .command_processor import CommandHandler
|
||||
@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
"""Handles abort commands."""
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
"""
|
||||
Handle an abort command.
|
||||
|
||||
@ -39,8 +39,8 @@ class CommandProcessor:
|
||||
command_channel: Channel for receiving commands
|
||||
graph_execution: Graph execution aggregate
|
||||
"""
|
||||
self.command_channel = command_channel
|
||||
self.graph_execution = graph_execution
|
||||
self._command_channel = command_channel
|
||||
self._graph_execution = graph_execution
|
||||
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
|
||||
|
||||
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
|
||||
@ -56,7 +56,7 @@ class CommandProcessor:
|
||||
def process_commands(self) -> None:
|
||||
"""Check for and process any pending commands."""
|
||||
try:
|
||||
commands = self.command_channel.fetch_commands()
|
||||
commands = self._command_channel.fetch_commands()
|
||||
for command in commands:
|
||||
self._handle_command(command)
|
||||
except Exception as e:
|
||||
@ -72,8 +72,8 @@ class CommandProcessor:
|
||||
handler = self._handlers.get(type(command))
|
||||
if handler:
|
||||
try:
|
||||
handler.handle(command, self.graph_execution)
|
||||
except Exception as e:
|
||||
handler.handle(command, self._graph_execution)
|
||||
except Exception:
|
||||
logger.exception("Error handling command %s", command.__class__.__name__)
|
||||
else:
|
||||
logger.warning("No handler registered for command: %s", command.__class__.__name__)
|
||||
|
||||
@ -32,6 +32,8 @@ class AbortStrategy:
|
||||
Returns:
|
||||
None - signals abortion
|
||||
"""
|
||||
_ = graph
|
||||
_ = retry_count
|
||||
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
|
||||
|
||||
# Return None to signal that execution should stop
|
||||
|
||||
@ -31,6 +31,7 @@ class DefaultValueStrategy:
|
||||
Returns:
|
||||
NodeRunExceptionEvent with default values
|
||||
"""
|
||||
_ = retry_count
|
||||
node = graph.nodes[event.node_id]
|
||||
|
||||
outputs = {
|
||||
|
||||
@ -31,6 +31,8 @@ class FailBranchStrategy:
|
||||
Returns:
|
||||
NodeRunExceptionEvent to continue via fail branch
|
||||
"""
|
||||
_ = graph
|
||||
_ = retry_count
|
||||
outputs = {
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
|
||||
@ -5,12 +5,10 @@ This package handles event routing, collection, and emission for
|
||||
workflow graph execution events.
|
||||
"""
|
||||
|
||||
from .event_collector import EventCollector
|
||||
from .event_emitter import EventEmitter
|
||||
from .event_handlers import EventHandlerRegistry
|
||||
from .event_handlers import EventHandler
|
||||
from .event_manager import EventManager
|
||||
|
||||
__all__ = [
|
||||
"EventCollector",
|
||||
"EventEmitter",
|
||||
"EventHandlerRegistry",
|
||||
"EventHandler",
|
||||
"EventManager",
|
||||
]
|
||||
|
||||
@ -1,58 +0,0 @@
|
||||
"""
|
||||
Event emitter for yielding events to external consumers.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from .event_collector import EventCollector
|
||||
|
||||
|
||||
@final
|
||||
class EventEmitter:
|
||||
"""
|
||||
Emits collected events as a generator for external consumption.
|
||||
|
||||
This provides a generator interface for yielding events as they're
|
||||
collected, with proper synchronization for multi-threaded access.
|
||||
"""
|
||||
|
||||
def __init__(self, event_collector: EventCollector) -> None:
|
||||
"""
|
||||
Initialize the event emitter.
|
||||
|
||||
Args:
|
||||
event_collector: The collector to emit events from
|
||||
"""
|
||||
self.event_collector = event_collector
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete to stop the generator."""
|
||||
self._execution_complete.set()
|
||||
|
||||
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generator that yields events as they're collected.
|
||||
|
||||
Yields:
|
||||
GraphEngineEvent instances as they're processed
|
||||
"""
|
||||
yielded_count = 0
|
||||
|
||||
while not self._execution_complete.is_set() or yielded_count < self.event_collector.event_count():
|
||||
# Get new events since last yield
|
||||
new_events = self.event_collector.get_new_events(yielded_count)
|
||||
|
||||
# Yield any new events
|
||||
for event in new_events:
|
||||
yield event
|
||||
yielded_count += 1
|
||||
|
||||
# Small sleep to avoid busy waiting
|
||||
if not self._execution_complete.is_set() and not new_events:
|
||||
time.sleep(0.001)
|
||||
@ -10,6 +10,7 @@ from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
@ -31,15 +32,15 @@ from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..error_handling import ErrorHandler
|
||||
from ..graph_traversal import BranchHandler, EdgeProcessor
|
||||
from ..state_management import ExecutionTracker, NodeStateManager
|
||||
from .event_collector import EventCollector
|
||||
from ..graph_traversal import EdgeProcessor
|
||||
from ..state_management import UnifiedStateManager
|
||||
from .event_manager import EventManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class EventHandlerRegistry:
|
||||
class EventHandler:
|
||||
"""
|
||||
Registry of event handlers for different event types.
|
||||
|
||||
@ -53,11 +54,9 @@ class EventHandlerRegistry:
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: "EventCollector",
|
||||
branch_handler: "BranchHandler",
|
||||
event_collector: "EventManager",
|
||||
edge_processor: "EdgeProcessor",
|
||||
node_state_manager: "NodeStateManager",
|
||||
execution_tracker: "ExecutionTracker",
|
||||
state_manager: "UnifiedStateManager",
|
||||
error_handler: "ErrorHandler",
|
||||
) -> None:
|
||||
"""
|
||||
@ -68,11 +67,9 @@ class EventHandlerRegistry:
|
||||
graph_runtime_state: Runtime state with variable pool
|
||||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Event collector for collecting events
|
||||
branch_handler: Branch handler for branch node processing
|
||||
event_collector: Event manager for collecting events
|
||||
edge_processor: Edge processor for edge traversal
|
||||
node_state_manager: Node state manager
|
||||
execution_tracker: Execution tracker
|
||||
state_manager: Unified state manager
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self._graph = graph
|
||||
@ -80,10 +77,8 @@ class EventHandlerRegistry:
|
||||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._event_collector = event_collector
|
||||
self._branch_handler = branch_handler
|
||||
self._edge_processor = edge_processor
|
||||
self._node_state_manager = node_state_manager
|
||||
self._execution_tracker = execution_tracker
|
||||
self._state_manager = state_manager
|
||||
self._error_handler = error_handler
|
||||
|
||||
def handle_event(self, event: GraphNodeEventBase) -> None:
|
||||
@ -122,6 +117,7 @@ class EventHandlerRegistry:
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
),
|
||||
):
|
||||
# Iteration and loop events are collected directly
|
||||
@ -187,7 +183,7 @@ class EventHandlerRegistry:
|
||||
# Process edges and get ready nodes
|
||||
node = self._graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
@ -199,11 +195,11 @@ class EventHandlerRegistry:
|
||||
|
||||
# Enqueue ready nodes
|
||||
for node_id in ready_nodes:
|
||||
self._node_state_manager.enqueue_node(node_id)
|
||||
self._execution_tracker.add(node_id)
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
@ -232,7 +228,7 @@ class EventHandlerRegistry:
|
||||
# Abort execution
|
||||
self._graph_execution.fail(RuntimeError(event.error))
|
||||
self._event_collector.collect(event)
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
@ -267,6 +263,8 @@ class EventHandlerRegistry:
|
||||
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
|
||||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.outputs.get("answer", "")
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
"""
|
||||
Event collector for buffering and managing events.
|
||||
Unified event manager for collecting and emitting events.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
@ -23,7 +25,7 @@ class ReadWriteLock:
|
||||
|
||||
def acquire_read(self) -> None:
|
||||
"""Acquire a read lock."""
|
||||
self._read_ready.acquire()
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers += 1
|
||||
finally:
|
||||
@ -31,7 +33,7 @@ class ReadWriteLock:
|
||||
|
||||
def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
self._read_ready.acquire()
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
@ -41,9 +43,9 @@ class ReadWriteLock:
|
||||
|
||||
def acquire_write(self) -> None:
|
||||
"""Acquire a write lock."""
|
||||
self._read_ready.acquire()
|
||||
_ = self._read_ready.acquire()
|
||||
while self._readers > 0:
|
||||
self._read_ready.wait()
|
||||
_ = self._read_ready.wait()
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release a write lock."""
|
||||
@ -89,19 +91,21 @@ class WriteLockContext:
|
||||
|
||||
|
||||
@final
|
||||
class EventCollector:
|
||||
class EventManager:
|
||||
"""
|
||||
Collects and buffers events for later retrieval.
|
||||
Unified event manager that collects, buffers, and emits events.
|
||||
|
||||
This provides thread-safe event collection with support for
|
||||
notifying layers about events as they're collected.
|
||||
This class combines event collection with event emission, providing
|
||||
thread-safe event management with support for notifying layers and
|
||||
streaming events to external consumers.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the event collector."""
|
||||
"""Initialize the event manager."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = ReadWriteLock()
|
||||
self._layers: list[Layer] = []
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def set_layers(self, layers: list[Layer]) -> None:
|
||||
"""
|
||||
@ -123,17 +127,7 @@ class EventCollector:
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
def get_events(self) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
Get all collected events.
|
||||
|
||||
Returns:
|
||||
List of collected events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return list(self._events)
|
||||
|
||||
def get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
Get new events starting from a specific index.
|
||||
|
||||
@ -146,7 +140,7 @@ class EventCollector:
|
||||
with self._lock.read_lock():
|
||||
return list(self._events[start_index:])
|
||||
|
||||
def event_count(self) -> int:
|
||||
def _event_count(self) -> int:
|
||||
"""
|
||||
Get the current count of collected events.
|
||||
|
||||
@ -156,10 +150,31 @@ class EventCollector:
|
||||
with self._lock.read_lock():
|
||||
return len(self._events)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all collected events."""
|
||||
with self._lock.write_lock():
|
||||
self._events.clear()
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete to stop the event emission generator."""
|
||||
self._execution_complete.set()
|
||||
|
||||
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generator that yields events as they're collected.
|
||||
|
||||
Yields:
|
||||
GraphEngineEvent instances as they're processed
|
||||
"""
|
||||
yielded_count = 0
|
||||
|
||||
while not self._execution_complete.is_set() or yielded_count < self._event_count():
|
||||
# Get new events since last yield
|
||||
new_events = self._get_new_events(yielded_count)
|
||||
|
||||
# Yield any new events
|
||||
for event in new_events:
|
||||
yield event
|
||||
yielded_count += 1
|
||||
|
||||
# Small sleep to avoid busy waiting
|
||||
if not self._execution_complete.is_set() and not new_events:
|
||||
time.sleep(0.001)
|
||||
|
||||
def _notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
@ -13,7 +13,6 @@ from typing import final
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
@ -32,15 +31,14 @@ from .command_processing import AbortCommandHandler, CommandProcessor
|
||||
from .domain import ExecutionContext, GraphExecution
|
||||
from .entities.commands import AbortCommand
|
||||
from .error_handling import ErrorHandler
|
||||
from .event_management import EventCollector, EventEmitter, EventHandlerRegistry
|
||||
from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import Layer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .output_registry import OutputRegistry
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .response_coordinator import ResponseStreamCoordinator
|
||||
from .state_management import EdgeStateManager, ExecutionTracker, NodeStateManager
|
||||
from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, WorkerPool
|
||||
from .state_management import UnifiedStateManager
|
||||
from .worker_management import SimpleWorkerPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -74,10 +72,11 @@ class GraphEngine:
|
||||
scale_up_threshold: int | None = None,
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with separated concerns."""
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
# Create domain models
|
||||
self.execution_context = ExecutionContext(
|
||||
# === Domain Models ===
|
||||
# Execution context encapsulates workflow execution metadata
|
||||
self._execution_context = ExecutionContext(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
@ -89,167 +88,149 @@ class GraphEngine:
|
||||
max_execution_time=max_execution_time,
|
||||
)
|
||||
|
||||
self.graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||
|
||||
# Store core dependencies
|
||||
self.graph = graph
|
||||
self.graph_config = graph_config
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
# === Core Dependencies ===
|
||||
# Graph structure and configuration
|
||||
self._graph = graph
|
||||
self._graph_config = graph_config
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._command_channel = command_channel
|
||||
|
||||
# Store worker management parameters
|
||||
# === Worker Management Parameters ===
|
||||
# Parameters for dynamic worker pool scaling
|
||||
self._min_workers = min_workers
|
||||
self._max_workers = max_workers
|
||||
self._scale_up_threshold = scale_up_threshold
|
||||
self._scale_down_idle_time = scale_down_idle_time
|
||||
|
||||
# Initialize queues
|
||||
self.ready_queue: queue.Queue[str] = queue.Queue()
|
||||
self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
# === Execution Queues ===
|
||||
# Queue for nodes ready to execute
|
||||
self._ready_queue: queue.Queue[str] = queue.Queue()
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
# Initialize subsystems
|
||||
self._initialize_subsystems()
|
||||
# === State Management ===
|
||||
# Unified state manager handles all node state transitions and queue operations
|
||||
self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
|
||||
|
||||
# Layers for extensibility
|
||||
self._layers: list[Layer] = []
|
||||
|
||||
# Validate graph state consistency
|
||||
self._validate_graph_state_consistency()
|
||||
|
||||
def _initialize_subsystems(self) -> None:
|
||||
"""Initialize all subsystems with proper dependency injection."""
|
||||
|
||||
# State management
|
||||
self.node_state_manager = NodeStateManager(self.graph, self.ready_queue)
|
||||
self.edge_state_manager = EdgeStateManager(self.graph)
|
||||
self.execution_tracker = ExecutionTracker()
|
||||
|
||||
# Response coordination
|
||||
self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
|
||||
self.response_coordinator = ResponseStreamCoordinator(registry=self.output_registry, graph=self.graph)
|
||||
|
||||
# Event management
|
||||
self.event_collector = EventCollector()
|
||||
self.event_emitter = EventEmitter(self.event_collector)
|
||||
|
||||
# Error handling
|
||||
self.error_handler = ErrorHandler(self.graph, self.graph_execution)
|
||||
|
||||
# Graph traversal
|
||||
self.node_readiness_checker = NodeReadinessChecker(self.graph)
|
||||
self.edge_processor = EdgeProcessor(
|
||||
graph=self.graph,
|
||||
edge_state_manager=self.edge_state_manager,
|
||||
node_state_manager=self.node_state_manager,
|
||||
response_coordinator=self.response_coordinator,
|
||||
)
|
||||
self.skip_propagator = SkipPropagator(
|
||||
graph=self.graph,
|
||||
edge_state_manager=self.edge_state_manager,
|
||||
node_state_manager=self.node_state_manager,
|
||||
)
|
||||
self.branch_handler = BranchHandler(
|
||||
graph=self.graph,
|
||||
edge_processor=self.edge_processor,
|
||||
skip_propagator=self.skip_propagator,
|
||||
edge_state_manager=self.edge_state_manager,
|
||||
# === Response Coordination ===
|
||||
# Coordinates response streaming from response nodes
|
||||
self._response_coordinator = ResponseStreamCoordinator(
|
||||
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
|
||||
)
|
||||
|
||||
# Event handler registry with all dependencies
|
||||
self.event_handler_registry = EventHandlerRegistry(
|
||||
graph=self.graph,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
graph_execution=self.graph_execution,
|
||||
response_coordinator=self.response_coordinator,
|
||||
event_collector=self.event_collector,
|
||||
branch_handler=self.branch_handler,
|
||||
edge_processor=self.edge_processor,
|
||||
node_state_manager=self.node_state_manager,
|
||||
execution_tracker=self.execution_tracker,
|
||||
error_handler=self.error_handler,
|
||||
# === Event Management ===
|
||||
# Event manager handles both collection and emission of events
|
||||
self._event_manager = EventManager()
|
||||
|
||||
# === Error Handling ===
|
||||
# Centralized error handler for graph execution errors
|
||||
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
|
||||
|
||||
# === Graph Traversal Components ===
|
||||
# Propagates skip status through the graph when conditions aren't met
|
||||
self._skip_propagator = SkipPropagator(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
)
|
||||
|
||||
# Command processing
|
||||
self.command_processor = CommandProcessor(
|
||||
command_channel=self.command_channel,
|
||||
graph_execution=self.graph_execution,
|
||||
)
|
||||
self._setup_command_handlers()
|
||||
|
||||
# Worker management
|
||||
self._setup_worker_management()
|
||||
|
||||
# Orchestration
|
||||
self.execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self.graph_execution,
|
||||
node_state_manager=self.node_state_manager,
|
||||
execution_tracker=self.execution_tracker,
|
||||
event_handler=self.event_handler_registry,
|
||||
event_collector=self.event_collector,
|
||||
command_processor=self.command_processor,
|
||||
worker_pool=self._worker_pool,
|
||||
# Processes edges to determine next nodes after execution
|
||||
# Also handles conditional branching and route selection
|
||||
self._edge_processor = EdgeProcessor(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
response_coordinator=self._response_coordinator,
|
||||
skip_propagator=self._skip_propagator,
|
||||
)
|
||||
|
||||
self.dispatcher = Dispatcher(
|
||||
event_queue=self.event_queue,
|
||||
event_handler=self.event_handler_registry,
|
||||
event_collector=self.event_collector,
|
||||
execution_coordinator=self.execution_coordinator,
|
||||
max_execution_time=self.execution_context.max_execution_time,
|
||||
event_emitter=self.event_emitter,
|
||||
# === Event Handler Registry ===
|
||||
# Central registry for handling all node execution events
|
||||
self._event_handler_registry = EventHandler(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_manager,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
def _setup_command_handlers(self) -> None:
|
||||
"""Configure command handlers."""
|
||||
# Create handler instance that follows the protocol
|
||||
# === Command Processing ===
|
||||
# Processes external commands (e.g., abort requests)
|
||||
self._command_processor = CommandProcessor(
|
||||
command_channel=self._command_channel,
|
||||
graph_execution=self._graph_execution,
|
||||
)
|
||||
|
||||
# Register abort command handler
|
||||
abort_handler = AbortCommandHandler()
|
||||
self.command_processor.register_handler(
|
||||
self._command_processor.register_handler(
|
||||
AbortCommand,
|
||||
abort_handler,
|
||||
)
|
||||
|
||||
def _setup_worker_management(self) -> None:
|
||||
"""Initialize worker management subsystem."""
|
||||
# Capture context for workers
|
||||
# === Worker Pool Setup ===
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app: Flask | None = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
app = current_app._get_current_object() # type: ignore
|
||||
if isinstance(app, Flask):
|
||||
flask_app = app
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Capture context variables for worker threads
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Create worker management components
|
||||
self._activity_tracker = ActivityTracker()
|
||||
self._dynamic_scaler = DynamicScaler(
|
||||
min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS),
|
||||
max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS),
|
||||
scale_up_threshold=(
|
||||
self._scale_up_threshold
|
||||
if self._scale_up_threshold is not None
|
||||
else dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
|
||||
),
|
||||
scale_down_idle_time=(
|
||||
self._scale_down_idle_time
|
||||
if self._scale_down_idle_time is not None
|
||||
else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
),
|
||||
# Create worker pool for parallel node execution
|
||||
self._worker_pool = SimpleWorkerPool(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
flask_app=flask_app,
|
||||
context_vars=context_vars,
|
||||
min_workers=self._min_workers,
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
scale_down_idle_time=self._scale_down_idle_time,
|
||||
)
|
||||
self._worker_factory = WorkerFactory(flask_app, context_vars)
|
||||
|
||||
self._worker_pool = WorkerPool(
|
||||
ready_queue=self.ready_queue,
|
||||
event_queue=self.event_queue,
|
||||
graph=self.graph,
|
||||
worker_factory=self._worker_factory,
|
||||
dynamic_scaler=self._dynamic_scaler,
|
||||
activity_tracker=self._activity_tracker,
|
||||
# === Orchestration ===
|
||||
# Coordinates the overall execution lifecycle
|
||||
self._execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self._graph_execution,
|
||||
state_manager=self._state_manager,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_manager,
|
||||
command_processor=self._command_processor,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
# Dispatches events and manages execution flow
|
||||
self._dispatcher = Dispatcher(
|
||||
event_queue=self._event_queue,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_manager,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
max_execution_time=self._execution_context.max_execution_time,
|
||||
event_emitter=self._event_manager,
|
||||
)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[Layer] = []
|
||||
|
||||
# === Validation ===
|
||||
# Ensure all nodes share the same GraphRuntimeState instance
|
||||
self._validate_graph_state_consistency()
|
||||
|
||||
def _validate_graph_state_consistency(self) -> None:
|
||||
"""Validate that all nodes share the same GraphRuntimeState."""
|
||||
expected_state_id = id(self.graph_runtime_state)
|
||||
for node in self.graph.nodes.values():
|
||||
expected_state_id = id(self._graph_runtime_state)
|
||||
for node in self._graph.nodes.values():
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
@ -270,7 +251,7 @@ class GraphEngine:
|
||||
self._initialize_layers()
|
||||
|
||||
# Start execution
|
||||
self.graph_execution.start()
|
||||
self._graph_execution.start()
|
||||
start_event = GraphRunStartedEvent()
|
||||
yield start_event
|
||||
|
||||
@ -278,23 +259,23 @@ class GraphEngine:
|
||||
self._start_execution()
|
||||
|
||||
# Yield events as they occur
|
||||
yield from self.event_emitter.emit_events()
|
||||
yield from self._event_manager.emit_events()
|
||||
|
||||
# Handle completion
|
||||
if self.graph_execution.aborted:
|
||||
if self._graph_execution.aborted:
|
||||
abort_reason = "Workflow execution aborted by user command"
|
||||
if self.graph_execution.error:
|
||||
abort_reason = str(self.graph_execution.error)
|
||||
if self._graph_execution.error:
|
||||
abort_reason = str(self._graph_execution.error)
|
||||
yield GraphRunAbortedEvent(
|
||||
reason=abort_reason,
|
||||
outputs=self.graph_runtime_state.outputs,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
elif self.graph_execution.has_error:
|
||||
if self.graph_execution.error:
|
||||
raise self.graph_execution.error
|
||||
elif self._graph_execution.has_error:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=self.graph_runtime_state.outputs,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@ -306,10 +287,10 @@ class GraphEngine:
|
||||
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self.event_collector.set_layers(self._layers)
|
||||
self._event_manager.set_layers(self._layers)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(self.graph_runtime_state, self.command_channel)
|
||||
layer.initialize(self._graph_runtime_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
|
||||
@ -320,28 +301,25 @@ class GraphEngine:
|
||||
|
||||
def _start_execution(self) -> None:
|
||||
"""Start execution subsystems."""
|
||||
# Calculate initial worker count
|
||||
initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph)
|
||||
|
||||
# Start worker pool
|
||||
self._worker_pool.start(initial_workers)
|
||||
# Start worker pool (it calculates initial workers internally)
|
||||
self._worker_pool.start()
|
||||
|
||||
# Register response nodes
|
||||
for node in self.graph.nodes.values():
|
||||
for node in self._graph.nodes.values():
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self.response_coordinator.register(node.id)
|
||||
self._response_coordinator.register(node.id)
|
||||
|
||||
# Enqueue root node
|
||||
root_node = self.graph.root_node
|
||||
self.node_state_manager.enqueue_node(root_node.id)
|
||||
self.execution_tracker.add(root_node.id)
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
|
||||
# Start dispatcher
|
||||
self.dispatcher.start()
|
||||
self._dispatcher.start()
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self.dispatcher.stop()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
@ -350,6 +328,12 @@ class GraphEngine:
|
||||
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self.graph_execution.error)
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
|
||||
|
||||
# Public property accessors for attributes that need external access
|
||||
@property
|
||||
def graph_runtime_state(self) -> GraphRuntimeState:
|
||||
"""Get the graph runtime state."""
|
||||
return self._graph_runtime_state
|
||||
|
||||
@ -5,14 +5,10 @@ This package handles graph navigation, edge processing,
|
||||
and skip propagation logic.
|
||||
"""
|
||||
|
||||
from .branch_handler import BranchHandler
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .node_readiness import NodeReadinessChecker
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
__all__ = [
|
||||
"BranchHandler",
|
||||
"EdgeProcessor",
|
||||
"NodeReadinessChecker",
|
||||
"SkipPropagator",
|
||||
]
|
||||
|
||||
@ -1,87 +0,0 @@
|
||||
"""
|
||||
Branch node handling for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
|
||||
|
||||
from ..state_management import EdgeStateManager
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class BranchHandler:
|
||||
"""
|
||||
Handles branch node logic during graph traversal.
|
||||
|
||||
Branch nodes select one of multiple paths based on conditions,
|
||||
requiring special handling for edge selection and skip propagation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
edge_processor: EdgeProcessor,
|
||||
skip_propagator: SkipPropagator,
|
||||
edge_state_manager: EdgeStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the branch handler.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
edge_processor: Processor for edges
|
||||
skip_propagator: Propagator for skip states
|
||||
edge_state_manager: Manager for edge states
|
||||
"""
|
||||
self.graph = graph
|
||||
self.edge_processor = edge_processor
|
||||
self.skip_propagator = skip_propagator
|
||||
self.edge_state_manager = edge_state_manager
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self.skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.edge_processor.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
@ -3,14 +3,17 @@ Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
from ..state_management import UnifiedStateManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
@ -19,29 +22,30 @@ class EdgeProcessor:
|
||||
Processes edges during graph execution.
|
||||
|
||||
This handles marking edges as taken or skipped, notifying
|
||||
the response coordinator, and triggering downstream node execution.
|
||||
the response coordinator, triggering downstream node execution,
|
||||
and managing branch node logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
edge_state_manager: EdgeStateManager,
|
||||
node_state_manager: NodeStateManager,
|
||||
state_manager: UnifiedStateManager,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
skip_propagator: "SkipPropagator",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the edge processor.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
edge_state_manager: Manager for edge states
|
||||
node_state_manager: Manager for node states
|
||||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
skip_propagator: Propagator for skip states
|
||||
"""
|
||||
self.graph = graph
|
||||
self.edge_state_manager = edge_state_manager
|
||||
self.node_state_manager = node_state_manager
|
||||
self.response_coordinator = response_coordinator
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
self._response_coordinator = response_coordinator
|
||||
self._skip_propagator = skip_propagator
|
||||
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
@ -56,7 +60,7 @@ class EdgeProcessor:
|
||||
Returns:
|
||||
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
|
||||
"""
|
||||
node = self.graph.nodes[node_id]
|
||||
node = self._graph.nodes[node_id]
|
||||
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
return self._process_branch_node_edges(node_id, selected_handle)
|
||||
@ -75,7 +79,7 @@ class EdgeProcessor:
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
@ -107,7 +111,7 @@ class EdgeProcessor:
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Process unselected edges first (mark as skipped)
|
||||
for edge in unselected_edges:
|
||||
@ -132,14 +136,14 @@ class EdgeProcessor:
|
||||
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
|
||||
"""
|
||||
# Mark edge as taken
|
||||
self.edge_state_manager.mark_edge_taken(edge.id)
|
||||
self._state_manager.mark_edge_taken(edge.id)
|
||||
|
||||
# Notify response coordinator and get streaming events
|
||||
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
|
||||
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes: list[str] = []
|
||||
if self.node_state_manager.is_node_ready(edge.head):
|
||||
if self._state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, streaming_events
|
||||
@ -151,4 +155,47 @@ class EdgeProcessor:
|
||||
Args:
|
||||
edge: The edge to skip
|
||||
"""
|
||||
self.edge_state_manager.mark_edge_skipped(edge.id)
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
|
||||
@ -1,86 +0,0 @@
|
||||
"""
|
||||
Node readiness checking for execution.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class NodeReadinessChecker:
|
||||
"""
|
||||
Checks if nodes are ready for execution based on their dependencies.
|
||||
|
||||
A node is ready when its dependencies (incoming edges) have been
|
||||
satisfied according to the graph's execution rules.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph) -> None:
|
||||
"""
|
||||
Initialize the readiness checker.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
"""
|
||||
self.graph = graph
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is ready to be executed.
|
||||
|
||||
A node is ready when:
|
||||
- It has no incoming edges (root or isolated node), OR
|
||||
- At least one incoming edge is TAKEN and none are UNKNOWN
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
incoming_edges = self.graph.get_incoming_edges(node_id)
|
||||
|
||||
# No dependencies means always ready
|
||||
if not incoming_edges:
|
||||
return True
|
||||
|
||||
# Check edge states
|
||||
has_unknown = False
|
||||
has_taken = False
|
||||
|
||||
for edge in incoming_edges:
|
||||
if edge.state == NodeState.UNKNOWN:
|
||||
has_unknown = True
|
||||
break
|
||||
elif edge.state == NodeState.TAKEN:
|
||||
has_taken = True
|
||||
|
||||
# Not ready if any dependency is still unknown
|
||||
if has_unknown:
|
||||
return False
|
||||
|
||||
# Ready if at least one path is taken
|
||||
return has_taken
|
||||
|
||||
def get_ready_downstream_nodes(self, from_node_id: str) -> list[str]:
|
||||
"""
|
||||
Get all downstream nodes that are ready after a node completes.
|
||||
|
||||
Args:
|
||||
from_node_id: The ID of the completed node
|
||||
|
||||
Returns:
|
||||
List of node IDs that are now ready
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.state == NodeState.TAKEN:
|
||||
downstream_node_id = edge.head
|
||||
if self.is_node_ready(downstream_node_id):
|
||||
ready_nodes.append(downstream_node_id)
|
||||
|
||||
return ready_nodes
|
||||
@ -7,7 +7,7 @@ from typing import final
|
||||
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
from ..state_management import UnifiedStateManager
|
||||
|
||||
|
||||
@final
|
||||
@ -22,20 +22,17 @@ class SkipPropagator:
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
edge_state_manager: EdgeStateManager,
|
||||
node_state_manager: NodeStateManager,
|
||||
state_manager: UnifiedStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the skip propagator.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
edge_state_manager: Manager for edge states
|
||||
node_state_manager: Manager for node states
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self.graph = graph
|
||||
self.edge_state_manager = edge_state_manager
|
||||
self.node_state_manager = node_state_manager
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
|
||||
def propagate_skip_from_edge(self, edge_id: str) -> None:
|
||||
"""
|
||||
@ -49,11 +46,11 @@ class SkipPropagator:
|
||||
Args:
|
||||
edge_id: The ID of the skipped edge to start from
|
||||
"""
|
||||
downstream_node_id = self.graph.edges[edge_id].head
|
||||
incoming_edges = self.graph.get_incoming_edges(downstream_node_id)
|
||||
downstream_node_id = self._graph.edges[edge_id].head
|
||||
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
|
||||
|
||||
# Analyze edge states
|
||||
edge_states = self.edge_state_manager.analyze_edge_states(incoming_edges)
|
||||
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
|
||||
|
||||
# Stop if there are unknown edges (not yet processed)
|
||||
if edge_states["has_unknown"]:
|
||||
@ -62,7 +59,7 @@ class SkipPropagator:
|
||||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self.node_state_manager.enqueue_node(downstream_node_id)
|
||||
self._state_manager.enqueue_node(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
@ -77,12 +74,12 @@ class SkipPropagator:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
# Mark node as skipped
|
||||
self.node_state_manager.mark_node_skipped(node_id)
|
||||
self._state_manager.mark_node_skipped(node_id)
|
||||
|
||||
# Mark all outgoing edges as skipped and propagate
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
for edge in outgoing_edges:
|
||||
self.edge_state_manager.mark_edge_skipped(edge.id)
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
@ -94,5 +91,5 @@ class SkipPropagator:
|
||||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
self.edge_state_manager.mark_edge_skipped(edge.id)
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
@ -9,6 +9,8 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
@ -93,13 +95,14 @@ class DebugLoggingLayer(Layer):
|
||||
if not data:
|
||||
return "{}"
|
||||
|
||||
formatted_items = []
|
||||
formatted_items: list[str] = []
|
||||
for key, value in data.items():
|
||||
formatted_value = self._truncate_value(value)
|
||||
formatted_items.append(f" {key}: {formatted_value}")
|
||||
|
||||
return "{\n" + ",\n".join(formatted_items) + "\n}"
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Log graph execution start."""
|
||||
self.logger.info("=" * 80)
|
||||
@ -112,7 +115,7 @@ class DebugLoggingLayer(Layer):
|
||||
|
||||
# Log inputs if available
|
||||
if self.graph_runtime_state.variable_pool:
|
||||
initial_vars = {}
|
||||
initial_vars: dict[str, Any] = {}
|
||||
# Access the variable dictionary directly
|
||||
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
|
||||
for var_key, var in variables.items():
|
||||
@ -121,6 +124,7 @@ class DebugLoggingLayer(Layer):
|
||||
if initial_vars:
|
||||
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""Log individual events based on their type."""
|
||||
event_class = event.__class__.__name__
|
||||
@ -222,6 +226,7 @@ class DebugLoggingLayer(Layer):
|
||||
# Log unknown events at debug level
|
||||
self.logger.debug("Event: %s", event_class)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Log graph execution end with summary statistics."""
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
@ -13,6 +13,8 @@ import time
|
||||
from enum import Enum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import Layer
|
||||
from core.workflow.graph_events import (
|
||||
@ -63,6 +65,7 @@ class ExecutionLimitsLayer(Layer):
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False # Track if abort command has been sent
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Called when graph execution starts."""
|
||||
self.start_time = time.time()
|
||||
@ -73,6 +76,7 @@ class ExecutionLimitsLayer(Layer):
|
||||
|
||||
self.logger.debug("Execution limits monitoring started")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
@ -95,6 +99,7 @@ class ExecutionLimitsLayer(Layer):
|
||||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
|
||||
@ -10,11 +10,11 @@ from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
|
||||
from ..event_management import EventCollector, EventEmitter
|
||||
from ..event_management import EventManager
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandlerRegistry
|
||||
from ..event_management import EventHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,11 +31,11 @@ class Dispatcher:
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandlerRegistry",
|
||||
event_collector: EventCollector,
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
max_execution_time: int,
|
||||
event_emitter: EventEmitter | None = None,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dispatcher.
|
||||
@ -43,17 +43,17 @@ class Dispatcher:
|
||||
Args:
|
||||
event_queue: Queue of events from workers
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event collector for collecting unhandled events
|
||||
event_collector: Event manager for collecting unhandled events
|
||||
execution_coordinator: Coordinator for execution flow
|
||||
max_execution_time: Maximum execution time in seconds
|
||||
event_emitter: Optional event emitter to signal completion
|
||||
event_emitter: Optional event manager to signal completion
|
||||
"""
|
||||
self.event_queue = event_queue
|
||||
self.event_handler = event_handler
|
||||
self.event_collector = event_collector
|
||||
self.execution_coordinator = execution_coordinator
|
||||
self.max_execution_time = max_execution_time
|
||||
self.event_emitter = event_emitter
|
||||
self._event_queue = event_queue
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._execution_coordinator = execution_coordinator
|
||||
self._max_execution_time = max_execution_time
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
@ -80,28 +80,28 @@ class Dispatcher:
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
# Check for commands
|
||||
self.execution_coordinator.check_commands()
|
||||
self._execution_coordinator.check_commands()
|
||||
|
||||
# Check for scaling
|
||||
self.execution_coordinator.check_scaling()
|
||||
self._execution_coordinator.check_scaling()
|
||||
|
||||
# Process events
|
||||
try:
|
||||
event = self.event_queue.get(timeout=0.1)
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self.event_handler.handle_event(event)
|
||||
self.event_queue.task_done()
|
||||
self._event_handler.handle_event(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Check if execution is complete
|
||||
if self.execution_coordinator.is_execution_complete():
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
self.execution_coordinator.mark_failed(e)
|
||||
self._execution_coordinator.mark_failed(e)
|
||||
|
||||
finally:
|
||||
self.execution_coordinator.mark_complete()
|
||||
self._execution_coordinator.mark_complete()
|
||||
# Signal the event emitter that execution is complete
|
||||
if self.event_emitter:
|
||||
self.event_emitter.mark_complete()
|
||||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
||||
|
||||
@ -6,12 +6,12 @@ from typing import TYPE_CHECKING, final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..event_management import EventCollector
|
||||
from ..state_management import ExecutionTracker, NodeStateManager
|
||||
from ..worker_management import WorkerPool
|
||||
from ..event_management import EventManager
|
||||
from ..state_management import UnifiedStateManager
|
||||
from ..worker_management import SimpleWorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandlerRegistry
|
||||
from ..event_management import EventHandler
|
||||
|
||||
|
||||
@final
|
||||
@ -26,42 +26,37 @@ class ExecutionCoordinator:
|
||||
def __init__(
|
||||
self,
|
||||
graph_execution: GraphExecution,
|
||||
node_state_manager: NodeStateManager,
|
||||
execution_tracker: ExecutionTracker,
|
||||
event_handler: "EventHandlerRegistry",
|
||||
event_collector: EventCollector,
|
||||
state_manager: UnifiedStateManager,
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
command_processor: CommandProcessor,
|
||||
worker_pool: WorkerPool,
|
||||
worker_pool: SimpleWorkerPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the execution coordinator.
|
||||
|
||||
Args:
|
||||
graph_execution: Graph execution aggregate
|
||||
node_state_manager: Manager for node states
|
||||
execution_tracker: Tracker for executing nodes
|
||||
state_manager: Unified state manager
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event collector for collecting events
|
||||
event_collector: Event manager for collecting events
|
||||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self.graph_execution = graph_execution
|
||||
self.node_state_manager = node_state_manager
|
||||
self.execution_tracker = execution_tracker
|
||||
self.event_handler = event_handler
|
||||
self.event_collector = event_collector
|
||||
self.command_processor = command_processor
|
||||
self.worker_pool = worker_pool
|
||||
self._graph_execution = graph_execution
|
||||
self._state_manager = state_manager
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
def check_commands(self) -> None:
|
||||
"""Process any pending commands."""
|
||||
self.command_processor.process_commands()
|
||||
self._command_processor.process_commands()
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
"""Check and perform worker scaling if needed."""
|
||||
queue_depth = self.node_state_manager.ready_queue.qsize()
|
||||
executing_count = self.execution_tracker.count()
|
||||
self.worker_pool.check_scaling(queue_depth, executing_count)
|
||||
self._worker_pool.check_and_scale()
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
@ -71,16 +66,16 @@ class ExecutionCoordinator:
|
||||
True if execution is complete
|
||||
"""
|
||||
# Check if aborted or failed
|
||||
if self.graph_execution.aborted or self.graph_execution.has_error:
|
||||
if self._graph_execution.aborted or self._graph_execution.has_error:
|
||||
return True
|
||||
|
||||
# Complete if no work remains
|
||||
return self.node_state_manager.ready_queue.empty() and self.execution_tracker.is_empty()
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
if not self.graph_execution.completed:
|
||||
self.graph_execution.complete()
|
||||
if not self._graph_execution.completed:
|
||||
self._graph_execution.complete()
|
||||
|
||||
def mark_failed(self, error: Exception) -> None:
|
||||
"""
|
||||
@ -89,4 +84,4 @@ class ExecutionCoordinator:
|
||||
Args:
|
||||
error: The error that caused failure
|
||||
"""
|
||||
self.graph_execution.fail(error)
|
||||
self._graph_execution.fail(error)
|
||||
|
||||
@ -1,10 +0,0 @@
|
||||
"""
|
||||
OutputRegistry - Thread-safe storage for node outputs (streams and scalars)
|
||||
|
||||
This component provides thread-safe storage and retrieval of node outputs,
|
||||
supporting both scalar values and streaming chunks with proper state management.
|
||||
"""
|
||||
|
||||
from .registry import OutputRegistry
|
||||
|
||||
__all__ = ["OutputRegistry"]
|
||||
@ -1,146 +0,0 @@
|
||||
"""
|
||||
Main OutputRegistry implementation.
|
||||
|
||||
This module contains the public OutputRegistry class that provides
|
||||
thread-safe storage for node outputs.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import TYPE_CHECKING, Union, final
|
||||
|
||||
from core.variables import Segment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from .stream import Stream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
|
||||
@final
|
||||
class OutputRegistry:
|
||||
"""
|
||||
Thread-safe registry for storing and retrieving node outputs.
|
||||
|
||||
Supports both scalar values and streaming chunks with proper state management.
|
||||
All operations are thread-safe using internal locking.
|
||||
"""
|
||||
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
"""Initialize empty registry with thread-safe storage."""
|
||||
self._lock = RLock()
|
||||
self._scalars = variable_pool
|
||||
self._streams: dict[tuple, Stream] = {}
|
||||
|
||||
def _selector_to_key(self, selector: Sequence[str]) -> tuple:
|
||||
"""Convert selector list to tuple key for internal storage."""
|
||||
return tuple(selector)
|
||||
|
||||
def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None:
|
||||
"""
|
||||
Set a scalar value for the given selector.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the output location
|
||||
value: The scalar value to store
|
||||
"""
|
||||
with self._lock:
|
||||
self._scalars.add(selector, value)
|
||||
|
||||
def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
|
||||
"""
|
||||
Get a scalar value for the given selector.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the output location
|
||||
|
||||
Returns:
|
||||
The stored Variable object, or None if not found
|
||||
"""
|
||||
with self._lock:
|
||||
return self._scalars.get(selector)
|
||||
|
||||
def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None:
|
||||
"""
|
||||
Append a NodeRunStreamChunkEvent to the stream for the given selector.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
event: The NodeRunStreamChunkEvent to append
|
||||
|
||||
Raises:
|
||||
ValueError: If the stream is already closed
|
||||
"""
|
||||
key = self._selector_to_key(selector)
|
||||
with self._lock:
|
||||
if key not in self._streams:
|
||||
self._streams[key] = Stream()
|
||||
|
||||
try:
|
||||
self._streams[key].append(event)
|
||||
except ValueError:
|
||||
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
|
||||
|
||||
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
|
||||
"""
|
||||
Pop the next unread NodeRunStreamChunkEvent from the stream.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
The next event, or None if no unread events available
|
||||
"""
|
||||
key = self._selector_to_key(selector)
|
||||
with self._lock:
|
||||
if key not in self._streams:
|
||||
return None
|
||||
|
||||
return self._streams[key].pop_next()
|
||||
|
||||
def has_unread(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if the stream has unread events.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if there are unread events, False otherwise
|
||||
"""
|
||||
key = self._selector_to_key(selector)
|
||||
with self._lock:
|
||||
if key not in self._streams:
|
||||
return False
|
||||
|
||||
return self._streams[key].has_unread()
|
||||
|
||||
def close_stream(self, selector: Sequence[str]) -> None:
|
||||
"""
|
||||
Mark a stream as closed (no more chunks can be appended).
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
"""
|
||||
key = self._selector_to_key(selector)
|
||||
with self._lock:
|
||||
if key not in self._streams:
|
||||
self._streams[key] = Stream()
|
||||
self._streams[key].close()
|
||||
|
||||
def stream_closed(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if a stream is closed.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if the stream is closed, False otherwise
|
||||
"""
|
||||
key = self._selector_to_key(selector)
|
||||
with self._lock:
|
||||
if key not in self._streams:
|
||||
return False
|
||||
return self._streams[key].is_closed
|
||||
@ -1,70 +0,0 @@
|
||||
"""
|
||||
Internal stream implementation for OutputRegistry.
|
||||
|
||||
This module contains the private Stream class used internally by OutputRegistry
|
||||
to manage streaming data chunks.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
|
||||
@final
|
||||
class Stream:
|
||||
"""
|
||||
A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
|
||||
|
||||
This class encapsulates stream-specific data and operations,
|
||||
including event storage, read position tracking, and closed state.
|
||||
|
||||
Note: This is an internal class not exposed in the public API.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an empty stream."""
|
||||
self.events: list[NodeRunStreamChunkEvent] = []
|
||||
self.read_position: int = 0
|
||||
self.is_closed: bool = False
|
||||
|
||||
def append(self, event: "NodeRunStreamChunkEvent") -> None:
|
||||
"""
|
||||
Append a NodeRunStreamChunkEvent to the stream.
|
||||
|
||||
Args:
|
||||
event: The NodeRunStreamChunkEvent to append
|
||||
|
||||
Raises:
|
||||
ValueError: If the stream is already closed
|
||||
"""
|
||||
if self.is_closed:
|
||||
raise ValueError("Cannot append to a closed stream")
|
||||
self.events.append(event)
|
||||
|
||||
def pop_next(self) -> "NodeRunStreamChunkEvent | None":
|
||||
"""
|
||||
Pop the next unread NodeRunStreamChunkEvent from the stream.
|
||||
|
||||
Returns:
|
||||
The next event, or None if no unread events available
|
||||
"""
|
||||
if self.read_position >= len(self.events):
|
||||
return None
|
||||
|
||||
event = self.events[self.read_position]
|
||||
self.read_position += 1
|
||||
return event
|
||||
|
||||
def has_unread(self) -> bool:
|
||||
"""
|
||||
Check if the stream has unread events.
|
||||
|
||||
Returns:
|
||||
True if there are unread events, False otherwise
|
||||
"""
|
||||
return self.read_position < len(self.events)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Mark the stream as closed (no more chunks can be appended)."""
|
||||
self.is_closed = True
|
||||
@ -12,12 +12,12 @@ from threading import RLock
|
||||
from typing import TypeAlias, final
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
|
||||
from ..output_registry import OutputRegistry
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
|
||||
@ -36,19 +36,24 @@ class ResponseStreamCoordinator:
|
||||
Ensures ordered streaming of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
def __init__(self, registry: OutputRegistry, graph: "Graph") -> None:
|
||||
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
|
||||
"""
|
||||
Initialize coordinator with output registry.
|
||||
Initialize coordinator with variable pool.
|
||||
|
||||
Args:
|
||||
registry: OutputRegistry instance for accessing node outputs
|
||||
variable_pool: VariablePool instance for accessing node variables
|
||||
graph: Graph instance for looking up node information
|
||||
"""
|
||||
self.registry = registry
|
||||
self.graph = graph
|
||||
self.active_session: ResponseSession | None = None
|
||||
self.waiting_sessions: deque[ResponseSession] = deque()
|
||||
self.lock = RLock()
|
||||
self._variable_pool = variable_pool
|
||||
self._graph = graph
|
||||
self._active_session: ResponseSession | None = None
|
||||
self._waiting_sessions: deque[ResponseSession] = deque()
|
||||
self._lock = RLock()
|
||||
|
||||
# Internal stream management (replacing OutputRegistry)
|
||||
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
|
||||
self._stream_positions: dict[tuple[str, ...], int] = {}
|
||||
self._closed_streams: set[tuple[str, ...]] = set()
|
||||
|
||||
# Track response nodes
|
||||
self._response_nodes: set[NodeID] = set()
|
||||
@ -63,7 +68,7 @@ class ResponseStreamCoordinator:
|
||||
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
|
||||
|
||||
def register(self, response_node_id: NodeID) -> None:
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
self._response_nodes.add(response_node_id)
|
||||
|
||||
# Build and save paths map for this response node
|
||||
@ -71,7 +76,7 @@ class ResponseStreamCoordinator:
|
||||
self._paths_maps[response_node_id] = paths_map
|
||||
|
||||
# Create and store response session for this node
|
||||
response_node = self.graph.nodes[response_node_id]
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
session = ResponseSession.from_node(response_node)
|
||||
self._response_sessions[response_node_id] = session
|
||||
|
||||
@ -82,7 +87,7 @@ class ResponseStreamCoordinator:
|
||||
node_id: The ID of the node
|
||||
execution_id: The execution ID from NodeRunStartedEvent
|
||||
"""
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
self._node_execution_ids[node_id] = execution_id
|
||||
|
||||
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
|
||||
@ -94,7 +99,7 @@ class ResponseStreamCoordinator:
|
||||
Returns:
|
||||
The execution ID for the node
|
||||
"""
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
if node_id not in self._node_execution_ids:
|
||||
self._node_execution_ids[node_id] = str(uuid4())
|
||||
return self._node_execution_ids[node_id]
|
||||
@ -111,14 +116,14 @@ class ResponseStreamCoordinator:
|
||||
List of Path objects, where each path contains branch edge IDs
|
||||
"""
|
||||
# Get root node ID
|
||||
root_node_id = self.graph.root_node.id
|
||||
root_node_id = self._graph.root_node.id
|
||||
|
||||
# If root is the response node, return empty path
|
||||
if root_node_id == response_node_id:
|
||||
return [Path()]
|
||||
|
||||
# Extract variable selectors from the response node's template
|
||||
response_node = self.graph.nodes[response_node_id]
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
response_session = ResponseSession.from_node(response_node)
|
||||
template = response_session.template
|
||||
|
||||
@ -144,7 +149,7 @@ class ResponseStreamCoordinator:
|
||||
visited.add(current_node_id)
|
||||
|
||||
# Explore outgoing edges
|
||||
outgoing_edges = self.graph.get_outgoing_edges(current_node_id)
|
||||
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
|
||||
for edge in outgoing_edges:
|
||||
edge_id = edge.id
|
||||
next_node_id = edge.head
|
||||
@ -161,10 +166,10 @@ class ResponseStreamCoordinator:
|
||||
# Step 2: For each complete path, filter edges based on node blocking behavior
|
||||
filtered_paths: list[Path] = []
|
||||
for path in all_complete_paths:
|
||||
blocking_edges = []
|
||||
blocking_edges: list[str] = []
|
||||
for edge_id in path:
|
||||
edge = self.graph.edges[edge_id]
|
||||
source_node = self.graph.nodes[edge.tail]
|
||||
edge = self._graph.edges[edge_id]
|
||||
source_node = self._graph.nodes[edge.tail]
|
||||
|
||||
# Check if node is a branch/container (original behavior)
|
||||
if source_node.execution_type in {
|
||||
@ -194,7 +199,7 @@ class ResponseStreamCoordinator:
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
# Check each response node in order
|
||||
for response_node_id in self._response_nodes:
|
||||
if response_node_id not in self._paths_maps:
|
||||
@ -240,33 +245,32 @@ class ResponseStreamCoordinator:
|
||||
# Remove from map to ensure it won't be activated again
|
||||
del self._response_sessions[node_id]
|
||||
|
||||
if self.active_session is None:
|
||||
self.active_session = session
|
||||
if self._active_session is None:
|
||||
self._active_session = session
|
||||
|
||||
# Try to flush immediately
|
||||
events.extend(self.try_flush())
|
||||
else:
|
||||
# Queue the session if another is active
|
||||
self.waiting_sessions.append(session)
|
||||
self._waiting_sessions.append(session)
|
||||
|
||||
return events
|
||||
|
||||
def intercept_event(
|
||||
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
|
||||
) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
self.registry.append_chunk(event.selector, event)
|
||||
self._append_stream_chunk(event.selector, event)
|
||||
if event.is_final:
|
||||
self.registry.close_stream(event.selector)
|
||||
self._close_stream(event.selector)
|
||||
return self.try_flush()
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
else:
|
||||
# Skip cause we share the same variable pool.
|
||||
#
|
||||
# for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
# self.registry.set_scalar((event.node_id, variable_name), variable_value)
|
||||
# self._variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
return self.try_flush()
|
||||
return []
|
||||
|
||||
def _create_stream_chunk_event(
|
||||
self,
|
||||
@ -282,9 +286,9 @@ class ResponseStreamCoordinator:
|
||||
active response node's information since these are not actual node IDs.
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self.graph.nodes and self.active_session:
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
# Use the active response node for special selectors
|
||||
response_node = self.graph.nodes[self.active_session.node_id]
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
@ -295,7 +299,7 @@ class ResponseStreamCoordinator:
|
||||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
node = self.graph.nodes[node_id]
|
||||
node = self._graph.nodes[node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=node.id,
|
||||
@ -318,21 +322,21 @@ class ResponseStreamCoordinator:
|
||||
# Determine which node to attribute the output to
|
||||
# For special selectors (sys, env, conversation), use the active response node
|
||||
# For regular selectors, use the source node
|
||||
if self.active_session and source_selector_prefix not in self.graph.nodes:
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
# Special selector - use active response node
|
||||
output_node_id = self.active_session.node_id
|
||||
output_node_id = self._active_session.node_id
|
||||
else:
|
||||
# Regular node selector
|
||||
output_node_id = source_selector_prefix
|
||||
execution_id = self._get_or_create_execution_id(output_node_id)
|
||||
|
||||
# Stream all available chunks
|
||||
while self.registry.has_unread(segment.selector):
|
||||
if event := self.registry.pop_chunk(segment.selector):
|
||||
while self._has_unread_stream(segment.selector):
|
||||
if event := self._pop_stream_chunk(segment.selector):
|
||||
# For special selectors, we need to update the event to use
|
||||
# the active response node's information
|
||||
if self.active_session and source_selector_prefix not in self.graph.nodes:
|
||||
response_node = self.graph.nodes[self.active_session.node_id]
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
# Create a new event with the response node's information
|
||||
# but keep the original selector
|
||||
updated_event = NodeRunStreamChunkEvent(
|
||||
@ -349,15 +353,15 @@ class ResponseStreamCoordinator:
|
||||
events.append(event)
|
||||
|
||||
# Check if this is the last chunk by looking ahead
|
||||
stream_closed = self.registry.stream_closed(segment.selector)
|
||||
stream_closed = self._is_stream_closed(segment.selector)
|
||||
# Check if stream is closed to determine if segment is complete
|
||||
if stream_closed:
|
||||
is_complete = True
|
||||
|
||||
elif value := self.registry.get_scalar(segment.selector):
|
||||
elif value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1
|
||||
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
@ -374,13 +378,13 @@ class ResponseStreamCoordinator:
|
||||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""Process a text segment. Returns (events, is_complete)."""
|
||||
assert self.active_session is not None
|
||||
current_response_node = self.graph.nodes[self.active_session.node_id]
|
||||
assert self._active_session is not None
|
||||
current_response_node = self._graph.nodes[self._active_session.node_id]
|
||||
|
||||
# Use get_or_create_execution_id to ensure we have a consistent ID
|
||||
execution_id = self._get_or_create_execution_id(current_response_node.id)
|
||||
|
||||
is_last_segment = self.active_session.index == len(self.active_session.template.segments) - 1
|
||||
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
event = self._create_stream_chunk_event(
|
||||
node_id=current_response_node.id,
|
||||
execution_id=execution_id,
|
||||
@ -391,29 +395,29 @@ class ResponseStreamCoordinator:
|
||||
return [event]
|
||||
|
||||
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
|
||||
with self.lock:
|
||||
if not self.active_session:
|
||||
with self._lock:
|
||||
if not self._active_session:
|
||||
return []
|
||||
|
||||
template = self.active_session.template
|
||||
response_node_id = self.active_session.node_id
|
||||
template = self._active_session.template
|
||||
response_node_id = self._active_session.node_id
|
||||
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Process segments sequentially from current index
|
||||
while self.active_session.index < len(template.segments):
|
||||
segment = template.segments[self.active_session.index]
|
||||
while self._active_session.index < len(template.segments):
|
||||
segment = template.segments[self._active_session.index]
|
||||
|
||||
if isinstance(segment, VariableSegment):
|
||||
# Check if the source node for this variable is skipped
|
||||
# Only check for actual nodes, not special selectors (sys, env, conversation)
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
if source_selector_prefix in self.graph.nodes:
|
||||
source_node = self.graph.nodes[source_selector_prefix]
|
||||
if source_selector_prefix in self._graph.nodes:
|
||||
source_node = self._graph.nodes[source_selector_prefix]
|
||||
|
||||
if source_node.state == NodeState.SKIPPED:
|
||||
# Skip this variable segment if the source node is skipped
|
||||
self.active_session.index += 1
|
||||
self._active_session.index += 1
|
||||
continue
|
||||
|
||||
segment_events, is_complete = self._process_variable_segment(segment)
|
||||
@ -421,17 +425,17 @@ class ResponseStreamCoordinator:
|
||||
|
||||
# Only advance index if this variable segment is complete
|
||||
if is_complete:
|
||||
self.active_session.index += 1
|
||||
self._active_session.index += 1
|
||||
else:
|
||||
# Wait for more data
|
||||
break
|
||||
|
||||
elif isinstance(segment, TextSegment):
|
||||
else:
|
||||
segment_events = self._process_text_segment(segment)
|
||||
events.extend(segment_events)
|
||||
self.active_session.index += 1
|
||||
self._active_session.index += 1
|
||||
|
||||
if self.active_session.is_complete():
|
||||
if self._active_session.is_complete():
|
||||
# End current session and get events from starting next session
|
||||
next_session_events = self.end_session(response_node_id)
|
||||
events.extend(next_session_events)
|
||||
@ -449,18 +453,108 @@ class ResponseStreamCoordinator:
|
||||
Returns:
|
||||
List of events from starting the next session
|
||||
"""
|
||||
with self.lock:
|
||||
with self._lock:
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
if self.active_session and self.active_session.node_id == node_id:
|
||||
self.active_session = None
|
||||
if self._active_session and self._active_session.node_id == node_id:
|
||||
self._active_session = None
|
||||
|
||||
# Try to start next waiting session
|
||||
if self.waiting_sessions:
|
||||
next_session = self.waiting_sessions.popleft()
|
||||
self.active_session = next_session
|
||||
if self._waiting_sessions:
|
||||
next_session = self._waiting_sessions.popleft()
|
||||
self._active_session = next_session
|
||||
|
||||
# Immediately try to flush any available segments
|
||||
events = self.try_flush()
|
||||
|
||||
return events
|
||||
|
||||
# ============= Internal Stream Management Methods =============
|
||||
|
||||
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Append a stream chunk to the internal buffer.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
event: The NodeRunStreamChunkEvent to append
|
||||
|
||||
Raises:
|
||||
ValueError: If the stream is already closed
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key in self._closed_streams:
|
||||
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
self._stream_buffers[key] = []
|
||||
self._stream_positions[key] = 0
|
||||
|
||||
self._stream_buffers[key].append(event)
|
||||
|
||||
def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
|
||||
"""
|
||||
Pop the next unread stream chunk from the buffer.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
The next event, or None if no unread events available
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
return None
|
||||
|
||||
position = self._stream_positions.get(key, 0)
|
||||
buffer = self._stream_buffers[key]
|
||||
|
||||
if position >= len(buffer):
|
||||
return None
|
||||
|
||||
event = buffer[position]
|
||||
self._stream_positions[key] = position + 1
|
||||
return event
|
||||
|
||||
def _has_unread_stream(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if the stream has unread events.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if there are unread events, False otherwise
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
return False
|
||||
|
||||
position = self._stream_positions.get(key, 0)
|
||||
return position < len(self._stream_buffers[key])
|
||||
|
||||
def _close_stream(self, selector: Sequence[str]) -> None:
|
||||
"""
|
||||
Mark a stream as closed (no more chunks can be appended).
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
"""
|
||||
key = tuple(selector)
|
||||
self._closed_streams.add(key)
|
||||
|
||||
def _is_stream_closed(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if a stream is closed.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if the stream is closed, False otherwise
|
||||
"""
|
||||
key = tuple(selector)
|
||||
return key in self._closed_streams
|
||||
|
||||
@ -5,12 +5,8 @@ This package manages node states, edge states, and execution tracking
|
||||
during workflow graph execution.
|
||||
"""
|
||||
|
||||
from .edge_state_manager import EdgeStateManager
|
||||
from .execution_tracker import ExecutionTracker
|
||||
from .node_state_manager import NodeStateManager
|
||||
from .unified_state_manager import UnifiedStateManager
|
||||
|
||||
__all__ = [
|
||||
"EdgeStateManager",
|
||||
"ExecutionTracker",
|
||||
"NodeStateManager",
|
||||
"UnifiedStateManager",
|
||||
]
|
||||
|
||||
@ -1,114 +0,0 @@
|
||||
"""
|
||||
Manager for edge states during graph execution.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
|
||||
class EdgeStateAnalysis(TypedDict):
|
||||
"""Analysis result for edge states."""
|
||||
|
||||
has_unknown: bool
|
||||
has_taken: bool
|
||||
all_skipped: bool
|
||||
|
||||
|
||||
@final
|
||||
class EdgeStateManager:
|
||||
"""
|
||||
Manages edge states and transitions during graph execution.
|
||||
|
||||
This handles edge state changes and provides analysis of edge
|
||||
states for decision making during execution.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph) -> None:
|
||||
"""
|
||||
Initialize the edge state manager.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
"""
|
||||
self.graph = graph
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def mark_edge_taken(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as TAKEN.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self.graph.edges[edge_id].state = NodeState.TAKEN
|
||||
|
||||
def mark_edge_skipped(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as SKIPPED.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self.graph.edges[edge_id].state = NodeState.SKIPPED
|
||||
|
||||
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
|
||||
"""
|
||||
Analyze the states of edges and return summary flags.
|
||||
|
||||
Args:
|
||||
edges: List of edges to analyze
|
||||
|
||||
Returns:
|
||||
Analysis result with state flags
|
||||
"""
|
||||
with self._lock:
|
||||
states = {edge.state for edge in edges}
|
||||
|
||||
return EdgeStateAnalysis(
|
||||
has_unknown=NodeState.UNKNOWN in states,
|
||||
has_taken=NodeState.TAKEN in states,
|
||||
all_skipped=states == {NodeState.SKIPPED} if states else True,
|
||||
)
|
||||
|
||||
def get_edge_state(self, edge_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of an edge.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge
|
||||
|
||||
Returns:
|
||||
The current edge state
|
||||
"""
|
||||
with self._lock:
|
||||
return self.graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
A tuple of (selected_edges, unselected_edges)
|
||||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
selected_edges.append(edge)
|
||||
else:
|
||||
unselected_edges.append(edge)
|
||||
|
||||
return selected_edges, unselected_edges
|
||||
@ -1,89 +0,0 @@
|
||||
"""
|
||||
Tracker for currently executing nodes.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionTracker:
|
||||
"""
|
||||
Tracks nodes that are currently being executed.
|
||||
|
||||
This replaces the ExecutingNodesManager with a cleaner interface
|
||||
focused on tracking which nodes are in progress.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the execution tracker."""
|
||||
self._executing_nodes: set[str] = set()
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def add(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node starting execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.add(node_id)
|
||||
|
||||
def remove(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as no longer executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node finishing execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.discard(node_id)
|
||||
|
||||
def is_executing(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is currently executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is executing
|
||||
"""
|
||||
with self._lock:
|
||||
return node_id in self._executing_nodes
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check if no nodes are currently executing.
|
||||
|
||||
Returns:
|
||||
True if no nodes are executing
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self._executing_nodes) == 0
|
||||
|
||||
def count(self) -> int:
|
||||
"""
|
||||
Get the count of currently executing nodes.
|
||||
|
||||
Returns:
|
||||
Number of executing nodes
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self._executing_nodes)
|
||||
|
||||
def get_executing_nodes(self) -> set[str]:
|
||||
"""
|
||||
Get a copy of the set of executing node IDs.
|
||||
|
||||
Returns:
|
||||
Set of node IDs currently executing
|
||||
"""
|
||||
with self._lock:
|
||||
return self._executing_nodes.copy()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all executing nodes."""
|
||||
with self._lock:
|
||||
self._executing_nodes.clear()
|
||||
@ -1,97 +0,0 @@
|
||||
"""
|
||||
Manager for node states during graph execution.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class NodeStateManager:
|
||||
"""
|
||||
Manages node states and the ready queue for execution.
|
||||
|
||||
This centralizes node state transitions and enqueueing logic,
|
||||
ensuring thread-safe operations on node states.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
|
||||
"""
|
||||
Initialize the node state manager.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
ready_queue: Queue for nodes ready to execute
|
||||
"""
|
||||
self.graph = graph
|
||||
self.ready_queue = ready_queue
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def enqueue_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as TAKEN and add it to the ready queue.
|
||||
|
||||
This combines the state transition and enqueueing operations
|
||||
that always occur together when preparing a node for execution.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to enqueue
|
||||
"""
|
||||
with self._lock:
|
||||
self.graph.nodes[node_id].state = NodeState.TAKEN
|
||||
self.ready_queue.put(node_id)
|
||||
|
||||
def mark_node_skipped(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as SKIPPED.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
with self._lock:
|
||||
self.graph.nodes[node_id].state = NodeState.SKIPPED
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is ready to be executed.
|
||||
|
||||
A node is ready when all its incoming edges from taken branches
|
||||
have been satisfied.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
with self._lock:
|
||||
# Get all incoming edges to this node
|
||||
incoming_edges = self.graph.get_incoming_edges(node_id)
|
||||
|
||||
# If no incoming edges, node is always ready
|
||||
if not incoming_edges:
|
||||
return True
|
||||
|
||||
# If any edge is UNKNOWN, node is not ready
|
||||
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
|
||||
return False
|
||||
|
||||
# Node is ready if at least one edge is TAKEN
|
||||
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
|
||||
|
||||
def get_node_state(self, node_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The current node state
|
||||
"""
|
||||
with self._lock:
|
||||
return self.graph.nodes[node_id].state
|
||||
@ -0,0 +1,304 @@
|
||||
"""
|
||||
Unified state manager that combines node, edge, and execution tracking.
|
||||
|
||||
This is a proposed simplification that merges NodeStateManager, EdgeStateManager,
|
||||
and ExecutionTracker into a single cohesive class.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
|
||||
class EdgeStateAnalysis(TypedDict):
|
||||
"""Analysis result for edge states."""
|
||||
|
||||
has_unknown: bool
|
||||
has_taken: bool
|
||||
all_skipped: bool
|
||||
|
||||
|
||||
@final
|
||||
class UnifiedStateManager:
|
||||
"""
|
||||
Unified manager for all graph state operations.
|
||||
|
||||
This class combines the responsibilities of:
|
||||
- NodeStateManager: Node state transitions and ready queue
|
||||
- EdgeStateManager: Edge state transitions and analysis
|
||||
- ExecutionTracker: Tracking executing nodes
|
||||
|
||||
Benefits:
|
||||
- Single lock for all state operations (reduced contention)
|
||||
- Cohesive state management interface
|
||||
- Simplified dependency injection
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
|
||||
"""
|
||||
Initialize the unified state manager.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
ready_queue: Queue for nodes ready to execute
|
||||
"""
|
||||
self._graph = graph
|
||||
self._ready_queue = ready_queue
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Execution tracking state
|
||||
self._executing_nodes: set[str] = set()
|
||||
|
||||
# ============= Node State Operations =============
|
||||
|
||||
def enqueue_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as TAKEN and add it to the ready queue.
|
||||
|
||||
This combines the state transition and enqueueing operations
|
||||
that always occur together when preparing a node for execution.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to enqueue
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.TAKEN
|
||||
self._ready_queue.put(node_id)
|
||||
|
||||
def mark_node_skipped(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as SKIPPED.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.SKIPPED
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is ready to be executed.
|
||||
|
||||
A node is ready when all its incoming edges from taken branches
|
||||
have been satisfied.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
with self._lock:
|
||||
# Get all incoming edges to this node
|
||||
incoming_edges = self._graph.get_incoming_edges(node_id)
|
||||
|
||||
# If no incoming edges, node is always ready
|
||||
if not incoming_edges:
|
||||
return True
|
||||
|
||||
# If any edge is UNKNOWN, node is not ready
|
||||
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
|
||||
return False
|
||||
|
||||
# Node is ready if at least one edge is TAKEN
|
||||
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
|
||||
|
||||
def get_node_state(self, node_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The current node state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.nodes[node_id].state
|
||||
|
||||
# ============= Edge State Operations =============
|
||||
|
||||
def mark_edge_taken(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as TAKEN.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.TAKEN
|
||||
|
||||
def mark_edge_skipped(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as SKIPPED.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.SKIPPED
|
||||
|
||||
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
|
||||
"""
|
||||
Analyze the states of edges and return summary flags.
|
||||
|
||||
Args:
|
||||
edges: List of edges to analyze
|
||||
|
||||
Returns:
|
||||
Analysis result with state flags
|
||||
"""
|
||||
with self._lock:
|
||||
states = {edge.state for edge in edges}
|
||||
|
||||
return EdgeStateAnalysis(
|
||||
has_unknown=NodeState.UNKNOWN in states,
|
||||
has_taken=NodeState.TAKEN in states,
|
||||
all_skipped=states == {NodeState.SKIPPED} if states else True,
|
||||
)
|
||||
|
||||
def get_edge_state(self, edge_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of an edge.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge
|
||||
|
||||
Returns:
|
||||
The current edge state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
A tuple of (selected_edges, unselected_edges)
|
||||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
selected_edges.append(edge)
|
||||
else:
|
||||
unselected_edges.append(edge)
|
||||
|
||||
return selected_edges, unselected_edges
|
||||
|
||||
# ============= Execution Tracking Operations =============
|
||||
|
||||
def start_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node starting execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.add(node_id)
|
||||
|
||||
def finish_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as no longer executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node finishing execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.discard(node_id)
|
||||
|
||||
def is_executing(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is currently executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is executing
|
||||
"""
|
||||
with self._lock:
|
||||
return node_id in self._executing_nodes
|
||||
|
||||
def get_executing_count(self) -> int:
|
||||
"""
|
||||
Get the count of currently executing nodes.
|
||||
|
||||
Returns:
|
||||
Number of executing nodes
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self._executing_nodes)
|
||||
|
||||
def get_executing_nodes(self) -> set[str]:
|
||||
"""
|
||||
Get a copy of the set of executing node IDs.
|
||||
|
||||
Returns:
|
||||
Set of node IDs currently executing
|
||||
"""
|
||||
with self._lock:
|
||||
return self._executing_nodes.copy()
|
||||
|
||||
def clear_executing(self) -> None:
|
||||
"""Clear all executing nodes."""
|
||||
with self._lock:
|
||||
self._executing_nodes.clear()
|
||||
|
||||
# ============= Composite Operations =============
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
Check if graph execution is complete.
|
||||
|
||||
Execution is complete when:
|
||||
- Ready queue is empty
|
||||
- No nodes are executing
|
||||
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
with self._lock:
|
||||
return self._ready_queue.empty() and len(self._executing_nodes) == 0
|
||||
|
||||
def get_queue_depth(self) -> int:
|
||||
"""
|
||||
Get the current depth of the ready queue.
|
||||
|
||||
Returns:
|
||||
Number of nodes in the ready queue
|
||||
"""
|
||||
return self._ready_queue.qsize()
|
||||
|
||||
def get_execution_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get execution statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with execution statistics
|
||||
"""
|
||||
with self._lock:
|
||||
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
|
||||
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
|
||||
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
|
||||
|
||||
return {
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"executing": len(self._executing_nodes),
|
||||
"taken_nodes": taken_nodes,
|
||||
"skipped_nodes": skipped_nodes,
|
||||
"unknown_nodes": unknown_nodes,
|
||||
}
|
||||
@ -15,6 +15,7 @@ from typing import final
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
@ -58,21 +59,22 @@ class Worker(threading.Thread):
|
||||
on_active_callback: Optional callback when worker becomes active
|
||||
"""
|
||||
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
|
||||
self.ready_queue = ready_queue
|
||||
self.event_queue = event_queue
|
||||
self.graph = graph
|
||||
self.worker_id = worker_id
|
||||
self.flask_app = flask_app
|
||||
self.context_vars = context_vars
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self.on_idle_callback = on_idle_callback
|
||||
self.on_active_callback = on_active_callback
|
||||
self.last_task_time = time.time()
|
||||
self._on_idle_callback = on_idle_callback
|
||||
self._on_active_callback = on_active_callback
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@override
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Main worker loop.
|
||||
@ -83,22 +85,22 @@ class Worker(threading.Thread):
|
||||
while not self._stop_event.is_set():
|
||||
# Try to get a node ID from the ready queue (with timeout)
|
||||
try:
|
||||
node_id = self.ready_queue.get(timeout=0.1)
|
||||
node_id = self._ready_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
# Notify that worker is idle
|
||||
if self.on_idle_callback:
|
||||
self.on_idle_callback(self.worker_id)
|
||||
if self._on_idle_callback:
|
||||
self._on_idle_callback(self._worker_id)
|
||||
continue
|
||||
|
||||
# Notify that worker is active
|
||||
if self.on_active_callback:
|
||||
self.on_active_callback(self.worker_id)
|
||||
if self._on_active_callback:
|
||||
self._on_active_callback(self._worker_id)
|
||||
|
||||
self.last_task_time = time.time()
|
||||
node = self.graph.nodes[node_id]
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._execute_node(node)
|
||||
self.ready_queue.task_done()
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=str(uuid4()),
|
||||
@ -108,7 +110,7 @@ class Worker(threading.Thread):
|
||||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
)
|
||||
self.event_queue.put(error_event)
|
||||
self._event_queue.put(error_event)
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
@ -118,19 +120,19 @@ class Worker(threading.Thread):
|
||||
node: The node instance to execute
|
||||
"""
|
||||
# Execute the node with preserved context if Flask app is provided
|
||||
if self.flask_app and self.context_vars:
|
||||
if self._flask_app and self._context_vars:
|
||||
with preserve_flask_contexts(
|
||||
flask_app=self.flask_app,
|
||||
context_vars=self.context_vars,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
):
|
||||
# Execute the node
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self.event_queue.put(event)
|
||||
self._event_queue.put(event)
|
||||
else:
|
||||
# Execute without context preservation
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self.event_queue.put(event)
|
||||
self._event_queue.put(event)
|
||||
|
||||
@ -1,81 +0,0 @@
|
||||
# Worker Management
|
||||
|
||||
Dynamic worker pool for node execution.
|
||||
|
||||
## Components
|
||||
|
||||
### WorkerPool
|
||||
|
||||
Manages worker thread lifecycle.
|
||||
|
||||
- `start/stop/wait()` - Control workers
|
||||
- `scale_up/down()` - Adjust pool size
|
||||
- `get_worker_count()` - Current count
|
||||
|
||||
### WorkerFactory
|
||||
|
||||
Creates workers with Flask context.
|
||||
|
||||
- `create_worker()` - Build with dependencies
|
||||
- Preserves request context
|
||||
|
||||
### DynamicScaler
|
||||
|
||||
Determines scaling decisions.
|
||||
|
||||
- `min/max_workers` - Pool bounds
|
||||
- `scale_up_threshold` - Queue trigger
|
||||
- `should_scale_up/down()` - Check conditions
|
||||
|
||||
### ActivityTracker
|
||||
|
||||
Tracks worker activity.
|
||||
|
||||
- `track_activity(worker_id)` - Record activity
|
||||
- `get_idle_workers(threshold)` - Find idle
|
||||
- `get_active_count()` - Active count
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
scaler = DynamicScaler(
|
||||
min_workers=2,
|
||||
max_workers=10,
|
||||
scale_up_threshold=5
|
||||
)
|
||||
|
||||
pool = WorkerPool(
|
||||
ready_queue=ready_queue,
|
||||
worker_factory=factory,
|
||||
dynamic_scaler=scaler
|
||||
)
|
||||
|
||||
pool.start()
|
||||
|
||||
# Scale based on load
|
||||
if scaler.should_scale_up(queue_size, active):
|
||||
pool.scale_up()
|
||||
|
||||
pool.stop()
|
||||
```
|
||||
|
||||
## Scaling Strategy
|
||||
|
||||
**Scale Up**: Queue size > threshold AND workers < max
|
||||
**Scale Down**: Idle workers exist AND workers > min
|
||||
|
||||
## Parameters
|
||||
|
||||
- `min_workers` - Minimum pool size
|
||||
- `max_workers` - Maximum pool size
|
||||
- `scale_up_threshold` - Queue trigger
|
||||
- `scale_down_threshold` - Idle seconds
|
||||
|
||||
## Flask Context
|
||||
|
||||
WorkerFactory preserves request context across threads:
|
||||
|
||||
```python
|
||||
context_vars = {"request_id": request.id}
|
||||
# Workers receive same context
|
||||
```
|
||||
@ -5,14 +5,8 @@ This package manages the worker pool, including creation,
|
||||
scaling, and activity tracking.
|
||||
"""
|
||||
|
||||
from .activity_tracker import ActivityTracker
|
||||
from .dynamic_scaler import DynamicScaler
|
||||
from .worker_factory import WorkerFactory
|
||||
from .worker_pool import WorkerPool
|
||||
from .simple_worker_pool import SimpleWorkerPool
|
||||
|
||||
__all__ = [
|
||||
"ActivityTracker",
|
||||
"DynamicScaler",
|
||||
"WorkerFactory",
|
||||
"WorkerPool",
|
||||
"SimpleWorkerPool",
|
||||
]
|
||||
|
||||
@ -1,76 +0,0 @@
|
||||
"""
|
||||
Activity tracker for monitoring worker activity.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import final
|
||||
|
||||
|
||||
@final
|
||||
class ActivityTracker:
|
||||
"""
|
||||
Tracks worker activity for scaling decisions.
|
||||
|
||||
This monitors which workers are active or idle to support
|
||||
dynamic scaling decisions.
|
||||
"""
|
||||
|
||||
def __init__(self, idle_threshold: float = 30.0) -> None:
|
||||
"""
|
||||
Initialize the activity tracker.
|
||||
|
||||
Args:
|
||||
idle_threshold: Seconds before a worker is considered idle
|
||||
"""
|
||||
self.idle_threshold = idle_threshold
|
||||
self._worker_activity: dict[int, tuple[bool, float]] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def track_activity(self, worker_id: int, is_active: bool) -> None:
|
||||
"""
|
||||
Track worker activity state.
|
||||
|
||||
Args:
|
||||
worker_id: ID of the worker
|
||||
is_active: Whether the worker is active
|
||||
"""
|
||||
with self._lock:
|
||||
self._worker_activity[worker_id] = (is_active, time.time())
|
||||
|
||||
def get_idle_workers(self) -> list[int]:
|
||||
"""
|
||||
Get list of workers that have been idle too long.
|
||||
|
||||
Returns:
|
||||
List of idle worker IDs
|
||||
"""
|
||||
current_time = time.time()
|
||||
idle_workers = []
|
||||
|
||||
with self._lock:
|
||||
for worker_id, (is_active, last_change) in self._worker_activity.items():
|
||||
if not is_active and (current_time - last_change) > self.idle_threshold:
|
||||
idle_workers.append(worker_id)
|
||||
|
||||
return idle_workers
|
||||
|
||||
def remove_worker(self, worker_id: int) -> None:
|
||||
"""
|
||||
Remove a worker from tracking.
|
||||
|
||||
Args:
|
||||
worker_id: ID of the worker to remove
|
||||
"""
|
||||
with self._lock:
|
||||
self._worker_activity.pop(worker_id, None)
|
||||
|
||||
def get_active_count(self) -> int:
|
||||
"""
|
||||
Get count of currently active workers.
|
||||
|
||||
Returns:
|
||||
Number of active workers
|
||||
"""
|
||||
with self._lock:
|
||||
return sum(1 for is_active, _ in self._worker_activity.values() if is_active)
|
||||
@ -1,101 +0,0 @@
|
||||
"""
|
||||
Dynamic scaler for worker pool sizing.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class DynamicScaler:
|
||||
"""
|
||||
Manages dynamic scaling decisions for the worker pool.
|
||||
|
||||
This encapsulates the logic for when to scale up or down
|
||||
based on workload and configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_workers: int = 2,
|
||||
max_workers: int = 10,
|
||||
scale_up_threshold: int = 5,
|
||||
scale_down_idle_time: float = 30.0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dynamic scaler.
|
||||
|
||||
Args:
|
||||
min_workers: Minimum number of workers
|
||||
max_workers: Maximum number of workers
|
||||
scale_up_threshold: Queue depth to trigger scale up
|
||||
scale_down_idle_time: Idle time before scaling down
|
||||
"""
|
||||
self.min_workers = min_workers
|
||||
self.max_workers = max_workers
|
||||
self.scale_up_threshold = scale_up_threshold
|
||||
self.scale_down_idle_time = scale_down_idle_time
|
||||
|
||||
def calculate_initial_workers(self, graph: Graph) -> int:
|
||||
"""
|
||||
Calculate initial worker count based on graph complexity.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
|
||||
Returns:
|
||||
Initial number of workers to create
|
||||
"""
|
||||
node_count = len(graph.nodes)
|
||||
|
||||
# Simple heuristic: more nodes = more workers
|
||||
if node_count < 10:
|
||||
initial = self.min_workers
|
||||
elif node_count < 50:
|
||||
initial = min(4, self.max_workers)
|
||||
elif node_count < 100:
|
||||
initial = min(6, self.max_workers)
|
||||
else:
|
||||
initial = min(8, self.max_workers)
|
||||
|
||||
return max(self.min_workers, initial)
|
||||
|
||||
def should_scale_up(self, current_workers: int, queue_depth: int, executing_count: int) -> bool:
|
||||
"""
|
||||
Determine if scaling up is needed.
|
||||
|
||||
Args:
|
||||
current_workers: Current number of workers
|
||||
queue_depth: Number of nodes waiting
|
||||
executing_count: Number of nodes executing
|
||||
|
||||
Returns:
|
||||
True if should scale up
|
||||
"""
|
||||
if current_workers >= self.max_workers:
|
||||
return False
|
||||
|
||||
# Scale up if queue is deep and workers are busy
|
||||
if queue_depth > self.scale_up_threshold:
|
||||
if executing_count >= current_workers * 0.8:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def should_scale_down(self, current_workers: int, idle_workers: list[int]) -> bool:
|
||||
"""
|
||||
Determine if scaling down is appropriate.
|
||||
|
||||
Args:
|
||||
current_workers: Current number of workers
|
||||
idle_workers: List of idle worker IDs
|
||||
|
||||
Returns:
|
||||
True if should scale down
|
||||
"""
|
||||
if current_workers <= self.min_workers:
|
||||
return False
|
||||
|
||||
# Scale down if we have idle workers
|
||||
return len(idle_workers) > 0
|
||||
@ -0,0 +1,168 @@
|
||||
"""
|
||||
Simple worker pool that consolidates functionality.
|
||||
|
||||
This is a simpler implementation that merges WorkerPool, ActivityTracker,
|
||||
DynamicScaler, and WorkerFactory into a single class.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
from ..worker import Worker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from contextvars import Context
|
||||
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@final
|
||||
class SimpleWorkerPool:
|
||||
"""
|
||||
Simple worker pool with integrated management.
|
||||
|
||||
This class consolidates all worker management functionality into
|
||||
a single, simpler implementation without excessive abstraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ready_queue: queue.Queue[str],
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
max_workers: int | None = None,
|
||||
scale_up_threshold: int | None = None,
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the simple worker pool.
|
||||
|
||||
Args:
|
||||
ready_queue: Queue of nodes ready for execution
|
||||
event_queue: Queue for worker events
|
||||
graph: The workflow graph
|
||||
flask_app: Optional Flask app for context preservation
|
||||
context_vars: Optional context variables
|
||||
min_workers: Minimum number of workers
|
||||
max_workers: Maximum number of workers
|
||||
scale_up_threshold: Queue depth to trigger scale up
|
||||
scale_down_idle_time: Seconds before scaling down idle workers
|
||||
"""
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
|
||||
# Scaling parameters with defaults
|
||||
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
|
||||
self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
|
||||
self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
|
||||
self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
|
||||
# Worker management
|
||||
self._workers: list[Worker] = []
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
|
||||
def start(self, initial_count: int | None = None) -> None:
|
||||
"""
|
||||
Start the worker pool.
|
||||
|
||||
Args:
|
||||
initial_count: Number of workers to start with (auto-calculated if None)
|
||||
"""
|
||||
with self._lock:
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Calculate initial worker count
|
||||
if initial_count is None:
|
||||
node_count = len(self._graph.nodes)
|
||||
if node_count < 10:
|
||||
initial_count = self._min_workers
|
||||
elif node_count < 50:
|
||||
initial_count = min(self._min_workers + 1, self._max_workers)
|
||||
else:
|
||||
initial_count = min(self._min_workers + 2, self._max_workers)
|
||||
|
||||
# Create initial workers
|
||||
for _ in range(initial_count):
|
||||
self._create_worker()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop all workers in the pool."""
|
||||
with self._lock:
|
||||
self._running = False
|
||||
|
||||
# Stop all workers
|
||||
for worker in self._workers:
|
||||
worker.stop()
|
||||
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
def _create_worker(self) -> None:
|
||||
"""Create and start a new worker."""
|
||||
worker_id = self._worker_counter
|
||||
self._worker_counter += 1
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
self._workers.append(worker)
|
||||
|
||||
def check_and_scale(self) -> None:
|
||||
"""Check and perform scaling if needed."""
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
current_count = len(self._workers)
|
||||
queue_depth = self._ready_queue.qsize()
|
||||
|
||||
# Simple scaling logic
|
||||
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
|
||||
self._create_worker()
|
||||
|
||||
def get_worker_count(self) -> int:
|
||||
"""Get current number of workers."""
|
||||
with self._lock:
|
||||
return len(self._workers)
|
||||
|
||||
def get_status(self) -> dict[str, int]:
|
||||
"""
|
||||
Get pool status information.
|
||||
|
||||
Returns:
|
||||
Dictionary with status information
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_workers": len(self._workers),
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"min_workers": self._min_workers,
|
||||
"max_workers": self._max_workers,
|
||||
}
|
||||
@ -1,75 +0,0 @@
|
||||
"""
|
||||
Factory for creating worker instances.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from typing import final
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
from ..worker import Worker
|
||||
|
||||
|
||||
@final
|
||||
class WorkerFactory:
|
||||
"""
|
||||
Factory for creating worker instances with proper context.
|
||||
|
||||
This encapsulates worker creation logic and ensures all workers
|
||||
are created with the necessary Flask and context variable setup.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flask_app: Flask | None,
|
||||
context_vars: contextvars.Context,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the worker factory.
|
||||
|
||||
Args:
|
||||
flask_app: Flask application context
|
||||
context_vars: Context variables to propagate
|
||||
"""
|
||||
self.flask_app = flask_app
|
||||
self.context_vars = context_vars
|
||||
self._next_worker_id = 0
|
||||
|
||||
def create_worker(
|
||||
self,
|
||||
ready_queue: queue.Queue[str],
|
||||
event_queue: queue.Queue,
|
||||
graph: Graph,
|
||||
on_idle_callback: Callable[[int], None] | None = None,
|
||||
on_active_callback: Callable[[int], None] | None = None,
|
||||
) -> Worker:
|
||||
"""
|
||||
Create a new worker instance.
|
||||
|
||||
Args:
|
||||
ready_queue: Queue of nodes ready for execution
|
||||
event_queue: Queue for worker events
|
||||
graph: The workflow graph
|
||||
on_idle_callback: Callback when worker becomes idle
|
||||
on_active_callback: Callback when worker becomes active
|
||||
|
||||
Returns:
|
||||
Configured worker instance
|
||||
"""
|
||||
worker_id = self._next_worker_id
|
||||
self._next_worker_id += 1
|
||||
|
||||
return Worker(
|
||||
ready_queue=ready_queue,
|
||||
event_queue=event_queue,
|
||||
graph=graph,
|
||||
worker_id=worker_id,
|
||||
flask_app=self.flask_app,
|
||||
context_vars=self.context_vars,
|
||||
on_idle_callback=on_idle_callback,
|
||||
on_active_callback=on_active_callback,
|
||||
)
|
||||
@ -1,147 +0,0 @@
|
||||
"""
|
||||
Worker pool management.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
from ..worker import Worker
|
||||
from .activity_tracker import ActivityTracker
|
||||
from .dynamic_scaler import DynamicScaler
|
||||
from .worker_factory import WorkerFactory
|
||||
|
||||
|
||||
@final
class WorkerPool:
    """
    Manages a pool of worker threads for executing nodes.

    This provides dynamic scaling, activity tracking, and lifecycle
    management for worker threads. All public methods are guarded by a
    reentrant lock so the pool can be driven from multiple threads
    (check_scaling -> scale_up re-enters the lock).
    """

    def __init__(
        self,
        ready_queue: queue.Queue[str],
        event_queue: queue.Queue,
        graph: Graph,
        worker_factory: WorkerFactory,
        dynamic_scaler: DynamicScaler,
        activity_tracker: ActivityTracker,
    ) -> None:
        """
        Initialize the worker pool.

        Args:
            ready_queue: Queue of nodes ready for execution
            event_queue: Queue for worker events
            graph: The workflow graph
            worker_factory: Factory for creating workers
            dynamic_scaler: Scaler for dynamic sizing
            activity_tracker: Tracker for worker activity
        """
        self.ready_queue = ready_queue
        self.event_queue = event_queue
        self.graph = graph
        self.worker_factory = worker_factory
        self.dynamic_scaler = dynamic_scaler
        self.activity_tracker = activity_tracker

        self.workers: list[Worker] = []
        # RLock because public methods call each other while holding it.
        self._lock = threading.RLock()
        self._running = False

    def start(self, initial_count: int) -> None:
        """
        Start the worker pool with initial workers.

        Idempotent: calling start() on an already-running pool is a no-op.

        Args:
            initial_count: Number of workers to start with
        """
        with self._lock:
            if self._running:
                return

            self._running = True

            # Create initial workers
            for _ in range(initial_count):
                worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
                worker.start()
                self.workers.append(worker)

    def stop(self) -> None:
        """Stop all workers in the pool and wait for them to finish."""
        with self._lock:
            self._running = False

            # Signal all workers first so they can wind down concurrently...
            for worker in self.workers:
                worker.stop()

            # ...then wait for each to finish (bounded wait per worker).
            for worker in self.workers:
                if worker.is_alive():
                    worker.join(timeout=10.0)

            self.workers.clear()

    def scale_up(self) -> None:
        """Add a worker to the pool if allowed by the scaler's maximum."""
        with self._lock:
            if not self._running:
                return

            if len(self.workers) >= self.dynamic_scaler.max_workers:
                return

            worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
            worker.start()
            self.workers.append(worker)

    def scale_down(self, worker_ids: list[int]) -> None:
        """
        Remove specific workers from the pool.

        The pool never shrinks below the scaler's configured minimum,
        even if more worker ids than that are requested for removal.

        Args:
            worker_ids: IDs of workers to remove
        """
        with self._lock:
            if not self._running:
                return

            ids_to_remove = set(worker_ids)
            workers_to_remove = [w for w in self.workers if w.worker_id in ids_to_remove]

            for worker in workers_to_remove:
                # Re-check the floor on every removal so a batch of ids
                # cannot push the pool below min_workers.
                if len(self.workers) <= self.dynamic_scaler.min_workers:
                    break
                worker.stop()
                self.workers.remove(worker)
                if worker.is_alive():
                    worker.join(timeout=1.0)

    def get_worker_count(self) -> int:
        """Get current number of workers."""
        with self._lock:
            return len(self.workers)

    def check_scaling(self, queue_depth: int, executing_count: int) -> None:
        """
        Check and perform scaling if needed.

        Scales up when the scaler reports demand, then retires any
        workers the activity tracker reports as idle.

        Args:
            queue_depth: Current queue depth
            executing_count: Number of executing nodes
        """
        current_count = self.get_worker_count()

        if self.dynamic_scaler.should_scale_up(current_count, queue_depth, executing_count):
            self.scale_up()

        idle_workers = self.activity_tracker.get_idle_workers()
        if idle_workers:
            self.scale_down(idle_workers)
Reference in New Issue
Block a user