Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

# Conflicts:
#	api/commands.py
#	api/core/app/apps/common/workflow_response_converter.py
#	api/core/llm_generator/llm_generator.py
#	api/core/plugin/entities/plugin.py
#	api/core/plugin/impl/tool.py
#	api/core/rag/index_processor/index_processor_base.py
#	api/core/workflow/entities/workflow_execution.py
#	api/core/workflow/entities/workflow_node_execution.py
#	api/core/workflow/enums.py
#	api/core/workflow/graph_engine/entities/graph.py
#	api/core/workflow/graph_engine/graph_engine.py
#	api/core/workflow/nodes/enums.py
#	api/services/dataset_service.py
jyong committed 2025-08-27 16:05:59 +08:00
385 changed files with 23289 additions and 11938 deletions

View File

@@ -0,0 +1,187 @@
# Graph Engine
Queue-based workflow execution engine for parallel graph processing.
## Architecture
The engine uses a modular architecture with specialized packages:
### Core Components
- **Domain** (`domain/`) - Core models: ExecutionContext, GraphExecution, NodeExecution
- **Event Management** (`event_management/`) - Event handling, collection, and emission
- **State Management** (`state_management/`) - Thread-safe state tracking for nodes and edges
- **Error Handling** (`error_handling/`) - Strategy-based error recovery (retry, abort, fail-branch, default-value)
- **Graph Traversal** (`graph_traversal/`) - Node readiness, edge processing, branch handling
- **Command Processing** (`command_processing/`) - External command handling (abort, pause, resume)
- **Worker Management** (`worker_management/`) - Dynamic worker pool with auto-scaling
- **Orchestration** (`orchestration/`) - Main event loop and execution coordination
### Supporting Components
- **Output Registry** (`output_registry/`) - Thread-safe storage for node outputs
- **Response Coordinator** (`response_coordinator/`) - Ordered streaming of response nodes
- **Command Channels** (`command_channels/`) - Command transport (InMemory/Redis)
- **Layers** (`layers/`) - Pluggable middleware for extensions
## Architecture Diagram
```mermaid
classDiagram
    class GraphEngine {
        +run()
        +add_layer()
    }
    class Domain {
        ExecutionContext
        GraphExecution
        NodeExecution
    }
    class EventManagement {
        EventHandlerRegistry
        EventCollector
        EventEmitter
    }
    class StateManagement {
        NodeStateManager
        EdgeStateManager
        ExecutionTracker
    }
    class WorkerManagement {
        WorkerPool
        WorkerFactory
        DynamicScaler
        ActivityTracker
    }
    class GraphTraversal {
        NodeReadinessChecker
        EdgeProcessor
        BranchHandler
        SkipPropagator
    }
    class Orchestration {
        Dispatcher
        ExecutionCoordinator
    }
    class ErrorHandling {
        ErrorHandler
        RetryStrategy
        AbortStrategy
        FailBranchStrategy
    }
    class CommandProcessing {
        CommandProcessor
        AbortCommandHandler
    }
    class CommandChannels {
        InMemoryChannel
        RedisChannel
    }
    class OutputRegistry {
        <<Storage>>
        Scalar Values
        Streaming Data
    }
    class ResponseCoordinator {
        Session Management
        Path Analysis
    }
    class Layers {
        <<Plugin>>
        DebugLoggingLayer
    }
    GraphEngine --> Orchestration : coordinates
    GraphEngine --> Layers : extends
    Orchestration --> EventManagement : processes events
    Orchestration --> WorkerManagement : manages scaling
    Orchestration --> CommandProcessing : checks commands
    Orchestration --> StateManagement : monitors state
    WorkerManagement --> StateManagement : consumes ready queue
    WorkerManagement --> EventManagement : produces events
    WorkerManagement --> Domain : executes nodes
    EventManagement --> ErrorHandling : failed events
    EventManagement --> GraphTraversal : success events
    EventManagement --> ResponseCoordinator : stream events
    EventManagement --> Layers : notifies
    GraphTraversal --> StateManagement : updates states
    GraphTraversal --> Domain : checks graph
    CommandProcessing --> CommandChannels : fetches commands
    CommandProcessing --> Domain : modifies execution
    ErrorHandling --> Domain : handles failures
    StateManagement --> Domain : tracks entities
    ResponseCoordinator --> OutputRegistry : reads outputs
    Domain --> OutputRegistry : writes outputs
```
## Package Relationships
### Core Dependencies
- **Orchestration** acts as the central coordinator, managing all subsystems
- **Domain** provides the core business entities used by all packages
- **EventManagement** serves as the communication backbone between components
- **StateManagement** maintains thread-safe state for the entire system
### Data Flow
1. **Commands** flow from CommandChannels → CommandProcessing → Domain
2. **Events** flow from Workers → EventHandlerRegistry → State updates
3. **Node outputs** flow from Workers → OutputRegistry → ResponseCoordinator
4. **Ready nodes** flow from GraphTraversal → StateManagement → WorkerManagement
### Extension Points
- **Layers** observe all events for monitoring, logging, and custom logic
- **ErrorHandling** strategies can be extended for custom failure recovery
- **CommandChannels** can be implemented for different transport mechanisms (a minimal sketch follows this list)
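For instance, a custom channel only needs to satisfy the two-method `CommandChannel` protocol (`fetch_commands()` / `send_command()`). A minimal sketch, not part of this diff, of a single-slot channel that retains only the newest command:

```python
from core.workflow.graph_engine.entities.commands import GraphEngineCommand


class LatestOnlyChannel:
    """Toy CommandChannel: keeps only the most recent command (not thread-safe)."""

    def __init__(self) -> None:
        self._pending: GraphEngineCommand | None = None

    def send_command(self, command: GraphEngineCommand) -> None:
        self._pending = command  # newer commands overwrite older ones

    def fetch_commands(self) -> list[GraphEngineCommand]:
        command, self._pending = self._pending, None
        return [command] if command is not None else []
```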
## Execution Flow
1. **Initialization**: GraphEngine creates all subsystems with the workflow graph
2. **Node Discovery**: Traversal components identify ready nodes
3. **Worker Execution**: Workers pull from the ready queue and execute nodes
4. **Event Processing**: Dispatcher routes events to appropriate handlers
5. **State Updates**: Managers track node/edge states for next steps
6. **Completion**: Coordinator detects when all nodes are done
## Usage
```python
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel

# Create and run engine
engine = GraphEngine(
    tenant_id="tenant_1",
    app_id="app_1",
    workflow_id="workflow_1",
    graph=graph,
    command_channel=InMemoryChannel(),
)

# Stream execution events
for event in engine.run():
    handle_event(event)
```
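Layers hook into the same run: per the diagram above, `add_layer()` attaches pluggable middleware before `run()` is called. A sketch, assuming `DebugLoggingLayer` is exported from the `layers` package:

```python
from core.workflow.graph_engine.layers import DebugLoggingLayer  # import path assumed

engine.add_layer(DebugLoggingLayer())  # the layer observes every event the engine emits
for event in engine.run():
    handle_event(event)
```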

View File

@@ -1,4 +1,3 @@
-from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
 from .graph_engine import GraphEngine
 
-__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
+__all__ = ["GraphEngine"]

View File

@@ -0,0 +1,33 @@
# Command Channels
Channel implementations for external workflow control.
## Components
### InMemoryChannel
Thread-safe in-memory queue for single-process deployments.
- `fetch_commands()` - Get pending commands
- `send_command()` - Add command to queue
### RedisChannel
Redis-based queue for distributed deployments.
- `fetch_commands()` - Get commands with JSON deserialization
- `send_command()` - Store commands with TTL
## Usage
```python
from core.workflow.graph_engine.command_channels import InMemoryChannel, RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand

# Local execution
channel = InMemoryChannel()
channel.send_command(AbortCommand(reason="user requested abort"))

# Distributed execution
redis_channel = RedisChannel(
    redis_client=redis_client,
    channel_key="workflow:123:commands",
)
```

View File

@@ -0,0 +1,6 @@
"""Command channel implementations for GraphEngine."""

from .in_memory_channel import InMemoryChannel
from .redis_channel import RedisChannel

__all__ = ["InMemoryChannel", "RedisChannel"]

View File

@@ -0,0 +1,51 @@
"""
In-memory implementation of CommandChannel for local/testing scenarios.

This implementation uses a thread-safe queue for command communication
within a single process. Each instance handles commands for one workflow execution.
"""

from queue import Queue

from ..entities.commands import GraphEngineCommand


class InMemoryChannel:
    """
    In-memory command channel implementation using a thread-safe queue.

    Each instance is dedicated to a single GraphEngine/workflow execution.
    Suitable for local development, testing, and single-instance deployments.
    """

    def __init__(self) -> None:
        """Initialize the in-memory channel with a single queue."""
        self._queue: Queue[GraphEngineCommand] = Queue()

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """
        Fetch all pending commands from the queue.

        Returns:
            List of pending commands (drains the queue)
        """
        commands: list[GraphEngineCommand] = []
        # Drain all available commands from the queue
        while not self._queue.empty():
            try:
                command = self._queue.get_nowait()
                commands.append(command)
            except Exception:
                break
        return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        """
        Send a command to this channel's queue.

        Args:
            command: The command to send
        """
        self._queue.put(command)
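A quick sketch of the drain semantics: `fetch_commands()` returns everything queued so far and leaves the queue empty (the command values are illustrative):

```python
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand

channel = InMemoryChannel()
channel.send_command(AbortCommand(reason="first"))
channel.send_command(AbortCommand(reason="second"))

assert [c.reason for c in channel.fetch_commands()] == ["first", "second"]
assert channel.fetch_commands() == []  # the first call drained the queue
```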

View File

@@ -0,0 +1,109 @@
"""
Redis-based implementation of CommandChannel for distributed scenarios.

This implementation uses Redis lists for command queuing, supporting
multi-instance deployments and cross-server communication.
Each instance uses a unique key for its command queue.
"""

import json
from typing import TYPE_CHECKING, Optional

from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand

if TYPE_CHECKING:
    from extensions.ext_redis import RedisClientWrapper


class RedisChannel:
    """
    Redis-based command channel implementation for distributed systems.

    Each instance uses a unique Redis key for its command queue.
    Commands are JSON-serialized for transport.
    """

    def __init__(
        self,
        redis_client: "RedisClientWrapper",
        channel_key: str,
        command_ttl: int = 3600,
    ) -> None:
        """
        Initialize the Redis channel.

        Args:
            redis_client: Redis client instance
            channel_key: Unique key for this channel's command queue
            command_ttl: TTL for command keys in seconds (default: 3600)
        """
        self._redis = redis_client
        self._key = channel_key
        self._command_ttl = command_ttl

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """
        Fetch all pending commands from Redis.

        Returns:
            List of pending commands (drains the Redis list)
        """
        commands: list[GraphEngineCommand] = []

        # Use pipeline for atomic operations
        with self._redis.pipeline() as pipe:
            # Get all commands and clear the list atomically
            pipe.lrange(self._key, 0, -1)
            pipe.delete(self._key)
            results = pipe.execute()

        # Parse commands from JSON
        if results[0]:
            for command_json in results[0]:
                try:
                    command_data = json.loads(command_json)
                    command = self._deserialize_command(command_data)
                    if command:
                        commands.append(command)
                except (json.JSONDecodeError, ValueError):
                    # Skip invalid commands
                    continue

        return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        """
        Send a command to Redis.

        Args:
            command: The command to send
        """
        command_json = json.dumps(command.model_dump())

        # Push to list and set expiry
        with self._redis.pipeline() as pipe:
            pipe.rpush(self._key, command_json)
            pipe.expire(self._key, self._command_ttl)
            pipe.execute()

    def _deserialize_command(self, data: dict) -> Optional[GraphEngineCommand]:
        """
        Deserialize a command from dictionary data.

        Args:
            data: Command data dictionary

        Returns:
            Deserialized command or None if invalid
        """
        try:
            command_type = CommandType(data.get("command_type"))
            if command_type == CommandType.ABORT:
                return AbortCommand(**data)
            else:
                # For other command types, use base class
                return GraphEngineCommand(**data)
        except (ValueError, TypeError):
            return None
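In a multi-instance deployment the abort can come from a different process than the one running the engine. A sketch, assuming `extensions.ext_redis` exposes a module-level `redis_client` (the diff only shows the `RedisClientWrapper` type import):

```python
from core.workflow.graph_engine.command_channels import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from extensions.ext_redis import redis_client  # assumed module-level client

# Both processes must agree on the channel key for this workflow run.
channel = RedisChannel(redis_client=redis_client, channel_key="workflow:123:commands")
channel.send_command(AbortCommand(reason="deployment restart"))
# The engine's CommandProcessor picks this up on its next poll.
```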

View File

@@ -0,0 +1,14 @@
"""
Command processing subsystem for graph engine.

This package handles external commands sent to the engine
during execution.
"""

from .command_handlers import AbortCommandHandler
from .command_processor import CommandProcessor

__all__ = [
    "AbortCommandHandler",
    "CommandProcessor",
]

View File

@@ -0,0 +1,27 @@
"""
Command handler implementations.
"""

import logging

from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from .command_processor import CommandHandler

logger = logging.getLogger(__name__)


class AbortCommandHandler(CommandHandler):
    """Handles abort commands."""

    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        """
        Handle an abort command.

        Args:
            command: The abort command
            execution: Graph execution to abort
        """
        assert isinstance(command, AbortCommand)
        logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
        execution.abort(command.reason or "User requested abort")

View File

@@ -0,0 +1,78 @@
"""
Main command processor for handling external commands.
"""

import logging
from typing import Protocol

from ..domain.graph_execution import GraphExecution
from ..entities.commands import GraphEngineCommand
from ..protocols.command_channel import CommandChannel

logger = logging.getLogger(__name__)


class CommandHandler(Protocol):
    """Protocol for command handlers."""

    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...


class CommandProcessor:
    """
    Processes external commands sent to the engine.

    This polls the command channel and dispatches commands to
    appropriate handlers.
    """

    def __init__(
        self,
        command_channel: CommandChannel,
        graph_execution: GraphExecution,
    ) -> None:
        """
        Initialize the command processor.

        Args:
            command_channel: Channel for receiving commands
            graph_execution: Graph execution aggregate
        """
        self.command_channel = command_channel
        self.graph_execution = graph_execution
        self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}

    def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
        """
        Register a handler for a command type.

        Args:
            command_type: Type of command to handle
            handler: Handler for the command
        """
        self._handlers[command_type] = handler

    def process_commands(self) -> None:
        """Check for and process any pending commands."""
        try:
            commands = self.command_channel.fetch_commands()
            for command in commands:
                self._handle_command(command)
        except Exception as e:
            logger.warning("Error processing commands: %s", e)

    def _handle_command(self, command: GraphEngineCommand) -> None:
        """
        Handle a single command.

        Args:
            command: The command to handle
        """
        handler = self._handlers.get(type(command))
        if handler:
            try:
                handler.handle(command, self.graph_execution)
            except Exception:
                logger.exception("Error handling command %s", command.__class__.__name__)
        else:
            logger.warning("No handler registered for command: %s", command.__class__.__name__)
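Putting the pieces together, a sketch of how the processor is wired to a channel and polled (the explicit polling here is illustrative; inside the engine this happens in the orchestration loop):

```python
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.command_processing import AbortCommandHandler, CommandProcessor
from core.workflow.graph_engine.domain import GraphExecution
from core.workflow.graph_engine.entities.commands import AbortCommand

execution = GraphExecution(workflow_id="workflow_1")
channel = InMemoryChannel()

processor = CommandProcessor(command_channel=channel, graph_execution=execution)
processor.register_handler(AbortCommand, AbortCommandHandler())

execution.start()
channel.send_command(AbortCommand(reason="user clicked stop"))
processor.process_commands()  # dispatches to AbortCommandHandler
assert execution.aborted and not execution.is_running
```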

View File

@ -1,25 +0,0 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class RunConditionHandler(ABC):
def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition
@abstractmethod
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
raise NotImplementedError

View File

@@ -1,25 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class BranchIdentifyRunConditionHandler(RunConditionHandler):
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        if not self.condition.branch_identify:
            raise Exception("Branch identify is required")

        run_result = previous_route_node_state.node_run_result
        if not run_result:
            return False

        if not run_result.edge_source_handle:
            return False

        return self.condition.branch_identify == run_result.edge_source_handle

View File

@ -1,27 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
if not self.condition.conditions:
return True
# process condition
condition_processor = ConditionProcessor()
_, _, final_result = condition_processor.process_conditions(
variable_pool=graph_runtime_state.variable_pool,
conditions=self.condition.conditions,
operator="and",
)
return final_result

View File

@@ -1,25 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition


class ConditionManager:
    @staticmethod
    def get_condition_handler(
        init_params: GraphInitParams, graph: Graph, run_condition: RunCondition
    ) -> RunConditionHandler:
        """
        Get condition handler

        :param init_params: init params
        :param graph: graph
        :param run_condition: run condition
        :return: condition handler
        """
        if run_condition.type == "branch_identify":
            return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition)
        else:
            return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition)

View File

@@ -0,0 +1,16 @@
"""
Domain models for graph engine.

This package contains the core domain entities, value objects, and aggregates
that represent the business concepts of workflow graph execution.
"""

from .execution_context import ExecutionContext
from .graph_execution import GraphExecution
from .node_execution import NodeExecution

__all__ = [
    "ExecutionContext",
    "GraphExecution",
    "NodeExecution",
]

View File

@@ -0,0 +1,37 @@
"""
ExecutionContext value object containing immutable execution parameters.
"""

from dataclasses import dataclass

from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import UserFrom


@dataclass(frozen=True)
class ExecutionContext:
    """
    Immutable value object containing the context for a graph execution.

    This encapsulates all the contextual information needed to execute a workflow,
    keeping it separate from the mutable execution state.
    """

    tenant_id: str
    app_id: str
    workflow_id: str
    user_id: str
    user_from: UserFrom
    invoke_from: InvokeFrom
    call_depth: int
    max_execution_steps: int
    max_execution_time: int

    def __post_init__(self) -> None:
        """Validate execution context parameters."""
        if self.call_depth < 0:
            raise ValueError("Call depth must be non-negative")
        if self.max_execution_steps <= 0:
            raise ValueError("Max execution steps must be positive")
        if self.max_execution_time <= 0:
            raise ValueError("Max execution time must be positive")
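Because the dataclass is frozen and validates in `__post_init__`, invalid parameters fail at construction time. A sketch (the specific `UserFrom`/`InvokeFrom` members are assumed from the wider codebase):

```python
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.graph_engine.domain import ExecutionContext
from models.enums import UserFrom

context = ExecutionContext(
    tenant_id="tenant_1",
    app_id="app_1",
    workflow_id="workflow_1",
    user_id="user_1",
    user_from=UserFrom.ACCOUNT,       # assumed enum member
    invoke_from=InvokeFrom.DEBUGGER,  # assumed enum member
    call_depth=0,
    max_execution_steps=500,
    max_execution_time=1200,
)
# call_depth=-1 would raise ValueError here, and attribute
# assignment afterwards raises dataclasses.FrozenInstanceError.
```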

View File

@@ -0,0 +1,72 @@
"""
GraphExecution aggregate root managing the overall graph execution state.
"""

from dataclasses import dataclass, field
from typing import Optional

from .node_execution import NodeExecution


@dataclass
class GraphExecution:
    """
    Aggregate root for graph execution.

    This manages the overall execution state of a workflow graph,
    coordinating between multiple node executions.
    """

    workflow_id: str
    started: bool = False
    completed: bool = False
    aborted: bool = False
    error: Optional[Exception] = None
    node_executions: dict[str, NodeExecution] = field(default_factory=dict)

    def start(self) -> None:
        """Mark the graph execution as started."""
        if self.started:
            raise RuntimeError("Graph execution already started")
        self.started = True

    def complete(self) -> None:
        """Mark the graph execution as completed."""
        if not self.started:
            raise RuntimeError("Cannot complete execution that hasn't started")
        if self.completed:
            raise RuntimeError("Graph execution already completed")
        self.completed = True

    def abort(self, reason: str) -> None:
        """Abort the graph execution."""
        self.aborted = True
        self.error = RuntimeError(f"Aborted: {reason}")

    def fail(self, error: Exception) -> None:
        """Mark the graph execution as failed."""
        self.error = error
        self.completed = True

    def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
        """Get or create a node execution entity."""
        if node_id not in self.node_executions:
            self.node_executions[node_id] = NodeExecution(node_id=node_id)
        return self.node_executions[node_id]

    @property
    def is_running(self) -> bool:
        """Check if the execution is currently running."""
        return self.started and not self.completed and not self.aborted

    @property
    def has_error(self) -> bool:
        """Check if the execution has encountered an error."""
        return self.error is not None

    @property
    def error_message(self) -> str | None:
        """Get the error message if an error exists."""
        if not self.error:
            return None
        return str(self.error)
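The state flags compose into the derived properties; a short lifecycle sketch:

```python
from core.workflow.graph_engine.domain import GraphExecution

execution = GraphExecution(workflow_id="workflow_1")
execution.start()
assert execution.is_running

execution.abort("User requested abort")
assert not execution.is_running
assert execution.has_error
assert execution.error_message == "Aborted: User requested abort"
```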

View File

@@ -0,0 +1,46 @@
"""
NodeExecution entity representing a node's execution state.
"""

from dataclasses import dataclass
from typing import Optional

from core.workflow.enums import NodeState


@dataclass
class NodeExecution:
    """
    Entity representing the execution state of a single node.

    This is a mutable entity that tracks the runtime state of a node
    during graph execution.
    """

    node_id: str
    state: NodeState = NodeState.UNKNOWN
    retry_count: int = 0
    execution_id: Optional[str] = None
    error: Optional[str] = None

    def mark_started(self, execution_id: str) -> None:
        """Mark the node as started with an execution ID."""
        self.state = NodeState.TAKEN
        self.execution_id = execution_id

    def mark_taken(self) -> None:
        """Mark the node as successfully completed."""
        self.state = NodeState.TAKEN
        self.error = None

    def mark_failed(self, error: str) -> None:
        """Mark the node as failed with an error."""
        self.error = error

    def mark_skipped(self) -> None:
        """Mark the node as skipped."""
        self.state = NodeState.SKIPPED

    def increment_retry(self) -> None:
        """Increment the retry count for this node."""
        self.retry_count += 1
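A sketch of how this bookkeeping composes with the retry handling described earlier (the retry policy itself lives in `error_handling/`, not here):

```python
from core.workflow.graph_engine.domain import NodeExecution

node = NodeExecution(node_id="llm_node")
node.mark_started(execution_id="exec-1")

node.mark_failed("upstream timeout")
node.increment_retry()   # a RetryStrategy would consult retry_count
assert node.retry_count == 1

node.mark_taken()        # second attempt succeeded; error is cleared
assert node.error is None
```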

View File

@ -1,6 +0,0 @@
from .graph import Graph
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .runtime_route_state import RuntimeRouteState
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

View File

@@ -0,0 +1,33 @@
"""
GraphEngine command entities for external control.

This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""

from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel, Field


class CommandType(str, Enum):
    """Types of commands that can be sent to GraphEngine."""

    ABORT = "abort"
    PAUSE = "pause"
    RESUME = "resume"


class GraphEngineCommand(BaseModel):
    """Base class for all GraphEngine commands."""

    command_type: CommandType = Field(..., description="Type of command")
    payload: Optional[dict[str, Any]] = Field(default=None, description="Optional command payload")


class AbortCommand(GraphEngineCommand):
    """Command to abort a running workflow execution."""

    command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
    reason: Optional[str] = Field(default=None, description="Optional reason for abort")
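These pydantic models are what RedisChannel serializes; a round-trip sketch showing the wire format:

```python
import json

from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType

wire = json.dumps(AbortCommand(reason="stop").model_dump())
# {"command_type": "abort", "payload": null, "reason": "stop"}

data = json.loads(wire)
assert CommandType(data["command_type"]) is CommandType.ABORT
restored = AbortCommand(**data)  # mirrors RedisChannel._deserialize_command
assert restored.reason == "stop"
```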

View File

@ -1,277 +0,0 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class GraphEngineEvent(BaseModel):
pass
###########################################
# Graph Events
###########################################
class BaseGraphEvent(GraphEngineEvent):
pass
class GraphRunStartedEvent(BaseGraphEvent):
pass
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: Optional[dict[str, Any]] = None
"""outputs"""
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: int = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
exceptions_count: int = Field(..., description="exception count")
outputs: Optional[dict[str, Any]] = None
###########################################
# Node Events
###########################################
class BaseNodeEvent(GraphEngineEvent):
id: str = Field(..., description="node execution id")
node_id: str = Field(..., description="node id")
node_type: NodeType = Field(..., description="node type")
node_data: BaseNodeData = Field(..., description="node data")
route_node_state: RouteNodeState = Field(..., description="route node state")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
parallel_mode_run_id: Optional[str] = None
"""iteration node parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(BaseNodeEvent):
pass
class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunExceptionEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInLoopFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
start_at: datetime = Field(..., description="retry start time")
###########################################
# Parallel Branch Events
###########################################
class BaseParallelBranchEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id")
"""parallel id"""
parallel_start_node_id: str = Field(..., description="parallel start node id")
"""parallel start node id"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
pass
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
pass
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
error: str = Field(..., description="failed reason")
###########################################
# Iteration Events
###########################################
class BaseIterationEvent(GraphEngineEvent):
iteration_id: str = Field(..., description="iteration node execution id")
iteration_node_id: str = Field(..., description="iteration node id")
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
iteration_node_data: BaseNodeData = Field(..., description="node data")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
class IterationRunNextEvent(BaseIterationEvent):
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = None
duration: Optional[float] = None
class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")
###########################################
# Loop Events
###########################################
class BaseLoopEvent(GraphEngineEvent):
loop_id: str = Field(..., description="loop node execution id")
loop_node_id: str = Field(..., description="loop node id")
loop_node_type: NodeType = Field(..., description="node type, loop or loop")
loop_node_data: BaseNodeData = Field(..., description="node data")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""loop run in parallel mode run id"""
class LoopRunStartedEvent(BaseLoopEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
class LoopRunNextEvent(BaseLoopEvent):
index: int = Field(..., description="index")
pre_loop_output: Optional[Any] = None
duration: Optional[float] = None
class LoopRunSucceededEvent(BaseLoopEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
loop_duration_map: Optional[dict[str, float]] = None
class LoopRunFailedEvent(BaseLoopEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")
###########################################
# Agent Events
###########################################
class BaseAgentEvent(GraphEngineEvent):
pass
class AgentLogEvent(BaseAgentEvent):
id: str = Field(..., description="id")
label: str = Field(..., description="label")
node_execution_id: str = Field(..., description="node execution id")
parent_id: str | None = Field(..., description="parent id")
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
node_id: str = Field(..., description="agent node id")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent

View File

@ -1,721 +0,0 @@
import uuid
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam
class GraphEdge(BaseModel):
source_node_id: str = Field(..., description="source node id")
target_node_id: str = Field(..., description="target node id")
run_condition: Optional[RunCondition] = None
"""run condition"""
class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = None
"""parent parallel id"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
class Graph(BaseModel):
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field(
default_factory=dict, description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict, description="graph edge mapping (source node id: edges)"
)
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict, description="reverse graph edge mapping (target node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict, description="graph parallel mapping (parallel id: parallel)"
)
node_parallel_mapping: dict[str, str] = Field(
default_factory=dict, description="graph node parallel mapping (node id: parallel id)"
)
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes")
end_stream_param: EndStreamParam = Field(..., description="end stream param")
@classmethod
def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph":
"""
Init graph
:param graph_config: graph config
:param root_node_id: root node id
:return: graph
"""
# edge configs
edge_configs = graph_config.get("edges")
if edge_configs is None:
edge_configs = []
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
edge_configs = cast(list, edge_configs)
node_configs = cast(list, node_configs)
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
fail_branch_source_node_id = [
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
]
for edge_config in edge_configs:
source_node_id = edge_config.get("source")
if not source_node_id:
continue
if source_node_id not in edge_mapping:
edge_mapping[source_node_id] = []
target_node_id = edge_config.get("target")
if not target_node_id:
continue
if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = []
target_edge_ids.add(target_node_id)
# parse run condition
run_condition = None
if edge_config.get("sourceHandle"):
if (
edge_config.get("source") in fail_branch_source_node_id
and edge_config.get("sourceHandle") != "fail-branch"
):
run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
elif edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(
type="branch_identify", branch_identify=edge_config.get("sourceHandle")
)
graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
)
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id:
continue
if node_id not in target_edge_ids:
root_node_configs.append(node_config)
all_node_id_config_mapping[node_id] = node_config
root_node_ids = [node_config.get("id") for node_config in root_node_configs]
# fetch root node
if not root_node_id:
# if no root node id, use the START type node as root node
root_node_id = next(
(
node_config.get("id")
for node_config in root_node_configs
if node_config.get("data", {}).get("type", "") == NodeType.START.value
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
),
None,
)
if not root_node_id or root_node_id not in root_node_ids:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Check whether it is connected to the previous node
cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping)
# fetch all node ids from root node
node_ids = [root_node_id]
cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id)
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
# init parallel mapping
parallel_mapping: dict[str, GraphParallel] = {}
node_parallel_mapping: dict[str, str] = {}
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=root_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
)
# Check if it exceeds N layers of parallel
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
parent_parallel_id=parallel.parent_parallel_id,
)
# init answer stream generate routes
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping
)
# init end stream param
end_stream_param = EndStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
node_parallel_mapping=node_parallel_mapping,
)
# init graph
graph = cls(
root_node_id=root_node_id,
node_ids=node_ids,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes,
end_stream_param=end_stream_param,
)
return graph
def add_extra_edge(
self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None
) -> None:
"""
Add extra edge to the graph
:param source_node_id: source node id
:param target_node_id: target node id
:param run_condition: run condition
"""
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
return
if source_node_id not in self.edge_mapping:
self.edge_mapping[source_node_id] = []
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
return
graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
)
self.edge_mapping[source_node_id].append(graph_edge)
def get_leaf_node_ids(self) -> list[str]:
"""
Get leaf node ids of the graph
:return: leaf node ids
"""
leaf_node_ids = []
for node_id in self.node_ids:
if node_id not in self.edge_mapping or (
len(self.edge_mapping[node_id]) == 1
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
):
leaf_node_ids.append(node_id)
return leaf_node_ids
@classmethod
def _recursively_add_node_ids(
cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str
) -> None:
"""
Recursively add node ids
:param node_ids: node ids
:param edge_mapping: edge mapping
:param node_id: node id
"""
for graph_edge in edge_mapping.get(node_id, []):
if graph_edge.target_node_id in node_ids:
continue
node_ids.append(graph_edge.target_node_id)
cls._recursively_add_node_ids(
node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id
)
@classmethod
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None:
"""
Check whether it is connected to the previous node
"""
last_node_id = route[-1]
for graph_edge in edge_mapping.get(last_node_id, []):
if not graph_edge.target_node_id:
continue
if graph_edge.target_node_id in route:
raise ValueError(
f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
)
new_route = route.copy()
new_route.append(graph_edge.target_node_id)
cls._check_connected_to_previous_node(
route=new_route,
edge_mapping=edge_mapping,
)
@classmethod
def _recursively_add_parallels(
cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str],
parent_parallel: Optional[GraphParallel] = None,
) -> None:
"""
Recursively add parallel ids
:param edge_mapping: edge mapping
:param start_node_id: start from node id
:param parallel_mapping: parallel mapping
:param node_parallel_mapping: node parallel mapping
:param parent_parallel: parent parallel
"""
target_node_edges = edge_mapping.get(start_node_id, [])
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = defaultdict(list)
condition_edge_mappings = defaultdict(list)
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
condition_edge_mappings[condition_hash].append(graph_edge)
for condition_hash, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
for graph_edge in graph_edges:
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
condition_parallels = {}
for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
# any target node id in node_parallel_mapping
parallel = None
if condition_parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None
parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
)
parallel_mapping[parallel.id] = parallel
condition_parallels[condition_hash] = parallel
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=condition_parallel_branch_node_ids,
)
# collect all branches node ids
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break
if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id
outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue
node_edges = edge_mapping.get(node_id)
if not node_edges:
continue
if len(node_edges) > 1:
continue
target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue
if (
(
node_parallel_mapping.get(target_node_id)
and node_parallel_mapping.get(target_node_id) == parent_parallel_id
)
or (
parent_parallel
and parent_parallel.end_to_node_id
and target_node_id == parent_parallel.end_to_node_id
)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
if len(outside_parallel_target_node_ids) == 1:
if (
parent_parallel
and parent_parallel.end_to_node_id
and parallel.end_to_node_id == parent_parallel.end_to_node_id
):
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
if condition_edge_mappings:
for condition_hash, graph_edges in condition_edge_mappings.items():
for graph_edge in graph_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=condition_parallels.get(condition_hash),
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
@classmethod
def _get_current_parallel(
cls,
parallel_mapping: dict[str, GraphParallel],
graph_edge: GraphEdge,
parallel: Optional[GraphParallel] = None,
parent_parallel: Optional[GraphParallel] = None,
) -> Optional[GraphParallel]:
"""
Get current parallel
"""
current_parallel = None
if parallel:
current_parallel = parallel
elif parent_parallel:
if not parent_parallel.end_to_node_id or (
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
):
current_parallel = parent_parallel
else:
# fetch parent parallel's parent parallel
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
if parent_parallel_parent_parallel_id:
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
if parent_parallel_parent_parallel and (
not parent_parallel_parent_parallel.end_to_node_id
or (
parent_parallel_parent_parallel.end_to_node_id
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
)
):
current_parallel = parent_parallel_parent_parallel
return current_parallel
@classmethod
def _check_exceed_parallel_limit(
cls,
parallel_mapping: dict[str, GraphParallel],
level_limit: int,
parent_parallel_id: str,
current_level: int = 1,
) -> None:
"""
Check if it exceeds N layers of parallel
"""
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
return
current_level += 1
if current_level > level_limit:
raise ValueError(f"Exceeds {level_limit} layers of parallel")
if parent_parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=level_limit,
parent_parallel_id=parent_parallel.parent_parallel_id,
current_level=current_level,
)
@classmethod
def _recursively_add_parallel_node_ids(
cls,
branch_node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
merge_node_id: str,
start_node_id: str,
) -> None:
"""
Recursively add node ids
:param branch_node_ids: in branch node ids
:param edge_mapping: edge mapping
:param merge_node_id: merge node id
:param start_node_id: start node id
"""
for graph_edge in edge_mapping.get(start_node_id, []):
if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids:
branch_node_ids.append(graph_edge.target_node_id)
cls._recursively_add_parallel_node_ids(
branch_node_ids=branch_node_ids,
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=graph_edge.target_node_id,
)
@classmethod
def _fetch_all_node_ids_in_parallels(
cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
parallel_branch_node_ids: list[str],
) -> dict[str, list[str]]:
"""
Fetch all node ids in parallels
"""
routes_node_ids: dict[str, list[str]] = {}
for parallel_branch_node_id in parallel_branch_node_ids:
routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
# fetch routes node ids
cls._recursively_fetch_routes(
edge_mapping=edge_mapping,
start_node_id=parallel_branch_node_id,
routes_node_ids=routes_node_ids[parallel_branch_node_id],
)
# fetch leaf node ids from routes node ids
leaf_node_ids: dict[str, list[str]] = {}
merge_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
if branch_node_id not in leaf_node_ids:
leaf_node_ids[branch_node_id] = []
leaf_node_ids[branch_node_id].append(node_id)
for branch_node_id2, inner_route2 in routes_node_ids.items():
if (
branch_node_id != branch_node_id2
and node_id in inner_route2
and len(reverse_edge_mapping.get(node_id, [])) > 1
and cls._is_node_in_routes(
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=node_id,
routes_node_ids=routes_node_ids,
)
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []
if branch_node_id2 not in merge_branch_node_ids[node_id]:
merge_branch_node_ids[node_id].append(branch_node_id2)
# sorted merge_branch_node_ids by branch_node_ids length desc
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
duplicate_end_node_ids = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
if (node_id, node_id2) not in duplicate_end_node_ids and (
node_id2,
node_id,
) not in duplicate_end_node_ids:
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
# check which node is after
if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
branches_merge_node_ids: dict[str, str] = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
if len(branch_node_ids) <= 1:
continue
for branch_node_id in branch_node_ids:
if branch_node_id in branches_merge_node_ids:
continue
branches_merge_node_ids[branch_node_id] = node_id
in_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
in_branch_node_ids[branch_node_id] = []
if branch_node_id not in branches_merge_node_ids:
# all node ids in current branch is in this thread
in_branch_node_ids[branch_node_id].append(branch_node_id)
in_branch_node_ids[branch_node_id].extend(node_ids)
else:
merge_node_id = branches_merge_node_ids[branch_node_id]
if merge_node_id != branch_node_id:
in_branch_node_ids[branch_node_id].append(branch_node_id)
# fetch all node ids from branch_node_id and merge_node_id
cls._recursively_add_parallel_node_ids(
branch_node_ids=in_branch_node_ids[branch_node_id],
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=branch_node_id,
)
return in_branch_node_ids
@classmethod
def _recursively_fetch_routes(
cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
) -> None:
"""
Recursively fetch route
"""
if start_node_id not in edge_mapping:
return
for graph_edge in edge_mapping[start_node_id]:
# find next node ids
if graph_edge.target_node_id not in routes_node_ids:
routes_node_ids.append(graph_edge.target_node_id)
cls._recursively_fetch_routes(
edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids
)
@classmethod
def _is_node_in_routes(
cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]]
) -> bool:
"""
Recursively check if the node is in the routes
"""
if start_node_id not in reverse_edge_mapping:
return False
all_routes_node_ids = set()
parallel_start_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
all_routes_node_ids.update(node_ids)
if branch_node_id in reverse_edge_mapping:
for graph_edge in reverse_edge_mapping[branch_node_id]:
if graph_edge.source_node_id not in parallel_start_node_ids:
parallel_start_node_ids[graph_edge.source_node_id] = []
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
for _, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()):
return True
return False
@classmethod
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
"""
is node2 after node1
"""
if node1_id not in edge_mapping:
return False
for graph_edge in edge_mapping[node1_id]:
if graph_edge.target_node_id == node2_id:
return True
if cls._is_node2_after_node1(
node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping
):
return True
return False

View File

@ -1,21 +0,0 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_type: WorkflowType = Field(..., description="workflow type")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
call_depth: int = Field(..., description="call depth")

View File

@ -1,31 +0,0 @@
from typing import Any
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
start_at: float = Field(..., description="start time")
"""start time"""
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
#
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
# after a serialization and deserialization round trip.
outputs: dict[str, Any] = Field(default_factory=dict)
node_run_steps: int = 0
"""node run steps"""
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""

View File

@ -1,21 +0,0 @@
import hashlib
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
"""branch identify like: sourceHandle, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""
@property
def hash(self) -> str:
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()

@@ -1,118 +0,0 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from libs.datetime_utils import naive_utc_now
class RouteNodeState(BaseModel):
class Status(Enum):
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
EXCEPTION = "exception"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""
node_id: str
"""node id"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.RUNNING
"""node status"""
start_at: datetime
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None:
"""
Node finished
:param run_result: run result
"""
if self.status in {
RouteNodeState.Status.SUCCESS,
RouteNodeState.Status.FAILED,
RouteNodeState.Status.EXCEPTION,
}:
raise Exception(f"Route state {self.id} already finished")
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
self.status = RouteNodeState.Status.SUCCESS
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
self.status = RouteNodeState.Status.FAILED
self.failed_reason = run_result.error
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
self.status = RouteNodeState.Status.EXCEPTION
self.failed_reason = run_result.error
else:
raise Exception(f"Invalid route status {run_result.status}")
self.node_run_result = run_result
self.finished_at = naive_utc_now()
class RuntimeRouteState(BaseModel):
routes: dict[str, list[str]] = Field(
default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
)
node_state_mapping: dict[str, RouteNodeState] = Field(
default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
)
def create_node_state(self, node_id: str) -> RouteNodeState:
"""
Create node state
:param node_id: node id
"""
state = RouteNodeState(node_id=node_id, start_at=naive_utc_now())
self.node_state_mapping[state.id] = state
return state
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
"""
Add route to the graph state
:param source_node_state_id: source node state id
:param target_node_state_id: target node state id
"""
if source_node_state_id not in self.routes:
self.routes[source_node_state_id] = []
self.routes[source_node_state_id].append(target_node_state_id)
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]:
"""
Get routes with node state by source node id
:param source_node_state_id: source node state id
:return: routes with node state
"""
return [
self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
]

@@ -0,0 +1,22 @@
"""
Error handling strategies for graph engine.
This package implements different error recovery strategies using
the Strategy pattern for clean separation of concerns.
"""
from .abort_strategy import AbortStrategy
from .default_value_strategy import DefaultValueStrategy
from .error_handler import ErrorHandler
from .error_strategy import ErrorStrategy
from .fail_branch_strategy import FailBranchStrategy
from .retry_strategy import RetryStrategy
__all__ = [
"AbortStrategy",
"DefaultValueStrategy",
"ErrorHandler",
"ErrorStrategy",
"FailBranchStrategy",
"RetryStrategy",
]

@@ -0,0 +1,37 @@
"""
Abort error strategy implementation.
"""
import logging
from typing import Optional
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
logger = logging.getLogger(__name__)
class AbortStrategy:
"""
Error strategy that aborts execution on failure.
This is the default strategy when no other strategy is specified.
It stops the entire graph execution when a node fails.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
"""
Handle error by aborting execution.
Args:
event: The failure event
graph: The workflow graph
retry_count: Current retry attempt count (unused)
Returns:
None - signals abortion
"""
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
# Return None to signal that execution should stop
return None
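
A minimal sketch of how a dispatcher might consume this strategy's `None` result, assuming `failed_event`, `graph`, and `graph_execution` are already in scope (the same contract `EventHandlerRegistry._handle_node_failed` follows later in this commit):

```python
# Hypothetical driver code, not part of this commit.
result = AbortStrategy().handle_error(failed_event, graph, retry_count=0)
if result is None:
    # None is the abort signal: fail the whole execution.
    graph_execution.fail(RuntimeError(failed_event.error))
```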

@@ -0,0 +1,56 @@
"""
Default value error strategy implementation.
"""
from typing import Optional
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent
from core.workflow.node_events import NodeRunResult
class DefaultValueStrategy:
"""
Error strategy that uses default values on failure.
This strategy allows nodes to fail gracefully by providing
predefined default output values.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
"""
Handle error by using default values.
Args:
event: The failure event
graph: The workflow graph
retry_count: Current retry attempt count (unused)
Returns:
NodeRunExceptionEvent with default values
"""
node = graph.nodes[event.node_id]
outputs = {
**node.default_value_dict,
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,
}
return NodeRunExceptionEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategy.DEFAULT_VALUE,
},
),
error=event.error,
)
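
A self-contained illustration of the output merge above; `node_defaults` stands in for `node.default_value_dict`:

```python
node_defaults = {"text": ""}  # stand-in for node.default_value_dict
outputs = {
    **node_defaults,
    "error_message": "upstream call timed out",
    "error_type": "TimeoutError",
}
# Downstream nodes receive the defaults plus the error details.
assert outputs == {
    "text": "",
    "error_message": "upstream call timed out",
    "error_type": "TimeoutError",
}
```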

@@ -0,0 +1,82 @@
"""
Main error handler that coordinates error strategies.
"""
from typing import TYPE_CHECKING, Optional
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from .abort_strategy import AbortStrategy
from .default_value_strategy import DefaultValueStrategy
from .fail_branch_strategy import FailBranchStrategy
from .retry_strategy import RetryStrategy
if TYPE_CHECKING:
from ..domain import GraphExecution
class ErrorHandler:
"""
Coordinates error handling strategies for node failures.
This acts as a facade for the various error strategies,
selecting and applying the appropriate strategy based on
node configuration.
"""
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
"""
Initialize the error handler.
Args:
graph: The workflow graph
graph_execution: The graph execution state
"""
self.graph = graph
self.graph_execution = graph_execution
# Initialize strategies
self.abort_strategy = AbortStrategy()
self.retry_strategy = RetryStrategy()
self.fail_branch_strategy = FailBranchStrategy()
self.default_value_strategy = DefaultValueStrategy()
def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]:
"""
Handle a node failure event.
Selects and applies the appropriate error strategy based on
the node's configuration.
Args:
event: The node failure event
Returns:
Optional new event to process, or None to abort
"""
node = self.graph.nodes[event.node_id]
# Get retry count from NodeExecution
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
retry_count = node_execution.retry_count
# First check if retry is configured and not exhausted
if node.retry and retry_count < node.retry_config.max_retries:
result = self.retry_strategy.handle_error(event, self.graph, retry_count)
if result:
# Retry count will be incremented when NodeRunRetryEvent is handled
return result
# Apply configured error strategy
strategy = node.error_strategy
if strategy is None:
return self.abort_strategy.handle_error(event, self.graph, retry_count)
elif strategy == ErrorStrategyEnum.FAIL_BRANCH:
return self.fail_branch_strategy.handle_error(event, self.graph, retry_count)
elif strategy == ErrorStrategyEnum.DEFAULT_VALUE:
return self.default_value_strategy.handle_error(event, self.graph, retry_count)
else:
# Unknown strategy, default to abort
return self.abort_strategy.handle_error(event, self.graph, retry_count)
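
A sketch of the calling convention, mirroring how `EventHandlerRegistry._handle_node_failed` (later in this commit) drives the handler; `failed_event` and `event_handler` are assumed to exist:

```python
handler = ErrorHandler(graph, graph_execution)
follow_up = handler.handle_node_failure(failed_event)
if follow_up is not None:
    # A NodeRunRetryEvent or NodeRunExceptionEvent to feed back into dispatch.
    event_handler.handle_event(follow_up)
else:
    # No strategy produced a follow-up event: abort the run.
    graph_execution.fail(RuntimeError(failed_event.error))
```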

@@ -0,0 +1,31 @@
"""
Base error strategy protocol.
"""
from typing import Optional, Protocol
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
class ErrorStrategy(Protocol):
"""
Protocol for error handling strategies.
Each strategy implements a different approach to handling
node execution failures.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
"""
Handle a node failure event.
Args:
event: The failure event
graph: The workflow graph
retry_count: Current retry attempt count
Returns:
Optional new event to process, or None to stop
"""
...
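
Because `ErrorStrategy` is a `Protocol`, conformance is structural; a hypothetical custom strategy needs no inheritance:

```python
import logging

class LogAndAbortStrategy:
    """Hypothetical strategy: satisfies ErrorStrategy without subclassing it."""

    def handle_error(self, event, graph, retry_count):
        logging.getLogger(__name__).error("Node %s failed: %s", event.node_id, event.error)
        return None  # None signals abort, as in AbortStrategy
```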

@@ -0,0 +1,54 @@
"""
Fail branch error strategy implementation.
"""
from typing import Optional
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent
from core.workflow.node_events import NodeRunResult
class FailBranchStrategy:
"""
Error strategy that continues execution via a fail branch.
This strategy converts failures to exceptions and routes execution
through a designated fail-branch edge.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
"""
Handle error by taking the fail branch.
Args:
event: The failure event
graph: The workflow graph
retry_count: Current retry attempt count (unused)
Returns:
NodeRunExceptionEvent to continue via fail branch
"""
outputs = {
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,
}
return NodeRunExceptionEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
edge_source_handle="fail-branch",
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategy.FAIL_BRANCH,
},
),
error=event.error,
)

@@ -0,0 +1,51 @@
"""
Retry error strategy implementation.
"""
import time
from typing import Optional
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent
class RetryStrategy:
"""
Error strategy that retries failed nodes.
This strategy re-attempts node execution up to a configured
maximum number of retries with configurable intervals.
"""
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
"""
Handle error by retrying the node.
Args:
event: The failure event
graph: The workflow graph
retry_count: Current retry attempt count
Returns:
NodeRunRetryEvent if retry should occur, None otherwise
"""
node = graph.nodes[event.node_id]
# Check if we've exceeded max retries
if not node.retry or retry_count >= node.retry_config.max_retries:
return None
# Wait for retry interval
time.sleep(node.retry_config.retry_interval_seconds)
# Create retry event
return NodeRunRetryEvent(
id=event.id,
node_title=node.title,
node_id=event.node_id,
node_type=event.node_type,
node_run_result=event.node_run_result,
start_at=event.start_at,
error=event.error,
retry_index=retry_count + 1,
)
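
A sketch of the exhaustion behavior, assuming `node`, `graph`, and `failed_event` are in scope. Note that `handle_error` blocks the calling thread for `retry_interval_seconds` before emitting each retry event:

```python
strategy = RetryStrategy()
for attempt in range(node.retry_config.max_retries + 1):
    result = strategy.handle_error(failed_event, graph, retry_count=attempt)
    # Attempts below max_retries yield NodeRunRetryEvent(retry_index=attempt + 1);
    # once retries are exhausted, result is None and the ErrorHandler falls
    # through to the node's configured error strategy.
print(result)  # None on the final, exhausted attempt
```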

@@ -0,0 +1,16 @@
"""
Event management subsystem for graph engine.
This package handles event routing, collection, and emission for
workflow graph execution events.
"""
from .event_collector import EventCollector
from .event_emitter import EventEmitter
from .event_handlers import EventHandlerRegistry
__all__ = [
"EventCollector",
"EventEmitter",
"EventHandlerRegistry",
]

@@ -0,0 +1,98 @@
"""
Event collector for buffering and managing events.
"""
import threading
from core.workflow.graph_events import GraphEngineEvent
from ..layers.base import Layer
class EventCollector:
"""
Collects and buffers events for later retrieval.
This provides thread-safe event collection with support for
notifying layers about events as they're collected.
"""
def __init__(self) -> None:
"""Initialize the event collector."""
self._events: list[GraphEngineEvent] = []
self._lock = threading.Lock()
self._layers: list[Layer] = []
def set_layers(self, layers: list[Layer]) -> None:
"""
Set the layers to notify on event collection.
Args:
layers: List of layers to notify
"""
self._layers = layers
def collect(self, event: GraphEngineEvent) -> None:
"""
Thread-safe method to collect an event.
Args:
event: The event to collect
"""
with self._lock:
self._events.append(event)
self._notify_layers(event)
def get_events(self) -> list[GraphEngineEvent]:
"""
Get all collected events.
Returns:
List of collected events
"""
with self._lock:
return list(self._events)
def get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""
Get new events starting from a specific index.
Args:
start_index: The index to start from
Returns:
List of new events
"""
with self._lock:
return list(self._events[start_index:])
def event_count(self) -> int:
"""
Get the current count of collected events.
Returns:
Number of collected events
"""
with self._lock:
return len(self._events)
def clear(self) -> None:
"""Clear all collected events."""
with self._lock:
self._events.clear()
def _notify_layers(self, event: GraphEngineEvent) -> None:
"""
Notify all layers of an event.
Layer exceptions are caught and logged to prevent disrupting collection.
Args:
event: The event to send to layers
"""
for layer in self._layers:
try:
layer.on_event(event)
except Exception:
# Silently ignore layer errors during collection
pass
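
A usage sketch, assuming `event` is any `GraphEngineEvent` instance produced by a worker:

```python
collector = EventCollector()
collector.set_layers([DebugLoggingLayer()])  # layers are notified on every collect()
collector.collect(event)
assert collector.event_count() == 1
assert collector.get_new_events(0) == [event]
```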

@@ -0,0 +1,56 @@
"""
Event emitter for yielding events to external consumers.
"""
import threading
import time
from collections.abc import Generator
from core.workflow.graph_events import GraphEngineEvent
from .event_collector import EventCollector
class EventEmitter:
"""
Emits collected events as a generator for external consumption.
This provides a generator interface for yielding events as they're
collected, with proper synchronization for multi-threaded access.
"""
def __init__(self, event_collector: EventCollector) -> None:
"""
Initialize the event emitter.
Args:
event_collector: The collector to emit events from
"""
self.event_collector = event_collector
self._execution_complete = threading.Event()
def mark_complete(self) -> None:
"""Mark execution as complete to stop the generator."""
self._execution_complete.set()
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
"""
Generator that yields events as they're collected.
Yields:
GraphEngineEvent instances as they're processed
"""
yielded_count = 0
while not self._execution_complete.is_set() or yielded_count < self.event_collector.event_count():
# Get new events since last yield
new_events = self.event_collector.get_new_events(yielded_count)
# Yield any new events
for event in new_events:
yield event
yielded_count += 1
# Small sleep to avoid busy waiting
if not self._execution_complete.is_set() and not new_events:
time.sleep(0.001)
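
A sketch of the termination contract: `emit_events()` only returns after `mark_complete()` has been called and every buffered event has been yielded. `event_a` and `event_b` are assumed `GraphEngineEvent` instances:

```python
import threading

collector = EventCollector()
emitter = EventEmitter(collector)

def produce():
    collector.collect(event_a)
    collector.collect(event_b)
    emitter.mark_complete()  # without this, emit_events() would poll forever

threading.Thread(target=produce).start()
consumed = list(emitter.emit_events())  # [event_a, event_b], then the generator ends
```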

@@ -0,0 +1,303 @@
"""
Event handler implementations for different event types.
"""
import logging
from typing import TYPE_CHECKING, Optional
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from ..error_handling import ErrorHandler
from ..graph_traversal import BranchHandler, EdgeProcessor
from ..state_management import ExecutionTracker, NodeStateManager
from .event_collector import EventCollector
logger = logging.getLogger(__name__)
class EventHandlerRegistry:
"""
Registry of event handlers for different event types.
This centralizes the business logic for handling specific events,
keeping it separate from the routing and collection infrastructure.
"""
def __init__(
self,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator,
event_collector: Optional["EventCollector"] = None,
branch_handler: Optional["BranchHandler"] = None,
edge_processor: Optional["EdgeProcessor"] = None,
node_state_manager: Optional["NodeStateManager"] = None,
execution_tracker: Optional["ExecutionTracker"] = None,
error_handler: Optional["ErrorHandler"] = None,
) -> None:
"""
Initialize the event handler registry.
Args:
graph: The workflow graph
graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator
event_collector: Optional event collector for collecting events
branch_handler: Optional branch handler for branch node processing
edge_processor: Optional edge processor for edge traversal
node_state_manager: Optional node state manager
execution_tracker: Optional execution tracker
error_handler: Optional error handler
"""
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.graph_execution = graph_execution
self.response_coordinator = response_coordinator
self.event_collector = event_collector
self.branch_handler = branch_handler
self.edge_processor = edge_processor
self.node_state_manager = node_state_manager
self.execution_tracker = execution_tracker
self.error_handler = error_handler
def handle_event(self, event: GraphNodeEventBase) -> None:
"""
Handle any node event by dispatching to the appropriate handler.
Args:
event: The event to handle
"""
# Events emitted inside loops or iterations are collected directly;
# `event` is already typed as GraphNodeEventBase, so no isinstance check is needed
if event.in_loop_id or event.in_iteration_id:
if self.event_collector:
self.event_collector.collect(event)
return
# Handle specific event types
if isinstance(event, NodeRunStartedEvent):
self._handle_node_started(event)
elif isinstance(event, NodeRunStreamChunkEvent):
self._handle_stream_chunk(event)
elif isinstance(event, NodeRunSucceededEvent):
self._handle_node_succeeded(event)
elif isinstance(event, NodeRunFailedEvent):
self._handle_node_failed(event)
elif isinstance(event, NodeRunExceptionEvent):
self._handle_node_exception(event)
elif isinstance(event, NodeRunRetryEvent):
self._handle_node_retry(event)
elif isinstance(
event,
(
NodeRunIterationStartedEvent,
NodeRunIterationNextEvent,
NodeRunIterationSucceededEvent,
NodeRunIterationFailedEvent,
NodeRunLoopStartedEvent,
NodeRunLoopNextEvent,
NodeRunLoopSucceededEvent,
NodeRunLoopFailedEvent,
),
):
# Iteration and loop events are collected directly
if self.event_collector:
self.event_collector.collect(event)
else:
# Collect unhandled events
if self.event_collector:
self.event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__)
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
"""
Handle node started event.
Args:
event: The node started event
"""
# Track execution in domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering
self.response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
if self.event_collector:
self.event_collector.collect(event)
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
Handle stream chunk event with full processing.
Args:
event: The stream chunk event
"""
# Process with response coordinator
streaming_events = list(self.response_coordinator.intercept_event(event))
# Collect all events
if self.event_collector:
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
Handle node success by coordinating subsystems.
This method coordinates between different subsystems to process
node completion, handle edges, and trigger downstream execution.
Args:
event: The node succeeded event
"""
# Update domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Store outputs in variable pool
self._store_node_outputs(event)
# Forward to response coordinator and emit streaming events
streaming_events = list(self.response_coordinator.intercept_event(event))
if self.event_collector:
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
# Process edges and get ready nodes
node = self.graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
if self.branch_handler:
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
ready_nodes, edge_streaming_events = [], []
else:
if self.edge_processor:
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
else:
ready_nodes, edge_streaming_events = [], []
# Collect streaming events from edge processing
if self.event_collector:
for edge_event in edge_streaming_events:
self.event_collector.collect(edge_event)
# Enqueue ready nodes
if self.node_state_manager and self.execution_tracker:
for node_id in ready_nodes:
self.node_state_manager.enqueue_node(node_id)
self.execution_tracker.add(node_id)
# Update execution tracking
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)
# Collect the event
if self.event_collector:
self.event_collector.collect(event)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
Handle node failure using error handler.
Args:
event: The node failed event
"""
# Update domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
if self.error_handler:
result = self.error_handler.handle_node_failure(event)
if result:
# Process the resulting event (retry, exception, etc.)
self.handle_event(result)
else:
# Abort execution
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
else:
# Without error handler, just fail
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
"""
Handle node exception event (fail-branch strategy).
Args:
event: The node exception event
"""
# Node continues via fail-branch, so it's technically "succeeded"
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
"""
Handle node retry event.
Args:
event: The node retry event
"""
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
"""
Store node outputs in the variable pool.
Args:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in event.node_run_result.outputs.items():
self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
"""Update response outputs for response nodes."""
for key, value in event.node_run_result.outputs.items():
if key == "answer":
existing = self.graph_runtime_state.outputs.get("answer", "")
if existing:
self.graph_runtime_state.outputs["answer"] = f"{existing}{value}"
else:
self.graph_runtime_state.outputs["answer"] = value
else:
self.graph_runtime_state.outputs[key] = value
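
A standalone restatement of the concatenation rule above: successive `answer` outputs from response nodes append rather than overwrite:

```python
outputs = {}
for chunk in ({"answer": "Hello, "}, {"answer": "world"}):
    existing = outputs.get("answer", "")
    outputs["answer"] = f"{existing}{chunk['answer']}" if existing else chunk["answer"]
assert outputs == {"answer": "Hello, world"}
```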

File diff suppressed because it is too large

@@ -0,0 +1,18 @@
"""
Graph traversal subsystem for graph engine.
This package handles graph navigation, edge processing,
and skip propagation logic.
"""
from .branch_handler import BranchHandler
from .edge_processor import EdgeProcessor
from .node_readiness import NodeReadinessChecker
from .skip_propagator import SkipPropagator
__all__ = [
"BranchHandler",
"EdgeProcessor",
"NodeReadinessChecker",
"SkipPropagator",
]

@@ -0,0 +1,82 @@
"""
Branch node handling for graph traversal.
"""
from typing import Optional
from core.workflow.graph import Graph
from ..state_management import EdgeStateManager
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator
class BranchHandler:
"""
Handles branch node logic during graph traversal.
Branch nodes select one of multiple paths based on conditions,
requiring special handling for edge selection and skip propagation.
"""
def __init__(
self,
graph: Graph,
edge_processor: EdgeProcessor,
skip_propagator: SkipPropagator,
edge_state_manager: EdgeStateManager,
) -> None:
"""
Initialize the branch handler.
Args:
graph: The workflow graph
edge_processor: Processor for edges
skip_propagator: Propagator for skip states
edge_state_manager: Manager for edge states
"""
self.graph = graph
self.edge_processor = edge_processor
self.skip_propagator = skip_propagator
self.edge_state_manager = edge_state_manager
def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]:
"""
Handle completion of a branch node.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected branch
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no branch was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges; only the unselected ones are needed here, since
# process_node_success re-categorizes and processes the selected edges itself
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self.skip_propagator.skip_branch_paths(node_id, unselected_edges)
# Process selected edges and get ready nodes and streaming events
return self.edge_processor.process_node_success(node_id, selected_handle)
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
"""
Validate that a branch selection is valid.
Args:
node_id: The ID of the branch node
selected_handle: The handle to validate
Returns:
True if the selection is valid
"""
outgoing_edges = self.graph.get_outgoing_edges(node_id)
valid_handles = {edge.source_handle for edge in outgoing_edges}
return selected_handle in valid_handles

@@ -0,0 +1,145 @@
"""
Edge processing logic for graph traversal.
"""
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Edge, Graph
from ..response_coordinator import ResponseStreamCoordinator
from ..state_management import EdgeStateManager, NodeStateManager
class EdgeProcessor:
"""
Processes edges during graph execution.
This handles marking edges as taken or skipped, notifying
the response coordinator, and triggering downstream node execution.
"""
def __init__(
self,
graph: Graph,
edge_state_manager: EdgeStateManager,
node_state_manager: NodeStateManager,
response_coordinator: ResponseStreamCoordinator,
) -> None:
"""
Initialize the edge processor.
Args:
graph: The workflow graph
edge_state_manager: Manager for edge states
node_state_manager: Manager for node states
response_coordinator: Response stream coordinator
"""
self.graph = graph
self.edge_state_manager = edge_state_manager
self.node_state_manager = node_state_manager
self.response_coordinator = response_coordinator
def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]:
"""
Process edges after a node succeeds.
Args:
node_id: The ID of the succeeded node
selected_handle: For branch nodes, the selected edge handle
Returns:
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
"""
node = self.graph.nodes[node_id]
if node.execution_type == NodeExecutionType.BRANCH:
return self._process_branch_node_edges(node_id, selected_handle)
else:
return self._process_non_branch_node_edges(node_id)
def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]:
"""
Process edges for non-branch nodes (mark all as TAKEN).
Args:
node_id: The ID of the succeeded node
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
"""
ready_nodes = []
all_streaming_events = []
outgoing_edges = self.graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]:
"""
Process edges for branch nodes.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no edge was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge")
ready_nodes = []
all_streaming_events = []
# Categorize edges
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
# Process unselected edges first (mark as skipped)
for edge in unselected_edges:
self._process_skipped_edge(edge)
# Process selected edges
for edge in selected_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]:
"""
Mark edge as taken and check downstream node.
Args:
edge: The edge to process
Returns:
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
"""
# Mark edge as taken
self.edge_state_manager.mark_edge_taken(edge.id)
# Notify response coordinator and get streaming events
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes = []
if self.node_state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, list(streaming_events)
def _process_skipped_edge(self, edge: Edge) -> None:
"""
Mark edge as skipped.
Args:
edge: The edge to skip
"""
self.edge_state_manager.mark_edge_skipped(edge.id)

@@ -0,0 +1,83 @@
"""
Node readiness checking for execution.
"""
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
class NodeReadinessChecker:
"""
Checks if nodes are ready for execution based on their dependencies.
A node is ready when its dependencies (incoming edges) have been
satisfied according to the graph's execution rules.
"""
def __init__(self, graph: Graph) -> None:
"""
Initialize the readiness checker.
Args:
graph: The workflow graph
"""
self.graph = graph
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when:
- It has no incoming edges (root or isolated node), OR
- At least one incoming edge is TAKEN and none are UNKNOWN
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
incoming_edges = self.graph.get_incoming_edges(node_id)
# No dependencies means always ready
if not incoming_edges:
return True
# Check edge states
has_unknown = False
has_taken = False
for edge in incoming_edges:
if edge.state == NodeState.UNKNOWN:
has_unknown = True
break
elif edge.state == NodeState.TAKEN:
has_taken = True
# Not ready if any dependency is still unknown
if has_unknown:
return False
# Ready if at least one path is taken
return has_taken
def get_ready_downstream_nodes(self, from_node_id: str) -> list[str]:
"""
Get all downstream nodes that are ready after a node completes.
Args:
from_node_id: The ID of the completed node
Returns:
List of node IDs that are now ready
"""
ready_nodes = []
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
for edge in outgoing_edges:
if edge.state == NodeState.TAKEN:
downstream_node_id = edge.head
if self.is_node_ready(downstream_node_id):
ready_nodes.append(downstream_node_id)
return ready_nodes
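
The readiness rule restated as a standalone pure function over incoming-edge states, which makes the truth table easy to check:

```python
def is_ready(edge_states: list[str]) -> bool:
    """Pure restatement of NodeReadinessChecker.is_node_ready."""
    if not edge_states:
        return True  # no dependencies
    if "UNKNOWN" in edge_states:
        return False  # some upstream path is still undecided
    return "TAKEN" in edge_states  # at least one taken path required

assert is_ready([]) is True
assert is_ready(["TAKEN", "SKIPPED"]) is True
assert is_ready(["TAKEN", "UNKNOWN"]) is False
assert is_ready(["SKIPPED", "SKIPPED"]) is False
```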

@@ -0,0 +1,96 @@
"""
Skip state propagation through the graph.
"""
from core.workflow.graph import Graph
from ..state_management import EdgeStateManager, NodeStateManager
class SkipPropagator:
"""
Propagates skip states through the graph.
When a node is skipped, this ensures all downstream nodes
that depend solely on it are also skipped.
"""
def __init__(
self,
graph: Graph,
edge_state_manager: EdgeStateManager,
node_state_manager: NodeStateManager,
) -> None:
"""
Initialize the skip propagator.
Args:
graph: The workflow graph
edge_state_manager: Manager for edge states
node_state_manager: Manager for node states
"""
self.graph = graph
self.edge_state_manager = edge_state_manager
self.node_state_manager = node_state_manager
def propagate_skip_from_edge(self, edge_id: str) -> None:
"""
Recursively propagate skip state from a skipped edge.
Rules:
- If a node has any UNKNOWN incoming edges, stop processing
- If all incoming edges are SKIPPED, skip the node and its edges
- If any incoming edge is TAKEN, the node may still execute
Args:
edge_id: The ID of the skipped edge to start from
"""
downstream_node_id = self.graph.edges[edge_id].head
incoming_edges = self.graph.get_incoming_edges(downstream_node_id)
# Analyze edge states
edge_states = self.edge_state_manager.analyze_edge_states(incoming_edges)
# Stop if there are unknown edges (not yet processed)
if edge_states["has_unknown"]:
return
# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Check if node is ready and enqueue if so
if self.node_state_manager.is_node_ready(downstream_node_id):
self.node_state_manager.enqueue_node(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
if edge_states["all_skipped"]:
self._propagate_skip_to_node(downstream_node_id)
def _propagate_skip_to_node(self, node_id: str) -> None:
"""
Mark a node and all its outgoing edges as skipped.
Args:
node_id: The ID of the node to skip
"""
# Mark node as skipped
self.node_state_manager.mark_node_skipped(node_id)
# Mark all outgoing edges as skipped and propagate
outgoing_edges = self.graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
self.edge_state_manager.mark_edge_skipped(edge.id)
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)
def skip_branch_paths(self, node_id: str, unselected_edges: list) -> None:
"""
Skip all paths from unselected branch edges.
Args:
node_id: The ID of the branch node
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:
self.edge_state_manager.mark_edge_skipped(edge.id)
self.propagate_skip_from_edge(edge.id)
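
The three propagation rules restated as a standalone decision function:

```python
def skip_decision(edge_states: list[str]) -> str:
    """Pure restatement of SkipPropagator.propagate_skip_from_edge's rules."""
    if "UNKNOWN" in edge_states:
        return "wait"       # an upstream path is still undecided; stop here
    if "TAKEN" in edge_states:
        return "maybe-run"  # node may execute; enqueue it if it is ready
    return "skip"           # every incoming path skipped: skip and recurse

assert skip_decision(["UNKNOWN", "SKIPPED"]) == "wait"
assert skip_decision(["TAKEN", "SKIPPED"]) == "maybe-run"
assert skip_decision(["SKIPPED", "SKIPPED"]) == "skip"
```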

@@ -0,0 +1,52 @@
# Layers
Pluggable middleware for engine extensions.
## Components
### Layer (base)
Abstract base class for layers.
- `initialize()` - Receive runtime context
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
### DebugLoggingLayer
Comprehensive execution logging.
- Configurable detail levels
- Tracks execution statistics
- Truncates long values
## Usage
```python
debug_layer = DebugLoggingLayer(
level="INFO",
include_outputs=True
)
engine = GraphEngine(graph)
engine.add_layer(debug_layer)
engine.run()
```
## Custom Layers
```python
class MetricsLayer(Layer):
    def __init__(self):
        super().__init__()
        self.metrics = {}
    def on_graph_start(self): ...
    def on_event(self, event):
        if isinstance(event, NodeRunSucceededEvent):
            self.metrics[event.node_id] = event.elapsed_time
    def on_graph_end(self, error): ...
```
## Configuration
**DebugLoggingLayer Options:**
- `level` - Log level (INFO, DEBUG, ERROR)
- `include_inputs/outputs` - Log data values
- `max_value_length` - Truncate long values

@@ -0,0 +1,16 @@
"""
Layer system for GraphEngine extensibility.
This module provides the layer infrastructure for extending GraphEngine functionality
with middleware-like components that can observe events and interact with execution.
"""
from .base import Layer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"Layer",
]

@@ -0,0 +1,86 @@
"""
Base layer class for GraphEngine extensions.
This module provides the abstract base class for implementing layers that can
intercept and respond to GraphEngine events.
"""
from abc import ABC, abstractmethod
from typing import Optional
from core.workflow.entities import GraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
class Layer(ABC):
"""
Abstract base class for GraphEngine layers.
Layers are middleware-like components that can:
- Observe all events emitted by the GraphEngine
- Access the graph runtime state
- Send commands to control execution
Subclasses should override the constructor to accept configuration parameters,
then implement the three lifecycle methods.
"""
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: Optional[GraphRuntimeState] = None
self.command_channel: Optional[CommandChannel] = None
def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the runtime state
and command channel. This allows layers to access engine context and send
commands.
Args:
graph_runtime_state: The runtime state of the graph execution
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
This is called after the engine has been initialized but before any nodes
are executed. Layers can use this to set up resources or log start information.
"""
pass
@abstractmethod
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
This method receives all events generated during graph execution, including:
- Graph lifecycle events (start, success, failure)
- Node execution events (start, success, failure, retry)
- Stream events for response nodes
- Container events (iteration, loop)
Args:
event: The event emitted by the engine
"""
pass
@abstractmethod
def on_graph_end(self, error: Optional[Exception]) -> None:
"""
Called when graph execution ends.
This is called after all nodes have been executed or when execution is
aborted. Layers can use this to clean up resources or log final state.
Args:
error: The exception that caused execution to fail, or None if successful
"""
pass

@@ -0,0 +1,246 @@
"""
Debug logging layer for GraphEngine.
This module provides a layer that logs all events and state changes during
graph execution for debugging purposes.
"""
import logging
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from .base import Layer
class DebugLoggingLayer(Layer):
"""
A layer that provides comprehensive logging of GraphEngine execution.
This layer logs all events with configurable detail levels, helping developers
debug workflow execution and understand the flow of events.
"""
def __init__(
self,
level: str = "INFO",
include_inputs: bool = False,
include_outputs: bool = True,
include_process_data: bool = False,
logger_name: str = "GraphEngine.Debug",
max_value_length: int = 500,
) -> None:
"""
Initialize the debug logging layer.
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
include_inputs: Whether to log node input values
include_outputs: Whether to log node output values
include_process_data: Whether to log node process data
logger_name: Name of the logger to use
max_value_length: Maximum length of logged values (truncated if longer)
"""
super().__init__()
self.level = level
self.include_inputs = include_inputs
self.include_outputs = include_outputs
self.include_process_data = include_process_data
self.max_value_length = max_value_length
# Set up logger
self.logger = logging.getLogger(logger_name)
log_level = getattr(logging, level.upper(), logging.INFO)
self.logger.setLevel(log_level)
# Track execution stats
self.node_count = 0
self.success_count = 0
self.failure_count = 0
self.retry_count = 0
def _truncate_value(self, value: Any) -> str:
"""Truncate long values for logging."""
str_value = str(value)
if len(str_value) > self.max_value_length:
return str_value[: self.max_value_length] + "... (truncated)"
return str_value
def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
"""Format a dictionary or mapping for logging with truncation."""
if not data:
return "{}"
formatted_items = []
for key, value in data.items():
formatted_value = self._truncate_value(value)
formatted_items.append(f" {key}: {formatted_value}")
return "{\n" + ",\n".join(formatted_items) + "\n}"
def on_graph_start(self) -> None:
"""Log graph execution start."""
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
# Log inputs if available
if self.graph_runtime_state.variable_pool:
initial_vars = {}
# Access the variable dictionary directly
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
for var_key, var in variables.items():
initial_vars[f"{node_id}.{var_key}"] = str(var.value) if hasattr(var, "value") else str(var)
if initial_vars:
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
def on_event(self, event: GraphEngineEvent) -> None:
"""Log individual events based on their type."""
event_class = event.__class__.__name__
# Graph-level events
if isinstance(event, GraphRunStartedEvent):
self.logger.debug("Graph run started event")
elif isinstance(event, GraphRunSucceededEvent):
self.logger.info("✅ Graph run succeeded")
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0:
self.logger.error(" Total exceptions: %s", event.exceptions_count)
elif isinstance(event, GraphRunAbortedEvent):
self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
if event.outputs:
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
# Node-level events
elif isinstance(event, NodeRunStartedEvent):
self.node_count += 1
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
if self.include_inputs and event.node_run_result.inputs:
self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs))
elif isinstance(event, NodeRunSucceededEvent):
self.success_count += 1
self.logger.info("✅ Node succeeded: %s", event.node_id)
if self.include_outputs and event.node_run_result.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs))
if self.include_process_data and event.node_run_result.process_data:
self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data))
elif isinstance(event, NodeRunFailedEvent):
self.failure_count += 1
self.logger.error("❌ Node failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
if event.node_run_result.error:
self.logger.error(" Details: %s", event.node_run_result.error)
elif isinstance(event, NodeRunExceptionEvent):
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
self.logger.warning(" Error: %s", event.error)
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStreamChunkEvent):
# Log stream chunks at debug level to avoid spam
final_indicator = " (FINAL)" if event.is_final else ""
self.logger.debug(
"📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
)
# Iteration events
elif isinstance(event, NodeRunIterationStartedEvent):
self.logger.info("🔁 Iteration started: %s", event.node_id)
elif isinstance(event, NodeRunIterationNextEvent):
self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index)
elif isinstance(event, NodeRunIterationSucceededEvent):
self.logger.info("✅ Iteration succeeded: %s", event.node_id)
if self.include_outputs and event.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, NodeRunIterationFailedEvent):
self.logger.error("❌ Iteration failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
# Loop events
elif isinstance(event, NodeRunLoopStartedEvent):
self.logger.info("🔄 Loop started: %s", event.node_id)
elif isinstance(event, NodeRunLoopNextEvent):
self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index)
elif isinstance(event, NodeRunLoopSucceededEvent):
self.logger.info("✅ Loop succeeded: %s", event.node_id)
if self.include_outputs and event.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, NodeRunLoopFailedEvent):
self.logger.error("❌ Loop failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
else:
# Log unknown events at debug level
self.logger.debug("Event: %s", event_class)
def on_graph_end(self, error: Optional[Exception]) -> None:
"""Log graph execution end with summary statistics."""
self.logger.info("=" * 80)
if error:
self.logger.error("🔴 GRAPH EXECUTION FAILED")
self.logger.error(" Error: %s", error)
else:
self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")
# Log execution statistics
self.logger.info("Execution Statistics:")
self.logger.info(" Total nodes executed: %s", self.node_count)
self.logger.info(" Successful nodes: %s", self.success_count)
self.logger.info(" Failed nodes: %s", self.failure_count)
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)
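
A configuration sketch using the options above; `engine` is assumed to be a constructed `GraphEngine` as in the Layers README:

```python
debug_layer = DebugLoggingLayer(
    level="DEBUG",
    include_inputs=True,   # also log node inputs
    include_outputs=True,
    max_value_length=200,  # tighter truncation than the 500-char default
)
engine.add_layer(debug_layer)
```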

@@ -0,0 +1,144 @@
"""
Execution limits layer for GraphEngine.
This layer monitors workflow execution to enforce limits on:
- Maximum execution steps
- Maximum execution time
When limits are exceeded, the layer automatically aborts execution.
"""
import logging
import time
from enum import Enum
from typing import Optional
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import Layer
from core.workflow.graph_events import (
GraphEngineEvent,
NodeRunStartedEvent,
)
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
class LimitType(Enum):
"""Types of execution limits that can be exceeded."""
STEP_LIMIT = "step_limit"
TIME_LIMIT = "time_limit"
class ExecutionLimitsLayer(Layer):
"""
Layer that enforces execution limits for workflows.
Monitors:
- Step count: Tracks number of node executions
- Time limit: Monitors total execution time
Automatically aborts execution when limits are exceeded.
"""
def __init__(self, max_steps: int, max_time: int) -> None:
"""
Initialize the execution limits layer.
Args:
max_steps: Maximum number of execution steps allowed
max_time: Maximum execution time in seconds allowed
"""
super().__init__()
self.max_steps = max_steps
self.max_time = max_time
# Runtime tracking
self.start_time: Optional[float] = None
self.step_count = 0
self.logger = logging.getLogger(__name__)
# State tracking
self._execution_started = False
self._execution_ended = False
self._abort_sent = False # Track if abort command has been sent
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self.start_time = time.time()
self.step_count = 0
self._execution_started = True
self._execution_ended = False
self._abort_sent = False
self.logger.debug("Execution limits monitoring started")
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
Monitors execution progress and enforces limits.
"""
if not self._execution_started or self._execution_ended or self._abort_sent:
return
# Track step count for node execution events
if isinstance(event, NodeRunStartedEvent):
self.step_count += 1
self.logger.debug("Step %d started: %s", self.step_count, event.node_id)
# Check step limit when node execution completes
if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
if self._reached_step_limitation():
self._send_abort_command(LimitType.STEP_LIMIT)
if self._reached_time_limitation():
self._send_abort_command(LimitType.TIME_LIMIT)
def on_graph_end(self, error: Optional[Exception]) -> None:
"""Called when graph execution ends."""
if self._execution_started and not self._execution_ended:
self._execution_ended = True
if self.start_time:
total_time = time.time() - self.start_time
self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)
def _reached_step_limitation(self) -> bool:
"""Check if step count limit has been exceeded."""
return self.step_count > self.max_steps
def _reached_time_limitation(self) -> bool:
"""Check if time limit has been exceeded."""
return self.start_time is not None and (time.time() - self.start_time) > self.max_time
def _send_abort_command(self, limit_type: LimitType) -> None:
"""
Send abort command due to limit violation.
Args:
limit_type: Type of limit exceeded
"""
if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
return
# Format detailed reason message
if limit_type == LimitType.STEP_LIMIT:
reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
self.logger.warning("Execution limit exceeded: %s", reason)
try:
# Send abort command to the engine
abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
self.command_channel.send_command(abort_command)
# Mark that abort has been sent to prevent duplicate commands
self._abort_sent = True
self.logger.debug("Abort command sent to engine")
except Exception:
self.logger.exception("Failed to send abort command")
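
A wiring sketch, again assuming `engine` as in the Layers README; the limit values here are illustrative:

```python
limits = ExecutionLimitsLayer(max_steps=500, max_time=600)  # 500 steps, 10 minutes
engine.add_layer(limits)
# If either limit is exceeded the layer sends a single AbortCommand, and the
# run should end with a GraphRunAbortedEvent carrying the formatted reason.
```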

@@ -0,0 +1,49 @@
"""
GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Currently only the stop operation is implemented.
"""
from typing import Optional
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from extensions.ext_redis import redis_client
class GraphEngineManager:
"""
Manager for sending control commands to GraphEngine instances.
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Currently only the stop operation is implemented.
"""
@staticmethod
def send_stop_command(task_id: str, reason: Optional[str] = None) -> None:
"""
Send a stop command to a running workflow.
Args:
task_id: The task ID of the workflow to stop
reason: Optional reason for stopping (defaults to "User requested stop")
"""
if not task_id:
return
# Create Redis channel for this task
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(redis_client, channel_key)
# Create and send abort command
abort_command = AbortCommand(reason=reason or "User requested stop")
try:
channel.send_command(abort_command)
except Exception:
# Silently fail if Redis is unavailable
# The legacy stop flag mechanism will still work
pass
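
Usage from outside the worker process; `"task-123"` is a placeholder task ID:

```python
GraphEngineManager.send_stop_command(task_id="task-123", reason="user cancelled")
# The AbortCommand is published to the Redis key "workflow:task-123:commands",
# where the running engine's RedisChannel picks it up.
```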

@@ -0,0 +1,14 @@
"""
Orchestration subsystem for graph engine.
This package coordinates the overall execution flow between
different subsystems.
"""
from .dispatcher import Dispatcher
from .execution_coordinator import ExecutionCoordinator
__all__ = [
"Dispatcher",
"ExecutionCoordinator",
]

@@ -0,0 +1,104 @@
"""
Main dispatcher for processing events from workers.
"""
import logging
import queue
import threading
import time
from typing import TYPE_CHECKING, Optional
from ..event_management import EventCollector, EventEmitter
from .execution_coordinator import ExecutionCoordinator
if TYPE_CHECKING:
from ..event_management import EventHandlerRegistry
logger = logging.getLogger(__name__)
class Dispatcher:
"""
Main dispatcher that processes events from the event queue.
This runs in a separate thread and coordinates event processing
with timeout and completion detection.
"""
def __init__(
self,
event_queue: queue.Queue,
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
execution_coordinator: ExecutionCoordinator,
max_execution_time: int,
event_emitter: Optional[EventEmitter] = None,
) -> None:
"""
Initialize the dispatcher.
Args:
event_queue: Queue of events from workers
event_handler: Event handler registry for processing events
event_collector: Event collector for collecting unhandled events
execution_coordinator: Coordinator for execution flow
max_execution_time: Maximum execution time in seconds
event_emitter: Optional event emitter to signal completion
"""
self.event_queue = event_queue
self.event_handler = event_handler
self.event_collector = event_collector
self.execution_coordinator = execution_coordinator
self.max_execution_time = max_execution_time
self.event_emitter = event_emitter
self._thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._start_time: Optional[float] = None
def start(self) -> None:
"""Start the dispatcher thread."""
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for commands
self.execution_coordinator.check_commands()
# Check for scaling
self.execution_coordinator.check_scaling()
# Process events
try:
event = self.event_queue.get(timeout=0.1)
# Route to the event handler
self.event_handler.handle_event(event)
self.event_queue.task_done()
except queue.Empty:
# Check if execution is complete
if self.execution_coordinator.is_execution_complete():
break
except Exception as e:
logger.exception("Dispatcher error")
self.execution_coordinator.mark_failed(e)
finally:
self.execution_coordinator.mark_complete()
# Signal the event emitter that execution is complete
if self.event_emitter:
self.event_emitter.mark_complete()

View File

@ -0,0 +1,91 @@
"""
Execution coordinator for managing overall workflow execution.
"""
from typing import TYPE_CHECKING
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventCollector
from ..state_management import ExecutionTracker, NodeStateManager
from ..worker_management import WorkerPool
if TYPE_CHECKING:
from ..event_management import EventHandlerRegistry
class ExecutionCoordinator:
"""
Coordinates overall execution flow between subsystems.
This provides high-level coordination methods used by the
dispatcher to manage execution state.
"""
def __init__(
self,
graph_execution: GraphExecution,
node_state_manager: NodeStateManager,
execution_tracker: ExecutionTracker,
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
command_processor: CommandProcessor,
worker_pool: WorkerPool,
) -> None:
"""
Initialize the execution coordinator.
Args:
graph_execution: Graph execution aggregate
node_state_manager: Manager for node states
execution_tracker: Tracker for executing nodes
event_handler: Event handler registry for processing events
event_collector: Event collector for collecting events
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self.graph_execution = graph_execution
self.node_state_manager = node_state_manager
self.execution_tracker = execution_tracker
self.event_handler = event_handler
self.event_collector = event_collector
self.command_processor = command_processor
self.worker_pool = worker_pool
def check_commands(self) -> None:
"""Process any pending commands."""
self.command_processor.process_commands()
def check_scaling(self) -> None:
"""Check and perform worker scaling if needed."""
queue_depth = self.node_state_manager.ready_queue.qsize()
executing_count = self.execution_tracker.count()
self.worker_pool.check_scaling(queue_depth, executing_count)
def is_execution_complete(self) -> bool:
"""
Check if execution is complete.
Returns:
True if execution is complete
"""
# Check if aborted or failed
if self.graph_execution.aborted or self.graph_execution.has_error:
return True
# Complete if no work remains
return self.node_state_manager.ready_queue.empty() and self.execution_tracker.is_empty()
def mark_complete(self) -> None:
"""Mark execution as complete."""
if not self.graph_execution.completed:
self.graph_execution.complete()
def mark_failed(self, error: Exception) -> None:
"""
Mark execution as failed.
Args:
error: The error that caused failure
"""
self.graph_execution.fail(error)
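
The completion rule above is compact enough to restate as a pure predicate; a sketch for reference, not part of the API:

```python
def execution_complete(
    aborted: bool,
    has_error: bool,
    ready_queue_empty: bool,
    no_nodes_executing: bool,
) -> bool:
    """Mirrors ExecutionCoordinator.is_execution_complete."""
    if aborted or has_error:
        return True  # abort or failure ends the run regardless of remaining work
    return ready_queue_empty and no_nodes_executing  # otherwise: no work left

assert execution_complete(False, False, True, True)
assert not execution_complete(False, False, True, False)  # nodes still in flight
```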

View File

@ -0,0 +1,10 @@
"""
OutputRegistry - Thread-safe storage for node outputs (streams and scalars)
This component provides thread-safe storage and retrieval of node outputs,
supporting both scalar values and streaming chunks with proper state management.
"""
from .registry import OutputRegistry
__all__ = ["OutputRegistry"]

View File

@ -0,0 +1,145 @@
"""
Main OutputRegistry implementation.
This module contains the public OutputRegistry class that provides
thread-safe storage for node outputs.
"""
from collections.abc import Sequence
from threading import RLock
from typing import TYPE_CHECKING, Optional, Union
from core.variables import Segment
from core.workflow.entities.variable_pool import VariablePool
from .stream import Stream
if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent
class OutputRegistry:
"""
Thread-safe registry for storing and retrieving node outputs.
Supports both scalar values and streaming chunks with proper state management.
All operations are thread-safe using internal locking.
"""
def __init__(self, variable_pool: VariablePool) -> None:
"""Initialize empty registry with thread-safe storage."""
self._lock = RLock()
self._scalars = variable_pool
self._streams: dict[tuple[str, ...], Stream] = {}
def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
"""Convert a selector sequence to a hashable tuple key for internal storage."""
return tuple(selector)
def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None:
"""
Set a scalar value for the given selector.
Args:
selector: List of strings identifying the output location
value: The scalar value to store
"""
with self._lock:
self._scalars.add(selector, value)
def get_scalar(self, selector: Sequence[str]) -> Optional["Segment"]:
"""
Get a scalar value for the given selector.
Args:
selector: List of strings identifying the output location
Returns:
The stored Variable object, or None if not found
"""
with self._lock:
return self._scalars.get(selector)
def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None:
"""
Append a NodeRunStreamChunkEvent to the stream for the given selector.
Args:
selector: List of strings identifying the stream location
event: The NodeRunStreamChunkEvent to append
Raises:
ValueError: If the stream is already closed
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
self._streams[key] = Stream()
try:
self._streams[key].append(event)
except ValueError:
raise ValueError(f"Stream {'.'.join(selector)} is already closed") from None
def pop_chunk(self, selector: Sequence[str]) -> Optional["NodeRunStreamChunkEvent"]:
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.
Args:
selector: List of strings identifying the stream location
Returns:
The next event, or None if no unread events available
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return None
return self._streams[key].pop_next()
def has_unread(self, selector: Sequence[str]) -> bool:
"""
Check if the stream has unread events.
Args:
selector: List of strings identifying the stream location
Returns:
True if there are unread events, False otherwise
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return False
return self._streams[key].has_unread()
def close_stream(self, selector: Sequence[str]) -> None:
"""
Mark a stream as closed (no more chunks can be appended).
Args:
selector: List of strings identifying the stream location
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
self._streams[key] = Stream()
self._streams[key].close()
def stream_closed(self, selector: Sequence[str]) -> bool:
"""
Check if a stream is closed.
Args:
selector: List of strings identifying the stream location
Returns:
True if the stream is closed, False otherwise
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return False
return self._streams[key].is_closed
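
A usage sketch for the streaming side of the registry, assuming an existing `VariablePool` instance named `pool`; the node ID and `NodeType.CODE` are placeholders chosen only to satisfy the event constructor:

```python
from uuid import uuid4

from core.workflow.enums import NodeType
from core.workflow.graph_events import NodeRunStreamChunkEvent

registry = OutputRegistry(variable_pool=pool)  # pool: an existing VariablePool
selector = ["node_1", "text"]

# Producer: append chunks, marking the last one final, then close the stream
for i, text in enumerate(["Hel", "lo"]):
    registry.append_chunk(
        selector,
        NodeRunStreamChunkEvent(
            id=str(uuid4()),
            node_id="node_1",
            node_type=NodeType.CODE,  # placeholder type for this sketch
            selector=selector,
            chunk=text,
            is_final=(i == 1),
        ),
    )
registry.close_stream(selector)

# Consumer: drain unread chunks in arrival order
while registry.has_unread(selector):
    if event := registry.pop_chunk(selector):
        print(event.chunk, end="")  # prints "Hello"
```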

View File

@ -0,0 +1,69 @@
"""
Internal stream implementation for OutputRegistry.
This module contains the private Stream class used internally by OutputRegistry
to manage streaming data chunks.
"""
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent
class Stream:
"""
A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
This class encapsulates stream-specific data and operations,
including event storage, read position tracking, and closed state.
Note: This is an internal class not exposed in the public API.
"""
def __init__(self) -> None:
"""Initialize an empty stream."""
self.events: list[NodeRunStreamChunkEvent] = []
self.read_position: int = 0
self.is_closed: bool = False
def append(self, event: "NodeRunStreamChunkEvent") -> None:
"""
Append a NodeRunStreamChunkEvent to the stream.
Args:
event: The NodeRunStreamChunkEvent to append
Raises:
ValueError: If the stream is already closed
"""
if self.is_closed:
raise ValueError("Cannot append to a closed stream")
self.events.append(event)
def pop_next(self) -> Optional["NodeRunStreamChunkEvent"]:
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.
Returns:
The next event, or None if no unread events available
"""
if self.read_position >= len(self.events):
return None
event = self.events[self.read_position]
self.read_position += 1
return event
def has_unread(self) -> bool:
"""
Check if the stream has unread events.
Returns:
True if there are unread events, False otherwise
"""
return self.read_position < len(self.events)
def close(self) -> None:
"""Mark the stream as closed (no more chunks can be appended)."""
self.is_closed = True

View File

@ -0,0 +1,41 @@
"""
CommandChannel protocol for GraphEngine command communication.
This protocol defines the interface for sending and receiving commands
to/from a GraphEngine instance, supporting both local and distributed scenarios.
"""
from typing import Protocol
from ..entities.commands import GraphEngineCommand
class CommandChannel(Protocol):
"""
Protocol for bidirectional command communication with GraphEngine.
Since each GraphEngine instance processes only one workflow execution,
this channel is dedicated to that single execution.
"""
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch pending commands for this GraphEngine instance.
Called by GraphEngine to poll for commands that need to be processed.
Returns:
List of pending commands (may be empty)
"""
...
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to be processed by this GraphEngine instance.
Called by external systems to send control commands to the running workflow.
Args:
command: The command to send
"""
...
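
Because `CommandChannel` is a `Protocol`, any class with these two methods satisfies it structurally, with no inheritance required. A toy list-backed channel as a sketch (this is not the shipped `InMemoryChannel`), assuming the `GraphEngineCommand` import above:

```python
class ListBackedChannel:
    """Toy CommandChannel: single-process, list-backed, not thread-safe."""

    def __init__(self) -> None:
        self._pending: list[GraphEngineCommand] = []

    def fetch_commands(self) -> list[GraphEngineCommand]:
        # Drain-and-return so each command is delivered exactly once
        pending, self._pending = self._pending, []
        return pending

    def send_command(self, command: GraphEngineCommand) -> None:
        self._pending.append(command)

channel: CommandChannel = ListBackedChannel()  # type-checks structurally
```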

View File

@ -0,0 +1,10 @@
"""
ResponseStreamCoordinator - Coordinates streaming output from response nodes
This component manages response streaming sessions and ensures ordered streaming
of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
__all__ = ["ResponseStreamCoordinator"]

View File

@ -0,0 +1,465 @@
"""
Main ResponseStreamCoordinator implementation.
This module contains the public ResponseStreamCoordinator class that manages
response streaming sessions and ensures ordered streaming of responses.
"""
import logging
from collections import deque
from collections.abc import Sequence
from threading import RLock
from typing import Optional, TypeAlias
from uuid import uuid4
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from ..output_registry import OutputRegistry
from .path import Path
from .session import ResponseSession
logger = logging.getLogger(__name__)
# Type definitions
NodeID: TypeAlias = str
EdgeID: TypeAlias = str
class ResponseStreamCoordinator:
"""
Manages response streaming sessions without relying on global state.
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
def __init__(self, registry: OutputRegistry, graph: "Graph") -> None:
"""
Initialize coordinator with output registry.
Args:
registry: OutputRegistry instance for accessing node outputs
graph: Graph instance for looking up node information
"""
self.registry = registry
self.graph = graph
self.active_session: Optional[ResponseSession] = None
self.waiting_sessions: deque[ResponseSession] = deque()
self.lock = RLock()
# Track response nodes
self._response_nodes: set[NodeID] = set()
# Store paths for each response node
self._paths_maps: dict[NodeID, list[Path]] = {}
# Track node execution IDs and types for proper event forwarding
self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id
# Track response sessions to ensure only one per node
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
def register(self, response_node_id: NodeID) -> None:
with self.lock:
self._response_nodes.add(response_node_id)
# Build and save paths map for this response node
paths_map = self._build_paths_map(response_node_id)
self._paths_maps[response_node_id] = paths_map
# Create and store response session for this node
response_node = self.graph.nodes[response_node_id]
session = ResponseSession.from_node(response_node)
self._response_sessions[response_node_id] = session
def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
"""Track the execution ID for a node when it starts executing.
Args:
node_id: The ID of the node
execution_id: The execution ID from NodeRunStartedEvent
"""
with self.lock:
self._node_execution_ids[node_id] = execution_id
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
"""Get the execution ID for a node, creating one if it doesn't exist.
Args:
node_id: The ID of the node
Returns:
The execution ID for the node
"""
with self.lock:
if node_id not in self._node_execution_ids:
self._node_execution_ids[node_id] = str(uuid4())
return self._node_execution_ids[node_id]
def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
"""
Build a paths map for a response node by finding all paths from root node
to the response node, recording branch edges along each path.
Args:
response_node_id: ID of the response node to analyze
Returns:
List of Path objects, where each path contains branch edge IDs
"""
# Get root node ID
root_node_id = self.graph.root_node.id
# If root is the response node, return empty path
if root_node_id == response_node_id:
return [Path()]
# Extract variable selectors from the response node's template
response_node = self.graph.nodes[response_node_id]
response_session = ResponseSession.from_node(response_node)
template = response_session.template
# Collect all variable selectors from the template
variable_selectors: set[tuple[str, ...]] = set()
for segment in template.segments:
if isinstance(segment, VariableSegment):
variable_selectors.add(tuple(segment.selector[:2]))
# Step 1: Find all complete paths from root to response node
all_complete_paths: list[list[EdgeID]] = []
def find_paths(
current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
) -> None:
"""Recursively find all paths from current node to target node."""
if current_node_id == target_node_id:
# Found a complete path, store it
all_complete_paths.append(current_path.copy())
return
# Mark as visited to avoid cycles
visited.add(current_node_id)
# Explore outgoing edges
outgoing_edges = self.graph.get_outgoing_edges(current_node_id)
for edge in outgoing_edges:
edge_id = edge.id
next_node_id = edge.head
# Skip if already visited in this path
if next_node_id not in visited:
# Add edge to path and recurse
new_path = current_path + [edge_id]
find_paths(next_node_id, target_node_id, new_path, visited.copy())
# Start searching from root node
find_paths(root_node_id, response_node_id, [], set())
# Step 2: For each complete path, filter edges based on node blocking behavior
filtered_paths: list[Path] = []
for path in all_complete_paths:
blocking_edges = []
for edge_id in path:
edge = self.graph.edges[edge_id]
source_node = self.graph.nodes[edge.tail]
# Check if node is a branch/container (original behavior)
if source_node.execution_type in {
NodeExecutionType.BRANCH,
NodeExecutionType.CONTAINER,
} or source_node.blocks_variable_output(variable_selectors):
blocking_edges.append(edge_id)
# Keep the path even if it's empty
filtered_paths.append(Path(edges=blocking_edges))
return filtered_paths
def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
"""
Handle when an edge is taken (selected by a branch node).
This method updates the paths for all response nodes by removing
the taken edge. If any response node has an empty path after removal,
it means the node is now deterministically reachable and should start.
Args:
edge_id: The ID of the edge that was taken
Returns:
List of events to emit from starting new sessions
"""
events: list[NodeRunStreamChunkEvent] = []
with self.lock:
# Check each registered response node
for response_node_id in self._response_nodes:
if response_node_id not in self._paths_maps:
continue
paths = self._paths_maps[response_node_id]
has_reachable_path = False
# Update each path by removing the taken edge
for path in paths:
# Remove the taken edge from this path
path.remove_edge(edge_id)
# Check if this path is now empty (node is reachable)
if path.is_empty():
has_reachable_path = True
# If node is now reachable (has empty path), start/queue session
if has_reachable_path:
# Pass the node_id to the activation method
# The method will handle checking and removing from map
events.extend(self._active_or_queue_session(response_node_id))
return events
def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
"""
Start a session immediately if no active session, otherwise queue it.
Only activates sessions that exist in the _response_sessions map.
Args:
node_id: The ID of the response node to activate
Returns:
List of events from flush attempt if session started immediately
"""
events: list[NodeRunStreamChunkEvent] = []
# Get the session from our map (only activate if it exists)
session = self._response_sessions.get(node_id)
if not session:
return events
# Remove from map to ensure it won't be activated again
del self._response_sessions[node_id]
if self.active_session is None:
self.active_session = session
# Try to flush immediately
events.extend(self.try_flush())
else:
# Queue the session if another is active
self.waiting_sessions.append(session)
return events
def intercept_event(
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
) -> Sequence[NodeRunStreamChunkEvent]:
with self.lock:
if isinstance(event, NodeRunStreamChunkEvent):
self.registry.append_chunk(event.selector, event)
if event.is_final:
self.registry.close_stream(event.selector)
return self.try_flush()
elif isinstance(event, NodeRunSucceededEvent):
# Skipped: node outputs are already visible through the shared variable pool.
#
# for variable_name, variable_value in event.node_run_result.outputs.items():
# self.registry.set_scalar((event.node_id, variable_name), variable_value)
return self.try_flush()
return []
def _create_stream_chunk_event(
self,
node_id: str,
execution_id: str,
selector: Sequence[str],
chunk: str,
is_final: bool = False,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self.graph.nodes and self.active_session:
# Use the active response node for special selectors
response_node = self.graph.nodes[self.active_session.node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=selector,
chunk=chunk,
is_final=is_final,
)
# Standard case: selector refers to an actual node
node = self.graph.nodes[node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=node.id,
node_type=node.node_type,
selector=selector,
chunk=chunk,
is_final=is_final,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
"""Process a variable segment. Returns (events, is_complete).
Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""
is_complete = False
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self.active_session and source_selector_prefix not in self.graph.nodes:
# Special selector - use active response node
output_node_id = self.active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)
# Stream all available chunks
while self.registry.has_unread(segment.selector):
if event := self.registry.pop_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self.active_session and source_selector_prefix not in self.graph.nodes:
response_node = self.graph.nodes[self.active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector, # Keep original selector
chunk=event.chunk,
is_final=event.is_final,
)
events.append(updated_event)
else:
# Regular node selector - use event as is
events.append(event)
# A closed stream means no more chunks will arrive, so the segment is complete
stream_closed = self.registry.stream_closed(segment.selector)
if stream_closed:
is_complete = True
elif value := self.registry.get_scalar(segment.selector):
# Process scalar value
is_last_segment = bool(
self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
)
)
is_complete = True
return events, is_complete
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment and return the events to emit; text segments always complete immediately."""
assert self.active_session is not None
current_response_node = self.graph.nodes[self.active_session.node_id]
# Use get_or_create_execution_id to ensure we have a consistent ID
execution_id = self._get_or_create_execution_id(current_response_node.id)
is_last_segment = self.active_session.index == len(self.active_session.template.segments) - 1
event = self._create_stream_chunk_event(
node_id=current_response_node.id,
execution_id=execution_id,
selector=[current_response_node.id, "answer"], # FIXME(-LAN-)
chunk=segment.text,
is_final=is_last_segment,
)
return [event]
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
with self.lock:
if not self.active_session:
return []
template = self.active_session.template
response_node_id = self.active_session.node_id
events: list[NodeRunStreamChunkEvent] = []
# Process segments sequentially from current index
while self.active_session.index < len(template.segments):
segment = template.segments[self.active_session.index]
if isinstance(segment, VariableSegment):
# Check if the source node for this variable is skipped
# Only check for actual nodes, not special selectors (sys, env, conversation)
source_selector_prefix = segment.selector[0] if segment.selector else ""
if source_selector_prefix in self.graph.nodes:
source_node = self.graph.nodes[source_selector_prefix]
if source_node.state == NodeState.SKIPPED:
# Skip this variable segment if the source node is skipped
self.active_session.index += 1
continue
segment_events, is_complete = self._process_variable_segment(segment)
events.extend(segment_events)
# Only advance index if this variable segment is complete
if is_complete:
self.active_session.index += 1
else:
# Wait for more data
break
elif isinstance(segment, TextSegment):
segment_events = self._process_text_segment(segment)
events.extend(segment_events)
self.active_session.index += 1
if self.active_session.is_complete():
# End current session and get events from starting next session
next_session_events = self.end_session(response_node_id)
events.extend(next_session_events)
return events
def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
"""
End the active session for a response node.
Automatically starts the next waiting session if available.
Args:
node_id: ID of the response node ending its session
Returns:
List of events from starting the next session
"""
with self.lock:
events: list[NodeRunStreamChunkEvent] = []
if self.active_session and self.active_session.node_id == node_id:
self.active_session = None
# Try to start next waiting session
if self.waiting_sessions:
next_session = self.waiting_sessions.popleft()
self.active_session = next_session
# Immediately try to flush any available segments
events = self.try_flush()
return events

View File

@ -0,0 +1,35 @@
"""
Internal path representation for response coordinator.
This module contains the private Path class used internally by ResponseStreamCoordinator
to track execution paths to response nodes.
"""
from dataclasses import dataclass, field
from typing import TypeAlias
EdgeID: TypeAlias = str
@dataclass
class Path:
"""
Represents a path of branch edges that must be taken to reach a response node.
Note: This is an internal class not exposed in the public API.
"""
edges: list[EdgeID] = field(default_factory=list)
def contains_edge(self, edge_id: EdgeID) -> bool:
"""Check if this path contains the given edge."""
return edge_id in self.edges
def remove_edge(self, edge_id: EdgeID) -> None:
"""Remove the given edge from this path in place."""
if self.contains_edge(edge_id):
self.edges.remove(edge_id)
def is_empty(self) -> bool:
"""Check if the path has no edges (node is reachable)."""
return len(self.edges) == 0
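
How `ResponseStreamCoordinator.on_edge_taken` uses these paths, in miniature; the edge IDs are hypothetical:

```python
# Two candidate routes to one response node, each gated by branch edges
paths = [
    Path(edges=["if_1->answer"]),
    Path(edges=["if_1->tmpl", "tmpl->answer"]),
]

# A branch node selects the direct edge; prune it from every path
for path in paths:
    path.remove_edge("if_1->answer")

# An empty path means the response node is now deterministically reachable,
# which is the coordinator's trigger to activate (or queue) its session
assert any(path.is_empty() for path in paths)
```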

View File

@ -0,0 +1,51 @@
"""
Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode
@dataclass
class ResponseSession:
"""
Represents an active response streaming session.
Note: This is an internal class not exposed in the public API.
"""
node_id: str
template: Template # Template object from the response node
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> "ResponseSession":
"""
Create a ResponseSession from an AnswerNode or EndNode.
Args:
node: Must be either an AnswerNode or EndNode instance
Returns:
ResponseSession configured with the node's streaming template
Raises:
TypeError: If node is not an AnswerNode or EndNode
"""
if not isinstance(node, AnswerNode | EndNode):
raise TypeError(f"Expected AnswerNode or EndNode, got {type(node).__name__}")
return cls(
node_id=node.id,
template=node.get_streaming_template(),
)
def is_complete(self) -> bool:
"""Check if all segments in the template have been processed."""
return self.index >= len(self.template.segments)

View File

@ -0,0 +1,16 @@
"""
State management subsystem for graph engine.
This package manages node states, edge states, and execution tracking
during workflow graph execution.
"""
from .edge_state_manager import EdgeStateManager
from .execution_tracker import ExecutionTracker
from .node_state_manager import NodeStateManager
__all__ = [
"EdgeStateManager",
"ExecutionTracker",
"NodeStateManager",
]

View File

@ -0,0 +1,112 @@
"""
Manager for edge states during graph execution.
"""
import threading
from typing import TypedDict
from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
class EdgeStateAnalysis(TypedDict):
"""Analysis result for edge states."""
has_unknown: bool
has_taken: bool
all_skipped: bool
class EdgeStateManager:
"""
Manages edge states and transitions during graph execution.
This handles edge state changes and provides analysis of edge
states for decision making during execution.
"""
def __init__(self, graph: Graph) -> None:
"""
Initialize the edge state manager.
Args:
graph: The workflow graph
"""
self.graph = graph
self._lock = threading.RLock()
def mark_edge_taken(self, edge_id: str) -> None:
"""
Mark an edge as TAKEN.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self.graph.edges[edge_id].state = NodeState.TAKEN
def mark_edge_skipped(self, edge_id: str) -> None:
"""
Mark an edge as SKIPPED.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self.graph.edges[edge_id].state = NodeState.SKIPPED
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
"""
Analyze the states of edges and return summary flags.
Args:
edges: List of edges to analyze
Returns:
Analysis result with state flags
"""
with self._lock:
states = {edge.state for edge in edges}
return EdgeStateAnalysis(
has_unknown=NodeState.UNKNOWN in states,
has_taken=NodeState.TAKEN in states,
all_skipped=states == {NodeState.SKIPPED} if states else True,
)
def get_edge_state(self, edge_id: str) -> NodeState:
"""
Get the current state of an edge.
Args:
edge_id: The ID of the edge
Returns:
The current edge state
"""
with self._lock:
return self.graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]:
"""
Categorize branch edges into selected and unselected.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
A tuple of (selected_edges, unselected_edges)
"""
with self._lock:
outgoing_edges = self.graph.get_outgoing_edges(node_id)
selected_edges = []
unselected_edges = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
selected_edges.append(edge)
else:
unselected_edges.append(edge)
return selected_edges, unselected_edges
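
The flag derivation in `analyze_edge_states` restated as a pure function with a few spot checks (a reference sketch, not part of the API):

```python
from core.workflow.enums import NodeState

def edge_flags(states: set[NodeState]) -> tuple[bool, bool, bool]:
    """(has_unknown, has_taken, all_skipped), mirroring analyze_edge_states."""
    return (
        NodeState.UNKNOWN in states,
        NodeState.TAKEN in states,
        states == {NodeState.SKIPPED} if states else True,
    )

assert edge_flags({NodeState.TAKEN, NodeState.SKIPPED}) == (False, True, False)
assert edge_flags({NodeState.UNKNOWN}) == (True, False, False)
assert edge_flags(set()) == (False, False, True)  # no edges counts as all-skipped
```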

View File

@ -0,0 +1,87 @@
"""
Tracker for currently executing nodes.
"""
import threading
class ExecutionTracker:
"""
Tracks nodes that are currently being executed.
This replaces the ExecutingNodesManager with a cleaner interface
focused on tracking which nodes are in progress.
"""
def __init__(self) -> None:
"""Initialize the execution tracker."""
self._executing_nodes: set[str] = set()
self._lock = threading.RLock()
def add(self, node_id: str) -> None:
"""
Mark a node as executing.
Args:
node_id: The ID of the node starting execution
"""
with self._lock:
self._executing_nodes.add(node_id)
def remove(self, node_id: str) -> None:
"""
Mark a node as no longer executing.
Args:
node_id: The ID of the node finishing execution
"""
with self._lock:
self._executing_nodes.discard(node_id)
def is_executing(self, node_id: str) -> bool:
"""
Check if a node is currently executing.
Args:
node_id: The ID of the node to check
Returns:
True if the node is executing
"""
with self._lock:
return node_id in self._executing_nodes
def is_empty(self) -> bool:
"""
Check if no nodes are currently executing.
Returns:
True if no nodes are executing
"""
with self._lock:
return len(self._executing_nodes) == 0
def count(self) -> int:
"""
Get the count of currently executing nodes.
Returns:
Number of executing nodes
"""
with self._lock:
return len(self._executing_nodes)
def get_executing_nodes(self) -> set[str]:
"""
Get a copy of the set of executing node IDs.
Returns:
Set of node IDs currently executing
"""
with self._lock:
return self._executing_nodes.copy()
def clear(self) -> None:
"""Clear all executing nodes."""
with self._lock:
self._executing_nodes.clear()
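
The tracker has no external dependencies, so its contract can be exercised directly:

```python
tracker = ExecutionTracker()
tracker.add("node_1")
tracker.add("node_2")
assert tracker.count() == 2 and tracker.is_executing("node_1")

tracker.remove("node_1")
tracker.remove("node_1")  # discard() makes double-removal a harmless no-op
assert not tracker.is_empty()  # node_2 is still in flight
assert tracker.get_executing_nodes() == {"node_2"}
```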

View File

@ -0,0 +1,95 @@
"""
Manager for node states during graph execution.
"""
import queue
import threading
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
class NodeStateManager:
"""
Manages node states and the ready queue for execution.
This centralizes node state transitions and enqueueing logic,
ensuring thread-safe operations on node states.
"""
def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
"""
Initialize the node state manager.
Args:
graph: The workflow graph
ready_queue: Queue for nodes ready to execute
"""
self.graph = graph
self.ready_queue = ready_queue
self._lock = threading.RLock()
def enqueue_node(self, node_id: str) -> None:
"""
Mark a node as TAKEN and add it to the ready queue.
This combines the state transition and enqueueing operations
that always occur together when preparing a node for execution.
Args:
node_id: The ID of the node to enqueue
"""
with self._lock:
self.graph.nodes[node_id].state = NodeState.TAKEN
self.ready_queue.put(node_id)
def mark_node_skipped(self, node_id: str) -> None:
"""
Mark a node as SKIPPED.
Args:
node_id: The ID of the node to skip
"""
with self._lock:
self.graph.nodes[node_id].state = NodeState.SKIPPED
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when all its incoming edges from taken branches
have been satisfied.
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
with self._lock:
# Get all incoming edges to this node
incoming_edges = self.graph.get_incoming_edges(node_id)
# If no incoming edges, node is always ready
if not incoming_edges:
return True
# If any edge is UNKNOWN, node is not ready
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
return False
# Node is ready if at least one edge is TAKEN
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
def get_node_state(self, node_id: str) -> NodeState:
"""
Get the current state of a node.
Args:
node_id: The ID of the node
Returns:
The current node state
"""
with self._lock:
return self.graph.nodes[node_id].state
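
The readiness rule in `is_node_ready`, distilled to a pure function over the states of a node's incoming edges (a reference sketch):

```python
from core.workflow.enums import NodeState

def node_ready(incoming_edge_states: list[NodeState]) -> bool:
    """Mirrors NodeStateManager.is_node_ready."""
    if not incoming_edge_states:
        return True   # no prerequisites, e.g. the root node
    if NodeState.UNKNOWN in incoming_edge_states:
        return False  # some upstream decision is still pending
    return NodeState.TAKEN in incoming_edge_states  # all-SKIPPED stays not ready

assert node_ready([])
assert not node_ready([NodeState.TAKEN, NodeState.UNKNOWN])
assert node_ready([NodeState.TAKEN, NodeState.SKIPPED])
```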

View File

@ -0,0 +1,135 @@
"""
Worker - Thread implementation for queue-based node execution
Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
import contextlib
import contextvars
import queue
import threading
import time
from collections.abc import Callable
from datetime import datetime
from typing import Optional
from uuid import uuid4
from flask import Flask
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
class Worker(threading.Thread):
"""
Worker thread that executes nodes from the ready queue.
Workers continuously pull node IDs from the ready_queue, execute the
corresponding nodes, and push the resulting events to the event_queue
for the dispatcher to process.
"""
def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
worker_id: int = 0,
flask_app: Optional[Flask] = None,
context_vars: Optional[contextvars.Context] = None,
on_idle_callback: Optional[Callable[[int], None]] = None,
on_active_callback: Optional[Callable[[int], None]] = None,
) -> None:
"""
Initialize worker thread.
Args:
ready_queue: Queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
on_idle_callback: Optional callback when worker becomes idle
on_active_callback: Optional callback when worker becomes active
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self.ready_queue = ready_queue
self.event_queue = event_queue
self.graph = graph
self.worker_id = worker_id
self.flask_app = flask_app
self.context_vars = context_vars
self._stop_event = threading.Event()
self.on_idle_callback = on_idle_callback
self.on_active_callback = on_active_callback
self.last_task_time = time.time()
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
def run(self) -> None:
"""
Main worker loop.
Continuously pulls node IDs from ready_queue, executes them,
and pushes events to event_queue until stopped.
"""
while not self._stop_event.is_set():
# Try to get a node ID from the ready queue (with timeout)
try:
node_id = self.ready_queue.get(timeout=0.1)
except queue.Empty:
# Notify that worker is idle
if self.on_idle_callback:
self.on_idle_callback(self.worker_id)
continue
# Notify that worker is active
if self.on_active_callback:
self.on_active_callback(self.worker_id)
self.last_task_time = time.time()
node = self.graph.nodes[node_id]
try:
self._execute_node(node)
except Exception as e:
# Report the failure with the real node identity so the dispatcher can attribute it
error_event = NodeRunFailedEvent(
id=str(uuid4()),
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),
)
self.event_queue.put(error_event)
finally:
# Balance the earlier ready_queue.get() so queue accounting stays correct
self.ready_queue.task_done()
def _execute_node(self, node: Node) -> None:
"""
Execute a single node and handle its events.
Args:
node: The node instance to execute
"""
# Preserve Flask and context variables when available; otherwise run bare
if self.flask_app and self.context_vars:
context = preserve_flask_contexts(flask_app=self.flask_app, context_vars=self.context_vars)
else:
context = contextlib.nullcontext()
with context:
for event in node.run():
# Forward each event to the dispatcher immediately for streaming
self.event_queue.put(event)

View File

@ -0,0 +1,81 @@
# Worker Management
Dynamic worker pool for node execution.
## Components
### WorkerPool
Manages worker thread lifecycle.
- `start(initial_count)` / `stop()` - Control worker lifecycle
- `scale_up()` / `scale_down(worker_ids)` - Adjust pool size
- `get_worker_count()` - Current count
### WorkerFactory
Creates workers with Flask context.
- `create_worker()` - Build with dependencies
- Preserves request context
### DynamicScaler
Determines scaling decisions.
- `min/max_workers` - Pool bounds
- `scale_up_threshold` - Queue trigger
- `should_scale_up/down()` - Check conditions
### ActivityTracker
Tracks worker activity.
- `track_activity(worker_id, is_active)` - Record activity state
- `get_idle_workers()` - Find workers idle past the threshold
- `get_active_count()` - Active count
## Usage
```python
scaler = DynamicScaler(
    min_workers=2,
    max_workers=10,
    scale_up_threshold=5,
)
pool = WorkerPool(
    ready_queue=ready_queue,
    event_queue=event_queue,
    graph=graph,
    worker_factory=factory,
    dynamic_scaler=scaler,
    activity_tracker=tracker,
)
pool.start(initial_count=scaler.calculate_initial_workers(graph))

# Scaling is normally driven by ExecutionCoordinator.check_scaling
pool.check_scaling(queue_depth=ready_queue.qsize(), executing_count=executing)

pool.stop()
```
## Scaling Strategy
**Scale Up**: Queue depth > threshold AND >=80% of workers busy AND workers < max
**Scale Down**: Idle workers exist AND workers > min
## Parameters
- `min_workers` - Minimum pool size
- `max_workers` - Maximum pool size
- `scale_up_threshold` - Queue trigger
- `scale_down_idle_time` - Idle seconds before scale-down
## Flask Context
WorkerFactory preserves the Flask app and context variables across worker threads:
```python
import contextvars

factory = WorkerFactory(flask_app=app, context_vars=contextvars.copy_context())
# Workers created by this factory run node code under the same context
```

View File

@ -0,0 +1,18 @@
"""
Worker management subsystem for graph engine.
This package manages the worker pool, including creation,
scaling, and activity tracking.
"""
from .activity_tracker import ActivityTracker
from .dynamic_scaler import DynamicScaler
from .worker_factory import WorkerFactory
from .worker_pool import WorkerPool
__all__ = [
"ActivityTracker",
"DynamicScaler",
"WorkerFactory",
"WorkerPool",
]

View File

@ -0,0 +1,74 @@
"""
Activity tracker for monitoring worker activity.
"""
import threading
import time
class ActivityTracker:
"""
Tracks worker activity for scaling decisions.
This monitors which workers are active or idle to support
dynamic scaling decisions.
"""
def __init__(self, idle_threshold: float = 30.0) -> None:
"""
Initialize the activity tracker.
Args:
idle_threshold: Seconds before a worker is considered idle
"""
self.idle_threshold = idle_threshold
self._worker_activity: dict[int, tuple[bool, float]] = {}
self._lock = threading.RLock()
def track_activity(self, worker_id: int, is_active: bool) -> None:
"""
Track worker activity state.
Args:
worker_id: ID of the worker
is_active: Whether the worker is active
"""
with self._lock:
self._worker_activity[worker_id] = (is_active, time.time())
def get_idle_workers(self) -> list[int]:
"""
Get list of workers that have been idle too long.
Returns:
List of idle worker IDs
"""
current_time = time.time()
idle_workers = []
with self._lock:
for worker_id, (is_active, last_change) in self._worker_activity.items():
if not is_active and (current_time - last_change) > self.idle_threshold:
idle_workers.append(worker_id)
return idle_workers
def remove_worker(self, worker_id: int) -> None:
"""
Remove a worker from tracking.
Args:
worker_id: ID of the worker to remove
"""
with self._lock:
self._worker_activity.pop(worker_id, None)
def get_active_count(self) -> int:
"""
Get count of currently active workers.
Returns:
Number of active workers
"""
with self._lock:
return sum(1 for is_active, _ in self._worker_activity.values() if is_active)
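
A quick self-contained check of the idle detection, with the threshold shortened for the demo:

```python
import time

tracker = ActivityTracker(idle_threshold=0.1)
tracker.track_activity(worker_id=0, is_active=False)
tracker.track_activity(worker_id=1, is_active=True)

time.sleep(0.2)  # let worker 0 sit idle past the threshold
assert tracker.get_idle_workers() == [0]  # active workers are never reported idle
assert tracker.get_active_count() == 1
```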

View File

@ -0,0 +1,98 @@
"""
Dynamic scaler for worker pool sizing.
"""
from core.workflow.graph import Graph
class DynamicScaler:
"""
Manages dynamic scaling decisions for the worker pool.
This encapsulates the logic for when to scale up or down
based on workload and configuration.
"""
def __init__(
self,
min_workers: int = 2,
max_workers: int = 10,
scale_up_threshold: int = 5,
scale_down_idle_time: float = 30.0,
) -> None:
"""
Initialize the dynamic scaler.
Args:
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Idle time before scaling down
"""
self.min_workers = min_workers
self.max_workers = max_workers
self.scale_up_threshold = scale_up_threshold
self.scale_down_idle_time = scale_down_idle_time
def calculate_initial_workers(self, graph: Graph) -> int:
"""
Calculate initial worker count based on graph complexity.
Args:
graph: The workflow graph
Returns:
Initial number of workers to create
"""
node_count = len(graph.nodes)
# Simple heuristic: more nodes = more workers
if node_count < 10:
initial = self.min_workers
elif node_count < 50:
initial = min(4, self.max_workers)
elif node_count < 100:
initial = min(6, self.max_workers)
else:
initial = min(8, self.max_workers)
return max(self.min_workers, initial)
def should_scale_up(self, current_workers: int, queue_depth: int, executing_count: int) -> bool:
"""
Determine if scaling up is needed.
Args:
current_workers: Current number of workers
queue_depth: Number of nodes waiting
executing_count: Number of nodes executing
Returns:
True if should scale up
"""
if current_workers >= self.max_workers:
return False
# Scale up if queue is deep and workers are busy
if queue_depth > self.scale_up_threshold:
if executing_count >= current_workers * 0.8:
return True
return False
def should_scale_down(self, current_workers: int, idle_workers: list[int]) -> bool:
"""
Determine if scaling down is appropriate.
Args:
current_workers: Current number of workers
idle_workers: List of idle worker IDs
Returns:
True if should scale down
"""
if current_workers <= self.min_workers:
return False
# Scale down if we have idle workers
return len(idle_workers) > 0
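
The scaling thresholds exercised directly:

```python
scaler = DynamicScaler(min_workers=2, max_workers=10, scale_up_threshold=5)

# Queue deeper than the threshold and >=80% of workers busy: grow
assert scaler.should_scale_up(current_workers=4, queue_depth=8, executing_count=4)
# Never grow past the ceiling, however deep the queue
assert not scaler.should_scale_up(current_workers=10, queue_depth=50, executing_count=10)
# Shrink only when idle workers exist above the floor
assert scaler.should_scale_down(current_workers=4, idle_workers=[3])
assert not scaler.should_scale_down(current_workers=2, idle_workers=[1])
```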

View File

@ -0,0 +1,74 @@
"""
Factory for creating worker instances.
"""
import contextvars
import queue
from collections.abc import Callable
from typing import Optional
from flask import Flask
from core.workflow.graph import Graph
from ..worker import Worker
class WorkerFactory:
"""
Factory for creating worker instances with proper context.
This encapsulates worker creation logic and ensures all workers
are created with the necessary Flask and context variable setup.
"""
def __init__(
self,
flask_app: Optional[Flask],
context_vars: contextvars.Context,
) -> None:
"""
Initialize the worker factory.
Args:
flask_app: Flask application context
context_vars: Context variables to propagate
"""
self.flask_app = flask_app
self.context_vars = context_vars
self._next_worker_id = 0
def create_worker(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue,
graph: Graph,
on_idle_callback: Optional[Callable[[int], None]] = None,
on_active_callback: Optional[Callable[[int], None]] = None,
) -> Worker:
"""
Create a new worker instance.
Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
on_idle_callback: Callback when worker becomes idle
on_active_callback: Callback when worker becomes active
Returns:
Configured worker instance
"""
worker_id = self._next_worker_id
self._next_worker_id += 1
return Worker(
ready_queue=ready_queue,
event_queue=event_queue,
graph=graph,
worker_id=worker_id,
flask_app=self.flask_app,
context_vars=self.context_vars,
on_idle_callback=on_idle_callback,
on_active_callback=on_active_callback,
)

View File

@ -0,0 +1,145 @@
"""
Worker pool management.
"""
import queue
import threading
from core.workflow.graph import Graph
from ..worker import Worker
from .activity_tracker import ActivityTracker
from .dynamic_scaler import DynamicScaler
from .worker_factory import WorkerFactory
class WorkerPool:
"""
Manages a pool of worker threads for executing nodes.
This provides dynamic scaling, activity tracking, and lifecycle
management for worker threads.
"""
def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue,
graph: Graph,
worker_factory: WorkerFactory,
dynamic_scaler: DynamicScaler,
activity_tracker: ActivityTracker,
) -> None:
"""
Initialize the worker pool.
Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
worker_factory: Factory for creating workers
dynamic_scaler: Scaler for dynamic sizing
activity_tracker: Tracker for worker activity
"""
self.ready_queue = ready_queue
self.event_queue = event_queue
self.graph = graph
self.worker_factory = worker_factory
self.dynamic_scaler = dynamic_scaler
self.activity_tracker = activity_tracker
self.workers: list[Worker] = []
self._lock = threading.RLock()
self._running = False
def start(self, initial_count: int) -> None:
"""
Start the worker pool with initial workers.
Args:
initial_count: Number of workers to start with
"""
with self._lock:
if self._running:
return
self._running = True
# Create initial workers
for _ in range(initial_count):
worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
worker.start()
self.workers.append(worker)
def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False
# Stop all workers
for worker in self.workers:
worker.stop()
# Wait for workers to finish
for worker in self.workers:
if worker.is_alive():
worker.join(timeout=10.0)
self.workers.clear()
def scale_up(self) -> None:
"""Add a worker to the pool if allowed."""
with self._lock:
if not self._running:
return
if len(self.workers) >= self.dynamic_scaler.max_workers:
return
worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
worker.start()
self.workers.append(worker)
def scale_down(self, worker_ids: list[int]) -> None:
"""
Remove specific workers from the pool.
Args:
worker_ids: IDs of workers to remove
"""
with self._lock:
if not self._running:
return
if len(self.workers) <= self.dynamic_scaler.min_workers:
return
workers_to_remove = [w for w in self.workers if w.worker_id in worker_ids]
for worker in workers_to_remove:
worker.stop()
self.workers.remove(worker)
if worker.is_alive():
worker.join(timeout=1.0)
def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self.workers)
def check_scaling(self, queue_depth: int, executing_count: int) -> None:
"""
Check and perform scaling if needed.
Args:
queue_depth: Current queue depth
executing_count: Number of executing nodes
"""
current_count = self.get_worker_count()
if self.dynamic_scaler.should_scale_up(current_count, queue_depth, executing_count):
self.scale_up()
idle_workers = self.activity_tracker.get_idle_workers()
if self.dynamic_scaler.should_scale_down(current_count, idle_workers):
self.scale_down(idle_workers)