mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 23:18:05 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
# Conflicts: # api/commands.py # api/core/app/apps/common/workflow_response_converter.py # api/core/llm_generator/llm_generator.py # api/core/plugin/entities/plugin.py # api/core/plugin/impl/tool.py # api/core/rag/index_processor/index_processor_base.py # api/core/workflow/entities/workflow_execution.py # api/core/workflow/entities/workflow_node_execution.py # api/core/workflow/enums.py # api/core/workflow/graph_engine/entities/graph.py # api/core/workflow/graph_engine/graph_engine.py # api/core/workflow/nodes/enums.py # api/services/dataset_service.py
This commit is contained in:
187
api/core/workflow/graph_engine/README.md
Normal file
187
api/core/workflow/graph_engine/README.md
Normal file
@ -0,0 +1,187 @@
|
||||
# Graph Engine
|
||||
|
||||
Queue-based workflow execution engine for parallel graph processing.
|
||||
|
||||
## Architecture
|
||||
|
||||
The engine uses a modular architecture with specialized packages:
|
||||
|
||||
### Core Components
|
||||
|
||||
- **Domain** (`domain/`) - Core models: ExecutionContext, GraphExecution, NodeExecution
|
||||
- **Event Management** (`event_management/`) - Event handling, collection, and emission
|
||||
- **State Management** (`state_management/`) - Thread-safe state tracking for nodes and edges
|
||||
- **Error Handling** (`error_handling/`) - Strategy-based error recovery (retry, abort, fail-branch, default-value)
|
||||
- **Graph Traversal** (`graph_traversal/`) - Node readiness, edge processing, branch handling
|
||||
- **Command Processing** (`command_processing/`) - External command handling (abort, pause, resume)
|
||||
- **Worker Management** (`worker_management/`) - Dynamic worker pool with auto-scaling
|
||||
- **Orchestration** (`orchestration/`) - Main event loop and execution coordination
|
||||
|
||||
### Supporting Components
|
||||
|
||||
- **Output Registry** (`output_registry/`) - Thread-safe storage for node outputs
|
||||
- **Response Coordinator** (`response_coordinator/`) - Ordered streaming of response nodes
|
||||
- **Command Channels** (`command_channels/`) - Command transport (InMemory/Redis)
|
||||
- **Layers** (`layers/`) - Pluggable middleware for extensions
|
||||
|
||||
## Architecture Diagram
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class GraphEngine {
|
||||
+run()
|
||||
+add_layer()
|
||||
}
|
||||
|
||||
class Domain {
|
||||
ExecutionContext
|
||||
GraphExecution
|
||||
NodeExecution
|
||||
}
|
||||
|
||||
class EventManagement {
|
||||
EventHandlerRegistry
|
||||
EventCollector
|
||||
EventEmitter
|
||||
}
|
||||
|
||||
class StateManagement {
|
||||
NodeStateManager
|
||||
EdgeStateManager
|
||||
ExecutionTracker
|
||||
}
|
||||
|
||||
class WorkerManagement {
|
||||
WorkerPool
|
||||
WorkerFactory
|
||||
DynamicScaler
|
||||
ActivityTracker
|
||||
}
|
||||
|
||||
class GraphTraversal {
|
||||
NodeReadinessChecker
|
||||
EdgeProcessor
|
||||
BranchHandler
|
||||
SkipPropagator
|
||||
}
|
||||
|
||||
class Orchestration {
|
||||
Dispatcher
|
||||
ExecutionCoordinator
|
||||
}
|
||||
|
||||
class ErrorHandling {
|
||||
ErrorHandler
|
||||
RetryStrategy
|
||||
AbortStrategy
|
||||
FailBranchStrategy
|
||||
}
|
||||
|
||||
class CommandProcessing {
|
||||
CommandProcessor
|
||||
AbortCommandHandler
|
||||
}
|
||||
|
||||
class CommandChannels {
|
||||
InMemoryChannel
|
||||
RedisChannel
|
||||
}
|
||||
|
||||
class OutputRegistry {
|
||||
<<Storage>>
|
||||
Scalar Values
|
||||
Streaming Data
|
||||
}
|
||||
|
||||
class ResponseCoordinator {
|
||||
Session Management
|
||||
Path Analysis
|
||||
}
|
||||
|
||||
class Layers {
|
||||
<<Plugin>>
|
||||
DebugLoggingLayer
|
||||
}
|
||||
|
||||
GraphEngine --> Orchestration : coordinates
|
||||
GraphEngine --> Layers : extends
|
||||
|
||||
Orchestration --> EventManagement : processes events
|
||||
Orchestration --> WorkerManagement : manages scaling
|
||||
Orchestration --> CommandProcessing : checks commands
|
||||
Orchestration --> StateManagement : monitors state
|
||||
|
||||
WorkerManagement --> StateManagement : consumes ready queue
|
||||
WorkerManagement --> EventManagement : produces events
|
||||
WorkerManagement --> Domain : executes nodes
|
||||
|
||||
EventManagement --> ErrorHandling : failed events
|
||||
EventManagement --> GraphTraversal : success events
|
||||
EventManagement --> ResponseCoordinator : stream events
|
||||
EventManagement --> Layers : notifies
|
||||
|
||||
GraphTraversal --> StateManagement : updates states
|
||||
GraphTraversal --> Domain : checks graph
|
||||
|
||||
CommandProcessing --> CommandChannels : fetches commands
|
||||
CommandProcessing --> Domain : modifies execution
|
||||
|
||||
ErrorHandling --> Domain : handles failures
|
||||
|
||||
StateManagement --> Domain : tracks entities
|
||||
|
||||
ResponseCoordinator --> OutputRegistry : reads outputs
|
||||
|
||||
Domain --> OutputRegistry : writes outputs
|
||||
```
|
||||
|
||||
## Package Relationships
|
||||
|
||||
### Core Dependencies
|
||||
|
||||
- **Orchestration** acts as the central coordinator, managing all subsystems
|
||||
- **Domain** provides the core business entities used by all packages
|
||||
- **EventManagement** serves as the communication backbone between components
|
||||
- **StateManagement** maintains thread-safe state for the entire system
|
||||
|
||||
### Data Flow
|
||||
|
||||
1. **Commands** flow from CommandChannels → CommandProcessing → Domain
|
||||
1. **Events** flow from Workers → EventHandlerRegistry → State updates
|
||||
1. **Node outputs** flow from Workers → OutputRegistry → ResponseCoordinator
|
||||
1. **Ready nodes** flow from GraphTraversal → StateManagement → WorkerManagement
|
||||
|
||||
### Extension Points
|
||||
|
||||
- **Layers** observe all events for monitoring, logging, and custom logic
|
||||
- **ErrorHandling** strategies can be extended for custom failure recovery
|
||||
- **CommandChannels** can be implemented for different transport mechanisms
|
||||
|
||||
## Execution Flow
|
||||
|
||||
1. **Initialization**: GraphEngine creates all subsystems with the workflow graph
|
||||
1. **Node Discovery**: Traversal components identify ready nodes
|
||||
1. **Worker Execution**: Workers pull from ready queue and execute nodes
|
||||
1. **Event Processing**: Dispatcher routes events to appropriate handlers
|
||||
1. **State Updates**: Managers track node/edge states for next steps
|
||||
1. **Completion**: Coordinator detects when all nodes are done
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
|
||||
# Create and run engine
|
||||
engine = GraphEngine(
|
||||
tenant_id="tenant_1",
|
||||
app_id="app_1",
|
||||
workflow_id="workflow_1",
|
||||
graph=graph,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Stream execution events
|
||||
for event in engine.run():
|
||||
handle_event(event)
|
||||
```
|
||||
@ -1,4 +1,3 @@
|
||||
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
|
||||
from .graph_engine import GraphEngine
|
||||
|
||||
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
|
||||
__all__ = ["GraphEngine"]
|
||||
|
||||
33
api/core/workflow/graph_engine/command_channels/README.md
Normal file
33
api/core/workflow/graph_engine/command_channels/README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# Command Channels
|
||||
|
||||
Channel implementations for external workflow control.
|
||||
|
||||
## Components
|
||||
|
||||
### InMemoryChannel
|
||||
|
||||
Thread-safe in-memory queue for single-process deployments.
|
||||
|
||||
- `fetch_commands()` - Get pending commands
|
||||
- `send_command()` - Add command to queue
|
||||
|
||||
### RedisChannel
|
||||
|
||||
Redis-based queue for distributed deployments.
|
||||
|
||||
- `fetch_commands()` - Get commands with JSON deserialization
|
||||
- `send_command()` - Store commands with TTL
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# Local execution
|
||||
channel = InMemoryChannel()
|
||||
channel.send_command(AbortCommand(graph_id="workflow-123"))
|
||||
|
||||
# Distributed execution
|
||||
redis_channel = RedisChannel(
|
||||
redis_client=redis_client,
|
||||
channel_key="workflow:123:commands"
|
||||
)
|
||||
```
|
||||
@ -0,0 +1,6 @@
|
||||
"""Command channel implementations for GraphEngine."""
|
||||
|
||||
from .in_memory_channel import InMemoryChannel
|
||||
from .redis_channel import RedisChannel
|
||||
|
||||
__all__ = ["InMemoryChannel", "RedisChannel"]
|
||||
@ -0,0 +1,51 @@
|
||||
"""
|
||||
In-memory implementation of CommandChannel for local/testing scenarios.
|
||||
|
||||
This implementation uses a thread-safe queue for command communication
|
||||
within a single process. Each instance handles commands for one workflow execution.
|
||||
"""
|
||||
|
||||
from queue import Queue
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
class InMemoryChannel:
    """
    In-memory command channel implementation using a thread-safe queue.

    Each instance is dedicated to a single GraphEngine/workflow execution.
    Suitable for local development, testing, and single-instance deployments.
    """

    def __init__(self) -> None:
        """Initialize the in-memory channel with a single queue."""
        self._queue: Queue[GraphEngineCommand] = Queue()

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """
        Fetch all pending commands from the queue.

        Returns:
            List of pending commands (drains the queue)
        """
        from queue import Empty  # local import keeps module-level imports unchanged

        commands: list[GraphEngineCommand] = []

        # Drain the queue relying on get_nowait() itself to signal exhaustion.
        # The previous `while not empty(): get_nowait()` form was racy and its
        # broad `except Exception` could mask unrelated errors; queue.Empty is
        # the exact exception that means "nothing left".
        while True:
            try:
                commands.append(self._queue.get_nowait())
            except Empty:
                break

        return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        """
        Send a command to this channel's queue.

        Args:
            command: The command to send
        """
        self._queue.put(command)
|
||||
109
api/core/workflow/graph_engine/command_channels/redis_channel.py
Normal file
109
api/core/workflow/graph_engine/command_channels/redis_channel.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""
|
||||
Redis-based implementation of CommandChannel for distributed scenarios.
|
||||
|
||||
This implementation uses Redis lists for command queuing, supporting
|
||||
multi-instance deployments and cross-server communication.
|
||||
Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
|
||||
class RedisChannel:
    """
    Redis-backed command channel for distributed deployments.

    Commands are JSON-serialized onto a Redis list stored under a unique
    per-channel key; every send refreshes the key's TTL.
    """

    def __init__(
        self,
        redis_client: "RedisClientWrapper",
        channel_key: str,
        command_ttl: int = 3600,
    ) -> None:
        """
        Initialize the Redis channel.

        Args:
            redis_client: Redis client instance
            channel_key: Unique key for this channel's command queue
            command_ttl: TTL for command keys in seconds (default: 3600)
        """
        self._redis = redis_client
        self._key = channel_key
        self._command_ttl = command_ttl

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """
        Fetch all pending commands from Redis.

        Returns:
            List of pending commands (drains the Redis list)
        """
        # Read the whole list and delete the key in one atomic pipeline.
        with self._redis.pipeline() as pipe:
            pipe.lrange(self._key, 0, -1)
            pipe.delete(self._key)
            raw_items = pipe.execute()[0] or []

        parsed: list[GraphEngineCommand] = []
        for raw in raw_items:
            try:
                command = self._deserialize_command(json.loads(raw))
            except (json.JSONDecodeError, ValueError):
                # Malformed entries are silently dropped.
                continue
            if command is not None:
                parsed.append(command)
        return parsed

    def send_command(self, command: GraphEngineCommand) -> None:
        """
        Send a command to Redis.

        Args:
            command: The command to send
        """
        payload = json.dumps(command.model_dump())

        # Append to the list and refresh the key's expiry together.
        with self._redis.pipeline() as pipe:
            pipe.rpush(self._key, payload)
            pipe.expire(self._key, self._command_ttl)
            pipe.execute()

    def _deserialize_command(self, data: dict) -> Optional[GraphEngineCommand]:
        """
        Deserialize a command from dictionary data.

        Args:
            data: Command data dictionary

        Returns:
            Deserialized command or None if invalid
        """
        try:
            command_type = CommandType(data.get("command_type"))
            if command_type == CommandType.ABORT:
                return AbortCommand(**data)
            # Unrecognized-but-valid types fall back to the base model.
            return GraphEngineCommand(**data)
        except (ValueError, TypeError):
            return None
|
||||
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Command processing subsystem for graph engine.
|
||||
|
||||
This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
]
|
||||
@ -0,0 +1,27 @@
|
||||
"""
|
||||
Command handler implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbortCommandHandler(CommandHandler):
    """Handles abort commands for a running graph execution."""

    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        """
        Handle an abort command.

        Args:
            command: The abort command
            execution: Graph execution to abort

        Raises:
            TypeError: If the command is not an AbortCommand.
        """
        # `assert` is stripped under `python -O`, so validate explicitly.
        if not isinstance(command, AbortCommand):
            raise TypeError(f"expected AbortCommand, got {type(command).__name__}")
        logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
        execution.abort(command.reason or "User requested abort")
|
||||
@ -0,0 +1,78 @@
|
||||
"""
|
||||
Main command processor for handling external commands.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Protocol
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
from ..protocols.command_channel import CommandChannel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommandHandler(Protocol):
    """Structural protocol for command handlers.

    Any object exposing a matching ``handle`` method can be registered with
    ``CommandProcessor`` — no inheritance is required.
    """

    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
|
||||
|
||||
|
||||
class CommandProcessor:
    """
    Processes external commands sent to the engine.

    Polls the command channel and dispatches each fetched command to the
    handler registered for its exact type.
    """

    def __init__(
        self,
        command_channel: CommandChannel,
        graph_execution: GraphExecution,
    ) -> None:
        """
        Initialize the command processor.

        Args:
            command_channel: Channel for receiving commands
            graph_execution: Graph execution aggregate
        """
        self.command_channel = command_channel
        self.graph_execution = graph_execution
        self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}

    def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
        """
        Register a handler for a command type.

        Args:
            command_type: Type of command to handle
            handler: Handler for the command
        """
        self._handlers[command_type] = handler

    def process_commands(self) -> None:
        """Check for and process any pending commands."""
        # Command processing must never crash the engine loop, so any
        # channel failure is reduced to a warning.
        try:
            for command in self.command_channel.fetch_commands():
                self._handle_command(command)
        except Exception as e:
            logger.warning("Error processing commands: %s", e)

    def _handle_command(self, command: GraphEngineCommand) -> None:
        """
        Handle a single command.

        Args:
            command: The command to handle
        """
        # NOTE: dispatch is by exact type — subclasses of a registered
        # command type are not matched.
        handler = self._handlers.get(type(command))
        if handler is None:
            logger.warning("No handler registered for command: %s", command.__class__.__name__)
            return
        try:
            handler.handle(command, self.graph_execution)
        except Exception:
            # fix: the previous `as e` binding was unused — logger.exception
            # already records the active traceback.
            logger.exception("Error handling command %s", command.__class__.__name__)
|
||||
@ -1,25 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
    """Abstract base for run-condition checks attached to graph edges."""

    def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
        # Stored as-is; subclasses read `self.condition` inside check().
        self.init_params = init_params
        self.graph = graph
        self.condition = condition

    @abstractmethod
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        raise NotImplementedError
|
||||
@ -1,25 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
    """Run-condition handler that matches on a branch identifier."""

    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check whether the previous node exited via the configured branch.

        :param graph_runtime_state: graph runtime state (unused by this handler)
        :param previous_route_node_state: previous route node state
        :return: True if the previous node's edge source handle matches
        :raises ValueError: if the condition has no branch identifier configured
        """
        if not self.condition.branch_identify:
            # A missing branch identifier is a configuration error; raise the
            # specific ValueError rather than a bare Exception so callers can
            # distinguish it (any `except Exception` caller still catches it).
            raise ValueError("Branch identify is required")

        run_result = previous_route_node_state.node_run_result
        if not run_result:
            return False

        if not run_result.edge_source_handle:
            return False

        return self.condition.branch_identify == run_result.edge_source_handle
|
||||
@ -1,27 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
    """Run-condition handler that evaluates configured variable conditions."""

    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state (unused here)
        :return: bool
        """
        # No conditions configured means the edge is unconditionally runnable.
        if not self.condition.conditions:
            return True

        # Evaluate all configured conditions AND-combined against the
        # current variable pool.
        condition_processor = ConditionProcessor()
        _, _, final_result = condition_processor.process_conditions(
            variable_pool=graph_runtime_state.variable_pool,
            conditions=self.condition.conditions,
            operator="and",
        )

        return final_result
|
||||
@ -1,25 +0,0 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class ConditionManager:
    """Factory resolving a RunConditionHandler for a given run condition."""

    @staticmethod
    def get_condition_handler(
        init_params: GraphInitParams, graph: Graph, run_condition: RunCondition
    ) -> RunConditionHandler:
        """
        Get condition handler

        :param init_params: init params
        :param graph: graph
        :param run_condition: run condition
        :return: condition handler
        """
        # "branch_identify" conditions get the branch matcher; everything
        # else falls back to the generic condition evaluator.
        handler_cls = (
            BranchIdentifyRunConditionHandler
            if run_condition.type == "branch_identify"
            else ConditionRunConditionHandlerHandler
        )
        return handler_cls(init_params=init_params, graph=graph, condition=run_condition)
|
||||
16
api/core/workflow/graph_engine/domain/__init__.py
Normal file
16
api/core/workflow/graph_engine/domain/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Domain models for graph engine.
|
||||
|
||||
This package contains the core domain entities, value objects, and aggregates
|
||||
that represent the business concepts of workflow graph execution.
|
||||
"""
|
||||
|
||||
from .execution_context import ExecutionContext
|
||||
from .graph_execution import GraphExecution
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
__all__ = [
|
||||
"ExecutionContext",
|
||||
"GraphExecution",
|
||||
"NodeExecution",
|
||||
]
|
||||
37
api/core/workflow/graph_engine/domain/execution_context.py
Normal file
37
api/core/workflow/graph_engine/domain/execution_context.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""
|
||||
ExecutionContext value object containing immutable execution parameters.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ExecutionContext:
    """
    Immutable value object describing the context of one graph execution.

    Bundles the contextual parameters a workflow run needs so they can be
    passed around independently of the mutable execution state.
    """

    tenant_id: str
    app_id: str
    workflow_id: str
    user_id: str
    user_from: UserFrom
    invoke_from: InvokeFrom
    call_depth: int
    max_execution_steps: int
    max_execution_time: int

    def __post_init__(self) -> None:
        """Reject out-of-range numeric parameters at construction time."""
        if self.call_depth < 0:
            raise ValueError("Call depth must be non-negative")
        if self.max_execution_steps <= 0:
            raise ValueError("Max execution steps must be positive")
        if self.max_execution_time <= 0:
            raise ValueError("Max execution time must be positive")
|
||||
72
api/core/workflow/graph_engine/domain/graph_execution.py
Normal file
72
api/core/workflow/graph_engine/domain/graph_execution.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""
|
||||
GraphExecution aggregate root managing the overall graph execution state.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
|
||||
@dataclass
class GraphExecution:
    """
    Aggregate root tracking the lifecycle of one workflow graph run.

    Owns the overall started/completed/aborted status plus the per-node
    execution entities created during the run.
    """

    workflow_id: str
    started: bool = False
    completed: bool = False
    aborted: bool = False
    error: Optional[Exception] = None
    node_executions: dict[str, NodeExecution] = field(default_factory=dict)

    def start(self) -> None:
        """Transition to the started state; may only happen once."""
        if self.started:
            raise RuntimeError("Graph execution already started")
        self.started = True

    def complete(self) -> None:
        """Transition to completed; requires a prior start and no prior completion."""
        if not self.started:
            raise RuntimeError("Cannot complete execution that hasn't started")
        if self.completed:
            raise RuntimeError("Graph execution already completed")
        self.completed = True

    def abort(self, reason: str) -> None:
        """Abort the run, recording the reason as the execution error."""
        self.aborted = True
        self.error = RuntimeError(f"Aborted: {reason}")

    def fail(self, error: Exception) -> None:
        """Record a failure and treat the run as finished."""
        self.error = error
        self.completed = True

    def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
        """Return the NodeExecution for node_id, creating one on first access."""
        return self.node_executions.setdefault(node_id, NodeExecution(node_id=node_id))

    @property
    def is_running(self) -> bool:
        """True while started but neither completed nor aborted."""
        return self.started and not (self.completed or self.aborted)

    @property
    def has_error(self) -> bool:
        """True once any error has been recorded."""
        return self.error is not None

    @property
    def error_message(self) -> str | None:
        """String form of the recorded error, or None when there is none."""
        return str(self.error) if self.error else None
|
||||
46
api/core/workflow/graph_engine/domain/node_execution.py
Normal file
46
api/core/workflow/graph_engine/domain/node_execution.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""
|
||||
NodeExecution entity representing a node's execution state.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
|
||||
@dataclass
class NodeExecution:
    """
    Entity representing the execution state of a single node.

    This is a mutable entity that tracks the runtime state of a node
    during graph execution.
    """

    node_id: str
    state: NodeState = NodeState.UNKNOWN  # current lifecycle state
    retry_count: int = 0  # retries attempted so far
    execution_id: Optional[str] = None  # id assigned when execution starts
    error: Optional[str] = None  # last recorded error message, if any

    def mark_started(self, execution_id: str) -> None:
        """Mark the node as started with an execution ID."""
        # NOTE(review): sets TAKEN immediately on start — the same state as
        # mark_taken(). Confirm NodeState intentionally has no separate
        # "running" value.
        self.state = NodeState.TAKEN
        self.execution_id = execution_id

    def mark_taken(self) -> None:
        """Mark the node as successfully completed."""
        self.state = NodeState.TAKEN
        self.error = None  # a success clears any earlier failure record

    def mark_failed(self, error: str) -> None:
        """Mark the node as failed with an error."""
        # NOTE(review): records the error but leaves `state` unchanged —
        # presumably failure state is tracked elsewhere; confirm.
        self.error = error

    def mark_skipped(self) -> None:
        """Mark the node as skipped."""
        self.state = NodeState.SKIPPED

    def increment_retry(self) -> None:
        """Increment the retry count for this node."""
        self.retry_count += 1
|
||||
@ -1,6 +0,0 @@
|
||||
from .graph import Graph
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .runtime_route_state import RuntimeRouteState
|
||||
|
||||
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
|
||||
|
||||
33
api/core/workflow/graph_engine/entities/commands.py
Normal file
33
api/core/workflow/graph_engine/entities/commands.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""
|
||||
GraphEngine command entities for external control.
|
||||
|
||||
This module defines command types that can be sent to a running GraphEngine
|
||||
instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CommandType(str, Enum):
    """Types of commands that can be sent to GraphEngine.

    NOTE(review): only ABORT has a visible handler (AbortCommandHandler);
    PAUSE/RESUME appear to be declared ahead of their implementation — confirm.
    """

    ABORT = "abort"
    PAUSE = "pause"
    RESUME = "resume"
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
    """Base class for all GraphEngine commands.

    Instances are JSON-serialized (via ``model_dump``) for transport over
    command channels; ``command_type`` is used to re-dispatch on receipt.
    """

    command_type: CommandType = Field(..., description="Type of command")
    payload: Optional[dict[str, Any]] = Field(default=None, description="Optional command payload")
|
||||
|
||||
|
||||
class AbortCommand(GraphEngineCommand):
    """Command to abort a running workflow execution."""

    # command_type is pre-bound so senders never have to set it explicitly.
    command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
    reason: Optional[str] = Field(default=None, description="Optional reason for abort")
|
||||
@ -1,277 +0,0 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
    """Root base class for every event emitted by the graph engine."""

    pass
|
||||
|
||||
|
||||
###########################################
|
||||
# Graph Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
    """Base class for graph-level (whole-run) events."""

    pass
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
    """Emitted when a graph run begins; carries no extra data."""

    pass
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
    """Emitted when a graph run finishes successfully."""

    outputs: Optional[dict[str, Any]] = None
    """outputs"""
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
    """Emitted when a graph run fails outright."""

    error: str = Field(..., description="failed reason")
    exceptions_count: int = Field(description="exception count", default=0)
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
    """Emitted when a run completes with some node exceptions tolerated."""

    exceptions_count: int = Field(..., description="exception count")
    outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
###########################################
|
||||
# Node Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
    """Base class for events tied to a single node execution."""

    id: str = Field(..., description="node execution id")
    node_id: str = Field(..., description="node id")
    node_type: NodeType = Field(..., description="node type")
    node_data: BaseNodeData = Field(..., description="node data")
    route_node_state: RouteNodeState = Field(..., description="route node state")
    # parallel id if node is in parallel
    parallel_id: Optional[str] = None
    # parallel start node id if node is in parallel
    parallel_start_node_id: Optional[str] = None
    # parent parallel id if node is in parallel
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id if node is in parallel
    parent_parallel_start_node_id: Optional[str] = None
    # iteration id if node is in iteration
    in_iteration_id: Optional[str] = None
    # loop id if node is in loop
    in_loop_id: Optional[str] = None
    # The version of the node, or "1" if not specified.
    node_version: str = "1"
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
    """Emitted when a node begins executing."""

    # predecessor node id
    predecessor_node_id: Optional[str] = None
    # iteration node parallel mode run id
    parallel_mode_run_id: Optional[str] = None
    # Strategy metadata, only present for agent nodes.
    agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
    """Emitted for each streamed output chunk produced by a node."""

    chunk_content: str = Field(..., description="chunk content")
    # Variable selector the chunk originated from, when known.
    from_variable_selector: Optional[list[str]] = None
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
    """Emitted when a node retrieves citation/context resources."""

    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
    """Emitted when a node finishes successfully."""
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
    """Emitted when a node fails terminally."""

    error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(BaseNodeEvent):
    """Emitted when a node raises but the run continues (error-strategy path)."""

    error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeInIterationFailedEvent(BaseNodeEvent):
    """Emitted when a node fails inside an iteration."""

    error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeInLoopFailedEvent(BaseNodeEvent):
    """Emitted when a node fails inside a loop."""

    error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
    """Emitted just before a failed node execution is retried."""

    error: str = Field(..., description="error")
    retry_index: int = Field(..., description="which retry attempt is about to be performed")
    start_at: datetime = Field(..., description="retry start time")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseParallelBranchEvent(GraphEngineEvent):
    """Base class for events describing one branch of a parallel region."""

    parallel_id: str = Field(..., description="parallel id")
    parallel_start_node_id: str = Field(..., description="parallel start node id")
    # parent parallel id if node is in parallel
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id if node is in parallel
    parent_parallel_start_node_id: Optional[str] = None
    # iteration id if node is in iteration
    in_iteration_id: Optional[str] = None
    # loop id if node is in loop
    in_loop_id: Optional[str] = None
|
||||
|
||||
|
||||
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
    """Emitted when a parallel branch starts running."""
|
||||
|
||||
|
||||
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
    """Emitted when a parallel branch finishes successfully."""
|
||||
|
||||
|
||||
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
    """Emitted when a parallel branch fails."""

    error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Iteration Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseIterationEvent(GraphEngineEvent):
    """Base class for events emitted by an iteration node."""

    iteration_id: str = Field(..., description="iteration node execution id")
    iteration_node_id: str = Field(..., description="iteration node id")
    iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
    iteration_node_data: BaseNodeData = Field(..., description="node data")
    # parallel id if node is in parallel
    parallel_id: Optional[str] = None
    # parallel start node id if node is in parallel
    parallel_start_node_id: Optional[str] = None
    # parent parallel id if node is in parallel
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id if node is in parallel
    parent_parallel_start_node_id: Optional[str] = None
    # run id when the iteration executes in parallel mode
    parallel_mode_run_id: Optional[str] = None
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
    """Emitted when an iteration node starts."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    predecessor_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class IterationRunNextEvent(BaseIterationEvent):
    """Emitted before each iteration step, carrying the previous step's output."""

    index: int = Field(..., description="index")
    pre_iteration_output: Optional[Any] = None
    duration: Optional[float] = None
|
||||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
    """Emitted when all iteration steps completed successfully."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    outputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0
    # Per-step wall-clock durations keyed by step identifier.
    iteration_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
    """Emitted when the iteration aborts with an error."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    outputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0
    error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Loop Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseLoopEvent(GraphEngineEvent):
    """Base class for events emitted by a loop node."""

    loop_id: str = Field(..., description="loop node execution id")
    loop_node_id: str = Field(..., description="loop node id")
    # Fix: description previously read "node type, loop or loop" — a copy-paste
    # remnant of BaseIterationEvent's "iteration or loop".
    loop_node_type: NodeType = Field(..., description="node type, loop")
    loop_node_data: BaseNodeData = Field(..., description="node data")
    # parallel id if node is in parallel
    parallel_id: Optional[str] = None
    # parallel start node id if node is in parallel
    parallel_start_node_id: Optional[str] = None
    # parent parallel id if node is in parallel
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id if node is in parallel
    parent_parallel_start_node_id: Optional[str] = None
    # run id when the loop executes in parallel mode
    parallel_mode_run_id: Optional[str] = None
|
||||
|
||||
|
||||
class LoopRunStartedEvent(BaseLoopEvent):
    """Emitted when a loop node starts."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    predecessor_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class LoopRunNextEvent(BaseLoopEvent):
    """Emitted before each loop pass, carrying the previous pass's output."""

    index: int = Field(..., description="index")
    pre_loop_output: Optional[Any] = None
    duration: Optional[float] = None
|
||||
|
||||
|
||||
class LoopRunSucceededEvent(BaseLoopEvent):
    """Emitted when the loop completed all passes successfully."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    outputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0
    # Per-pass wall-clock durations keyed by pass identifier.
    loop_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class LoopRunFailedEvent(BaseLoopEvent):
    """Emitted when the loop aborts with an error."""

    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    outputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0
    error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Agent Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseAgentEvent(GraphEngineEvent):
    """Base class for events emitted by agent nodes."""
|
||||
|
||||
|
||||
class AgentLogEvent(BaseAgentEvent):
    """Structured log record emitted during an agent node's execution."""

    id: str = Field(..., description="id")
    label: str = Field(..., description="label")
    node_execution_id: str = Field(..., description="node execution id")
    parent_id: str | None = Field(..., description="parent id")
    error: str | None = Field(..., description="error")
    status: str = Field(..., description="status")
    data: Mapping[str, Any] = Field(..., description="data")
    metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
    node_id: str = Field(..., description="agent node id")
|
||||
|
||||
|
||||
# Union of the event families that can occur within a node's execution scope
# (node, parallel-branch, iteration, agent and loop events).
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
|
||||
@ -1,721 +0,0 @@
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
|
||||
from core.workflow.nodes.end.entities import EndStreamParam
|
||||
|
||||
|
||||
class GraphEdge(BaseModel):
    """Directed edge between two nodes of the workflow graph."""

    source_node_id: str = Field(..., description="source node id")
    target_node_id: str = Field(..., description="target node id")
    # Optional condition that must hold at runtime for this edge to be taken.
    run_condition: Optional[RunCondition] = None
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
    """A detected parallel region of the graph, possibly nested in a parent region."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
    start_from_node_id: str = Field(..., description="start from node id")
    # parent parallel id
    parent_parallel_id: Optional[str] = None
    # parent parallel start node id
    parent_parallel_start_node_id: Optional[str] = None
    # node id where the branches merge again, when such a merge point exists
    end_to_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class Graph(BaseModel):
    """
    Parsed workflow graph: node/edge topology, detected parallel regions and
    pre-computed answer/end streaming routes. Built via :meth:`init`.
    """

    root_node_id: str = Field(..., description="root node id of the graph")
    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
    node_id_config_mapping: dict[str, dict] = Field(
        default_factory=dict, description="node configs mapping (node id: node config)"
    )
    edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict, description="graph edge mapping (source node id: edges)"
    )
    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict, description="reverse graph edge mapping (target node id: edges)"
    )
    parallel_mapping: dict[str, GraphParallel] = Field(
        default_factory=dict, description="graph parallel mapping (parallel id: parallel)"
    )
    node_parallel_mapping: dict[str, str] = Field(
        default_factory=dict, description="graph node parallel mapping (node id: parallel id)"
    )
    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes")
    end_stream_param: EndStreamParam = Field(..., description="end stream param")
|
||||
|
||||
    @classmethod
    def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph":
        """
        Build a Graph from a raw workflow config.

        :param graph_config: mapping with "edges" and "nodes" lists (front-end format)
        :param root_node_id: explicit root node id; when omitted, the first
            START/DATASOURCE node without predecessors is used
        :return: fully initialized graph
        :raises ValueError: when there are no nodes, the root node cannot be
            resolved, the graph contains a cycle, or parallel nesting exceeds
            the configured limit
        """
        # edge configs
        edge_configs = graph_config.get("edges")
        if edge_configs is None:
            edge_configs = []
        # node configs
        node_configs = graph_config.get("nodes")
        if not node_configs:
            raise ValueError("Graph must have at least one node")

        edge_configs = cast(list, edge_configs)
        node_configs = cast(list, node_configs)

        # reorganize edges mapping
        edge_mapping: dict[str, list[GraphEdge]] = {}
        reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
        target_edge_ids = set()
        # Nodes configured with the "fail-branch" error strategy: their non-fail
        # outgoing edges get an implicit "success-branch" run condition below.
        fail_branch_source_node_id = [
            node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
        ]
        for edge_config in edge_configs:
            source_node_id = edge_config.get("source")
            if not source_node_id:
                continue

            if source_node_id not in edge_mapping:
                edge_mapping[source_node_id] = []

            target_node_id = edge_config.get("target")
            if not target_node_id:
                continue

            if target_node_id not in reverse_edge_mapping:
                reverse_edge_mapping[target_node_id] = []

            # Any node that appears as a target has a predecessor.
            target_edge_ids.add(target_node_id)

            # parse run condition from the edge's sourceHandle
            run_condition = None
            if edge_config.get("sourceHandle"):
                if (
                    edge_config.get("source") in fail_branch_source_node_id
                    and edge_config.get("sourceHandle") != "fail-branch"
                ):
                    # Normal edge out of a fail-branch node only runs on success.
                    run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
                elif edge_config.get("sourceHandle") != "source":
                    # "source" is the default handle; anything else names a branch.
                    run_condition = RunCondition(
                        type="branch_identify", branch_identify=edge_config.get("sourceHandle")
                    )

            graph_edge = GraphEdge(
                source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
            )

            edge_mapping[source_node_id].append(graph_edge)
            reverse_edge_mapping[target_node_id].append(graph_edge)

        # fetch nodes that have no predecessor node
        root_node_configs = []
        all_node_id_config_mapping: dict[str, dict] = {}

        for node_config in node_configs:
            node_id = node_config.get("id")
            if not node_id:
                continue

            if node_id not in target_edge_ids:
                root_node_configs.append(node_config)

            all_node_id_config_mapping[node_id] = node_config

        root_node_ids = [node_config.get("id") for node_config in root_node_configs]

        # fetch root node
        if not root_node_id:
            # if no root node id, use the START (or DATASOURCE) type node as root node
            root_node_id = next(
                (
                    node_config.get("id")
                    for node_config in root_node_configs
                    if node_config.get("data", {}).get("type", "") == NodeType.START.value
                    or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
                ),
                None,
            )

        if not root_node_id or root_node_id not in root_node_ids:
            raise ValueError(f"Root node id {root_node_id} not found in the graph")

        # Reject cycles: no path from the root may revisit a node on its route.
        cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping)

        # fetch all node ids reachable from the root node
        node_ids = [root_node_id]
        cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id)

        # Keep configs only for reachable nodes.
        node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}

        # init parallel mapping
        parallel_mapping: dict[str, GraphParallel] = {}
        node_parallel_mapping: dict[str, str] = {}
        cls._recursively_add_parallels(
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            start_node_id=root_node_id,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
        )

        # Check if it exceeds N layers of parallel nesting
        for parallel in parallel_mapping.values():
            if parallel.parent_parallel_id:
                cls._check_exceed_parallel_limit(
                    parallel_mapping=parallel_mapping,
                    level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
                    parent_parallel_id=parallel.parent_parallel_id,
                )

        # init answer stream generate routes
        answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping
        )

        # init end stream param
        end_stream_param = EndStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            node_parallel_mapping=node_parallel_mapping,
        )

        # init graph
        graph = cls(
            root_node_id=root_node_id,
            node_ids=node_ids,
            node_id_config_mapping=node_id_config_mapping,
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
            answer_stream_generate_routes=answer_stream_generate_routes,
            end_stream_param=end_stream_param,
        )

        return graph
|
||||
|
||||
def add_extra_edge(
|
||||
self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None
|
||||
) -> None:
|
||||
"""
|
||||
Add extra edge to the graph
|
||||
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param run_condition: run condition
|
||||
"""
|
||||
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
|
||||
return
|
||||
|
||||
if source_node_id not in self.edge_mapping:
|
||||
self.edge_mapping[source_node_id] = []
|
||||
|
||||
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
|
||||
return
|
||||
|
||||
graph_edge = GraphEdge(
|
||||
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
|
||||
)
|
||||
|
||||
self.edge_mapping[source_node_id].append(graph_edge)
|
||||
|
||||
def get_leaf_node_ids(self) -> list[str]:
|
||||
"""
|
||||
Get leaf node ids of the graph
|
||||
|
||||
:return: leaf node ids
|
||||
"""
|
||||
leaf_node_ids = []
|
||||
for node_id in self.node_ids:
|
||||
if node_id not in self.edge_mapping or (
|
||||
len(self.edge_mapping[node_id]) == 1
|
||||
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
|
||||
):
|
||||
leaf_node_ids.append(node_id)
|
||||
|
||||
return leaf_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_node_ids(
|
||||
cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param node_ids: node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param node_id: node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(node_id, []):
|
||||
if graph_edge.target_node_id in node_ids:
|
||||
continue
|
||||
|
||||
node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_node_ids(
|
||||
node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None:
|
||||
"""
|
||||
Check whether it is connected to the previous node
|
||||
"""
|
||||
last_node_id = route[-1]
|
||||
|
||||
for graph_edge in edge_mapping.get(last_node_id, []):
|
||||
if not graph_edge.target_node_id:
|
||||
continue
|
||||
|
||||
if graph_edge.target_node_id in route:
|
||||
raise ValueError(
|
||||
f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
|
||||
)
|
||||
|
||||
new_route = route.copy()
|
||||
new_route.append(graph_edge.target_node_id)
|
||||
cls._check_connected_to_previous_node(
|
||||
route=new_route,
|
||||
edge_mapping=edge_mapping,
|
||||
)
|
||||
|
||||
    @classmethod
    def _recursively_add_parallels(
        cls,
        edge_mapping: dict[str, list[GraphEdge]],
        reverse_edge_mapping: dict[str, list[GraphEdge]],
        start_node_id: str,
        parallel_mapping: dict[str, GraphParallel],
        node_parallel_mapping: dict[str, str],
        parent_parallel: Optional[GraphParallel] = None,
    ) -> None:
        """
        Walk the graph from *start_node_id* and register parallel regions:
        whenever a node has more than one outgoing edge, its branches form a
        GraphParallel (one per run-condition group). Results are accumulated in
        *parallel_mapping* (parallel id -> parallel) and *node_parallel_mapping*
        (node id -> owning parallel id), both mutated in place.

        :param edge_mapping: edge mapping (source node id -> outgoing edges)
        :param reverse_edge_mapping: reverse edge mapping (target node id -> incoming edges)
        :param start_node_id: node whose outgoing edges are inspected
        :param parallel_mapping: accumulator for detected parallels
        :param node_parallel_mapping: accumulator mapping nodes to their parallel
        :param parent_parallel: parallel enclosing *start_node_id*, if any
        """
        target_node_edges = edge_mapping.get(start_node_id, [])
        parallel = None
        if len(target_node_edges) > 1:
            # Group the fan-out: unconditional edges go to the "default" group,
            # conditional edges are grouped by their run-condition hash.
            parallel_branch_node_ids = defaultdict(list)
            condition_edge_mappings = defaultdict(list)
            for graph_edge in target_node_edges:
                if graph_edge.run_condition is None:
                    parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
                else:
                    condition_hash = graph_edge.run_condition.hash
                    condition_edge_mappings[condition_hash].append(graph_edge)

            # Only a condition shared by 2+ edges creates a parallel of its own.
            for condition_hash, graph_edges in condition_edge_mappings.items():
                if len(graph_edges) > 1:
                    for graph_edge in graph_edges:
                        parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)

            condition_parallels = {}
            for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
                # any target node id in node_parallel_mapping
                parallel = None
                if condition_parallel_branch_node_ids:
                    parent_parallel_id = parent_parallel.id if parent_parallel else None

                    parallel = GraphParallel(
                        start_from_node_id=start_node_id,
                        parent_parallel_id=parent_parallel_id,
                        parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
                    )
                    parallel_mapping[parallel.id] = parallel
                    condition_parallels[condition_hash] = parallel

                    in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                        edge_mapping=edge_mapping,
                        reverse_edge_mapping=reverse_edge_mapping,
                        parallel_branch_node_ids=condition_parallel_branch_node_ids,
                    )

                    # collect all branches node ids that still belong to the
                    # parent parallel (or to no parallel at the top level)
                    parallel_node_ids = []
                    for _, node_ids in in_branch_node_ids.items():
                        for node_id in node_ids:
                            in_parent_parallel = True
                            if parent_parallel_id:
                                in_parent_parallel = False
                                for parallel_node_id, parallel_id in node_parallel_mapping.items():
                                    if parallel_id == parent_parallel_id and parallel_node_id == node_id:
                                        in_parent_parallel = True
                                        break

                            if in_parent_parallel:
                                parallel_node_ids.append(node_id)
                                node_parallel_mapping[node_id] = parallel.id

                    # Find the unique node the branches converge to, if any:
                    # a single-edge successor that lies outside the parallel.
                    outside_parallel_target_node_ids = set()
                    for node_id in parallel_node_ids:
                        if node_id == parallel.start_from_node_id:
                            continue

                        node_edges = edge_mapping.get(node_id)
                        if not node_edges:
                            continue

                        if len(node_edges) > 1:
                            continue

                        target_node_id = node_edges[0].target_node_id
                        if target_node_id in parallel_node_ids:
                            continue

                        if parent_parallel_id:
                            # NOTE(review): this rebinds the *parent_parallel*
                            # parameter inside the loop; later iterations and the
                            # recursion below observe the rebound value — confirm
                            # this shadowing is intentional.
                            parent_parallel = parallel_mapping.get(parent_parallel_id)
                            if not parent_parallel:
                                continue

                        if (
                            (
                                node_parallel_mapping.get(target_node_id)
                                and node_parallel_mapping.get(target_node_id) == parent_parallel_id
                            )
                            or (
                                parent_parallel
                                and parent_parallel.end_to_node_id
                                and target_node_id == parent_parallel.end_to_node_id
                            )
                            or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
                        ):
                            outside_parallel_target_node_ids.add(target_node_id)

                    if len(outside_parallel_target_node_ids) == 1:
                        if (
                            parent_parallel
                            and parent_parallel.end_to_node_id
                            and parallel.end_to_node_id == parent_parallel.end_to_node_id
                        ):
                            # Merge point coincides with the parent's; leave unset.
                            parallel.end_to_node_id = None
                        else:
                            parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

            if condition_edge_mappings:
                # Recurse per conditional edge, carrying the parallel created
                # for that edge's condition group.
                for condition_hash, graph_edges in condition_edge_mappings.items():
                    for graph_edge in graph_edges:
                        current_parallel = cls._get_current_parallel(
                            parallel_mapping=parallel_mapping,
                            graph_edge=graph_edge,
                            parallel=condition_parallels.get(condition_hash),
                            parent_parallel=parent_parallel,
                        )

                        cls._recursively_add_parallels(
                            edge_mapping=edge_mapping,
                            reverse_edge_mapping=reverse_edge_mapping,
                            start_node_id=graph_edge.target_node_id,
                            parallel_mapping=parallel_mapping,
                            node_parallel_mapping=node_parallel_mapping,
                            parent_parallel=current_parallel,
                        )
            else:
                for graph_edge in target_node_edges:
                    current_parallel = cls._get_current_parallel(
                        parallel_mapping=parallel_mapping,
                        graph_edge=graph_edge,
                        parallel=parallel,
                        parent_parallel=parent_parallel,
                    )

                    cls._recursively_add_parallels(
                        edge_mapping=edge_mapping,
                        reverse_edge_mapping=reverse_edge_mapping,
                        start_node_id=graph_edge.target_node_id,
                        parallel_mapping=parallel_mapping,
                        node_parallel_mapping=node_parallel_mapping,
                        parent_parallel=current_parallel,
                    )
        else:
            # Single (or no) outgoing edge: no new parallel, just keep walking.
            for graph_edge in target_node_edges:
                current_parallel = cls._get_current_parallel(
                    parallel_mapping=parallel_mapping,
                    graph_edge=graph_edge,
                    parallel=parallel,
                    parent_parallel=parent_parallel,
                )

                cls._recursively_add_parallels(
                    edge_mapping=edge_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    start_node_id=graph_edge.target_node_id,
                    parallel_mapping=parallel_mapping,
                    node_parallel_mapping=node_parallel_mapping,
                    parent_parallel=current_parallel,
                )
|
||||
|
||||
@classmethod
|
||||
def _get_current_parallel(
|
||||
cls,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
graph_edge: GraphEdge,
|
||||
parallel: Optional[GraphParallel] = None,
|
||||
parent_parallel: Optional[GraphParallel] = None,
|
||||
) -> Optional[GraphParallel]:
|
||||
"""
|
||||
Get current parallel
|
||||
"""
|
||||
current_parallel = None
|
||||
if parallel:
|
||||
current_parallel = parallel
|
||||
elif parent_parallel:
|
||||
if not parent_parallel.end_to_node_id or (
|
||||
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
|
||||
):
|
||||
current_parallel = parent_parallel
|
||||
else:
|
||||
# fetch parent parallel's parent parallel
|
||||
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
|
||||
if parent_parallel_parent_parallel_id:
|
||||
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
|
||||
if parent_parallel_parent_parallel and (
|
||||
not parent_parallel_parent_parallel.end_to_node_id
|
||||
or (
|
||||
parent_parallel_parent_parallel.end_to_node_id
|
||||
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
|
||||
)
|
||||
):
|
||||
current_parallel = parent_parallel_parent_parallel
|
||||
|
||||
return current_parallel
|
||||
|
||||
@classmethod
|
||||
def _check_exceed_parallel_limit(
|
||||
cls,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
level_limit: int,
|
||||
parent_parallel_id: str,
|
||||
current_level: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Check if it exceeds N layers of parallel
|
||||
"""
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
return
|
||||
|
||||
current_level += 1
|
||||
if current_level > level_limit:
|
||||
raise ValueError(f"Exceeds {level_limit} layers of parallel")
|
||||
|
||||
if parent_parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=level_limit,
|
||||
parent_parallel_id=parent_parallel.parent_parallel_id,
|
||||
current_level=current_level,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallel_node_ids(
|
||||
cls,
|
||||
branch_node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
merge_node_id: str,
|
||||
start_node_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param branch_node_ids: in branch node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param merge_node_id: merge node id
|
||||
:param start_node_id: start node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(start_node_id, []):
|
||||
if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids:
|
||||
branch_node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=branch_node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
)
|
||||
|
||||
    @classmethod
    def _fetch_all_node_ids_in_parallels(
        cls,
        edge_mapping: dict[str, list[GraphEdge]],
        reverse_edge_mapping: dict[str, list[GraphEdge]],
        parallel_branch_node_ids: list[str],
    ) -> dict[str, list[str]]:
        """
        Given the head node of each parallel branch, work out which nodes
        belong to each branch, detecting the merge node where branches
        converge so shared downstream nodes are excluded.

        :param edge_mapping: edge mapping (source node id -> outgoing edges)
        :param reverse_edge_mapping: reverse edge mapping (target node id -> incoming edges)
        :param parallel_branch_node_ids: head node id of each branch
        :return: mapping of branch head node id -> node ids inside that branch
        """
        # Full forward reachability from each branch head.
        routes_node_ids: dict[str, list[str]] = {}
        for parallel_branch_node_id in parallel_branch_node_ids:
            routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]

            # fetch routes node ids
            cls._recursively_fetch_routes(
                edge_mapping=edge_mapping,
                start_node_id=parallel_branch_node_id,
                routes_node_ids=routes_node_ids[parallel_branch_node_id],
            )

        # Find leaf nodes per branch and candidate merge nodes: nodes shared by
        # several branches that have more than one incoming edge.
        leaf_node_ids: dict[str, list[str]] = {}
        merge_branch_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            for node_id in node_ids:
                if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
                    if branch_node_id not in leaf_node_ids:
                        leaf_node_ids[branch_node_id] = []

                    leaf_node_ids[branch_node_id].append(node_id)

                for branch_node_id2, inner_route2 in routes_node_ids.items():
                    if (
                        branch_node_id != branch_node_id2
                        and node_id in inner_route2
                        and len(reverse_edge_mapping.get(node_id, [])) > 1
                        and cls._is_node_in_routes(
                            reverse_edge_mapping=reverse_edge_mapping,
                            start_node_id=node_id,
                            routes_node_ids=routes_node_ids,
                        )
                    ):
                        if node_id not in merge_branch_node_ids:
                            merge_branch_node_ids[node_id] = []

                        if branch_node_id2 not in merge_branch_node_ids[node_id]:
                            merge_branch_node_ids[node_id].append(branch_node_id2)

        # sorted merge_branch_node_ids by branch_node_ids length desc
        merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))

        # When two merge candidates cover the same set of branches, keep only
        # the earlier one (the one the other is downstream of).
        duplicate_end_node_ids = {}
        for node_id, branch_node_ids in merge_branch_node_ids.items():
            for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
                if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
                    if (node_id, node_id2) not in duplicate_end_node_ids and (
                        node_id2,
                        node_id,
                    ) not in duplicate_end_node_ids:
                        duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids

        for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
            # check which node is after
            if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
                if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
                    del merge_branch_node_ids[node_id2]
            elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
                if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
                    del merge_branch_node_ids[node_id]

        # Assign each branch head its first (largest-coverage) merge node.
        branches_merge_node_ids: dict[str, str] = {}
        for node_id, branch_node_ids in merge_branch_node_ids.items():
            if len(branch_node_ids) <= 1:
                continue

            for branch_node_id in branch_node_ids:
                if branch_node_id in branches_merge_node_ids:
                    continue

                branches_merge_node_ids[branch_node_id] = node_id

        # Final pass: branch contents are either the full route (no merge node)
        # or only the nodes strictly before the merge node.
        in_branch_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            in_branch_node_ids[branch_node_id] = []
            if branch_node_id not in branches_merge_node_ids:
                # all node ids in current branch is in this thread
                in_branch_node_ids[branch_node_id].append(branch_node_id)
                in_branch_node_ids[branch_node_id].extend(node_ids)
            else:
                merge_node_id = branches_merge_node_ids[branch_node_id]
                if merge_node_id != branch_node_id:
                    in_branch_node_ids[branch_node_id].append(branch_node_id)

                # fetch all node ids between branch_node_id and merge_node_id
                cls._recursively_add_parallel_node_ids(
                    branch_node_ids=in_branch_node_ids[branch_node_id],
                    edge_mapping=edge_mapping,
                    merge_node_id=merge_node_id,
                    start_node_id=branch_node_id,
                )

        return in_branch_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_fetch_routes(
|
||||
cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
Recursively fetch route
|
||||
"""
|
||||
if start_node_id not in edge_mapping:
|
||||
return
|
||||
|
||||
for graph_edge in edge_mapping[start_node_id]:
|
||||
# find next node ids
|
||||
if graph_edge.target_node_id not in routes_node_ids:
|
||||
routes_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _is_node_in_routes(
|
||||
cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]]
|
||||
) -> bool:
|
||||
"""
|
||||
Recursively check if the node is in the routes
|
||||
"""
|
||||
if start_node_id not in reverse_edge_mapping:
|
||||
return False
|
||||
|
||||
all_routes_node_ids = set()
|
||||
parallel_start_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
all_routes_node_ids.update(node_ids)
|
||||
|
||||
if branch_node_id in reverse_edge_mapping:
|
||||
for graph_edge in reverse_edge_mapping[branch_node_id]:
|
||||
if graph_edge.source_node_id not in parallel_start_node_ids:
|
||||
parallel_start_node_ids[graph_edge.source_node_id] = []
|
||||
|
||||
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
|
||||
|
||||
for _, branch_node_ids in parallel_start_node_ids.items():
|
||||
if set(branch_node_ids) == set(routes_node_ids.keys()):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
|
||||
"""
|
||||
is node2 after node1
|
||||
"""
|
||||
if node1_id not in edge_mapping:
|
||||
return False
|
||||
|
||||
for graph_edge in edge_mapping[node1_id]:
|
||||
if graph_edge.target_node_id == node2_id:
|
||||
return True
|
||||
|
||||
if cls._is_node2_after_node1(
|
||||
node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_type: WorkflowType = Field(..., description="workflow type")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
@ -1,31 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
"""variable pool"""
|
||||
|
||||
start_at: float = Field(..., description="start time")
|
||||
"""start time"""
|
||||
total_tokens: int = 0
|
||||
"""total tokens"""
|
||||
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
||||
"""llm usage info"""
|
||||
|
||||
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
|
||||
#
|
||||
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
|
||||
# after a serialization and deserialization round trip.
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
|
||||
node_run_state: RuntimeRouteState = RuntimeRouteState()
|
||||
"""node run state"""
|
||||
@ -1,21 +0,0 @@
|
||||
import hashlib
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class RunCondition(BaseModel):
|
||||
type: Literal["branch_identify", "condition"]
|
||||
"""condition type"""
|
||||
|
||||
branch_identify: Optional[str] = None
|
||||
"""branch identify like: sourceHandle, required when type is branch_identify"""
|
||||
|
||||
conditions: Optional[list[Condition]] = None
|
||||
"""conditions to run the node, required when type is condition"""
|
||||
|
||||
@property
|
||||
def hash(self) -> str:
|
||||
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
|
||||
@ -1,118 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
|
||||
class Status(Enum):
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
EXCEPTION = "exception"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""node state id"""
|
||||
|
||||
node_id: str
|
||||
"""node id"""
|
||||
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
"""node run result"""
|
||||
|
||||
status: Status = Status.RUNNING
|
||||
"""node status"""
|
||||
|
||||
start_at: datetime
|
||||
"""start time"""
|
||||
|
||||
paused_at: Optional[datetime] = None
|
||||
"""paused time"""
|
||||
|
||||
finished_at: Optional[datetime] = None
|
||||
"""finished time"""
|
||||
|
||||
failed_reason: Optional[str] = None
|
||||
"""failed reason"""
|
||||
|
||||
paused_by: Optional[str] = None
|
||||
"""paused by"""
|
||||
|
||||
index: int = 1
|
||||
|
||||
def set_finished(self, run_result: NodeRunResult) -> None:
|
||||
"""
|
||||
Node finished
|
||||
|
||||
:param run_result: run result
|
||||
"""
|
||||
if self.status in {
|
||||
RouteNodeState.Status.SUCCESS,
|
||||
RouteNodeState.Status.FAILED,
|
||||
RouteNodeState.Status.EXCEPTION,
|
||||
}:
|
||||
raise Exception(f"Route state {self.id} already finished")
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
self.status = RouteNodeState.Status.SUCCESS
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
self.status = RouteNodeState.Status.FAILED
|
||||
self.failed_reason = run_result.error
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
|
||||
self.status = RouteNodeState.Status.EXCEPTION
|
||||
self.failed_reason = run_result.error
|
||||
else:
|
||||
raise Exception(f"Invalid route status {run_result.status}")
|
||||
|
||||
self.node_run_result = run_result
|
||||
self.finished_at = naive_utc_now()
|
||||
|
||||
|
||||
class RuntimeRouteState(BaseModel):
|
||||
routes: dict[str, list[str]] = Field(
|
||||
default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
|
||||
)
|
||||
|
||||
node_state_mapping: dict[str, RouteNodeState] = Field(
|
||||
default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
|
||||
)
|
||||
|
||||
def create_node_state(self, node_id: str) -> RouteNodeState:
|
||||
"""
|
||||
Create node state
|
||||
|
||||
:param node_id: node id
|
||||
"""
|
||||
state = RouteNodeState(node_id=node_id, start_at=naive_utc_now())
|
||||
self.node_state_mapping[state.id] = state
|
||||
return state
|
||||
|
||||
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
|
||||
"""
|
||||
Add route to the graph state
|
||||
|
||||
:param source_node_state_id: source node state id
|
||||
:param target_node_state_id: target node state id
|
||||
"""
|
||||
if source_node_state_id not in self.routes:
|
||||
self.routes[source_node_state_id] = []
|
||||
|
||||
self.routes[source_node_state_id].append(target_node_state_id)
|
||||
|
||||
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]:
|
||||
"""
|
||||
Get routes with node state by source node id
|
||||
|
||||
:param source_node_state_id: source node state id
|
||||
:return: routes with node state
|
||||
"""
|
||||
return [
|
||||
self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
|
||||
]
|
||||
22
api/core/workflow/graph_engine/error_handling/__init__.py
Normal file
22
api/core/workflow/graph_engine/error_handling/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""
|
||||
Error handling strategies for graph engine.
|
||||
|
||||
This package implements different error recovery strategies using
|
||||
the Strategy pattern for clean separation of concerns.
|
||||
"""
|
||||
|
||||
from .abort_strategy import AbortStrategy
|
||||
from .default_value_strategy import DefaultValueStrategy
|
||||
from .error_handler import ErrorHandler
|
||||
from .error_strategy import ErrorStrategy
|
||||
from .fail_branch_strategy import FailBranchStrategy
|
||||
from .retry_strategy import RetryStrategy
|
||||
|
||||
__all__ = [
|
||||
"AbortStrategy",
|
||||
"DefaultValueStrategy",
|
||||
"ErrorHandler",
|
||||
"ErrorStrategy",
|
||||
"FailBranchStrategy",
|
||||
"RetryStrategy",
|
||||
]
|
||||
@ -0,0 +1,37 @@
|
||||
"""
|
||||
Abort error strategy implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbortStrategy:
|
||||
"""
|
||||
Error strategy that aborts execution on failure.
|
||||
|
||||
This is the default strategy when no other strategy is specified.
|
||||
It stops the entire graph execution when a node fails.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
"""
|
||||
Handle error by aborting execution.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
graph: The workflow graph
|
||||
retry_count: Current retry attempt count (unused)
|
||||
|
||||
Returns:
|
||||
None - signals abortion
|
||||
"""
|
||||
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
|
||||
|
||||
# Return None to signal that execution should stop
|
||||
return None
|
||||
@ -0,0 +1,56 @@
|
||||
"""
|
||||
Default value error strategy implementation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
class DefaultValueStrategy:
|
||||
"""
|
||||
Error strategy that uses default values on failure.
|
||||
|
||||
This strategy allows nodes to fail gracefully by providing
|
||||
predefined default output values.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
"""
|
||||
Handle error by using default values.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
graph: The workflow graph
|
||||
retry_count: Current retry attempt count (unused)
|
||||
|
||||
Returns:
|
||||
NodeRunExceptionEvent with default values
|
||||
"""
|
||||
node = graph.nodes[event.node_id]
|
||||
|
||||
outputs = {
|
||||
**node.default_value_dict,
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
}
|
||||
|
||||
return NodeRunExceptionEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategy.DEFAULT_VALUE,
|
||||
},
|
||||
),
|
||||
error=event.error,
|
||||
)
|
||||
@ -0,0 +1,82 @@
|
||||
"""
|
||||
Main error handler that coordinates error strategies.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
|
||||
from .abort_strategy import AbortStrategy
|
||||
from .default_value_strategy import DefaultValueStrategy
|
||||
from .fail_branch_strategy import FailBranchStrategy
|
||||
from .retry_strategy import RetryStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..domain import GraphExecution
|
||||
|
||||
|
||||
class ErrorHandler:
|
||||
"""
|
||||
Coordinates error handling strategies for node failures.
|
||||
|
||||
This acts as a facade for the various error strategies,
|
||||
selecting and applying the appropriate strategy based on
|
||||
node configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
|
||||
"""
|
||||
Initialize the error handler.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
graph_execution: The graph execution state
|
||||
"""
|
||||
self.graph = graph
|
||||
self.graph_execution = graph_execution
|
||||
|
||||
# Initialize strategies
|
||||
self.abort_strategy = AbortStrategy()
|
||||
self.retry_strategy = RetryStrategy()
|
||||
self.fail_branch_strategy = FailBranchStrategy()
|
||||
self.default_value_strategy = DefaultValueStrategy()
|
||||
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
Selects and applies the appropriate error strategy based on
|
||||
the node's configuration.
|
||||
|
||||
Args:
|
||||
event: The node failure event
|
||||
|
||||
Returns:
|
||||
Optional new event to process, or None to abort
|
||||
"""
|
||||
node = self.graph.nodes[event.node_id]
|
||||
# Get retry count from NodeExecution
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
retry_count = node_execution.retry_count
|
||||
|
||||
# First check if retry is configured and not exhausted
|
||||
if node.retry and retry_count < node.retry_config.max_retries:
|
||||
result = self.retry_strategy.handle_error(event, self.graph, retry_count)
|
||||
if result:
|
||||
# Retry count will be incremented when NodeRunRetryEvent is handled
|
||||
return result
|
||||
|
||||
# Apply configured error strategy
|
||||
strategy = node.error_strategy
|
||||
|
||||
if strategy is None:
|
||||
return self.abort_strategy.handle_error(event, self.graph, retry_count)
|
||||
elif strategy == ErrorStrategyEnum.FAIL_BRANCH:
|
||||
return self.fail_branch_strategy.handle_error(event, self.graph, retry_count)
|
||||
elif strategy == ErrorStrategyEnum.DEFAULT_VALUE:
|
||||
return self.default_value_strategy.handle_error(event, self.graph, retry_count)
|
||||
else:
|
||||
# Unknown strategy, default to abort
|
||||
return self.abort_strategy.handle_error(event, self.graph, retry_count)
|
||||
@ -0,0 +1,31 @@
|
||||
"""
|
||||
Base error strategy protocol.
|
||||
"""
|
||||
|
||||
from typing import Optional, Protocol
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
|
||||
|
||||
class ErrorStrategy(Protocol):
|
||||
"""
|
||||
Protocol for error handling strategies.
|
||||
|
||||
Each strategy implements a different approach to handling
|
||||
node execution failures.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
graph: The workflow graph
|
||||
retry_count: Current retry attempt count
|
||||
|
||||
Returns:
|
||||
Optional new event to process, or None to stop
|
||||
"""
|
||||
...
|
||||
@ -0,0 +1,54 @@
|
||||
"""
|
||||
Fail branch error strategy implementation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
class FailBranchStrategy:
|
||||
"""
|
||||
Error strategy that continues execution via a fail branch.
|
||||
|
||||
This strategy converts failures to exceptions and routes execution
|
||||
through a designated fail-branch edge.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
"""
|
||||
Handle error by taking the fail branch.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
graph: The workflow graph
|
||||
retry_count: Current retry attempt count (unused)
|
||||
|
||||
Returns:
|
||||
NodeRunExceptionEvent to continue via fail branch
|
||||
"""
|
||||
outputs = {
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
}
|
||||
|
||||
return NodeRunExceptionEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
edge_source_handle="fail-branch",
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategy.FAIL_BRANCH,
|
||||
},
|
||||
),
|
||||
error=event.error,
|
||||
)
|
||||
@ -0,0 +1,51 @@
|
||||
"""
|
||||
Retry error strategy implementation.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent
|
||||
|
||||
|
||||
class RetryStrategy:
|
||||
"""
|
||||
Error strategy that retries failed nodes.
|
||||
|
||||
This strategy re-attempts node execution up to a configured
|
||||
maximum number of retries with configurable intervals.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
"""
|
||||
Handle error by retrying the node.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
graph: The workflow graph
|
||||
retry_count: Current retry attempt count
|
||||
|
||||
Returns:
|
||||
NodeRunRetryEvent if retry should occur, None otherwise
|
||||
"""
|
||||
node = graph.nodes[event.node_id]
|
||||
|
||||
# Check if we've exceeded max retries
|
||||
if not node.retry or retry_count >= node.retry_config.max_retries:
|
||||
return None
|
||||
|
||||
# Wait for retry interval
|
||||
time.sleep(node.retry_config.retry_interval_seconds)
|
||||
|
||||
# Create retry event
|
||||
return NodeRunRetryEvent(
|
||||
id=event.id,
|
||||
node_title=node.title,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_run_result=event.node_run_result,
|
||||
start_at=event.start_at,
|
||||
error=event.error,
|
||||
retry_index=retry_count + 1,
|
||||
)
|
||||
16
api/core/workflow/graph_engine/event_management/__init__.py
Normal file
16
api/core/workflow/graph_engine/event_management/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Event management subsystem for graph engine.
|
||||
|
||||
This package handles event routing, collection, and emission for
|
||||
workflow graph execution events.
|
||||
"""
|
||||
|
||||
from .event_collector import EventCollector
|
||||
from .event_emitter import EventEmitter
|
||||
from .event_handlers import EventHandlerRegistry
|
||||
|
||||
__all__ = [
|
||||
"EventCollector",
|
||||
"EventEmitter",
|
||||
"EventHandlerRegistry",
|
||||
]
|
||||
@ -0,0 +1,98 @@
|
||||
"""
|
||||
Event collector for buffering and managing events.
|
||||
"""
|
||||
|
||||
import threading
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import Layer
|
||||
|
||||
|
||||
class EventCollector:
|
||||
"""
|
||||
Collects and buffers events for later retrieval.
|
||||
|
||||
This provides thread-safe event collection with support for
|
||||
notifying layers about events as they're collected.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the event collector."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = threading.Lock()
|
||||
self._layers: list[Layer] = []
|
||||
|
||||
def set_layers(self, layers: list[Layer]) -> None:
|
||||
"""
|
||||
Set the layers to notify on event collection.
|
||||
|
||||
Args:
|
||||
layers: List of layers to notify
|
||||
"""
|
||||
self._layers = layers
|
||||
|
||||
def collect(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Thread-safe method to collect an event.
|
||||
|
||||
Args:
|
||||
event: The event to collect
|
||||
"""
|
||||
with self._lock:
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
def get_events(self) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
Get all collected events.
|
||||
|
||||
Returns:
|
||||
List of collected events
|
||||
"""
|
||||
with self._lock:
|
||||
return list(self._events)
|
||||
|
||||
def get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
Get new events starting from a specific index.
|
||||
|
||||
Args:
|
||||
start_index: The index to start from
|
||||
|
||||
Returns:
|
||||
List of new events
|
||||
"""
|
||||
with self._lock:
|
||||
return list(self._events[start_index:])
|
||||
|
||||
def event_count(self) -> int:
|
||||
"""
|
||||
Get the current count of collected events.
|
||||
|
||||
Returns:
|
||||
Number of collected events
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self._events)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all collected events."""
|
||||
with self._lock:
|
||||
self._events.clear()
|
||||
|
||||
def _notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Notify all layers of an event.
|
||||
|
||||
Layer exceptions are caught and logged to prevent disrupting collection.
|
||||
|
||||
Args:
|
||||
event: The event to send to layers
|
||||
"""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_event(event)
|
||||
except Exception:
|
||||
# Silently ignore layer errors during collection
|
||||
pass
|
||||
@ -0,0 +1,56 @@
|
||||
"""
|
||||
Event emitter for yielding events to external consumers.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from .event_collector import EventCollector
|
||||
|
||||
|
||||
class EventEmitter:
|
||||
"""
|
||||
Emits collected events as a generator for external consumption.
|
||||
|
||||
This provides a generator interface for yielding events as they're
|
||||
collected, with proper synchronization for multi-threaded access.
|
||||
"""
|
||||
|
||||
def __init__(self, event_collector: EventCollector) -> None:
|
||||
"""
|
||||
Initialize the event emitter.
|
||||
|
||||
Args:
|
||||
event_collector: The collector to emit events from
|
||||
"""
|
||||
self.event_collector = event_collector
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete to stop the generator."""
|
||||
self._execution_complete.set()
|
||||
|
||||
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generator that yields events as they're collected.
|
||||
|
||||
Yields:
|
||||
GraphEngineEvent instances as they're processed
|
||||
"""
|
||||
yielded_count = 0
|
||||
|
||||
while not self._execution_complete.is_set() or yielded_count < self.event_collector.event_count():
|
||||
# Get new events since last yield
|
||||
new_events = self.event_collector.get_new_events(yielded_count)
|
||||
|
||||
# Yield any new events
|
||||
for event in new_events:
|
||||
yield event
|
||||
yielded_count += 1
|
||||
|
||||
# Small sleep to avoid busy waiting
|
||||
if not self._execution_complete.is_set() and not new_events:
|
||||
time.sleep(0.001)
|
||||
@ -0,0 +1,303 @@
|
||||
"""
|
||||
Event handler implementations for different event types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..error_handling import ErrorHandler
|
||||
from ..graph_traversal import BranchHandler, EdgeProcessor
|
||||
from ..state_management import ExecutionTracker, NodeStateManager
|
||||
from .event_collector import EventCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventHandlerRegistry:
|
||||
"""
|
||||
Registry of event handlers for different event types.
|
||||
|
||||
This centralizes the business logic for handling specific events,
|
||||
keeping it separate from the routing and collection infrastructure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: Optional["EventCollector"] = None,
|
||||
branch_handler: Optional["BranchHandler"] = None,
|
||||
edge_processor: Optional["EdgeProcessor"] = None,
|
||||
node_state_manager: Optional["NodeStateManager"] = None,
|
||||
execution_tracker: Optional["ExecutionTracker"] = None,
|
||||
error_handler: Optional["ErrorHandler"] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the event handler registry.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
graph_runtime_state: Runtime state with variable pool
|
||||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Optional event collector for collecting events
|
||||
branch_handler: Optional branch handler for branch node processing
|
||||
edge_processor: Optional edge processor for edge traversal
|
||||
node_state_manager: Optional node state manager
|
||||
execution_tracker: Optional execution tracker
|
||||
error_handler: Optional error handler
|
||||
"""
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.graph_execution = graph_execution
|
||||
self.response_coordinator = response_coordinator
|
||||
self.event_collector = event_collector
|
||||
self.branch_handler = branch_handler
|
||||
self.edge_processor = edge_processor
|
||||
self.node_state_manager = node_state_manager
|
||||
self.execution_tracker = execution_tracker
|
||||
self.error_handler = error_handler
|
||||
|
||||
def handle_event(self, event: GraphNodeEventBase) -> None:
|
||||
"""
|
||||
Handle any node event by dispatching to the appropriate handler.
|
||||
|
||||
Args:
|
||||
event: The event to handle
|
||||
"""
|
||||
# Events in loops or iterations are always collected
|
||||
if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id):
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
return
|
||||
|
||||
# Handle specific event types
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._handle_node_started(event)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._handle_stream_chunk(event)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self._handle_node_succeeded(event)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._handle_node_failed(event)
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
self._handle_node_exception(event)
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
self._handle_node_retry(event)
|
||||
elif isinstance(
|
||||
event,
|
||||
(
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
),
|
||||
):
|
||||
# Iteration and loop events are collected directly
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
else:
|
||||
# Collect unhandled events
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||
|
||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Handle node started event.
|
||||
|
||||
Args:
|
||||
event: The node started event
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self.response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
|
||||
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Handle stream chunk event with full processing.
|
||||
|
||||
Args:
|
||||
event: The stream chunk event
|
||||
"""
|
||||
# Process with response coordinator
|
||||
streaming_events = list(self.response_coordinator.intercept_event(event))
|
||||
|
||||
# Collect all events
|
||||
if self.event_collector:
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
    """
    Handle node success by coordinating subsystems.

    This method coordinates between different subsystems to process
    node completion, handle edges, and trigger downstream execution.

    Sequence: mark the execution taken, store outputs, emit streaming
    events, process outgoing edges, enqueue ready nodes, then remove the
    finished node from the execution tracker and collect the event.

    Args:
        event: The node succeeded event
    """
    # Update domain model
    node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
    node_execution.mark_taken()

    # Store outputs in variable pool
    self._store_node_outputs(event)

    # Forward to response coordinator and emit streaming events
    streaming_events = list(self.response_coordinator.intercept_event(event))
    if self.event_collector:
        for stream_event in streaming_events:
            self.event_collector.collect(stream_event)

    # Process edges and get ready nodes.
    # Branch nodes follow only the edge selected via edge_source_handle;
    # every other node type processes all outgoing edges.
    node = self.graph.nodes[event.node_id]
    if node.execution_type == NodeExecutionType.BRANCH:
        if self.branch_handler:
            ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
                event.node_id, event.node_run_result.edge_source_handle
            )
        else:
            # No branch handler configured: nothing downstream is triggered.
            ready_nodes, edge_streaming_events = [], []
    else:
        if self.edge_processor:
            ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
        else:
            ready_nodes, edge_streaming_events = [], []

    # Collect streaming events from edge processing
    if self.event_collector:
        for edge_event in edge_streaming_events:
            self.event_collector.collect(edge_event)

    # Enqueue ready nodes (before removing this node from the tracker,
    # so the tracker never appears momentarily empty mid-handoff)
    if self.node_state_manager and self.execution_tracker:
        for node_id in ready_nodes:
            self.node_state_manager.enqueue_node(node_id)
            self.execution_tracker.add(node_id)

    # Update execution tracking
    if self.execution_tracker:
        self.execution_tracker.remove(event.node_id)

    # Handle response node outputs
    if node.execution_type == NodeExecutionType.RESPONSE:
        self._update_response_outputs(event)

    # Collect the event
    if self.event_collector:
        self.event_collector.collect(event)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Handle node failure using error handler.
|
||||
|
||||
Args:
|
||||
event: The node failed event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
|
||||
if self.error_handler:
|
||||
result = self.error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.handle_event(result)
|
||||
else:
|
||||
# Abort execution
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
else:
|
||||
# Without error handler, just fail
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
Handle node exception event (fail-branch strategy).
|
||||
|
||||
Args:
|
||||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch, so it's technically "succeeded"
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
"""
|
||||
Handle node retry event.
|
||||
|
||||
Args:
|
||||
event: The node retry event
|
||||
"""
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self.graph_runtime_state.outputs.get("answer", "")
|
||||
if existing:
|
||||
self.graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
else:
|
||||
self.graph_runtime_state.outputs["answer"] = value
|
||||
else:
|
||||
self.graph_runtime_state.outputs[key] = value
|
||||
File diff suppressed because it is too large
Load Diff
18
api/core/workflow/graph_engine/graph_traversal/__init__.py
Normal file
18
api/core/workflow/graph_engine/graph_traversal/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
"""
|
||||
Graph traversal subsystem for graph engine.
|
||||
|
||||
This package handles graph navigation, edge processing,
|
||||
and skip propagation logic.
|
||||
"""
|
||||
|
||||
from .branch_handler import BranchHandler
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .node_readiness import NodeReadinessChecker
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
__all__ = [
|
||||
"BranchHandler",
|
||||
"EdgeProcessor",
|
||||
"NodeReadinessChecker",
|
||||
"SkipPropagator",
|
||||
]
|
||||
@ -0,0 +1,82 @@
|
||||
"""
|
||||
Branch node handling for graph traversal.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
from ..state_management import EdgeStateManager
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
class BranchHandler:
    """
    Coordinates branch-node completion during graph traversal.

    A branch node selects exactly one of several outgoing paths, so
    completing one means skipping every unselected path before the
    selected edges are processed.
    """

    def __init__(
        self,
        graph: Graph,
        edge_processor: EdgeProcessor,
        skip_propagator: SkipPropagator,
        edge_state_manager: EdgeStateManager,
    ) -> None:
        """
        Initialize the branch handler.

        Args:
            graph: The workflow graph
            edge_processor: Processor for edges
            skip_propagator: Propagator for skip states
            edge_state_manager: Manager for edge states
        """
        self.graph = graph
        self.edge_processor = edge_processor
        self.skip_propagator = skip_propagator
        self.edge_state_manager = edge_state_manager

    def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]:
        """
        Handle completion of a branch node.

        Args:
            node_id: The ID of the branch node
            selected_handle: The handle of the selected branch

        Returns:
            Tuple of (ready downstream node IDs, streaming events)

        Raises:
            ValueError: If no branch was selected
        """
        if not selected_handle:
            raise ValueError(f"Branch node {node_id} completed without selecting a branch")

        # Split outgoing edges by whether the branch chose them.
        _, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)

        # Every path the branch did not choose is skipped first.
        self.skip_propagator.skip_branch_paths(node_id, unselected_edges)

        # The chosen edges are then processed like a normal success.
        return self.edge_processor.process_node_success(node_id, selected_handle)

    def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
        """
        Check whether the handle matches one of the node's outgoing edges.

        Args:
            node_id: The ID of the branch node
            selected_handle: The handle to validate

        Returns:
            True if the selection is valid
        """
        handles = {edge.source_handle for edge in self.graph.get_outgoing_edges(node_id)}
        return selected_handle in handles
145
api/core/workflow/graph_engine/graph_traversal/edge_processor.py
Normal file
145
api/core/workflow/graph_engine/graph_traversal/edge_processor.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""
|
||||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
|
||||
|
||||
class EdgeProcessor:
    """
    Processes edges during graph execution.

    Marks edges as taken or skipped, notifies the response coordinator,
    and reports which downstream nodes became ready to run.
    """

    def __init__(
        self,
        graph: Graph,
        edge_state_manager: EdgeStateManager,
        node_state_manager: NodeStateManager,
        response_coordinator: ResponseStreamCoordinator,
    ) -> None:
        """
        Initialize the edge processor.

        Args:
            graph: The workflow graph
            edge_state_manager: Manager for edge states
            node_state_manager: Manager for node states
            response_coordinator: Response stream coordinator
        """
        self.graph = graph
        self.edge_state_manager = edge_state_manager
        self.node_state_manager = node_state_manager
        self.response_coordinator = response_coordinator

    def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]:
        """
        Process edges after a node succeeds.

        Args:
            node_id: The ID of the succeeded node
            selected_handle: For branch nodes, the selected edge handle

        Returns:
            Tuple of (ready downstream node IDs, streaming events)
        """
        if self.graph.nodes[node_id].execution_type == NodeExecutionType.BRANCH:
            return self._process_branch_node_edges(node_id, selected_handle)
        return self._process_non_branch_node_edges(node_id)

    def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]:
        """
        Mark every outgoing edge as TAKEN and gather ready nodes/events.

        Args:
            node_id: The ID of the succeeded node

        Returns:
            Tuple of (ready downstream node IDs, streaming events)
        """
        ready: list[str] = []
        events: list = []
        for edge in self.graph.get_outgoing_edges(node_id):
            edge_ready, edge_events = self._process_taken_edge(edge)
            ready += edge_ready
            events += edge_events
        return ready, events

    def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]:
        """
        Process edges for a branch node: skip unselected paths, take the rest.

        Args:
            node_id: The ID of the branch node
            selected_handle: The handle of the selected edge

        Returns:
            Tuple of (ready downstream node IDs, streaming events)

        Raises:
            ValueError: If no edge was selected
        """
        if not selected_handle:
            raise ValueError(f"Branch node {node_id} did not select any edge")

        selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)

        # Unselected paths are skipped before the chosen ones are taken.
        for edge in unselected_edges:
            self._process_skipped_edge(edge)

        ready: list[str] = []
        events: list = []
        for edge in selected_edges:
            edge_ready, edge_events = self._process_taken_edge(edge)
            ready += edge_ready
            events += edge_events
        return ready, events

    def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]:
        """
        Mark the edge as taken, notify the coordinator, and test readiness.

        Args:
            edge: The edge to process

        Returns:
            Tuple of (downstream node ID if ready, streaming events)
        """
        self.edge_state_manager.mark_edge_taken(edge.id)

        # The coordinator may release buffered streaming events once the
        # edge is known to be taken.
        streaming = list(self.response_coordinator.on_edge_taken(edge.id))

        ready = [edge.head] if self.node_state_manager.is_node_ready(edge.head) else []
        return ready, streaming

    def _process_skipped_edge(self, edge: Edge) -> None:
        """
        Mark the edge as skipped.

        Args:
            edge: The edge to skip
        """
        self.edge_state_manager.mark_edge_skipped(edge.id)
||||
@ -0,0 +1,83 @@
|
||||
"""
|
||||
Node readiness checking for execution.
|
||||
"""
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
class NodeReadinessChecker:
    """
    Determines whether nodes may execute based on incoming-edge states.

    A node's dependencies are its incoming edges; it may run once those
    edges are resolved according to the graph's execution rules.
    """

    def __init__(self, graph: Graph) -> None:
        """
        Initialize the readiness checker.

        Args:
            graph: The workflow graph
        """
        self.graph = graph

    def is_node_ready(self, node_id: str) -> bool:
        """
        Return True when the node's dependencies allow execution.

        A node is ready when it has no incoming edges at all, or when no
        incoming edge is still UNKNOWN and at least one edge is TAKEN.

        Args:
            node_id: The ID of the node to check

        Returns:
            True if the node is ready for execution
        """
        incoming = self.graph.get_incoming_edges(node_id)

        # A node with no dependencies (root/isolated) is always ready.
        if not incoming:
            return True

        # Any undecided edge blocks execution.
        if any(edge.state == NodeState.UNKNOWN for edge in incoming):
            return False

        # At least one live path must have reached the node.
        return any(edge.state == NodeState.TAKEN for edge in incoming)

    def get_ready_downstream_nodes(self, from_node_id: str) -> list[str]:
        """
        Return downstream nodes (over TAKEN edges) that are now ready.

        Args:
            from_node_id: The ID of the completed node

        Returns:
            List of node IDs that are now ready
        """
        return [
            edge.head
            for edge in self.graph.get_outgoing_edges(from_node_id)
            if edge.state == NodeState.TAKEN and self.is_node_ready(edge.head)
        ]
|
||||
@ -0,0 +1,96 @@
|
||||
"""
|
||||
Skip state propagation through the graph.
|
||||
"""
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
|
||||
|
||||
class SkipPropagator:
    """
    Propagates skip states downstream through the graph.

    When a node is skipped, every downstream node that depends solely on
    it must be skipped as well.
    """

    def __init__(
        self,
        graph: Graph,
        edge_state_manager: EdgeStateManager,
        node_state_manager: NodeStateManager,
    ) -> None:
        """
        Initialize the skip propagator.

        Args:
            graph: The workflow graph
            edge_state_manager: Manager for edge states
            node_state_manager: Manager for node states
        """
        self.graph = graph
        self.edge_state_manager = edge_state_manager
        self.node_state_manager = node_state_manager

    def propagate_skip_from_edge(self, edge_id: str) -> None:
        """
        Recursively propagate skip state from a skipped edge.

        Rules:
        - Any UNKNOWN incoming edge stops processing (still undecided)
        - Any TAKEN incoming edge means the node may still execute
        - All incoming edges SKIPPED skips the node and its edges

        Args:
            edge_id: The ID of the skipped edge to start from
        """
        target = self.graph.edges[edge_id].head
        states = self.edge_state_manager.analyze_edge_states(self.graph.get_incoming_edges(target))

        # Some upstream path is still undecided: nothing to conclude yet.
        if states["has_unknown"]:
            return

        if states["has_taken"]:
            # A live path remains, so the node may still run; enqueue it
            # if its dependencies are already satisfied.
            if self.node_state_manager.is_node_ready(target):
                self.node_state_manager.enqueue_node(target)
            return

        # Every incoming path was skipped: the node can never run.
        if states["all_skipped"]:
            self._propagate_skip_to_node(target)

    def _propagate_skip_to_node(self, node_id: str) -> None:
        """
        Mark a node and all its outgoing edges as skipped.

        Args:
            node_id: The ID of the node to skip
        """
        self.node_state_manager.mark_node_skipped(node_id)
        for edge in self.graph.get_outgoing_edges(node_id):
            self.edge_state_manager.mark_edge_skipped(edge.id)
            # Recurse so the skip cascades through the whole subtree.
            self.propagate_skip_from_edge(edge.id)

    def skip_branch_paths(self, node_id: str, unselected_edges: list) -> None:
        """
        Skip all paths from unselected branch edges.

        Args:
            node_id: The ID of the branch node
            unselected_edges: List of edges not taken by the branch
        """
        for edge in unselected_edges:
            self.edge_state_manager.mark_edge_skipped(edge.id)
            self.propagate_skip_from_edge(edge.id)
||||
52
api/core/workflow/graph_engine/layers/README.md
Normal file
52
api/core/workflow/graph_engine/layers/README.md
Normal file
@ -0,0 +1,52 @@
|
||||
# Layers
|
||||
|
||||
Pluggable middleware for engine extensions.
|
||||
|
||||
## Components
|
||||
|
||||
### Layer (base)
|
||||
|
||||
Abstract base class for layers.
|
||||
|
||||
- `initialize()` - Receive runtime context
|
||||
- `on_graph_start()` - Execution start hook
|
||||
- `on_event()` - Process all events
|
||||
- `on_graph_end()` - Execution end hook
|
||||
|
||||
### DebugLoggingLayer
|
||||
|
||||
Comprehensive execution logging.
|
||||
|
||||
- Configurable detail levels
|
||||
- Tracks execution statistics
|
||||
- Truncates long values
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
debug_layer = DebugLoggingLayer(
|
||||
level="INFO",
|
||||
include_outputs=True
|
||||
)
|
||||
|
||||
engine = GraphEngine(graph)
|
||||
engine.add_layer(debug_layer)
|
||||
engine.run()
|
||||
```
|
||||
|
||||
## Custom Layers
|
||||
|
||||
```python
|
||||
class MetricsLayer(Layer):
|
||||
def on_event(self, event):
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
self.metrics[event.node_id] = event.elapsed_time
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
**DebugLoggingLayer Options:**
|
||||
|
||||
- `level` - Log level (INFO, DEBUG, ERROR)
|
||||
- `include_inputs/outputs` - Log data values
|
||||
- `max_value_length` - Truncate long values
|
||||
16
api/core/workflow/graph_engine/layers/__init__.py
Normal file
16
api/core/workflow/graph_engine/layers/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Layer system for GraphEngine extensibility.
|
||||
|
||||
This module provides the layer infrastructure for extending GraphEngine functionality
|
||||
with middleware-like components that can observe events and interact with execution.
|
||||
"""
|
||||
|
||||
from .base import Layer
|
||||
from .debug_logging import DebugLoggingLayer
|
||||
from .execution_limits import ExecutionLimitsLayer
|
||||
|
||||
__all__ = [
|
||||
"DebugLoggingLayer",
|
||||
"ExecutionLimitsLayer",
|
||||
"Layer",
|
||||
]
|
||||
86
api/core/workflow/graph_engine/layers/base.py
Normal file
86
api/core/workflow/graph_engine/layers/base.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""
|
||||
Base layer class for GraphEngine extensions.
|
||||
|
||||
This module provides the abstract base class for implementing layers that can
|
||||
intercept and respond to GraphEngine events.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
|
||||
class Layer(ABC):
    """
    Abstract base class for GraphEngine layers.

    A layer is a middleware-like extension point. Concrete layers can
    observe every event the engine emits, read the graph runtime state,
    and send commands through the command channel.

    Subclasses typically override ``__init__`` to accept configuration,
    then implement the three lifecycle hooks below.
    """

    def __init__(self) -> None:
        """Initialize the layer. Subclasses can override with custom parameters."""
        # Both dependencies are injected later via initialize().
        self.command_channel: Optional[CommandChannel] = None
        self.graph_runtime_state: Optional[GraphRuntimeState] = None

    def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None:
        """
        Inject engine dependencies before execution starts.

        Called by GraphEngine so the layer can access runtime context and
        send commands during execution.

        Args:
            graph_runtime_state: The runtime state of the graph execution
            command_channel: Channel for sending commands to the engine
        """
        self.graph_runtime_state = graph_runtime_state
        self.command_channel = command_channel

    @abstractmethod
    def on_graph_start(self) -> None:
        """
        Hook invoked when graph execution starts.

        Runs after engine initialization but before any node executes;
        use it to set up resources or log start information.
        """
        pass

    @abstractmethod
    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Hook invoked for every event emitted by the engine.

        Receives all events generated during graph execution: graph
        lifecycle events (start, success, failure), node execution events
        (start, success, failure, retry), stream events for response
        nodes, and container events (iteration, loop).

        Args:
            event: The event emitted by the engine
        """
        pass

    @abstractmethod
    def on_graph_end(self, error: Optional[Exception]) -> None:
        """
        Hook invoked when graph execution ends.

        Runs after all nodes finish or when execution is aborted; use it
        to clean up resources or log the final state.

        Args:
            error: The exception that caused execution to fail, or None if successful
        """
        pass
||||
246
api/core/workflow/graph_engine/layers/debug_logging.py
Normal file
246
api/core/workflow/graph_engine/layers/debug_logging.py
Normal file
@ -0,0 +1,246 @@
|
||||
"""
|
||||
Debug logging layer for GraphEngine.
|
||||
|
||||
This module provides a layer that logs all events and state changes during
|
||||
graph execution for debugging purposes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .base import Layer
|
||||
|
||||
|
||||
class DebugLoggingLayer(Layer):
    """
    A layer that provides comprehensive logging of GraphEngine execution.

    This layer logs all events with configurable detail levels, helping developers
    debug workflow execution and understand the flow of events.
    """

    def __init__(
        self,
        level: str = "INFO",
        include_inputs: bool = False,
        include_outputs: bool = True,
        include_process_data: bool = False,
        logger_name: str = "GraphEngine.Debug",
        max_value_length: int = 500,
    ) -> None:
        """
        Initialize the debug logging layer.

        Args:
            level: Logging level (DEBUG, INFO, WARNING, ERROR)
            include_inputs: Whether to log node input values
            include_outputs: Whether to log node output values
            include_process_data: Whether to log node process data
            logger_name: Name of the logger to use
            max_value_length: Maximum length of logged values (truncated if longer)
        """
        super().__init__()
        self.level = level
        self.include_inputs = include_inputs
        self.include_outputs = include_outputs
        self.include_process_data = include_process_data
        self.max_value_length = max_value_length

        # Set up logger
        # NOTE: loggers are shared by name, so setLevel here also affects
        # any other user of the same logger name.
        self.logger = logging.getLogger(logger_name)
        log_level = getattr(logging, level.upper(), logging.INFO)
        self.logger.setLevel(log_level)

        # Track execution stats (updated as events arrive in on_event)
        self.node_count = 0
        self.success_count = 0
        self.failure_count = 0
        self.retry_count = 0

    def _truncate_value(self, value: Any) -> str:
        """Truncate long values for logging."""
        str_value = str(value)
        if len(str_value) > self.max_value_length:
            return str_value[: self.max_value_length] + "... (truncated)"
        return str_value

    def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
        """Format a dictionary or mapping for logging with truncation."""
        if not data:
            return "{}"

        formatted_items = []
        for key, value in data.items():
            formatted_value = self._truncate_value(value)
            formatted_items.append(f" {key}: {formatted_value}")

        return "{\n" + ",\n".join(formatted_items) + "\n}"

    def on_graph_start(self) -> None:
        """Log graph execution start."""
        self.logger.info("=" * 80)
        self.logger.info("🚀 GRAPH EXECUTION STARTED")
        self.logger.info("=" * 80)

        if self.graph_runtime_state:
            # Log initial state
            self.logger.info("Initial State:")

            # Log inputs if available
            if self.graph_runtime_state.variable_pool:
                initial_vars = {}
                # Access the variable dictionary directly
                for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
                    for var_key, var in variables.items():
                        initial_vars[f"{node_id}.{var_key}"] = str(var.value) if hasattr(var, "value") else str(var)

                if initial_vars:
                    self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))

    def on_event(self, event: GraphEngineEvent) -> None:
        """Log individual events based on their type."""
        # Only used in the fallback branch for unknown event types.
        event_class = event.__class__.__name__

        # Graph-level events
        if isinstance(event, GraphRunStartedEvent):
            self.logger.debug("Graph run started event")

        elif isinstance(event, GraphRunSucceededEvent):
            self.logger.info("✅ Graph run succeeded")
            if self.include_outputs and event.outputs:
                self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))

        elif isinstance(event, GraphRunFailedEvent):
            self.logger.error("❌ Graph run failed: %s", event.error)
            if event.exceptions_count > 0:
                self.logger.error(" Total exceptions: %s", event.exceptions_count)

        elif isinstance(event, GraphRunAbortedEvent):
            self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
            if event.outputs:
                self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))

        # Node-level events
        elif isinstance(event, NodeRunStartedEvent):
            self.node_count += 1
            self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)

            if self.include_inputs and event.node_run_result.inputs:
                self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs))

        elif isinstance(event, NodeRunSucceededEvent):
            self.success_count += 1
            self.logger.info("✅ Node succeeded: %s", event.node_id)

            if self.include_outputs and event.node_run_result.outputs:
                self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs))

            if self.include_process_data and event.node_run_result.process_data:
                self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data))

        elif isinstance(event, NodeRunFailedEvent):
            self.failure_count += 1
            self.logger.error("❌ Node failed: %s", event.node_id)
            self.logger.error(" Error: %s", event.error)

            if event.node_run_result.error:
                self.logger.error(" Details: %s", event.node_run_result.error)

        elif isinstance(event, NodeRunExceptionEvent):
            self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
            self.logger.warning(" Error: %s", event.error)

        elif isinstance(event, NodeRunRetryEvent):
            self.retry_count += 1
            self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
            self.logger.warning(" Previous error: %s", event.error)

        elif isinstance(event, NodeRunStreamChunkEvent):
            # Log stream chunks at debug level to avoid spam
            final_indicator = " (FINAL)" if event.is_final else ""
            self.logger.debug(
                "📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
            )

        # Iteration events
        elif isinstance(event, NodeRunIterationStartedEvent):
            self.logger.info("🔁 Iteration started: %s", event.node_id)

        elif isinstance(event, NodeRunIterationNextEvent):
            self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index)

        elif isinstance(event, NodeRunIterationSucceededEvent):
            self.logger.info("✅ Iteration succeeded: %s", event.node_id)
            if self.include_outputs and event.outputs:
                self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))

        elif isinstance(event, NodeRunIterationFailedEvent):
            self.logger.error("❌ Iteration failed: %s", event.node_id)
            self.logger.error(" Error: %s", event.error)

        # Loop events
        elif isinstance(event, NodeRunLoopStartedEvent):
            self.logger.info("🔄 Loop started: %s", event.node_id)

        elif isinstance(event, NodeRunLoopNextEvent):
            self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index)

        elif isinstance(event, NodeRunLoopSucceededEvent):
            self.logger.info("✅ Loop succeeded: %s", event.node_id)
            if self.include_outputs and event.outputs:
                self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))

        elif isinstance(event, NodeRunLoopFailedEvent):
            self.logger.error("❌ Loop failed: %s", event.node_id)
            self.logger.error(" Error: %s", event.error)

        else:
            # Log unknown events at debug level
            self.logger.debug("Event: %s", event_class)

    def on_graph_end(self, error: Optional[Exception]) -> None:
        """Log graph execution end with summary statistics."""
        self.logger.info("=" * 80)

        if error:
            self.logger.error("🔴 GRAPH EXECUTION FAILED")
            self.logger.error(" Error: %s", error)
        else:
            self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")

        # Log execution statistics
        self.logger.info("Execution Statistics:")
        self.logger.info(" Total nodes executed: %s", self.node_count)
        self.logger.info(" Successful nodes: %s", self.success_count)
        self.logger.info(" Failed nodes: %s", self.failure_count)
        self.logger.info(" Node retries: %s", self.retry_count)

        # Log final state if available
        if self.graph_runtime_state and self.include_outputs:
            if self.graph_runtime_state.outputs:
                self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))

        self.logger.info("=" * 80)
||||
144
api/core/workflow/graph_engine/layers/execution_limits.py
Normal file
144
api/core/workflow/graph_engine/layers/execution_limits.py
Normal file
@ -0,0 +1,144 @@
|
||||
"""
|
||||
Execution limits layer for GraphEngine.
|
||||
|
||||
This layer monitors workflow execution to enforce limits on:
|
||||
- Maximum execution steps
|
||||
- Maximum execution time
|
||||
|
||||
When limits are exceeded, the layer automatically aborts execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import Layer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(Enum):
    """Types of execution limits that can be exceeded."""

    # Node-execution step budget (max_steps) exhausted.
    STEP_LIMIT = "step_limit"
    # Wall-clock budget (max_time, seconds) exhausted.
    TIME_LIMIT = "time_limit"
|
||||
|
||||
|
||||
class ExecutionLimitsLayer(Layer):
    """
    Layer that enforces execution limits for workflows.

    Monitors:
    - Step count: Tracks number of node executions
    - Time limit: Monitors total execution time

    Automatically aborts execution (via an AbortCommand on the command
    channel) when either limit is exceeded.
    """

    def __init__(self, max_steps: int, max_time: int) -> None:
        """
        Initialize the execution limits layer.

        Args:
            max_steps: Maximum number of execution steps allowed
            max_time: Maximum execution time in seconds allowed
        """
        super().__init__()
        self.max_steps = max_steps
        self.max_time = max_time

        # Runtime tracking
        self.start_time: Optional[float] = None
        self.step_count = 0
        self.logger = logging.getLogger(__name__)

        # State tracking
        self._execution_started = False
        self._execution_ended = False
        self._abort_sent = False  # Track if abort command has been sent

    def on_graph_start(self) -> None:
        """Reset counters and start the clock when graph execution begins."""
        self.start_time = time.time()
        self.step_count = 0
        self._execution_started = True
        self._execution_ended = False
        self._abort_sent = False

        self.logger.debug("Execution limits monitoring started")

    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Called for every event emitted by the engine.

        Counts node starts as steps and, on each node completion, checks
        both limits; exceeding either triggers a single abort command.
        """
        # Ignore events outside an active run or after an abort was issued.
        if not self._execution_started or self._execution_ended or self._abort_sent:
            return

        # Track step count for node execution events
        if isinstance(event, NodeRunStartedEvent):
            self.step_count += 1
            self.logger.debug("Step %d started: %s", self.step_count, event.node_id)

        # Check limits when node execution completes
        if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
            if self._reached_step_limitation():
                self._send_abort_command(LimitType.STEP_LIMIT)

            if self._reached_time_limitation():
                self._send_abort_command(LimitType.TIME_LIMIT)

    def on_graph_end(self, error: Optional[Exception]) -> None:
        """Called when graph execution ends; logs final step/time stats."""
        if self._execution_started and not self._execution_ended:
            self._execution_ended = True

            if self.start_time:
                total_time = time.time() - self.start_time
                self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)

    def _reached_step_limitation(self) -> bool:
        """Check if step count limit has been exceeded."""
        return self.step_count > self.max_steps

    def _reached_time_limitation(self) -> bool:
        """Check if time limit has been exceeded."""
        return self.start_time is not None and (time.time() - self.start_time) > self.max_time

    def _send_abort_command(self, limit_type: LimitType) -> None:
        """
        Send abort command due to limit violation.

        Args:
            limit_type: Type of limit exceeded
        """
        if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
            return

        # Format detailed reason message. The else branch guarantees `reason`
        # is always bound even if new LimitType members are added later
        # (the original if/elif left it potentially unbound).
        if limit_type == LimitType.STEP_LIMIT:
            reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
        elif limit_type == LimitType.TIME_LIMIT:
            elapsed_time = time.time() - self.start_time if self.start_time else 0
            reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
        else:
            reason = f"Execution limit exceeded: {limit_type.value}"

        self.logger.warning("Execution limit exceeded: %s", reason)

        try:
            # Send abort command to the engine
            abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
            self.command_channel.send_command(abort_command)

            # Mark that abort has been sent to prevent duplicate commands
            self._abort_sent = True

            self.logger.debug("Abort command sent to engine")

        except Exception:
            # Fix: the original message contained a dangling "%s" placeholder
            # with no corresponding argument.
            self.logger.exception("Failed to send abort command")
|
||||
49
api/core/workflow/graph_engine/manager.py
Normal file
49
api/core/workflow/graph_engine/manager.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""
|
||||
GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class GraphEngineManager:
    """
    Manager for sending control commands to GraphEngine instances.

    Commands are delivered through a per-task Redis channel; no user
    permission validation is performed here.
    """

    @staticmethod
    def send_stop_command(task_id: str, reason: Optional[str] = None) -> None:
        """
        Send a stop command to a running workflow.

        Args:
            task_id: The task ID of the workflow to stop
            reason: Optional reason for stopping (defaults to "User requested stop")
        """
        if not task_id:
            return

        # Each workflow execution listens on its own Redis command channel.
        channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
        command = AbortCommand(reason=reason or "User requested stop")

        try:
            channel.send_command(command)
        except Exception:
            # Best-effort delivery: if Redis is unavailable, fail silently —
            # the legacy stop-flag mechanism will still stop the workflow.
            pass
|
||||
14
api/core/workflow/graph_engine/orchestration/__init__.py
Normal file
14
api/core/workflow/graph_engine/orchestration/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Orchestration subsystem for graph engine.
|
||||
|
||||
This package coordinates the overall execution flow between
|
||||
different subsystems.
|
||||
"""
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
__all__ = [
|
||||
"Dispatcher",
|
||||
"ExecutionCoordinator",
|
||||
]
|
||||
104
api/core/workflow/graph_engine/orchestration/dispatcher.py
Normal file
104
api/core/workflow/graph_engine/orchestration/dispatcher.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""
|
||||
Main dispatcher for processing events from workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..event_management import EventCollector, EventEmitter
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandlerRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Dispatcher:
    """
    Main dispatcher that processes events from the event queue.

    This runs in a separate thread and coordinates event processing
    with timeout and completion detection.
    """

    def __init__(
        self,
        event_queue: queue.Queue,
        event_handler: "EventHandlerRegistry",
        event_collector: EventCollector,
        execution_coordinator: ExecutionCoordinator,
        max_execution_time: int,
        event_emitter: Optional[EventEmitter] = None,
    ) -> None:
        """
        Initialize the dispatcher.

        Args:
            event_queue: Queue of events from workers
            event_handler: Event handler registry for processing events
            event_collector: Event collector for collecting unhandled events
            execution_coordinator: Coordinator for execution flow
            max_execution_time: Maximum execution time in seconds
            event_emitter: Optional event emitter to signal completion
        """
        self.event_queue = event_queue
        self.event_handler = event_handler
        self.event_collector = event_collector
        self.execution_coordinator = execution_coordinator
        # NOTE(review): max_execution_time is stored (and _start_time recorded in
        # start()) but never checked inside _dispatcher_loop — the time limit
        # appears to be enforced elsewhere (e.g. ExecutionLimitsLayer). Confirm
        # before relying on this parameter.
        self.max_execution_time = max_execution_time
        self.event_emitter = event_emitter

        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._start_time: Optional[float] = None

    def start(self) -> None:
        """Start the dispatcher thread (no-op if a thread is already alive)."""
        if self._thread and self._thread.is_alive():
            return

        self._stop_event.clear()
        self._start_time = time.time()
        self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the dispatcher thread to stop and join it with a bounded wait."""
        self._stop_event.set()
        if self._thread and self._thread.is_alive():
            # Bounded join so callers are never blocked indefinitely.
            self._thread.join(timeout=10.0)

    def _dispatcher_loop(self) -> None:
        """Main dispatcher loop: poll commands/scaling, drain events, detect completion."""
        try:
            while not self._stop_event.is_set():
                # Check for commands
                self.execution_coordinator.check_commands()

                # Check for scaling
                self.execution_coordinator.check_scaling()

                # Process events. The 0.1s timeout keeps the loop responsive to
                # stop requests and lets completion be re-checked when idle.
                try:
                    event = self.event_queue.get(timeout=0.1)
                    # Route to the event handler
                    self.event_handler.handle_event(event)
                    self.event_queue.task_done()
                except queue.Empty:
                    # Queue idle: check if execution is complete
                    if self.execution_coordinator.is_execution_complete():
                        break

        except Exception as e:
            logger.exception("Dispatcher error")
            self.execution_coordinator.mark_failed(e)

        finally:
            # Always finalize state, even after an error or stop request.
            self.execution_coordinator.mark_complete()
            # Signal the event emitter that execution is complete
            if self.event_emitter:
                self.event_emitter.mark_complete()
|
||||
@ -0,0 +1,91 @@
|
||||
"""
|
||||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..event_management import EventCollector
|
||||
from ..state_management import ExecutionTracker, NodeStateManager
|
||||
from ..worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandlerRegistry
|
||||
|
||||
|
||||
class ExecutionCoordinator:
    """
    Coordinates overall execution flow between subsystems.

    Thin delegation facade used by the dispatcher to process commands,
    drive worker scaling, and query/update overall execution state.
    """

    def __init__(
        self,
        graph_execution: GraphExecution,
        node_state_manager: NodeStateManager,
        execution_tracker: ExecutionTracker,
        event_handler: "EventHandlerRegistry",
        event_collector: EventCollector,
        command_processor: CommandProcessor,
        worker_pool: WorkerPool,
    ) -> None:
        """
        Wire the coordinator to its collaborating subsystems.

        Args:
            graph_execution: Graph execution aggregate
            node_state_manager: Manager for node states
            execution_tracker: Tracker for executing nodes
            event_handler: Event handler registry for processing events
            event_collector: Event collector for collecting events
            command_processor: Processor for commands
            worker_pool: Pool of workers
        """
        self.graph_execution = graph_execution
        self.node_state_manager = node_state_manager
        self.execution_tracker = execution_tracker
        self.event_handler = event_handler
        self.event_collector = event_collector
        self.command_processor = command_processor
        self.worker_pool = worker_pool

    def check_commands(self) -> None:
        """Delegate processing of any pending commands."""
        self.command_processor.process_commands()

    def check_scaling(self) -> None:
        """Ask the worker pool to scale based on current queue depth and load."""
        pending = self.node_state_manager.ready_queue.qsize()
        in_flight = self.execution_tracker.count()
        self.worker_pool.check_scaling(pending, in_flight)

    def is_execution_complete(self) -> bool:
        """
        Report whether execution has finished.

        Returns:
            True when the run was aborted, failed, or has no remaining work.
        """
        execution = self.graph_execution
        if execution.aborted or execution.has_error:
            return True

        # Done only when nothing is queued and nothing is executing.
        nothing_queued = self.node_state_manager.ready_queue.empty()
        return nothing_queued and self.execution_tracker.is_empty()

    def mark_complete(self) -> None:
        """Mark execution as complete (no-op if already completed)."""
        if not self.graph_execution.completed:
            self.graph_execution.complete()

    def mark_failed(self, error: Exception) -> None:
        """
        Record a fatal error on the execution aggregate.

        Args:
            error: The error that caused failure
        """
        self.graph_execution.fail(error)
|
||||
10
api/core/workflow/graph_engine/output_registry/__init__.py
Normal file
10
api/core/workflow/graph_engine/output_registry/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""
|
||||
OutputRegistry - Thread-safe storage for node outputs (streams and scalars)
|
||||
|
||||
This component provides thread-safe storage and retrieval of node outputs,
|
||||
supporting both scalar values and streaming chunks with proper state management.
|
||||
"""
|
||||
|
||||
from .registry import OutputRegistry
|
||||
|
||||
__all__ = ["OutputRegistry"]
|
||||
145
api/core/workflow/graph_engine/output_registry/registry.py
Normal file
145
api/core/workflow/graph_engine/output_registry/registry.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""
|
||||
Main OutputRegistry implementation.
|
||||
|
||||
This module contains the public OutputRegistry class that provides
|
||||
thread-safe storage for node outputs.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from core.variables import Segment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from .stream import Stream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
|
||||
class OutputRegistry:
    """
    Thread-safe registry for storing and retrieving node outputs.

    Scalar values are delegated to the wrapped VariablePool; streaming
    chunks are held in per-selector Stream objects. All operations are
    serialized by an internal re-entrant lock.
    """

    def __init__(self, variable_pool: VariablePool) -> None:
        """Initialize empty registry with thread-safe storage.

        Args:
            variable_pool: Backing store for scalar values
        """
        self._lock = RLock()
        self._scalars = variable_pool
        # Keyed by the selector tuple; streams are created lazily on first
        # append or close. (Tightened from dict[tuple, Stream].)
        self._streams: dict[tuple[str, ...], Stream] = {}

    def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
        """Convert selector sequence to a hashable tuple key for internal storage."""
        return tuple(selector)

    def set_scalar(self, selector: Sequence[str], value: Union[str, int, float, bool, dict, list]) -> None:
        """
        Set a scalar value for the given selector.

        Args:
            selector: List of strings identifying the output location
            value: The scalar value to store
        """
        with self._lock:
            self._scalars.add(selector, value)

    def get_scalar(self, selector: Sequence[str]) -> Optional["Segment"]:
        """
        Get a scalar value for the given selector.

        Args:
            selector: List of strings identifying the output location

        Returns:
            The stored Segment, or None if not found
        """
        with self._lock:
            return self._scalars.get(selector)

    def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None:
        """
        Append a NodeRunStreamChunkEvent to the stream for the given selector.

        Args:
            selector: List of strings identifying the stream location
            event: The NodeRunStreamChunkEvent to append

        Raises:
            ValueError: If the stream is already closed
        """
        key = self._selector_to_key(selector)
        with self._lock:
            if key not in self._streams:
                self._streams[key] = Stream()

            try:
                self._streams[key].append(event)
            except ValueError as e:
                # Re-raise with the selector for context, chaining the original
                # exception (the previous bare re-raise discarded the cause).
                raise ValueError(f"Stream {'.'.join(selector)} is already closed") from e

    def pop_chunk(self, selector: Sequence[str]) -> Optional["NodeRunStreamChunkEvent"]:
        """
        Pop the next unread NodeRunStreamChunkEvent from the stream.

        Args:
            selector: List of strings identifying the stream location

        Returns:
            The next event, or None if no unread events available
        """
        key = self._selector_to_key(selector)
        with self._lock:
            stream = self._streams.get(key)
            if stream is None:
                return None
            return stream.pop_next()

    def has_unread(self, selector: Sequence[str]) -> bool:
        """
        Check if the stream has unread events.

        Args:
            selector: List of strings identifying the stream location

        Returns:
            True if there are unread events, False otherwise
        """
        key = self._selector_to_key(selector)
        with self._lock:
            stream = self._streams.get(key)
            if stream is None:
                return False
            return stream.has_unread()

    def close_stream(self, selector: Sequence[str]) -> None:
        """
        Mark a stream as closed (no more chunks can be appended).

        Args:
            selector: List of strings identifying the stream location
        """
        key = self._selector_to_key(selector)
        with self._lock:
            # Create the stream if absent so a close issued before any append
            # is still recorded.
            if key not in self._streams:
                self._streams[key] = Stream()
            self._streams[key].close()

    def stream_closed(self, selector: Sequence[str]) -> bool:
        """
        Check if a stream is closed.

        Args:
            selector: List of strings identifying the stream location

        Returns:
            True if the stream is closed, False otherwise
        """
        key = self._selector_to_key(selector)
        with self._lock:
            stream = self._streams.get(key)
            if stream is None:
                return False
            return stream.is_closed
|
||||
69
api/core/workflow/graph_engine/output_registry/stream.py
Normal file
69
api/core/workflow/graph_engine/output_registry/stream.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""
|
||||
Internal stream implementation for OutputRegistry.
|
||||
|
||||
This module contains the private Stream class used internally by OutputRegistry
|
||||
to manage streaming data chunks.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
|
||||
class Stream:
    """
    An append-only sequence of NodeRunStreamChunkEvent objects with a read cursor.

    Tracks which events have been consumed and whether the producer has
    closed the stream.

    Note: This is an internal class not exposed in the public API.
    """

    def __init__(self) -> None:
        """Create a stream with no events, cursor at zero, open state."""
        self.events: list[NodeRunStreamChunkEvent] = []
        self.read_position: int = 0
        self.is_closed: bool = False

    def append(self, event: "NodeRunStreamChunkEvent") -> None:
        """
        Add an event to the end of the stream.

        Args:
            event: The NodeRunStreamChunkEvent to append

        Raises:
            ValueError: If the stream is already closed
        """
        if self.is_closed:
            raise ValueError("Cannot append to a closed stream")
        self.events.append(event)

    def pop_next(self) -> Optional["NodeRunStreamChunkEvent"]:
        """
        Return the next unread event and advance the cursor.

        Returns:
            The next event, or None if no unread events available
        """
        try:
            event = self.events[self.read_position]
        except IndexError:
            return None
        self.read_position += 1
        return event

    def has_unread(self) -> bool:
        """
        Report whether the cursor has not yet reached the end of events.

        Returns:
            True if there are unread events, False otherwise
        """
        return self.read_position < len(self.events)

    def close(self) -> None:
        """Mark the stream closed; further append() calls will raise ValueError."""
        self.is_closed = True
|
||||
41
api/core/workflow/graph_engine/protocols/command_channel.py
Normal file
41
api/core/workflow/graph_engine/protocols/command_channel.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""
|
||||
CommandChannel protocol for GraphEngine command communication.
|
||||
|
||||
This protocol defines the interface for sending and receiving commands
|
||||
to/from a GraphEngine instance, supporting both local and distributed scenarios.
|
||||
"""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
class CommandChannel(Protocol):
    """
    Protocol for bidirectional command communication with GraphEngine.

    Since each GraphEngine instance processes only one workflow execution,
    this channel is dedicated to that single execution.
    """

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """
        Fetch pending commands for this GraphEngine instance.

        Called by GraphEngine to poll for commands that need to be processed.

        Returns:
            List of pending commands (may be empty)
        """
        ...

    def send_command(self, command: GraphEngineCommand) -> None:
        """
        Send a command to be processed by this GraphEngine instance.

        Called by external systems to send control commands to the running workflow.

        Args:
            command: The command to send
        """
        ...
|
||||
@ -0,0 +1,10 @@
|
||||
"""
|
||||
ResponseStreamCoordinator - Coordinates streaming output from response nodes
|
||||
|
||||
This component manages response streaming sessions and ensures ordered streaming
|
||||
of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
from .coordinator import ResponseStreamCoordinator
|
||||
|
||||
__all__ = ["ResponseStreamCoordinator"]
|
||||
@ -0,0 +1,465 @@
|
||||
"""
|
||||
Main ResponseStreamCoordinator implementation.
|
||||
|
||||
This module contains the public ResponseStreamCoordinator class that manages
|
||||
response streaming sessions and ensures ordered streaming of responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import Optional, TypeAlias
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
|
||||
from ..output_registry import OutputRegistry
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type definitions
|
||||
NodeID: TypeAlias = str
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
class ResponseStreamCoordinator:
|
||||
"""
|
||||
Manages response streaming sessions without relying on global state.
|
||||
|
||||
Ensures ordered streaming of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
    def __init__(self, registry: OutputRegistry, graph: "Graph") -> None:
        """
        Initialize coordinator with output registry.

        Args:
            registry: OutputRegistry instance for accessing node outputs
            graph: Graph instance for looking up node information
        """
        self.registry = registry
        self.graph = graph
        # Only one session streams at a time; the rest wait in FIFO order.
        self.active_session: Optional[ResponseSession] = None
        self.waiting_sessions: deque[ResponseSession] = deque()
        self.lock = RLock()

        # Track registered response nodes
        self._response_nodes: set[NodeID] = set()

        # Precomputed paths (blocking branch edges) for each response node
        self._paths_maps: dict[NodeID, list[Path]] = {}

        # Track node execution IDs for proper event forwarding
        self._node_execution_ids: dict[NodeID, str] = {}  # node_id -> execution_id

        # Track response sessions to ensure only one per node
        self._response_sessions: dict[NodeID, ResponseSession] = {}  # node_id -> session
|
||||
|
||||
def register(self, response_node_id: NodeID) -> None:
|
||||
with self.lock:
|
||||
self._response_nodes.add(response_node_id)
|
||||
|
||||
# Build and save paths map for this response node
|
||||
paths_map = self._build_paths_map(response_node_id)
|
||||
self._paths_maps[response_node_id] = paths_map
|
||||
|
||||
# Create and store response session for this node
|
||||
response_node = self.graph.nodes[response_node_id]
|
||||
session = ResponseSession.from_node(response_node)
|
||||
self._response_sessions[response_node_id] = session
|
||||
|
||||
def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
|
||||
"""Track the execution ID for a node when it starts executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
execution_id: The execution ID from NodeRunStartedEvent
|
||||
"""
|
||||
with self.lock:
|
||||
self._node_execution_ids[node_id] = execution_id
|
||||
|
||||
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
|
||||
"""Get the execution ID for a node, creating one if it doesn't exist.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The execution ID for the node
|
||||
"""
|
||||
with self.lock:
|
||||
if node_id not in self._node_execution_ids:
|
||||
self._node_execution_ids[node_id] = str(uuid4())
|
||||
return self._node_execution_ids[node_id]
|
||||
|
||||
    def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
        """
        Build a paths map for a response node by finding all paths from root node
        to the response node, recording branch edges along each path.

        Args:
            response_node_id: ID of the response node to analyze

        Returns:
            List of Path objects, where each path contains branch edge IDs
        """
        # Get root node ID
        root_node_id = self.graph.root_node.id

        # If root is the response node, it is trivially reachable: one empty path
        if root_node_id == response_node_id:
            return [Path()]

        # Extract variable selectors from the response node's template
        response_node = self.graph.nodes[response_node_id]
        response_session = ResponseSession.from_node(response_node)
        template = response_session.template

        # Collect all variable selectors from the template.
        # Only the first two selector components (node id + variable name) are kept.
        variable_selectors: set[tuple[str, ...]] = set()
        for segment in template.segments:
            if isinstance(segment, VariableSegment):
                variable_selectors.add(tuple(segment.selector[:2]))

        # Step 1: Find all complete paths from root to response node.
        # NOTE(review): this enumerates every simple path via DFS, which is
        # exponential in the worst case — presumably acceptable for typical
        # workflow graph sizes; confirm for very large/branchy graphs.
        all_complete_paths: list[list[EdgeID]] = []

        def find_paths(
            current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
        ) -> None:
            """Recursively find all paths from current node to target node."""
            if current_node_id == target_node_id:
                # Found a complete path, store it
                all_complete_paths.append(current_path.copy())
                return

            # Mark as visited to avoid cycles
            visited.add(current_node_id)

            # Explore outgoing edges
            outgoing_edges = self.graph.get_outgoing_edges(current_node_id)
            for edge in outgoing_edges:
                edge_id = edge.id
                next_node_id = edge.head

                # Skip if already visited in this path
                if next_node_id not in visited:
                    # Add edge to path and recurse; visited.copy() keeps each
                    # branch's visited set independent so a node may appear on
                    # multiple distinct paths.
                    new_path = current_path + [edge_id]
                    find_paths(next_node_id, target_node_id, new_path, visited.copy())

        # Start searching from root node
        find_paths(root_node_id, response_node_id, [], set())

        # Step 2: For each complete path, keep only the "blocking" edges — those
        # leaving a branch/container node or a node that blocks the template's
        # variable outputs. Non-blocking edges never gate reachability.
        filtered_paths: list[Path] = []
        for path in all_complete_paths:
            blocking_edges = []
            for edge_id in path:
                edge = self.graph.edges[edge_id]
                source_node = self.graph.nodes[edge.tail]

                # Check if node is a branch/container (original behavior)
                if source_node.execution_type in {
                    NodeExecutionType.BRANCH,
                    NodeExecutionType.CONTAINER,
                } or source_node.blocks_variable_output(variable_selectors):
                    blocking_edges.append(edge_id)

            # Keep the path even if it's empty (node already reachable)
            filtered_paths.append(Path(edges=blocking_edges))

        return filtered_paths
|
||||
|
||||
def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
    """
    Handle when an edge is taken (selected by a branch node).

    Removes the taken edge from every tracked path of every response node.
    A response node whose path set now contains an empty path has become
    deterministically reachable, so its streaming session is started (or
    queued if another session is already active).

    Args:
        edge_id: The ID of the edge that was taken

    Returns:
        List of events to emit from starting new sessions
    """
    events: list[NodeRunStreamChunkEvent] = []

    with self.lock:
        # Visit response nodes in registration order so activation order is stable.
        for response_node_id in self._response_nodes:
            paths = self._paths_maps.get(response_node_id)
            if paths is None:
                continue

            # Drop the taken edge from every path; an empty path afterwards
            # means the response node is now guaranteed to execute.
            for path in paths:
                path.remove_edge(edge_id)

            if any(path.is_empty() for path in paths):
                # The activation method itself checks the session map and
                # removes the session so it cannot be activated twice.
                events.extend(self._active_or_queue_session(response_node_id))
        return events
|
||||
|
||||
def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
    """
    Start a session immediately if no active session, otherwise queue it.

    Only sessions registered in ``_response_sessions`` are activated; the
    session entry is removed from the map so it can never be activated twice.

    Args:
        node_id: The ID of the response node to activate

    Returns:
        List of events from the flush attempt if the session started immediately
    """
    # pop() both fetches and removes the session, guaranteeing one-shot activation.
    session = self._response_sessions.pop(node_id, None)
    if session is None:
        return []

    if self.active_session is not None:
        # Another session is currently streaming; wait our turn.
        self.waiting_sessions.append(session)
        return []

    self.active_session = session
    # Flush whatever output is already available for the new session.
    return self.try_flush()
|
||||
|
||||
def intercept_event(
    self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
) -> Sequence[NodeRunStreamChunkEvent]:
    """Record an incoming event and flush any output that became ready.

    Stream chunks are appended to the registry (closing the stream when the
    final chunk arrives). Succeeded events merely trigger a flush: node
    outputs are already visible through the shared variable pool, so no
    scalar copying is needed here.
    """
    with self.lock:
        if isinstance(event, NodeRunStreamChunkEvent):
            self.registry.append_chunk(event.selector, event)
            if event.is_final:
                self.registry.close_stream(event.selector)
        elif not isinstance(event, NodeRunSucceededEvent):
            return []
        return self.try_flush()
|
||||
|
||||
def _create_stream_chunk_event(
    self,
    node_id: str,
    execution_id: str,
    selector: Sequence[str],
    chunk: str,
    is_final: bool = False,
) -> NodeRunStreamChunkEvent:
    """Create a stream chunk event with consistent structure.

    For selectors with special prefixes (sys, env, conversation), we use the
    active response node's information since these are not actual node IDs.

    Args:
        node_id: Node to attribute the chunk to (used for regular selectors)
        execution_id: Execution ID carried by the event
        selector: Variable selector the chunk belongs to
        chunk: The text chunk to stream
        is_final: Whether this is the last chunk of the stream

    Returns:
        The constructed stream chunk event
    """
    # Pick the node to attribute the event to, then build the event once
    # (previously the construction was duplicated in both branches).
    if selector and selector[0] not in self.graph.nodes and self.active_session:
        # Special selector (not a real node ID) - use the active response node.
        node = self.graph.nodes[self.active_session.node_id]
    else:
        # Standard case: selector refers to an actual node.
        node = self.graph.nodes[node_id]

    return NodeRunStreamChunkEvent(
        id=execution_id,
        node_id=node.id,
        node_type=node.node_type,
        selector=selector,
        chunk=chunk,
        is_final=is_final,
    )
|
||||
|
||||
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
    """Process a variable segment. Returns (events, is_complete).

    Handles both regular node selectors and special system selectors (sys, env, conversation).
    For special selectors, we attribute the output to the active response node.

    ``is_complete`` is True when the segment's stream is closed or a scalar
    value was emitted; False means the caller must wait for more chunks.
    """
    events: list[NodeRunStreamChunkEvent] = []
    source_selector_prefix = segment.selector[0] if segment.selector else ""
    is_complete = False

    # Determine which node to attribute the output to
    # For special selectors (sys, env, conversation), use the active response node
    # For regular selectors, use the source node
    if self.active_session and source_selector_prefix not in self.graph.nodes:
        # Special selector - use active response node
        output_node_id = self.active_session.node_id
    else:
        # Regular node selector
        output_node_id = source_selector_prefix
    execution_id = self._get_or_create_execution_id(output_node_id)

    # Stream all available chunks
    while self.registry.has_unread(segment.selector):
        if event := self.registry.pop_chunk(segment.selector):
            # For special selectors, we need to update the event to use
            # the active response node's information
            if self.active_session and source_selector_prefix not in self.graph.nodes:
                response_node = self.graph.nodes[self.active_session.node_id]
                # Create a new event with the response node's information
                # but keep the original selector
                updated_event = NodeRunStreamChunkEvent(
                    id=execution_id,
                    node_id=response_node.id,
                    node_type=response_node.node_type,
                    selector=event.selector,  # Keep original selector
                    chunk=event.chunk,
                    is_final=event.is_final,
                )
                events.append(updated_event)
            else:
                # Regular node selector - use event as is
                events.append(event)

    # Check if this is the last chunk by looking ahead
    stream_closed = self.registry.stream_closed(segment.selector)
    # Check if stream is closed to determine if segment is complete
    if stream_closed:
        is_complete = True

    elif value := self.registry.get_scalar(segment.selector):
        # Process scalar value (no streaming involved for this segment)
        # is_final is set when this is the last segment of the active template.
        is_last_segment = bool(
            self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1
        )
        events.append(
            self._create_stream_chunk_event(
                node_id=output_node_id,
                execution_id=execution_id,
                selector=segment.selector,
                chunk=value.markdown,
                is_final=is_last_segment,
            )
        )
        is_complete = True

    return events, is_complete
|
||||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
    """Emit a single chunk event carrying this literal text segment.

    Text segments are always attributed to the active response node and
    complete immediately, so exactly one event is returned.
    """
    session = self.active_session
    assert session is not None
    node = self.graph.nodes[session.node_id]

    # _get_or_create_execution_id keeps the execution ID consistent across chunks.
    return [
        self._create_stream_chunk_event(
            node_id=node.id,
            execution_id=self._get_or_create_execution_id(node.id),
            selector=[node.id, "answer"],  # FIXME(-LAN-)
            chunk=segment.text,
            is_final=session.index == len(session.template.segments) - 1,
        )
    ]
|
||||
|
||||
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
    """Flush as much of the active session's template as currently possible.

    Walks the template segments from the session's current index, emitting
    text segments immediately and variable segments as their data becomes
    available. Stops at the first incomplete variable segment. When the
    whole template has been emitted, the session is ended and the next
    waiting session (if any) is started.

    Returns:
        All stream chunk events produced by this flush attempt.
    """
    with self.lock:
        if not self.active_session:
            return []

        template = self.active_session.template
        response_node_id = self.active_session.node_id

        events: list[NodeRunStreamChunkEvent] = []

        # Process segments sequentially from current index
        while self.active_session.index < len(template.segments):
            segment = template.segments[self.active_session.index]

            if isinstance(segment, VariableSegment):
                # Check if the source node for this variable is skipped
                # Only check for actual nodes, not special selectors (sys, env, conversation)
                source_selector_prefix = segment.selector[0] if segment.selector else ""
                if source_selector_prefix in self.graph.nodes:
                    source_node = self.graph.nodes[source_selector_prefix]

                    if source_node.state == NodeState.SKIPPED:
                        # Skip this variable segment if the source node is skipped
                        self.active_session.index += 1
                        continue

                segment_events, is_complete = self._process_variable_segment(segment)
                events.extend(segment_events)

                # Only advance index if this variable segment is complete
                if is_complete:
                    self.active_session.index += 1
                else:
                    # Wait for more data
                    break

            elif isinstance(segment, TextSegment):
                # Text segments never block; emit and advance.
                segment_events = self._process_text_segment(segment)
                events.extend(segment_events)
                self.active_session.index += 1

        if self.active_session.is_complete():
            # End current session and get events from starting next session.
            # NOTE(review): end_session re-acquires self.lock - assumes the
            # lock is reentrant (RLock); confirm at the lock's definition.
            next_session_events = self.end_session(response_node_id)
            events.extend(next_session_events)

        return events
|
||||
|
||||
def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
    """
    End the active session for a response node.

    Automatically promotes the oldest waiting session, if any, and
    immediately flushes whatever it can emit.

    Args:
        node_id: ID of the response node ending its session

    Returns:
        List of events from starting the next session
    """
    with self.lock:
        # Ignore the call unless the named node actually owns the active session.
        if not (self.active_session and self.active_session.node_id == node_id):
            return []

        self.active_session = None
        if not self.waiting_sessions:
            return []

        # Promote the next waiting session (FIFO) and flush it right away.
        self.active_session = self.waiting_sessions.popleft()
        return self.try_flush()
|
||||
35
api/core/workflow/graph_engine/response_coordinator/path.py
Normal file
35
api/core/workflow/graph_engine/response_coordinator/path.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
Internal path representation for response coordinator.
|
||||
|
||||
This module contains the private Path class used internally by ResponseStreamCoordinator
|
||||
to track execution paths to response nodes.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeAlias
|
||||
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Path:
|
||||
"""
|
||||
Represents a path of branch edges that must be taken to reach a response node.
|
||||
|
||||
Note: This is an internal class not exposed in the public API.
|
||||
"""
|
||||
|
||||
edges: list[EdgeID] = field(default_factory=list)
|
||||
|
||||
def contains_edge(self, edge_id: EdgeID) -> bool:
|
||||
"""Check if this path contains the given edge."""
|
||||
return edge_id in self.edges
|
||||
|
||||
def remove_edge(self, edge_id: EdgeID) -> None:
|
||||
"""Remove the given edge from this path in place."""
|
||||
if self.contains_edge(edge_id):
|
||||
self.edges.remove(edge_id)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the path has no edges (node is reachable)."""
|
||||
return len(self.edges) == 0
|
||||
@ -0,0 +1,51 @@
|
||||
"""
|
||||
Internal response session management for response coordinator.
|
||||
|
||||
This module contains the private ResponseSession class used internally
|
||||
by ResponseStreamCoordinator to manage streaming sessions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
|
||||
|
||||
@dataclass
class ResponseSession:
    """
    Represents an active response streaming session.

    Note: This is an internal class not exposed in the public API.
    """

    # ID of the response node (AnswerNode or EndNode) that owns this session.
    node_id: str
    template: Template  # Template object from the response node
    index: int = 0  # Current position in the template segments

    @classmethod
    def from_node(cls, node: Node) -> "ResponseSession":
        """
        Create a ResponseSession from an AnswerNode or EndNode.

        Args:
            node: Must be either an AnswerNode or EndNode instance

        Returns:
            ResponseSession configured with the node's streaming template

        Raises:
            TypeError: If node is not an AnswerNode or EndNode
        """
        if not isinstance(node, AnswerNode | EndNode):
            # Include the offending type in the message to aid debugging
            # (previously raised a bare TypeError with no context).
            raise TypeError(f"ResponseSession requires an AnswerNode or EndNode, got {type(node).__name__}")
        return cls(
            node_id=node.id,
            template=node.get_streaming_template(),
        )

    def is_complete(self) -> bool:
        """Check if all segments in the template have been processed."""
        return self.index >= len(self.template.segments)
|
||||
16
api/core/workflow/graph_engine/state_management/__init__.py
Normal file
16
api/core/workflow/graph_engine/state_management/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
State management subsystem for graph engine.
|
||||
|
||||
This package manages node states, edge states, and execution tracking
|
||||
during workflow graph execution.
|
||||
"""
|
||||
|
||||
from .edge_state_manager import EdgeStateManager
|
||||
from .execution_tracker import ExecutionTracker
|
||||
from .node_state_manager import NodeStateManager
|
||||
|
||||
__all__ = [
|
||||
"EdgeStateManager",
|
||||
"ExecutionTracker",
|
||||
"NodeStateManager",
|
||||
]
|
||||
@ -0,0 +1,112 @@
|
||||
"""
|
||||
Manager for edge states during graph execution.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import TypedDict
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
|
||||
class EdgeStateAnalysis(TypedDict):
    """Analysis result for edge states.

    Summary flags derived from the states of a collection of edges,
    used for readiness and skip decisions during graph traversal.
    """

    has_unknown: bool  # At least one edge is still in the UNKNOWN state.
    has_taken: bool  # At least one edge has been TAKEN.
    all_skipped: bool  # Every edge is SKIPPED (vacuously True for an empty list).
|
||||
|
||||
class EdgeStateManager:
    """
    Manages edge states and transitions during graph execution.

    Provides thread-safe state changes for edges plus summary analysis of
    edge states used for decision making during execution.
    """

    def __init__(self, graph: Graph) -> None:
        """
        Initialize the edge state manager.

        Args:
            graph: The workflow graph
        """
        self.graph = graph
        self._lock = threading.RLock()

    def _set_state(self, edge_id: str, state: NodeState) -> None:
        """Assign the given state to the edge under the lock."""
        with self._lock:
            self.graph.edges[edge_id].state = state

    def mark_edge_taken(self, edge_id: str) -> None:
        """
        Mark an edge as TAKEN.

        Args:
            edge_id: The ID of the edge to mark
        """
        self._set_state(edge_id, NodeState.TAKEN)

    def mark_edge_skipped(self, edge_id: str) -> None:
        """
        Mark an edge as SKIPPED.

        Args:
            edge_id: The ID of the edge to mark
        """
        self._set_state(edge_id, NodeState.SKIPPED)

    def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
        """
        Analyze the states of edges and return summary flags.

        Args:
            edges: List of edges to analyze

        Returns:
            Analysis result with state flags
        """
        with self._lock:
            observed = {edge.state for edge in edges}

        return EdgeStateAnalysis(
            has_unknown=NodeState.UNKNOWN in observed,
            has_taken=NodeState.TAKEN in observed,
            # An empty edge list counts as "all skipped" by convention.
            all_skipped=observed == {NodeState.SKIPPED} if observed else True,
        )

    def get_edge_state(self, edge_id: str) -> NodeState:
        """
        Get the current state of an edge.

        Args:
            edge_id: The ID of the edge

        Returns:
            The current edge state
        """
        with self._lock:
            return self.graph.edges[edge_id].state

    def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]:
        """
        Categorize branch edges into selected and unselected.

        Args:
            node_id: The ID of the branch node
            selected_handle: The handle of the selected edge

        Returns:
            A tuple of (selected_edges, unselected_edges)
        """
        with self._lock:
            selected: list[Edge] = []
            unselected: list[Edge] = []
            for edge in self.graph.get_outgoing_edges(node_id):
                bucket = selected if edge.source_handle == selected_handle else unselected
                bucket.append(edge)
            return selected, unselected
|
||||
@ -0,0 +1,87 @@
|
||||
"""
|
||||
Tracker for currently executing nodes.
|
||||
"""
|
||||
|
||||
import threading
|
||||
|
||||
|
||||
class ExecutionTracker:
    """
    Tracks nodes that are currently being executed.

    A thin thread-safe wrapper around a set of node IDs. This replaces the
    ExecutingNodesManager with a cleaner interface focused on tracking which
    nodes are in progress.
    """

    def __init__(self) -> None:
        """Initialize the execution tracker."""
        self._lock = threading.RLock()
        self._executing_nodes: set[str] = set()

    def add(self, node_id: str) -> None:
        """Record that the node identified by *node_id* has started executing."""
        with self._lock:
            self._executing_nodes.add(node_id)

    def remove(self, node_id: str) -> None:
        """Record that *node_id* finished executing (no-op if not tracked)."""
        with self._lock:
            # discard() tolerates unknown IDs, unlike remove().
            self._executing_nodes.discard(node_id)

    def is_executing(self, node_id: str) -> bool:
        """Return True if the node is currently executing."""
        with self._lock:
            return node_id in self._executing_nodes

    def is_empty(self) -> bool:
        """Return True if no nodes are currently executing."""
        return self.count() == 0

    def count(self) -> int:
        """Return the number of currently executing nodes."""
        with self._lock:
            return len(self._executing_nodes)

    def get_executing_nodes(self) -> set[str]:
        """Return a snapshot copy of the executing node IDs."""
        with self._lock:
            return set(self._executing_nodes)

    def clear(self) -> None:
        """Forget all executing nodes."""
        with self._lock:
            self._executing_nodes.clear()
|
||||
@ -0,0 +1,95 @@
|
||||
"""
|
||||
Manager for node states during graph execution.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
class NodeStateManager:
    """
    Manages node states and the ready queue for execution.

    Centralizes node state transitions and enqueueing logic so they stay
    thread-safe and always happen together.
    """

    def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
        """
        Initialize the node state manager.

        Args:
            graph: The workflow graph
            ready_queue: Queue for nodes ready to execute
        """
        self.graph = graph
        self.ready_queue = ready_queue
        self._lock = threading.RLock()

    def enqueue_node(self, node_id: str) -> None:
        """
        Mark a node as TAKEN and add it to the ready queue.

        The state transition and the enqueue always occur together, under
        the same lock, when a node is prepared for execution.

        Args:
            node_id: The ID of the node to enqueue
        """
        with self._lock:
            self.graph.nodes[node_id].state = NodeState.TAKEN
            self.ready_queue.put(node_id)

    def mark_node_skipped(self, node_id: str) -> None:
        """
        Mark a node as SKIPPED.

        Args:
            node_id: The ID of the node to skip
        """
        with self._lock:
            self.graph.nodes[node_id].state = NodeState.SKIPPED

    def is_node_ready(self, node_id: str) -> bool:
        """
        Check if a node is ready to be executed.

        A node is ready when it has no incoming edges, or when none of its
        incoming edges is UNKNOWN and at least one is TAKEN.

        Args:
            node_id: The ID of the node to check

        Returns:
            True if the node is ready for execution
        """
        with self._lock:
            incoming = self.graph.get_incoming_edges(node_id)

            # Source nodes (no predecessors) are always ready.
            if not incoming:
                return True

            states = [edge.state for edge in incoming]
            # Any undecided predecessor means we must keep waiting.
            if NodeState.UNKNOWN in states:
                return False
            # Ready once at least one predecessor path was actually taken.
            return NodeState.TAKEN in states

    def get_node_state(self, node_id: str) -> NodeState:
        """
        Get the current state of a node.

        Args:
            node_id: The ID of the node

        Returns:
            The current node state
        """
        with self._lock:
            return self.graph.nodes[node_id].state
|
||||
135
api/core/workflow/graph_engine/worker.py
Normal file
135
api/core/workflow/graph_engine/worker.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""
|
||||
Worker - Thread implementation for queue-based node execution
|
||||
|
||||
Workers pull node IDs from the ready_queue, execute nodes, and push events
|
||||
to the event_queue for the dispatcher to process.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
|
||||
class Worker(threading.Thread):
    """
    Worker thread that executes nodes from the ready queue.

    Workers continuously pull node IDs from the ready_queue, execute the
    corresponding nodes, and push the resulting events to the event_queue
    for the dispatcher to process.
    """

    def __init__(
        self,
        ready_queue: queue.Queue[str],
        event_queue: queue.Queue[GraphNodeEventBase],
        graph: Graph,
        worker_id: int = 0,
        flask_app: Optional[Flask] = None,
        context_vars: Optional[contextvars.Context] = None,
        on_idle_callback: Optional[Callable[[int], None]] = None,
        on_active_callback: Optional[Callable[[int], None]] = None,
    ) -> None:
        """
        Initialize worker thread.

        Args:
            ready_queue: Queue containing node IDs ready for execution
            event_queue: Queue for pushing execution events
            graph: Graph containing nodes to execute
            worker_id: Unique identifier for this worker
            flask_app: Optional Flask application for context preservation
            context_vars: Optional context variables to preserve in worker thread
            on_idle_callback: Optional callback when worker becomes idle
            on_active_callback: Optional callback when worker becomes active
        """
        super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
        self.ready_queue = ready_queue
        self.event_queue = event_queue
        self.graph = graph
        self.worker_id = worker_id
        self.flask_app = flask_app
        self.context_vars = context_vars
        self._stop_event = threading.Event()
        self.on_idle_callback = on_idle_callback
        self.on_active_callback = on_active_callback
        self.last_task_time = time.time()

    def stop(self) -> None:
        """Signal the worker to stop processing."""
        self._stop_event.set()

    def run(self) -> None:
        """
        Main worker loop.

        Continuously pulls node IDs from ready_queue, executes them,
        and pushes events to event_queue until stopped.
        """
        while not self._stop_event.is_set():
            # Try to get a node ID from the ready queue (with timeout so the
            # stop event is re-checked regularly)
            try:
                node_id = self.ready_queue.get(timeout=0.1)
            except queue.Empty:
                # Notify that worker is idle
                if self.on_idle_callback:
                    self.on_idle_callback(self.worker_id)
                continue

            # Notify that worker is active
            if self.on_active_callback:
                self.on_active_callback(self.worker_id)

            self.last_task_time = time.time()
            node = self.graph.nodes[node_id]
            try:
                self._execute_node(node)
            except Exception as e:
                # Report the failure for the actual node that failed
                # (previously reported node_id="unknown" / NodeType.CODE).
                error_event = NodeRunFailedEvent(
                    id=str(uuid4()),
                    node_id=node.id,
                    node_type=node.node_type,
                    in_iteration_id=None,
                    error=str(e),
                    start_at=datetime.now(),
                )
                self.event_queue.put(error_event)
            finally:
                # Always acknowledge the item, even on failure; otherwise a
                # raised exception would leave ready_queue.join() blocked forever.
                self.ready_queue.task_done()

    def _execute_node(self, node: Node) -> None:
        """
        Execute a single node and handle its events.

        Args:
            node: The node instance to execute
        """
        # Execute the node with preserved context if Flask app is provided
        if self.flask_app and self.context_vars:
            with preserve_flask_contexts(
                flask_app=self.flask_app,
                context_vars=self.context_vars,
            ):
                self._forward_events(node)
        else:
            # Execute without context preservation
            self._forward_events(node)

    def _forward_events(self, node: Node) -> None:
        """Run the node and forward each event to the dispatcher immediately for streaming."""
        for event in node.run():
            self.event_queue.put(event)
|
||||
81
api/core/workflow/graph_engine/worker_management/README.md
Normal file
81
api/core/workflow/graph_engine/worker_management/README.md
Normal file
@ -0,0 +1,81 @@
|
||||
# Worker Management
|
||||
|
||||
Dynamic worker pool for node execution.
|
||||
|
||||
## Components
|
||||
|
||||
### WorkerPool
|
||||
|
||||
Manages worker thread lifecycle.
|
||||
|
||||
- `start/stop/wait()` - Control workers
|
||||
- `scale_up/down()` - Adjust pool size
|
||||
- `get_worker_count()` - Current count
|
||||
|
||||
### WorkerFactory
|
||||
|
||||
Creates workers with Flask context.
|
||||
|
||||
- `create_worker()` - Build with dependencies
|
||||
- Preserves request context
|
||||
|
||||
### DynamicScaler
|
||||
|
||||
Determines scaling decisions.
|
||||
|
||||
- `min/max_workers` - Pool bounds
|
||||
- `scale_up_threshold` - Queue trigger
|
||||
- `should_scale_up/down()` - Check conditions
|
||||
|
||||
### ActivityTracker
|
||||
|
||||
Tracks worker activity.
|
||||
|
||||
- `track_activity(worker_id)` - Record activity
|
||||
- `get_idle_workers(threshold)` - Find idle
|
||||
- `get_active_count()` - Active count
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
scaler = DynamicScaler(
|
||||
min_workers=2,
|
||||
max_workers=10,
|
||||
scale_up_threshold=5
|
||||
)
|
||||
|
||||
pool = WorkerPool(
|
||||
ready_queue=ready_queue,
|
||||
worker_factory=factory,
|
||||
dynamic_scaler=scaler
|
||||
)
|
||||
|
||||
pool.start()
|
||||
|
||||
# Scale based on load
|
||||
if scaler.should_scale_up(queue_size, active):
|
||||
pool.scale_up()
|
||||
|
||||
pool.stop()
|
||||
```
|
||||
|
||||
## Scaling Strategy
|
||||
|
||||
- **Scale Up**: Queue size > threshold AND workers < max
- **Scale Down**: Idle workers exist AND workers > min
|
||||
|
||||
## Parameters
|
||||
|
||||
- `min_workers` - Minimum pool size
|
||||
- `max_workers` - Maximum pool size
|
||||
- `scale_up_threshold` - Queue trigger
|
||||
- `scale_down_threshold` - Idle seconds
|
||||
|
||||
## Flask Context
|
||||
|
||||
WorkerFactory preserves request context across threads:
|
||||
|
||||
```python
|
||||
context_vars = {"request_id": request.id}
|
||||
# Workers receive same context
|
||||
```
|
||||
18
api/core/workflow/graph_engine/worker_management/__init__.py
Normal file
18
api/core/workflow/graph_engine/worker_management/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
"""
|
||||
Worker management subsystem for graph engine.
|
||||
|
||||
This package manages the worker pool, including creation,
|
||||
scaling, and activity tracking.
|
||||
"""
|
||||
|
||||
from .activity_tracker import ActivityTracker
|
||||
from .dynamic_scaler import DynamicScaler
|
||||
from .worker_factory import WorkerFactory
|
||||
from .worker_pool import WorkerPool
|
||||
|
||||
__all__ = [
|
||||
"ActivityTracker",
|
||||
"DynamicScaler",
|
||||
"WorkerFactory",
|
||||
"WorkerPool",
|
||||
]
|
||||
@ -0,0 +1,74 @@
|
||||
"""
|
||||
Activity tracker for monitoring worker activity.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class ActivityTracker:
    """
    Tracks worker activity for scaling decisions.

    Records, for every worker, whether it is active and when that state was
    last reported, so the pool can identify workers idle beyond a threshold.
    """

    def __init__(self, idle_threshold: float = 30.0) -> None:
        """
        Initialize the activity tracker.

        Args:
            idle_threshold: Seconds before a worker is considered idle
        """
        self.idle_threshold = idle_threshold
        self._lock = threading.RLock()
        # worker_id -> (is_active, timestamp of the last activity report)
        self._worker_activity: dict[int, tuple[bool, float]] = {}

    def track_activity(self, worker_id: int, is_active: bool) -> None:
        """
        Track worker activity state.

        Args:
            worker_id: ID of the worker
            is_active: Whether the worker is active
        """
        with self._lock:
            self._worker_activity[worker_id] = (is_active, time.time())

    def get_idle_workers(self) -> list[int]:
        """
        Get list of workers that have been idle too long.

        Returns:
            List of idle worker IDs
        """
        now = time.time()
        with self._lock:
            return [
                worker_id
                for worker_id, (is_active, last_change) in self._worker_activity.items()
                if not is_active and (now - last_change) > self.idle_threshold
            ]

    def remove_worker(self, worker_id: int) -> None:
        """
        Remove a worker from tracking (no-op if unknown).

        Args:
            worker_id: ID of the worker to remove
        """
        with self._lock:
            self._worker_activity.pop(worker_id, None)

    def get_active_count(self) -> int:
        """
        Get count of currently active workers.

        Returns:
            Number of active workers
        """
        with self._lock:
            return sum(1 for is_active, _ in self._worker_activity.values() if is_active)
|
||||
@ -0,0 +1,98 @@
|
||||
"""
|
||||
Dynamic scaler for worker pool sizing.
|
||||
"""
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
class DynamicScaler:
    """
    Encapsulates the scale-up/scale-down policy for the worker pool.

    Decisions are driven by graph size, queue depth, worker utilization,
    and the configured pool bounds.
    """

    def __init__(
        self,
        min_workers: int = 2,
        max_workers: int = 10,
        scale_up_threshold: int = 5,
        scale_down_idle_time: float = 30.0,
    ) -> None:
        """
        Initialize the scaler.

        Args:
            min_workers: Minimum number of workers
            max_workers: Maximum number of workers
            scale_up_threshold: Queue depth that triggers a scale up
            scale_down_idle_time: Idle seconds before scaling down
        """
        self.min_workers = min_workers
        self.max_workers = max_workers
        self.scale_up_threshold = scale_up_threshold
        self.scale_down_idle_time = scale_down_idle_time

    def calculate_initial_workers(self, graph: Graph) -> int:
        """
        Pick an initial worker count from the graph's node count.

        Args:
            graph: The workflow graph

        Returns:
            Initial number of workers to create
        """
        node_count = len(graph.nodes)

        # Heuristic tiers: bigger graphs start with more workers,
        # capped at max_workers and floored at min_workers.
        if node_count < 10:
            base = self.min_workers
        elif node_count < 50:
            base = 4
        elif node_count < 100:
            base = 6
        else:
            base = 8

        return max(self.min_workers, min(base, self.max_workers))

    def should_scale_up(self, current_workers: int, queue_depth: int, executing_count: int) -> bool:
        """
        Decide whether another worker should be added.

        Args:
            current_workers: Current number of workers
            queue_depth: Number of nodes waiting
            executing_count: Number of nodes executing

        Returns:
            True if the pool should scale up
        """
        if current_workers >= self.max_workers:
            return False

        # Grow only when the backlog is deep AND most workers are busy.
        backlog_deep = queue_depth > self.scale_up_threshold
        workers_busy = executing_count >= current_workers * 0.8
        return backlog_deep and workers_busy

    def should_scale_down(self, current_workers: int, idle_workers: list[int]) -> bool:
        """
        Decide whether the pool may shed workers.

        Args:
            current_workers: Current number of workers
            idle_workers: IDs of workers currently idle

        Returns:
            True if the pool should scale down
        """
        # Shrink only above the floor, and only when someone is idle.
        return current_workers > self.min_workers and len(idle_workers) > 0
|
||||
@ -0,0 +1,74 @@
|
||||
"""
|
||||
Factory for creating worker instances.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
from ..worker import Worker
|
||||
|
||||
|
||||
class WorkerFactory:
    """
    Creates worker instances pre-wired with Flask and context-variable state.

    Keeps a monotonically increasing counter so each worker receives a
    unique ID.
    """

    def __init__(
        self,
        flask_app: Optional[Flask],
        context_vars: contextvars.Context,
    ) -> None:
        """
        Initialize the worker factory.

        Args:
            flask_app: Flask application context
            context_vars: Context variables to propagate
        """
        self.flask_app = flask_app
        self.context_vars = context_vars
        self._id_counter = 0

    def create_worker(
        self,
        ready_queue: queue.Queue[str],
        event_queue: queue.Queue,
        graph: Graph,
        on_idle_callback: Optional[Callable[[int], None]] = None,
        on_active_callback: Optional[Callable[[int], None]] = None,
    ) -> Worker:
        """
        Build a new worker bound to the given queues and graph.

        Args:
            ready_queue: Queue of nodes ready for execution
            event_queue: Queue for worker events
            graph: The workflow graph
            on_idle_callback: Invoked when the worker becomes idle
            on_active_callback: Invoked when the worker becomes active

        Returns:
            Configured worker instance
        """
        assigned_id = self._id_counter
        self._id_counter += 1

        return Worker(
            ready_queue=ready_queue,
            event_queue=event_queue,
            graph=graph,
            worker_id=assigned_id,
            flask_app=self.flask_app,
            context_vars=self.context_vars,
            on_idle_callback=on_idle_callback,
            on_active_callback=on_active_callback,
        )
|
||||
145
api/core/workflow/graph_engine/worker_management/worker_pool.py
Normal file
145
api/core/workflow/graph_engine/worker_management/worker_pool.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""
|
||||
Worker pool management.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
from ..worker import Worker
|
||||
from .activity_tracker import ActivityTracker
|
||||
from .dynamic_scaler import DynamicScaler
|
||||
from .worker_factory import WorkerFactory
|
||||
|
||||
|
||||
class WorkerPool:
    """
    Manages a pool of worker threads for executing nodes.

    Provides dynamic scaling, activity tracking, and lifecycle management
    for worker threads. All public methods are thread-safe via an internal
    re-entrant lock.
    """

    def __init__(
        self,
        ready_queue: queue.Queue[str],
        event_queue: queue.Queue,
        graph: Graph,
        worker_factory: WorkerFactory,
        dynamic_scaler: DynamicScaler,
        activity_tracker: ActivityTracker,
    ) -> None:
        """
        Initialize the worker pool.

        Args:
            ready_queue: Queue of nodes ready for execution
            event_queue: Queue for worker events
            graph: The workflow graph
            worker_factory: Factory for creating workers
            dynamic_scaler: Scaler for dynamic sizing
            activity_tracker: Tracker for worker activity
        """
        self.ready_queue = ready_queue
        self.event_queue = event_queue
        self.graph = graph
        self.worker_factory = worker_factory
        self.dynamic_scaler = dynamic_scaler
        self.activity_tracker = activity_tracker

        self.workers: list[Worker] = []
        self._lock = threading.RLock()
        self._running = False

    def _spawn_worker(self) -> None:
        """Create, start, and register one worker. Caller must hold the lock."""
        worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
        worker.start()
        self.workers.append(worker)

    def start(self, initial_count: int) -> None:
        """
        Start the worker pool with initial workers.

        Idempotent: a second call while running is a no-op.

        Args:
            initial_count: Number of workers to start with
        """
        with self._lock:
            if self._running:
                return

            self._running = True

            for _ in range(initial_count):
                self._spawn_worker()

    def stop(self) -> None:
        """Stop all workers in the pool and wait (bounded) for them to exit."""
        with self._lock:
            self._running = False

            # Signal every worker first so they can wind down in parallel.
            for worker in self.workers:
                worker.stop()

            # Then wait for each; bounded join so shutdown cannot hang forever.
            for worker in self.workers:
                if worker.is_alive():
                    worker.join(timeout=10.0)
                # Purge activity records so the tracker holds no stale entries.
                self.activity_tracker.remove_worker(worker.worker_id)

            self.workers.clear()

    def scale_up(self) -> None:
        """Add a worker to the pool if running and below the configured maximum."""
        with self._lock:
            if not self._running:
                return

            if len(self.workers) >= self.dynamic_scaler.max_workers:
                return

            self._spawn_worker()

    def scale_down(self, worker_ids: list[int]) -> None:
        """
        Remove specific workers from the pool, never dropping below the minimum.

        Args:
            worker_ids: IDs of workers to remove
        """
        with self._lock:
            if not self._running:
                return

            workers_to_remove = [w for w in self.workers if w.worker_id in worker_ids]

            for worker in workers_to_remove:
                # Re-check the floor on every removal: a single call may name
                # several idle workers, and removing them all unchecked could
                # shrink the pool below min_workers.
                if len(self.workers) <= self.dynamic_scaler.min_workers:
                    break

                worker.stop()
                self.workers.remove(worker)
                # Forget the worker's activity record; otherwise it would keep
                # being reported as idle on every subsequent scaling check.
                self.activity_tracker.remove_worker(worker.worker_id)
                if worker.is_alive():
                    worker.join(timeout=1.0)

    def get_worker_count(self) -> int:
        """Get the current number of workers."""
        with self._lock:
            return len(self.workers)

    def check_scaling(self, queue_depth: int, executing_count: int) -> None:
        """
        Check and perform scaling if needed.

        Args:
            queue_depth: Current queue depth
            executing_count: Number of executing nodes
        """
        current_count = self.get_worker_count()

        if self.dynamic_scaler.should_scale_up(current_count, queue_depth, executing_count):
            self.scale_up()

        idle_workers = self.activity_tracker.get_idle_workers()
        if idle_workers:
            self.scale_down(idle_workers)
|
||||
Reference in New Issue
Block a user