Mirror of https://github.com/langgenius/dify.git
Synced 2026-02-26 04:27:41 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
132
api/core/workflow/README.md
Normal file
@@ -0,0 +1,132 @@
# Workflow

## Project Overview

This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.

## Architecture

### Core Components

The graph engine follows a layered architecture with strict dependency rules:

1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution

   - **Manager** - External control interface for stop/pause/resume commands
   - **Worker** - Node execution runtime
   - **Command Processing** - Handles control commands (abort, pause, resume)
   - **Event Management** - Event propagation and layer notifications
   - **Graph Traversal** - Edge processing and skip propagation
   - **Response Coordinator** - Path tracking and session management
   - **Layers** - Pluggable middleware (debug logging, execution limits)
   - **Command Channels** - Communication channels (InMemory, Redis)

1. **Graph** (`graph/`) - Graph structure and runtime state

   - **Graph Template** - Workflow definition
   - **Edge** - Node connections with conditions
   - **Runtime State Protocol** - State management interface

1. **Nodes** (`nodes/`) - Node implementations

   - **Base** - Abstract node classes and variable parsing
   - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.

1. **Events** (`node_events/`) - Event system

   - **Base** - Event protocols
   - **Node Events** - Node lifecycle events

1. **Entities** (`entities/`) - Domain models

   - **Variable Pool** - Variable storage
   - **Graph Init Params** - Initialization configuration

## Key Design Patterns

### Command Channel Pattern

External workflow control via Redis or in-memory channels:

```python
# Send stop command to running workflow
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
channel.send_command(AbortCommand(reason="User requested"))
```

### Layer System

Extensible middleware for cross-cutting concerns:

```python
engine = GraphEngine(graph)
engine.add_layer(DebugLoggingLayer(level="INFO"))
engine.add_layer(ExecutionLimitsLayer(max_nodes=100))
```

### Event-Driven Architecture

All node executions emit events for monitoring and integration:

- `NodeRunStartedEvent` - Node execution begins
- `NodeRunSucceededEvent` - Node completes successfully
- `NodeRunFailedEvent` - Node encounters error
- `GraphRunStartedEvent` / `GraphRunCompletedEvent` - Workflow lifecycle

### Variable Pool

Centralized variable storage with namespace isolation:

```python
# Variables scoped by node_id
pool.add(["node1", "output"], value)
result = pool.get(["node1", "output"])
```

## Import Architecture Rules

The codebase enforces strict layering via import-linter:

1. **Workflow Layers** (top to bottom):

   - graph_engine → graph_events → graph → nodes → node_events → entities

1. **Graph Engine Internal Layers**:

   - orchestration → command_processing → event_management → graph_traversal → domain

1. **Domain Isolation**:

   - Domain models cannot import from infrastructure layers

1. **Command Channel Independence**:

   - InMemory and Redis channels must remain independent

## Common Tasks

### Adding a New Node Type

1. Create node class in `nodes/<node_type>/`
1. Inherit from `BaseNode` or appropriate base class
1. Implement `_run()` method
1. Register in `nodes/node_mapping.py`
1. Add tests in `tests/unit_tests/core/workflow/nodes/`
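
Steps 1-3 can be sketched as follows. Note the `BaseNode` below is a minimal illustrative stand-in; the real base class in `nodes/base` handles variable parsing and event emission and returns a `NodeRunResult` rather than a plain dict.

```python
# Minimal stand-in for the real BaseNode (illustrative only).
class BaseNode:
    def run(self) -> dict:
        return self._run()

    def _run(self) -> dict:
        raise NotImplementedError


class UppercaseNode(BaseNode):
    """Hypothetical example node: uppercases a configured string."""

    def __init__(self, text: str) -> None:
        self.text = text

    def _run(self) -> dict:
        # Real nodes build a NodeRunResult; a dict keeps this sketch runnable.
        return {"output": self.text.upper()}


print(UppercaseNode("hello").run())  # {'output': 'HELLO'}
```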

### Implementing a Custom Layer

1. Create class inheriting from `Layer` base
1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
1. Add to engine via `engine.add_layer()`

### Debugging Workflow Execution

Enable debug logging layer:

```python
debug_layer = DebugLoggingLayer(
    level="DEBUG",
    include_inputs=True,
    include_outputs=True
)
```

@@ -1,7 +0,0 @@
from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback

__all__ = [
    "WorkflowCallback",
    "WorkflowLoggingCallback",
]
@@ -1,12 +0,0 @@
from abc import ABC, abstractmethod

from core.workflow.graph_engine.entities.event import GraphEngineEvent


class WorkflowCallback(ABC):
    @abstractmethod
    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Published event
        """
        raise NotImplementedError
@@ -1,263 +0,0 @@
from typing import Optional

from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    IterationRunFailedEvent,
    IterationRunNextEvent,
    IterationRunStartedEvent,
    IterationRunSucceededEvent,
    LoopRunFailedEvent,
    LoopRunNextEvent,
    LoopRunStartedEvent,
    LoopRunSucceededEvent,
    NodeRunFailedEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
    ParallelBranchRunFailedEvent,
    ParallelBranchRunStartedEvent,
    ParallelBranchRunSucceededEvent,
)

from .base_workflow_callback import WorkflowCallback

_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
    "yellow": "33;1",
    "pink": "38;5;200",
    "green": "32;1",
    "red": "31;1",
}


class WorkflowLoggingCallback(WorkflowCallback):
    def __init__(self) -> None:
        self.current_node_id: Optional[str] = None

    def on_event(self, event: GraphEngineEvent) -> None:
        if isinstance(event, GraphRunStartedEvent):
            self.print_text("\n[GraphRunStartedEvent]", color="pink")
        elif isinstance(event, GraphRunSucceededEvent):
            self.print_text("\n[GraphRunSucceededEvent]", color="green")
        elif isinstance(event, GraphRunPartialSucceededEvent):
            self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
        elif isinstance(event, GraphRunFailedEvent):
            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
        elif isinstance(event, NodeRunStartedEvent):
            self.on_workflow_node_execute_started(event=event)
        elif isinstance(event, NodeRunSucceededEvent):
            self.on_workflow_node_execute_succeeded(event=event)
        elif isinstance(event, NodeRunFailedEvent):
            self.on_workflow_node_execute_failed(event=event)
        elif isinstance(event, NodeRunStreamChunkEvent):
            self.on_node_text_chunk(event=event)
        elif isinstance(event, ParallelBranchRunStartedEvent):
            self.on_workflow_parallel_started(event=event)
        elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
            self.on_workflow_parallel_completed(event=event)
        elif isinstance(event, IterationRunStartedEvent):
            self.on_workflow_iteration_started(event=event)
        elif isinstance(event, IterationRunNextEvent):
            self.on_workflow_iteration_next(event=event)
        elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
            self.on_workflow_iteration_completed(event=event)
        elif isinstance(event, LoopRunStartedEvent):
            self.on_workflow_loop_started(event=event)
        elif isinstance(event, LoopRunNextEvent):
            self.on_workflow_loop_next(event=event)
        elif isinstance(event, LoopRunSucceededEvent | LoopRunFailedEvent):
            self.on_workflow_loop_completed(event=event)
        else:
            self.print_text(f"\n[{event.__class__.__name__}]", color="blue")

    def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
        """
        Workflow node execute started
        """
        self.print_text("\n[NodeRunStartedEvent]", color="yellow")
        self.print_text(f"Node ID: {event.node_id}", color="yellow")
        self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
        self.print_text(f"Type: {event.node_type.value}", color="yellow")

    def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
        """
        Workflow node execute succeeded
        """
        route_node_state = event.route_node_state

        self.print_text("\n[NodeRunSucceededEvent]", color="green")
        self.print_text(f"Node ID: {event.node_id}", color="green")
        self.print_text(f"Node Title: {event.node_data.title}", color="green")
        self.print_text(f"Type: {event.node_type.value}", color="green")

        if route_node_state.node_run_result:
            node_run_result = route_node_state.node_run_result
            self.print_text(
                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
                color="green",
            )
            self.print_text(
                f"Process Data: "
                f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
                color="green",
            )
            self.print_text(
                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
                color="green",
            )
            self.print_text(
                f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
                color="green",
            )

    def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
        """
        Workflow node execute failed
        """
        route_node_state = event.route_node_state

        self.print_text("\n[NodeRunFailedEvent]", color="red")
        self.print_text(f"Node ID: {event.node_id}", color="red")
        self.print_text(f"Node Title: {event.node_data.title}", color="red")
        self.print_text(f"Type: {event.node_type.value}", color="red")

        if route_node_state.node_run_result:
            node_run_result = route_node_state.node_run_result
            self.print_text(f"Error: {node_run_result.error}", color="red")
            self.print_text(
                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
                color="red",
            )
            self.print_text(
                f"Process Data: "
                f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
                color="red",
            )
            self.print_text(
                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
                color="red",
            )

    def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
        """
        Publish text chunk
        """
        route_node_state = event.route_node_state
        if not self.current_node_id or self.current_node_id != route_node_state.node_id:
            self.current_node_id = route_node_state.node_id
            self.print_text("\n[NodeRunStreamChunkEvent]")
            self.print_text(f"Node ID: {route_node_state.node_id}")

            node_run_result = route_node_state.node_run_result
            if node_run_result:
                self.print_text(
                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
                )

        self.print_text(event.chunk_content, color="pink", end="")

    def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
        """
        Publish parallel started
        """
        self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
        self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
        if event.in_loop_id:
            self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")

    def on_workflow_parallel_completed(
        self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
    ) -> None:
        """
        Publish parallel completed
        """
        if isinstance(event, ParallelBranchRunSucceededEvent):
            color = "blue"
        elif isinstance(event, ParallelBranchRunFailedEvent):
            color = "red"

        self.print_text(
            "\n[ParallelBranchRunSucceededEvent]"
            if isinstance(event, ParallelBranchRunSucceededEvent)
            else "\n[ParallelBranchRunFailedEvent]",
            color=color,
        )
        self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
        if event.in_loop_id:
            self.print_text(f"Loop ID: {event.in_loop_id}", color=color)

        if isinstance(event, ParallelBranchRunFailedEvent):
            self.print_text(f"Error: {event.error}", color=color)

    def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
        """
        Publish iteration started
        """
        self.print_text("\n[IterationRunStartedEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")

    def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
        """
        Publish iteration next
        """
        self.print_text("\n[IterationRunNextEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
        self.print_text(f"Iteration Index: {event.index}", color="blue")

    def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
        """
        Publish iteration completed
        """
        self.print_text(
            "\n[IterationRunSucceededEvent]"
            if isinstance(event, IterationRunSucceededEvent)
            else "\n[IterationRunFailedEvent]",
            color="blue",
        )
        self.print_text(f"Node ID: {event.iteration_id}", color="blue")

    def on_workflow_loop_started(self, event: LoopRunStartedEvent) -> None:
        """
        Publish loop started
        """
        self.print_text("\n[LoopRunStartedEvent]", color="blue")
        self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")

    def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
        """
        Publish loop next
        """
        self.print_text("\n[LoopRunNextEvent]", color="blue")
        self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
        self.print_text(f"Loop Index: {event.index}", color="blue")

    def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
        """
        Publish loop completed
        """
        self.print_text(
            "\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
            color="blue",
        )
        self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")

    def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
        """Print text with highlighting and no end characters."""
        text_to_print = self._get_colored_text(text, color) if color else text
        print(f"{text_to_print}", end=end)

    def _get_colored_text(self, text: str, color: str) -> str:
        """Get colored text."""
        color_str = _TEXT_COLOR_MAPPING[color]
        return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
@@ -1,3 +1,4 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
+RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
@@ -20,7 +20,7 @@ class ConversationVariableUpdater(Protocol):
    """

    @abc.abstractmethod
-    def update(self, conversation_id: str, variable: "Variable") -> None:
+    def update(self, conversation_id: str, variable: "Variable"):
        """
        Updates the value of the specified conversation variable in the underlying storage.

@@ -0,0 +1,18 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .run_condition import RunCondition
from .variable_pool import VariablePool, VariableValue
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution

__all__ = [
    "AgentNodeStrategyInit",
    "GraphInitParams",
    "GraphRuntimeState",
    "RunCondition",
    "VariablePool",
    "VariableValue",
    "WorkflowExecution",
    "WorkflowNodeExecution",
]
8
api/core/workflow/entities/agent.py
Normal file
@@ -0,0 +1,8 @@
from pydantic import BaseModel


class AgentNodeStrategyInit(BaseModel):
    """Agent node strategy initialization data."""

    name: str
    icon: str | None = None
@@ -3,19 +3,18 @@ from typing import Any

from pydantic import BaseModel, Field

from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import UserFrom
from models.workflow import WorkflowType


class GraphInitParams(BaseModel):
    # init params
    tenant_id: str = Field(..., description="tenant / workspace id")
    app_id: str = Field(..., description="app id")
    workflow_type: WorkflowType = Field(..., description="workflow type")
    workflow_id: str = Field(..., description="workflow id")
    graph_config: Mapping[str, Any] = Field(..., description="graph config")
    user_id: str = Field(..., description="user id")
-    user_from: UserFrom = Field(..., description="user from, account or end-user")
-    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
+    user_from: str = Field(
+        ..., description="user from, account or end-user"
+    )  # Should be UserFrom enum: 'account' | 'end-user'
+    invoke_from: str = Field(
+        ..., description="invoke from, service-api, web-app, explore or debugger"
+    )  # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger'
    call_depth: int = Field(..., description="call depth")
160
api/core/workflow/entities/graph_runtime_state.py
Normal file
@@ -0,0 +1,160 @@
from copy import deepcopy

from pydantic import BaseModel, PrivateAttr

from core.model_runtime.entities.llm_entities import LLMUsage

from .variable_pool import VariablePool


class GraphRuntimeState(BaseModel):
    # Private attributes to prevent direct modification
    _variable_pool: VariablePool = PrivateAttr()
    _start_at: float = PrivateAttr()
    _total_tokens: int = PrivateAttr(default=0)
    _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
    _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
    _node_run_steps: int = PrivateAttr(default=0)
    _ready_queue_json: str = PrivateAttr()
    _graph_execution_json: str = PrivateAttr()
    _response_coordinator_json: str = PrivateAttr()

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue_json: str = "",
        graph_execution_json: str = "",
        response_coordinator_json: str = "",
        **kwargs: object,
    ):
        """Initialize the GraphRuntimeState with validation."""
        super().__init__(**kwargs)

        # Initialize private attributes with validation
        self._variable_pool = variable_pool

        self._start_at = start_at

        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens

        if llm_usage is None:
            llm_usage = LLMUsage.empty_usage()
        self._llm_usage = llm_usage

        if outputs is None:
            outputs = {}
        self._outputs = deepcopy(outputs)

        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps

        self._ready_queue_json = ready_queue_json
        self._graph_execution_json = graph_execution_json
        self._response_coordinator_json = response_coordinator_json

    @property
    def variable_pool(self) -> VariablePool:
        """Get the variable pool."""
        return self._variable_pool

    @property
    def start_at(self) -> float:
        """Get the start time."""
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        """Set the start time."""
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        """Get the total tokens count."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int):
        """Set the total tokens count."""
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        """Get the LLM usage info."""
        # Return a copy to prevent external modification
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage):
        """Set the LLM usage info."""
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, object]:
        """Get a copy of the outputs dictionary."""
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, object]) -> None:
        """Set the outputs dictionary."""
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        """Set a single output value."""
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        """Get a single output value."""
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        """Update multiple output values."""
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        """Get the node run steps count."""
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        """Set the node run steps count."""
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        """Increment the node run steps by 1."""
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        """Add tokens to the total count."""
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    @property
    def ready_queue_json(self) -> str:
        """Get a copy of the ready queue state."""
        return self._ready_queue_json

    @property
    def graph_execution_json(self) -> str:
        """Get a copy of the serialized graph execution state."""
        return self._graph_execution_json

    @property
    def response_coordinator_json(self) -> str:
        """Get a copy of the serialized response coordinator state."""
        return self._response_coordinator_json
@@ -1,34 +0,0 @@
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class NodeRunResult(BaseModel):
    """
    Node Run Result.
    """

    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING

    inputs: Optional[Mapping[str, Any]] = None  # node inputs
    process_data: Optional[Mapping[str, Any]] = None  # process data
    outputs: Optional[Mapping[str, Any]] = None  # node outputs
    metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None  # node metadata
    llm_usage: Optional[LLMUsage] = None  # llm usage

    edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches

    error: Optional[str] = None  # error message if status is failed
    error_type: Optional[str] = None  # error type if status is failed

    # single step node run retry
    retry_index: int = 0


class AgentNodeStrategyInit(BaseModel):
    name: str
    icon: str | None = None
@@ -1,5 +1,5 @@
import hashlib
-from typing import Literal, Optional
+from typing import Literal

from pydantic import BaseModel

@@ -10,10 +10,10 @@ class RunCondition(BaseModel):
    type: Literal["branch_identify", "condition"]
    """condition type"""

-    branch_identify: Optional[str] = None
+    branch_identify: str | None = None
    """branch identify like: sourceHandle, required when type is branch_identify"""

-    conditions: Optional[list[Condition]] = None
+    conditions: list[Condition] | None = None
    """conditions to run the node, required when type is condition"""

    @property
@@ -1,12 +0,0 @@
from collections.abc import Sequence

from pydantic import BaseModel


class VariableSelector(BaseModel):
    """
    Variable Selector.
    """

    variable: str
    value_selector: Sequence[str]
@@ -9,12 +9,17 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
-from core.variables.variables import VariableUnion
-from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from core.variables.variables import RAGPipelineVariableInput, VariableUnion
+from core.workflow.constants import (
+    CONVERSATION_VARIABLE_NODE_ID,
+    ENVIRONMENT_VARIABLE_NODE_ID,
+    RAG_PIPELINE_VARIABLE_NODE_ID,
+    SYSTEM_VARIABLE_NODE_ID,
+)
from core.workflow.system_variable import SystemVariable
from factories import variable_factory

-VariableValue = Union[str, int, float, dict, list, File]
+VariableValue = Union[str, int, float, dict[str, object], list[object], File]

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@@ -40,14 +45,18 @@ class VariablePool(BaseModel):
    )
    environment_variables: Sequence[VariableUnion] = Field(
        description="Environment variables.",
-        default_factory=list,
+        default_factory=list[VariableUnion],
    )
    conversation_variables: Sequence[VariableUnion] = Field(
        description="Conversation variables.",
        default_factory=list[VariableUnion],
    )
+    rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
+        description="RAG pipeline variables.",
+        default_factory=list,
+    )

-    def model_post_init(self, context: Any, /) -> None:
+    def model_post_init(self, context: Any, /):
        # Create a mapping from field names to SystemVariableKey enum values
        self._add_system_variables(self.system_variables)
        # Add environment variables to the variable pool

@@ -56,8 +65,18 @@ class VariablePool(BaseModel):
        # Add conversation variables to the variable pool
        for var in self.conversation_variables:
            self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
+        # Add rag pipeline variables to the variable pool
+        if self.rag_pipeline_variables:
+            rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
+            for rag_var in self.rag_pipeline_variables:
+                node_id = rag_var.variable.belong_to_node_id
+                key = rag_var.variable.variable
+                value = rag_var.value
+                rag_pipeline_variables_map[node_id][key] = value
+            for key, value in rag_pipeline_variables_map.items():
+                self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)

-    def add(self, selector: Sequence[str], value: Any, /) -> None:
+    def add(self, selector: Sequence[str], value: Any, /):
        """
        Add a variable to the variable pool.

@@ -161,11 +180,11 @@ class VariablePool(BaseModel):
        # Return result as Segment
        return result if isinstance(result, Segment) else variable_factory.build_segment(result)

-    def _extract_value(self, obj: Any) -> Any:
+    def _extract_value(self, obj: Any):
        """Extract the actual value from an ObjectSegment."""
        return obj.value if isinstance(obj, ObjectSegment) else obj

-    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
+    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
        """Get a nested attribute from a dictionary-like object."""
        if not isinstance(obj, dict):
            return None

@@ -191,7 +210,7 @@ class VariablePool(BaseModel):

    def convert_template(self, template: str, /):
        parts = VARIABLE_PATTERN.split(template)
-        segments = []
+        segments: list[Segment] = []
        for part in filter(lambda x: x, parts):
            if "." in part and (variable := self.get(part.split("."))):
                segments.append(variable)
@@ -7,31 +7,14 @@ implementation details like tenant_id, app_id, etc.

from collections.abc import Mapping
from datetime import datetime
-from enum import StrEnum
-from typing import Any, Optional
+from typing import Any

from pydantic import BaseModel, Field

+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from libs.datetime_utils import naive_utc_now


-class WorkflowType(StrEnum):
-    """
-    Workflow Type Enum for domain layer
-    """
-
-    WORKFLOW = "workflow"
-    CHAT = "chat"
-
-
-class WorkflowExecutionStatus(StrEnum):
-    RUNNING = "running"
-    SUCCEEDED = "succeeded"
-    FAILED = "failed"
-    STOPPED = "stopped"
-    PARTIAL_SUCCEEDED = "partial-succeeded"
-
-
class WorkflowExecution(BaseModel):
    """
    Domain model for workflow execution based on WorkflowRun but without

@@ -45,7 +28,7 @@ class WorkflowExecution(BaseModel):
    graph: Mapping[str, Any] = Field(...)

    inputs: Mapping[str, Any] = Field(...)
-    outputs: Optional[Mapping[str, Any]] = None
+    outputs: Mapping[str, Any] | None = None

    status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
    error_message: str = Field(default="")

@@ -54,7 +37,7 @@ class WorkflowExecution(BaseModel):
    exceptions_count: int = Field(default=0)

    started_at: datetime = Field(...)
-    finished_at: Optional[datetime] = None
+    finished_at: datetime | None = None

    @property
    def elapsed_time(self) -> float:
@@ -8,50 +8,11 @@ and don't contain implementation details like tenant_id, app_id, etc.

from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr

from core.workflow.nodes.enums import NodeType


class WorkflowNodeExecutionMetadataKey(StrEnum):
    """
    Node Run Metadata Key.
    """

    TOTAL_TOKENS = "total_tokens"
    TOTAL_PRICE = "total_price"
    CURRENCY = "currency"
    TOOL_INFO = "tool_info"
    TRIGGER_INFO = "trigger_info"
    AGENT_LOG = "agent_log"
    ITERATION_ID = "iteration_id"
    ITERATION_INDEX = "iteration_index"
    LOOP_ID = "loop_id"
    LOOP_INDEX = "loop_index"
    PARALLEL_ID = "parallel_id"
    PARALLEL_START_NODE_ID = "parallel_start_node_id"
    PARENT_PARALLEL_ID = "parent_parallel_id"
    PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
    PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
    ITERATION_DURATION_MAP = "iteration_duration_map"  # single iteration duration if iteration node runs
    LOOP_DURATION_MAP = "loop_duration_map"  # single loop duration if loop node runs
    ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output


class WorkflowNodeExecutionStatus(StrEnum):
    """
    Node Execution Status Enum.
    """

    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    EXCEPTION = "exception"
    RETRY = "retry"
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class WorkflowNodeExecution(BaseModel):
@@ -78,42 +39,95 @@ class WorkflowNodeExecution(BaseModel):
    # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
    # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
    # In most scenarios, `id` should be used as the primary identifier.
    node_execution_id: Optional[str] = None
    node_execution_id: str | None = None
    workflow_id: str  # ID of the workflow this node belongs to
    workflow_execution_id: Optional[str] = None  # ID of the specific workflow run (null for single-step debugging)
    workflow_execution_id: str | None = None  # ID of the specific workflow run (null for single-step debugging)
    # --------- Core identification fields ends ---------

    # Execution positioning and flow
    index: int  # Sequence number for ordering in trace visualization
    predecessor_node_id: Optional[str] = None  # ID of the node that executed before this one
    predecessor_node_id: str | None = None  # ID of the node that executed before this one
    node_id: str  # ID of the node being executed
    node_type: NodeType  # Type of node (e.g., start, llm, knowledge)
    title: str  # Display title of the node

    # Execution data
    inputs: Optional[Mapping[str, Any]] = None  # Input variables used by this node
    process_data: Optional[Mapping[str, Any]] = None  # Intermediate processing data
    outputs: Optional[Mapping[str, Any]] = None  # Output variables produced by this node
    # The `inputs` and `outputs` fields hold the full content
    inputs: Mapping[str, Any] | None = None  # Input variables used by this node
    process_data: Mapping[str, Any] | None = None  # Intermediate processing data
    outputs: Mapping[str, Any] | None = None  # Output variables produced by this node

    # Execution state
    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING  # Current execution status
    error: Optional[str] = None  # Error message if execution failed
    error: str | None = None  # Error message if execution failed
    elapsed_time: float = Field(default=0.0)  # Time taken for execution in seconds

    # Additional metadata
    metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None  # Execution metadata (tokens, cost, etc.)
    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None  # Execution metadata (tokens, cost, etc.)

    # Timing information
    created_at: datetime  # When execution started
    finished_at: Optional[datetime] = None  # When execution completed
    finished_at: datetime | None = None  # When execution completed

    _truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None)
    _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None)
    _truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None)

    def get_truncated_inputs(self) -> Mapping[str, Any] | None:
        return self._truncated_inputs

    def get_truncated_outputs(self) -> Mapping[str, Any] | None:
        return self._truncated_outputs

    def get_truncated_process_data(self) -> Mapping[str, Any] | None:
        return self._truncated_process_data

    def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None):
        self._truncated_inputs = truncated_inputs

    def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None):
        self._truncated_outputs = truncated_outputs

    def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None):
        self._truncated_process_data = truncated_process_data

    def get_response_inputs(self) -> Mapping[str, Any] | None:
        inputs = self.get_truncated_inputs()
        if inputs:
            return inputs
        return self.inputs

    @property
    def inputs_truncated(self):
        return self._truncated_inputs is not None

    @property
    def outputs_truncated(self):
        return self._truncated_outputs is not None

    @property
    def process_data_truncated(self):
        return self._truncated_process_data is not None

    def get_response_outputs(self) -> Mapping[str, Any] | None:
        outputs = self.get_truncated_outputs()
        if outputs is not None:
            return outputs
        return self.outputs

    def get_response_process_data(self) -> Mapping[str, Any] | None:
        process_data = self.get_truncated_process_data()
        if process_data is not None:
            return process_data
        return self.process_data

    def update_from_mapping(
        self,
        inputs: Optional[Mapping[str, Any]] = None,
        process_data: Optional[Mapping[str, Any]] = None,
        outputs: Optional[Mapping[str, Any]] = None,
        metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
    ) -> None:
        inputs: Mapping[str, Any] | None = None,
        process_data: Mapping[str, Any] | None = None,
        outputs: Mapping[str, Any] | None = None,
        metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
    ):
        """
        Update the model from mappings.

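The new private `_truncated_*` attributes let a repository attach a size-limited copy of a node's payload while keeping the full value, and the `get_response_*` accessors prefer the truncated copy when one exists. A minimal sketch of that fallback pattern for the outputs field (a hypothetical plain class, not the pydantic model):

```python
# Sketch of the truncated-payload fallback: response accessors return the
# truncated copy when one was set, otherwise fall back to the full value.
class NodePayload:
    def __init__(self, outputs=None):
        self.outputs = outputs           # full content
        self._truncated_outputs = None   # size-limited copy, if any

    def set_truncated_outputs(self, truncated):
        self._truncated_outputs = truncated

    @property
    def outputs_truncated(self):
        return self._truncated_outputs is not None

    def get_response_outputs(self):
        if self._truncated_outputs is not None:
            return self._truncated_outputs
        return self.outputs

p = NodePayload(outputs={"text": "x" * 10000})
assert not p.outputs_truncated
p.set_truncated_outputs({"text": "x" * 100 + "..."})
print(p.outputs_truncated, len(p.get_response_outputs()["text"]))  # -> True 103
```

The `is not None` check matters: an empty-but-present truncated mapping still wins over the full payload, matching `get_response_outputs` and `get_response_process_data` above.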
@@ -1,4 +1,12 @@
from enum import StrEnum
from enum import Enum, StrEnum


class NodeState(Enum):
    """State of a node or edge during workflow execution."""

    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"


class SystemVariableKey(StrEnum):
@@ -14,3 +22,120 @@ class SystemVariableKey(StrEnum):
    APP_ID = "app_id"
    WORKFLOW_ID = "workflow_id"
    WORKFLOW_EXECUTION_ID = "workflow_run_id"
    # RAG Pipeline
    DOCUMENT_ID = "document_id"
    ORIGINAL_DOCUMENT_ID = "original_document_id"
    BATCH = "batch"
    DATASET_ID = "dataset_id"
    DATASOURCE_TYPE = "datasource_type"
    DATASOURCE_INFO = "datasource_info"
    INVOKE_FROM = "invoke_from"


class NodeType(StrEnum):
    START = "start"
    END = "end"
    ANSWER = "answer"
    LLM = "llm"
    KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
    KNOWLEDGE_INDEX = "knowledge-index"
    IF_ELSE = "if-else"
    CODE = "code"
    TEMPLATE_TRANSFORM = "template-transform"
    QUESTION_CLASSIFIER = "question-classifier"
    HTTP_REQUEST = "http-request"
    TOOL = "tool"
    DATASOURCE = "datasource"
    VARIABLE_AGGREGATOR = "variable-aggregator"
    LEGACY_VARIABLE_AGGREGATOR = "variable-assigner"  # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
    LOOP = "loop"
    LOOP_START = "loop-start"
    LOOP_END = "loop-end"
    ITERATION = "iteration"
    ITERATION_START = "iteration-start"  # Fake start node for iteration.
    PARAMETER_EXTRACTOR = "parameter-extractor"
    VARIABLE_ASSIGNER = "assigner"
    DOCUMENT_EXTRACTOR = "document-extractor"
    LIST_OPERATOR = "list-operator"
    AGENT = "agent"
    TRIGGER_WEBHOOK = "trigger-webhook"
    TRIGGER_SCHEDULE = "trigger-schedule"
    TRIGGER_PLUGIN = "trigger-plugin"


class NodeExecutionType(StrEnum):
    """Node execution type classification."""

    EXECUTABLE = "executable"  # Regular nodes that execute and produce outputs
    RESPONSE = "response"  # Response nodes that stream outputs (Answer, End)
    BRANCH = "branch"  # Nodes that can choose different branches (if-else, question-classifier)
    CONTAINER = "container"  # Container nodes that manage subgraphs (iteration, loop, graph)
    ROOT = "root"  # Nodes that can serve as execution entry points


class ErrorStrategy(StrEnum):
    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"


class FailBranchSourceHandle(StrEnum):
    FAILED = "fail-branch"
    SUCCESS = "success-branch"


class WorkflowType(StrEnum):
    """
    Workflow Type Enum for domain layer
    """

    WORKFLOW = "workflow"
    CHAT = "chat"
    RAG_PIPELINE = "rag-pipeline"


class WorkflowExecutionStatus(StrEnum):
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    STOPPED = "stopped"
    PARTIAL_SUCCEEDED = "partial-succeeded"


class WorkflowNodeExecutionMetadataKey(StrEnum):
    """
    Node Run Metadata Key.
    """

    TOTAL_TOKENS = "total_tokens"
    TOTAL_PRICE = "total_price"
    CURRENCY = "currency"
    TOOL_INFO = "tool_info"
    AGENT_LOG = "agent_log"
    TRIGGER_INFO = "trigger_info"
    ITERATION_ID = "iteration_id"
    ITERATION_INDEX = "iteration_index"
    LOOP_ID = "loop_id"
    LOOP_INDEX = "loop_index"
    PARALLEL_ID = "parallel_id"
    PARALLEL_START_NODE_ID = "parallel_start_node_id"
    PARENT_PARALLEL_ID = "parent_parallel_id"
    PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
    PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
    ITERATION_DURATION_MAP = "iteration_duration_map"  # single iteration duration if iteration node runs
    LOOP_DURATION_MAP = "loop_duration_map"  # single loop duration if loop node runs
    ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
    DATASOURCE_INFO = "datasource_info"


class WorkflowNodeExecutionStatus(StrEnum):
    PENDING = "pending"  # Node is scheduled but not yet executing
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    EXCEPTION = "exception"
    STOPPED = "stopped"
    PAUSED = "paused"

    # Legacy statuses - kept for backward compatibility
    RETRY = "retry"  # Legacy: replaced by retry mechanism in error handling

@@ -1,8 +1,16 @@
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.node import Node


class WorkflowNodeRunFailedError(Exception):
    def __init__(self, node: BaseNode, err_msg: str):
    def __init__(self, node: Node, err_msg: str):
        self._node = node
        self._error = err_msg
        super().__init__(f"Node {node.title} run failed: {err_msg}")

    @property
    def node(self) -> Node:
        return self._node

    @property
    def error(self) -> str:
        return self._error

16
api/core/workflow/graph/__init__.py
Normal file
@@ -0,0 +1,16 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper

__all__ = [
    "Edge",
    "Graph",
    "GraphTemplate",
    "NodeFactory",
    "ReadOnlyGraphRuntimeState",
    "ReadOnlyGraphRuntimeStateWrapper",
    "ReadOnlyVariablePool",
    "ReadOnlyVariablePoolWrapper",
]
15
api/core/workflow/graph/edge.py
Normal file
@@ -0,0 +1,15 @@
import uuid
from dataclasses import dataclass, field

from core.workflow.enums import NodeState


@dataclass
class Edge:
    """Edge connecting two nodes in a workflow graph."""

    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    tail: str = ""  # tail node id (source)
    head: str = ""  # head node id (target)
    source_handle: str = "source"  # source handle for conditional branching
    state: NodeState = field(default=NodeState.UNKNOWN)  # edge execution state
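The `Edge` dataclass only records tail, head, and handle; the `Graph` builder (next file) derives the `in_edges`/`out_edges` adjacency maps from it. A self-contained sketch of that construction, following the field names above (the `edge_{n}` id scheme mirrors `_build_edges`):

```python
import uuid
from collections import defaultdict
from dataclasses import dataclass, field

# Simplified Edge matching the dataclass above, minus the NodeState field.
@dataclass
class Edge:
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    tail: str = ""            # source node id
    head: str = ""            # target node id
    source_handle: str = "source"

edges: dict[str, Edge] = {}
in_edges: dict[str, list[str]] = defaultdict(list)
out_edges: dict[str, list[str]] = defaultdict(list)

# Build adjacency maps from (source, target) pairs, as _build_edges does.
for i, (src, dst) in enumerate([("start", "llm"), ("llm", "end")]):
    edge = Edge(id=f"edge_{i}", tail=src, head=dst)
    edges[edge.id] = edge
    out_edges[src].append(edge.id)
    in_edges[dst].append(edge.id)

print(out_edges["llm"], in_edges["llm"])  # -> ['edge_1'] ['edge_0']
```

Storing edge ids (not edge objects) in the adjacency maps keeps each edge owned by the single `edges` dict, so state changes are visible from both directions.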
346
api/core/workflow/graph/graph.py
Normal file
@@ -0,0 +1,346 @@
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final

from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict

from .edge import Edge

logger = logging.getLogger(__name__)


class NodeFactory(Protocol):
    """
    Protocol for creating Node instances from node data dictionaries.

    This protocol decouples the Graph class from specific node mapping implementations,
    allowing for different node creation strategies while maintaining type safety.
    """

    def create_node(self, node_config: dict[str, object]) -> Node:
        """
        Create a Node instance from node configuration data.

        :param node_config: node configuration dictionary containing type and other data
        :return: initialized Node instance
        :raises ValueError: if node type is unknown or configuration is invalid
        """
        ...


@final
class Graph:
    """Graph representation with nodes and edges for workflow execution."""

    def __init__(
        self,
        *,
        nodes: dict[str, Node] | None = None,
        edges: dict[str, Edge] | None = None,
        in_edges: dict[str, list[str]] | None = None,
        out_edges: dict[str, list[str]] | None = None,
        root_node: Node,
    ):
        """
        Initialize Graph instance.

        :param nodes: graph nodes mapping (node id: node object)
        :param edges: graph edges mapping (edge id: edge object)
        :param in_edges: incoming edges mapping (node id: list of edge ids)
        :param out_edges: outgoing edges mapping (node id: list of edge ids)
        :param root_node: root node object
        """
        self.nodes = nodes or {}
        self.edges = edges or {}
        self.in_edges = in_edges or {}
        self.out_edges = out_edges or {}
        self.root_node = root_node

    @classmethod
    def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
        """
        Parse node configurations and build a mapping of node IDs to configs.

        :param node_configs: list of node configuration dictionaries
        :return: mapping of node ID to node config
        """
        node_configs_map: dict[str, dict[str, object]] = {}

        for node_config in node_configs:
            node_id = node_config.get("id")
            if not node_id or not isinstance(node_id, str):
                continue

            node_configs_map[node_id] = node_config

        return node_configs_map

    @classmethod
    def _find_root_node_id(
        cls,
        node_configs_map: Mapping[str, Mapping[str, object]],
        edge_configs: Sequence[Mapping[str, object]],
        root_node_id: str | None = None,
    ) -> str:
        """
        Find the root node ID if not specified.

        :param node_configs_map: mapping of node ID to node config
        :param edge_configs: list of edge configurations
        :param root_node_id: explicitly specified root node ID
        :return: determined root node ID
        """
        if root_node_id:
            if root_node_id not in node_configs_map:
                raise ValueError(f"Root node id {root_node_id} not found in the graph")
            return root_node_id

        # Find nodes with no incoming edges
        nodes_with_incoming: set[str] = set()
        for edge_config in edge_configs:
            target = edge_config.get("target")
            if isinstance(target, str):
                nodes_with_incoming.add(target)

        root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]

        # Prefer START node if available
        start_node_id = None
        for nid in root_candidates:
            node_data = node_configs_map[nid].get("data")
            if not is_str_dict(node_data):
                continue
            node_type = node_data.get("type")
            if not isinstance(node_type, str):
                continue
            if node_type in [NodeType.START, NodeType.DATASOURCE]:
                start_node_id = nid
                break

        root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)

        if not root_node_id:
            raise ValueError("Unable to determine root node ID")

        return root_node_id

    @classmethod
    def _build_edges(
        cls, edge_configs: list[dict[str, object]]
    ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
        """
        Build edge objects and mappings from edge configurations.

        :param edge_configs: list of edge configurations
        :return: tuple of (edges dict, in_edges dict, out_edges dict)
        """
        edges: dict[str, Edge] = {}
        in_edges: dict[str, list[str]] = defaultdict(list)
        out_edges: dict[str, list[str]] = defaultdict(list)

        edge_counter = 0
        for edge_config in edge_configs:
            source = edge_config.get("source")
            target = edge_config.get("target")

            if not is_str(source) or not is_str(target):
                continue

            # Create edge
            edge_id = f"edge_{edge_counter}"
            edge_counter += 1

            source_handle = edge_config.get("sourceHandle", "source")
            if not is_str(source_handle):
                continue

            edge = Edge(
                id=edge_id,
                tail=source,
                head=target,
                source_handle=source_handle,
            )

            edges[edge_id] = edge
            out_edges[source].append(edge_id)
            in_edges[target].append(edge_id)

        return edges, dict(in_edges), dict(out_edges)

    @classmethod
    def _create_node_instances(
        cls,
        node_configs_map: dict[str, dict[str, object]],
        node_factory: "NodeFactory",
    ) -> dict[str, Node]:
        """
        Create node instances from configurations using the node factory.

        :param node_configs_map: mapping of node ID to node config
        :param node_factory: factory for creating node instances
        :return: mapping of node ID to node instance
        """
        nodes: dict[str, Node] = {}

        for node_id, node_config in node_configs_map.items():
            try:
                node_instance = node_factory.create_node(node_config)
            except Exception:
                logger.exception("Failed to create node instance for node_id %s", node_id)
                raise
            nodes[node_id] = node_instance

        return nodes

    @classmethod
    def _mark_inactive_root_branches(
        cls,
        nodes: dict[str, Node],
        edges: dict[str, Edge],
        in_edges: dict[str, list[str]],
        out_edges: dict[str, list[str]],
        active_root_id: str,
    ) -> None:
        """
        Mark nodes and edges from inactive root branches as skipped.

        Algorithm:
        1. Mark inactive root nodes as skipped
        2. For skipped nodes, mark all their outgoing edges as skipped
        3. For each edge marked as skipped, check its target node:
           - If ALL incoming edges are skipped, mark the node as skipped
           - Otherwise, leave the node state unchanged

        :param nodes: mapping of node ID to node instance
        :param edges: mapping of edge ID to edge instance
        :param in_edges: mapping of node ID to incoming edge IDs
        :param out_edges: mapping of node ID to outgoing edge IDs
        :param active_root_id: ID of the active root node
        """
        # Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
        top_level_roots: list[str] = [
            node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
        ]

        # If there's only one root or the active root is not a top-level root, no marking needed
        if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
            return

        # Mark inactive root nodes as skipped
        inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
        for root_id in inactive_roots:
            if root_id in nodes:
                nodes[root_id].state = NodeState.SKIPPED

        # Recursively mark downstream nodes and edges
        def mark_downstream(node_id: str) -> None:
            """Recursively mark downstream nodes and edges as skipped."""
            if nodes[node_id].state != NodeState.SKIPPED:
                return
            # If this node is skipped, mark all its outgoing edges as skipped
            out_edge_ids = out_edges.get(node_id, [])
            for edge_id in out_edge_ids:
                edge = edges[edge_id]
                edge.state = NodeState.SKIPPED

                # Check the target node of this edge
                target_node = nodes[edge.head]
                in_edge_ids = in_edges.get(target_node.id, [])
                in_edge_states = [edges[eid].state for eid in in_edge_ids]

                # If all incoming edges are skipped, mark the node as skipped
                if all(state == NodeState.SKIPPED for state in in_edge_states):
                    target_node.state = NodeState.SKIPPED
                    # Recursively process downstream nodes
                    mark_downstream(target_node.id)

        # Process each inactive root and its downstream nodes
        for root_id in inactive_roots:
            mark_downstream(root_id)

    @classmethod
    def init(
        cls,
        *,
        graph_config: Mapping[str, object],
        node_factory: "NodeFactory",
        root_node_id: str | None = None,
    ) -> "Graph":
        """
        Initialize graph

        :param graph_config: graph config containing nodes and edges
        :param node_factory: factory for creating node instances from config data
        :param root_node_id: root node id
        :return: graph instance
        """
        # Parse configs
        edge_configs = graph_config.get("edges", [])
        node_configs = graph_config.get("nodes", [])

        edge_configs = cast(list[dict[str, object]], edge_configs)
        node_configs = cast(list[dict[str, object]], node_configs)

        if not node_configs:
            raise ValueError("Graph must have at least one node")

        node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]

        # Parse node configurations
        node_configs_map = cls._parse_node_configs(node_configs)

        # Find root node
        root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)

        # Build edges
        edges, in_edges, out_edges = cls._build_edges(edge_configs)

        # Create node instances
        nodes = cls._create_node_instances(node_configs_map, node_factory)

        # Get root node instance
        root_node = nodes[root_node_id]

        # Mark inactive root branches as skipped
        cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)

        # Create and return the graph
        return cls(
            nodes=nodes,
            edges=edges,
            in_edges=in_edges,
            out_edges=out_edges,
            root_node=root_node,
        )

    @property
    def node_ids(self) -> list[str]:
        """
        Get list of node IDs (compatibility property for existing code)

        :return: list of node IDs
        """
        return list(self.nodes.keys())

    def get_outgoing_edges(self, node_id: str) -> list[Edge]:
        """
        Get all outgoing edges from a node (V2 method)

        :param node_id: node id
        :return: list of outgoing edges
        """
        edge_ids = self.out_edges.get(node_id, [])
        return [self.edges[eid] for eid in edge_ids if eid in self.edges]

    def get_incoming_edges(self, node_id: str) -> list[Edge]:
        """
        Get all incoming edges to a node (V2 method)

        :param node_id: node id
        :return: list of incoming edges
        """
        edge_ids = self.in_edges.get(node_id, [])
        return [self.edges[eid] for eid in edge_ids if eid in self.edges]
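The core rule in `_mark_inactive_root_branches` is that skip status propagates along edges but a node is only skipped once ALL of its incoming edges are skipped, so a join fed by both an active and an inactive branch stays live. A stripped-down sketch of that rule over plain dicts (state strings and the `heads` map are simplifications of the `Edge`/`NodeState` objects above):

```python
SKIPPED, UNKNOWN = "skipped", "unknown"

def mark_downstream(node_id, node_state, edge_state, out_edges, in_edges, heads):
    """Propagate SKIPPED: skip outgoing edges, and skip a target node only
    when every one of its incoming edges is skipped."""
    if node_state[node_id] != SKIPPED:
        return
    for edge_id in out_edges.get(node_id, []):
        edge_state[edge_id] = SKIPPED
        target = heads[edge_id]  # edge's target node id
        if all(edge_state[e] == SKIPPED for e in in_edges[target]):
            node_state[target] = SKIPPED
            mark_downstream(target, node_state, edge_state, out_edges, in_edges, heads)

# Diamond: roots "a" (active) and "b" (skipped) both feed join node "c".
out_edges = {"b": ["e1"], "a": ["e2"]}
in_edges = {"c": ["e1", "e2"]}
heads = {"e1": "c", "e2": "c"}
node_state = {"a": UNKNOWN, "b": SKIPPED, "c": UNKNOWN}
edge_state = {"e1": UNKNOWN, "e2": UNKNOWN}

mark_downstream("b", node_state, edge_state, out_edges, in_edges, heads)
print(node_state["c"], edge_state["e1"])  # -> unknown skipped
```

Edge `e1` is skipped, but `c` remains reachable through `e2`, which is exactly the "otherwise, leave the node state unchanged" case in the docstring.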
61
api/core/workflow/graph/graph_runtime_state_protocol.py
Normal file
@@ -0,0 +1,61 @@
from collections.abc import Mapping
from typing import Any, Protocol

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment


class ReadOnlyVariablePool(Protocol):
    """Read-only interface for VariablePool."""

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Get a variable value (read-only)."""
        ...

    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Get all variables for a node (read-only)."""
        ...


class ReadOnlyGraphRuntimeState(Protocol):
    """
    Read-only view of GraphRuntimeState for layers.

    This protocol defines a read-only interface that prevents layers from
    modifying the graph runtime state while still allowing observation.
    All methods return defensive copies to ensure immutability.
    """

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        """Get read-only access to the variable pool."""
        ...

    @property
    def start_at(self) -> float:
        """Get the start time (read-only)."""
        ...

    @property
    def total_tokens(self) -> int:
        """Get the total tokens count (read-only)."""
        ...

    @property
    def llm_usage(self) -> LLMUsage:
        """Get a copy of LLM usage info (read-only)."""
        ...

    @property
    def outputs(self) -> dict[str, Any]:
        """Get a defensive copy of outputs (read-only)."""
        ...

    @property
    def node_run_steps(self) -> int:
        """Get the node run steps count (read-only)."""
        ...

    def get_output(self, key: str, default: Any = None) -> Any:
        """Get a single output value (returns a copy)."""
        ...
20
api/core/workflow/graph/graph_template.py
Normal file
@@ -0,0 +1,20 @@
from typing import Any

from pydantic import BaseModel, Field


class GraphTemplate(BaseModel):
    """
    Graph Template for container nodes and subgraph expansion

    According to GraphEngine V2 spec, GraphTemplate contains:
    - nodes: mapping of node definitions
    - edges: mapping of edge definitions
    - root_ids: list of root node IDs
    - output_selectors: list of output selectors for the template
    """

    nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
    edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
    root_ids: list[str] = Field(default_factory=list, description="root node IDs")
    output_selectors: list[str] = Field(default_factory=list, description="output selectors")
77
api/core/workflow/graph/read_only_state_wrapper.py
Normal file
@@ -0,0 +1,77 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool


class ReadOnlyVariablePoolWrapper:
    """Wrapper that provides read-only access to VariablePool."""

    def __init__(self, variable_pool: VariablePool):
        self._variable_pool = variable_pool

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Get a variable value (returns a defensive copy)."""
        value = self._variable_pool.get([node_id, variable_key])
        return deepcopy(value) if value is not None else None

    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Get all variables for a node (returns defensive copies)."""
        variables: dict[str, object] = {}
        if node_id in self._variable_pool.variable_dictionary:
            for key, var in self._variable_pool.variable_dictionary[node_id].items():
                # Variables expose the actual data through their `value` property
                variables[key] = deepcopy(var.value)
        return variables


class ReadOnlyGraphRuntimeStateWrapper:
    """
    Wrapper that provides read-only access to GraphRuntimeState.

    This wrapper ensures that layers can observe the state without
    modifying it. All returned values are defensive copies.
    """

    def __init__(self, state: GraphRuntimeState):
        self._state = state
        self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)

    @property
    def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
        """Get read-only access to the variable pool."""
        return self._variable_pool_wrapper

    @property
    def start_at(self) -> float:
        """Get the start time (read-only)."""
        return self._state.start_at

    @property
    def total_tokens(self) -> int:
        """Get the total token count (read-only)."""
        return self._state.total_tokens

    @property
    def llm_usage(self) -> LLMUsage:
        """Get a copy of the LLM usage info (read-only)."""
        # Return a copy to prevent modification
        return self._state.llm_usage.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        """Get a defensive copy of outputs (read-only)."""
        return deepcopy(self._state.outputs)

    @property
    def node_run_steps(self) -> int:
        """Get the node run step count (read-only)."""
        return self._state.node_run_steps

    def get_output(self, key: str, default: Any = None) -> Any:
        """Get a single output value (returns a copy)."""
        return self._state.get_output(key, default)
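The point of the defensive copies above is that a caller can freely mutate what it receives without corrupting engine state. A self-contained sketch (using hypothetical `TinyState` and `ReadOnlyTinyWrapper` stand-ins, not the real classes) demonstrates the pattern:

```python
from copy import deepcopy


class TinyState:
    """Hypothetical stand-in for GraphRuntimeState with a mutable outputs dict."""

    def __init__(self) -> None:
        self.outputs = {"answer": ["a", "b"]}


class ReadOnlyTinyWrapper:
    def __init__(self, state: TinyState) -> None:
        self._state = state

    @property
    def outputs(self):
        # Same pattern as ReadOnlyGraphRuntimeStateWrapper.outputs above
        return deepcopy(self._state.outputs)


state = TinyState()
wrapper = ReadOnlyTinyWrapper(state)

snapshot = wrapper.outputs
snapshot["answer"].append("c")  # mutating the copy...

# ...leaves the underlying state untouched
assert state.outputs == {"answer": ["a", "b"]}
```

A shallow copy would not be enough here: the nested list would still be shared, which is why `deepcopy` is used throughout the wrapper.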
@@ -1,4 +1,3 @@
-from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
 from .graph_engine import GraphEngine
 
-__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
+__all__ = ["GraphEngine"]
33
api/core/workflow/graph_engine/command_channels/README.md
Normal file
@@ -0,0 +1,33 @@
# Command Channels

Channel implementations for external workflow control.

## Components

### InMemoryChannel

Thread-safe in-memory queue for single-process deployments.

- `fetch_commands()` - Get pending commands
- `send_command()` - Add a command to the queue

### RedisChannel

Redis-based queue for distributed deployments.

- `fetch_commands()` - Get commands with JSON deserialization
- `send_command()` - Store commands with a TTL

## Usage

```python
# Local execution
channel = InMemoryChannel()
channel.send_command(AbortCommand(reason="user requested stop"))

# Distributed execution
redis_channel = RedisChannel(
    redis_client=redis_client,
    channel_key="workflow:123:commands"
)
```
@@ -0,0 +1,6 @@
"""Command channel implementations for GraphEngine."""

from .in_memory_channel import InMemoryChannel
from .redis_channel import RedisChannel

__all__ = ["InMemoryChannel", "RedisChannel"]
@@ -0,0 +1,53 @@
"""
In-memory implementation of CommandChannel for local/testing scenarios.

This implementation uses a thread-safe queue for command communication
within a single process. Each instance handles commands for one workflow execution.
"""

from queue import Empty, Queue
from typing import final

from ..entities.commands import GraphEngineCommand


@final
class InMemoryChannel:
    """
    In-memory command channel implementation using a thread-safe queue.

    Each instance is dedicated to a single GraphEngine/workflow execution.
    Suitable for local development, testing, and single-instance deployments.
    """

    def __init__(self) -> None:
        """Initialize the in-memory channel with a single queue."""
        self._queue: Queue[GraphEngineCommand] = Queue()

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """
        Fetch all pending commands from the queue.

        Returns:
            List of pending commands (drains the queue)
        """
        commands: list[GraphEngineCommand] = []

        # Drain all available commands from the queue
        while not self._queue.empty():
            try:
                commands.append(self._queue.get_nowait())
            except Empty:
                # Another consumer drained the queue between empty() and get
                break

        return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        """
        Send a command to this channel's queue.

        Args:
            command: The command to send
        """
        self._queue.put(command)
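Because `Queue` is thread-safe, producers on other threads can send commands while the engine drains them. A self-contained check of that drain behavior, using a hypothetical `MiniChannel` stripped down from the class above:

```python
import threading
from queue import Empty, Queue


class MiniChannel:
    """Stripped-down, hypothetical version of InMemoryChannel."""

    def __init__(self) -> None:
        self._queue: Queue = Queue()

    def send_command(self, command) -> None:
        self._queue.put(command)

    def fetch_commands(self) -> list:
        commands = []
        while not self._queue.empty():
            try:
                commands.append(self._queue.get_nowait())
            except Empty:
                break
        return commands


channel = MiniChannel()

# Several producer threads send commands concurrently, as callers might
threads = [
    threading.Thread(target=channel.send_command, args=(f"cmd-{i}",))
    for i in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

received = channel.fetch_commands()
assert sorted(received) == [f"cmd-{i}" for i in range(8)]
assert channel.fetch_commands() == []  # the queue is fully drained
```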
114
api/core/workflow/graph_engine/command_channels/redis_channel.py
Normal file
@@ -0,0 +1,114 @@
"""
Redis-based implementation of CommandChannel for distributed scenarios.

This implementation uses Redis lists for command queuing, supporting
multi-instance deployments and cross-server communication.
Each instance uses a unique key for its command queue.
"""

import json
from typing import TYPE_CHECKING, Any, final

from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand

if TYPE_CHECKING:
    from extensions.ext_redis import RedisClientWrapper


@final
class RedisChannel:
    """
    Redis-based command channel implementation for distributed systems.

    Each instance uses a unique Redis key for its command queue.
    Commands are JSON-serialized for transport.
    """

    def __init__(
        self,
        redis_client: "RedisClientWrapper",
        channel_key: str,
        command_ttl: int = 3600,
    ) -> None:
        """
        Initialize the Redis channel.

        Args:
            redis_client: Redis client instance
            channel_key: Unique key for this channel's command queue
            command_ttl: TTL for command keys in seconds (default: 3600)
        """
        self._redis = redis_client
        self._key = channel_key
        self._command_ttl = command_ttl

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """
        Fetch all pending commands from Redis.

        Returns:
            List of pending commands (drains the Redis list)
        """
        commands: list[GraphEngineCommand] = []

        # Use a pipeline so the read and delete happen atomically
        with self._redis.pipeline() as pipe:
            # Get all commands and clear the list atomically
            pipe.lrange(self._key, 0, -1)
            pipe.delete(self._key)
            results = pipe.execute()

        # Parse commands from JSON
        if results[0]:
            for command_json in results[0]:
                try:
                    command_data = json.loads(command_json)
                    command = self._deserialize_command(command_data)
                    if command:
                        commands.append(command)
                except (json.JSONDecodeError, ValueError):
                    # Skip invalid commands
                    continue

        return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        """
        Send a command to Redis.

        Args:
            command: The command to send
        """
        command_json = json.dumps(command.model_dump())

        # Push to the list and refresh its expiry
        with self._redis.pipeline() as pipe:
            pipe.rpush(self._key, command_json)
            pipe.expire(self._key, self._command_ttl)
            pipe.execute()

    def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
        """
        Deserialize a command from dictionary data.

        Args:
            data: Command data dictionary

        Returns:
            Deserialized command, or None if invalid
        """
        command_type_value = data.get("command_type")
        if not isinstance(command_type_value, str):
            return None

        try:
            command_type = CommandType(command_type_value)

            if command_type == CommandType.ABORT:
                return AbortCommand(**data)
            else:
                # For other command types, use the base class
                return GraphEngineCommand(**data)

        except (ValueError, TypeError):
            return None
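The wire format is plain JSON: `send_command` stores `model_dump()` output, and `fetch_commands` parses it back, skipping anything with an unknown `command_type`. That framing can be shown with plain dicts and the standard library (`is_known` below is a hypothetical stand-in for the `CommandType(...)` validation, not part of the real code):

```python
import json

# What rpush stores in the Redis list, and what fetch_commands reads back
command = {"command_type": "abort", "payload": None, "reason": "deploy rollout"}

wire = json.dumps(command)
restored = json.loads(wire)

assert restored == command
assert restored["command_type"] == "abort"


def is_known(data: dict) -> bool:
    """Hypothetical stand-in for the CommandType(...) validation step."""
    return data.get("command_type") in {"abort", "pause", "resume"}


assert is_known(restored)
assert not is_known({"command_type": "reboot"})  # would be skipped on fetch
```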
@@ -0,0 +1,14 @@
"""
Command processing subsystem for the graph engine.

This package handles external commands sent to the engine
during execution.
"""

from .command_handlers import AbortCommandHandler
from .command_processor import CommandProcessor

__all__ = [
    "AbortCommandHandler",
    "CommandProcessor",
]
@@ -0,0 +1,32 @@
"""
Command handler implementations.
"""

import logging
from typing import final

from typing_extensions import override

from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from .command_processor import CommandHandler

logger = logging.getLogger(__name__)


@final
class AbortCommandHandler(CommandHandler):
    """Handles abort commands."""

    @override
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        """
        Handle an abort command.

        Args:
            command: The abort command
            execution: Graph execution to abort
        """
        assert isinstance(command, AbortCommand)
        logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
        execution.abort(command.reason or "User requested abort")
@@ -0,0 +1,79 @@
"""
Main command processor for handling external commands.
"""

import logging
from typing import Protocol, final

from ..domain.graph_execution import GraphExecution
from ..entities.commands import GraphEngineCommand
from ..protocols.command_channel import CommandChannel

logger = logging.getLogger(__name__)


class CommandHandler(Protocol):
    """Protocol for command handlers."""

    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...


@final
class CommandProcessor:
    """
    Processes external commands sent to the engine.

    This polls the command channel and dispatches commands to the
    appropriate handlers.
    """

    def __init__(
        self,
        command_channel: CommandChannel,
        graph_execution: GraphExecution,
    ) -> None:
        """
        Initialize the command processor.

        Args:
            command_channel: Channel for receiving commands
            graph_execution: Graph execution aggregate
        """
        self._command_channel = command_channel
        self._graph_execution = graph_execution
        self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}

    def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
        """
        Register a handler for a command type.

        Args:
            command_type: Type of command to handle
            handler: Handler for the command
        """
        self._handlers[command_type] = handler

    def process_commands(self) -> None:
        """Check for and process any pending commands."""
        try:
            commands = self._command_channel.fetch_commands()
            for command in commands:
                self._handle_command(command)
        except Exception as e:
            logger.warning("Error processing commands: %s", e)

    def _handle_command(self, command: GraphEngineCommand) -> None:
        """
        Handle a single command.

        Args:
            command: The command to handle
        """
        handler = self._handlers.get(type(command))
        if handler:
            try:
                handler.handle(command, self._graph_execution)
            except Exception:
                logger.exception("Error handling command %s", command.__class__.__name__)
        else:
            logger.warning("No handler registered for command: %s", command.__class__.__name__)
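The core of `_handle_command` is dispatch on the concrete command type via a dict keyed by `type(command)`. That pattern can be sketched in isolation with hypothetical local classes (not the real command types):

```python
class Command:
    pass


class Abort(Command):
    reason = "stop"


class Pause(Command):
    pass


handled: list[str] = []

# Dispatch table keyed by concrete type, as in CommandProcessor._handlers
handlers = {
    Abort: lambda cmd: handled.append(f"abort:{cmd.reason}"),
    Pause: lambda cmd: handled.append("pause"),
}


def dispatch(command: Command) -> None:
    handler = handlers.get(type(command))
    if handler:
        handler(command)
    else:
        handled.append("unhandled")  # mirrors the logger.warning branch


dispatch(Abort())
dispatch(Pause())
dispatch(Command())  # no handler registered for the base type

assert handled == ["abort:stop", "pause", "unhandled"]
```

Note that `dict.get(type(command))` matches the exact class only; a subclass of `Abort` would fall through to the unhandled branch, which is a deliberate property of this dispatch style.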
@@ -1,25 +0,0 @@
from abc import ABC, abstractmethod

from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class RunConditionHandler(ABC):
    def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
        self.init_params = init_params
        self.graph = graph
        self.condition = condition

    @abstractmethod
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        raise NotImplementedError
@@ -1,25 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class BranchIdentifyRunConditionHandler(RunConditionHandler):
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        if not self.condition.branch_identify:
            raise Exception("Branch identify is required")

        run_result = previous_route_node_state.node_run_result
        if not run_result:
            return False

        if not run_result.edge_source_handle:
            return False

        return self.condition.branch_identify == run_result.edge_source_handle
@@ -1,27 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor


class ConditionRunConditionHandlerHandler(RunConditionHandler):
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        if not self.condition.conditions:
            return True

        # process condition
        condition_processor = ConditionProcessor()
        _, _, final_result = condition_processor.process_conditions(
            variable_pool=graph_runtime_state.variable_pool,
            conditions=self.condition.conditions,
            operator="and",
        )

        return final_result
@@ -1,25 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition


class ConditionManager:
    @staticmethod
    def get_condition_handler(
        init_params: GraphInitParams, graph: Graph, run_condition: RunCondition
    ) -> RunConditionHandler:
        """
        Get condition handler

        :param init_params: init params
        :param graph: graph
        :param run_condition: run condition
        :return: condition handler
        """
        if run_condition.type == "branch_identify":
            return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition)
        else:
            return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition)
14
api/core/workflow/graph_engine/domain/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""
Domain models for the graph engine.

This package contains the core domain entities, value objects, and aggregates
that represent the business concepts of workflow graph execution.
"""

from .graph_execution import GraphExecution
from .node_execution import NodeExecution

__all__ = [
    "GraphExecution",
    "NodeExecution",
]
215
api/core/workflow/graph_engine/domain/graph_execution.py
Normal file
@@ -0,0 +1,215 @@
"""GraphExecution aggregate root managing the overall graph execution state."""

from __future__ import annotations

from dataclasses import dataclass, field
from importlib import import_module
from typing import Literal

from pydantic import BaseModel, Field

from core.workflow.enums import NodeState

from .node_execution import NodeExecution


class GraphExecutionErrorState(BaseModel):
    """Serializable representation of an execution error."""

    module: str = Field(description="Module containing the exception class")
    qualname: str = Field(description="Qualified name of the exception class")
    message: str | None = Field(default=None, description="Exception message string")


class NodeExecutionState(BaseModel):
    """Serializable representation of a node execution entity."""

    node_id: str
    state: NodeState = Field(default=NodeState.UNKNOWN)
    retry_count: int = Field(default=0)
    execution_id: str | None = Field(default=None)
    error: str | None = Field(default=None)


class GraphExecutionState(BaseModel):
    """Pydantic model describing serialized GraphExecution state."""

    type: Literal["GraphExecution"] = Field(default="GraphExecution")
    version: str = Field(default="1.0")
    workflow_id: str
    started: bool = Field(default=False)
    completed: bool = Field(default=False)
    aborted: bool = Field(default=False)
    error: GraphExecutionErrorState | None = Field(default=None)
    exceptions_count: int = Field(default=0)
    node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])


def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
    """Convert an exception into its serializable representation."""
    if error is None:
        return None

    return GraphExecutionErrorState(
        module=error.__class__.__module__,
        qualname=error.__class__.__qualname__,
        message=str(error),
    )


def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
    """Locate an exception class from its module and qualified name."""
    module = import_module(module_name)
    attr: object = module
    for part in qualname.split("."):
        attr = getattr(attr, part)

    if isinstance(attr, type) and issubclass(attr, Exception):
        return attr

    raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")


def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
    """Reconstruct an exception instance from serialized data."""
    if state is None:
        return None

    try:
        exception_class = _resolve_exception_class(state.module, state.qualname)
        if state.message is None:
            return exception_class()
        return exception_class(state.message)
    except Exception:
        # Fall back to RuntimeError when reconstruction fails
        if state.message is None:
            return RuntimeError(state.qualname)
        return RuntimeError(state.message)


@dataclass
class GraphExecution:
    """
    Aggregate root for graph execution.

    This manages the overall execution state of a workflow graph,
    coordinating between multiple node executions.
    """

    workflow_id: str
    started: bool = False
    completed: bool = False
    aborted: bool = False
    error: Exception | None = None
    node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
    exceptions_count: int = 0

    def start(self) -> None:
        """Mark the graph execution as started."""
        if self.started:
            raise RuntimeError("Graph execution already started")
        self.started = True

    def complete(self) -> None:
        """Mark the graph execution as completed."""
        if not self.started:
            raise RuntimeError("Cannot complete execution that hasn't started")
        if self.completed:
            raise RuntimeError("Graph execution already completed")
        self.completed = True

    def abort(self, reason: str) -> None:
        """Abort the graph execution."""
        self.aborted = True
        self.error = RuntimeError(f"Aborted: {reason}")

    def fail(self, error: Exception) -> None:
        """Mark the graph execution as failed."""
        self.error = error
        self.completed = True

    def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
        """Get or create a node execution entity."""
        if node_id not in self.node_executions:
            self.node_executions[node_id] = NodeExecution(node_id=node_id)
        return self.node_executions[node_id]

    @property
    def is_running(self) -> bool:
        """Check if the execution is currently running."""
        return self.started and not self.completed and not self.aborted

    @property
    def has_error(self) -> bool:
        """Check if the execution has encountered an error."""
        return self.error is not None

    @property
    def error_message(self) -> str | None:
        """Get the error message if an error exists."""
        if not self.error:
            return None
        return str(self.error)

    def dumps(self) -> str:
        """Serialize the aggregate state into a JSON string."""
        node_states = [
            NodeExecutionState(
                node_id=node_id,
                state=node_execution.state,
                retry_count=node_execution.retry_count,
                execution_id=node_execution.execution_id,
                error=node_execution.error,
            )
            for node_id, node_execution in sorted(self.node_executions.items())
        ]

        state = GraphExecutionState(
            workflow_id=self.workflow_id,
            started=self.started,
            completed=self.completed,
            aborted=self.aborted,
            error=_serialize_error(self.error),
            exceptions_count=self.exceptions_count,
            node_executions=node_states,
        )

        return state.model_dump_json()

    def loads(self, data: str) -> None:
        """Restore aggregate state from a serialized JSON string."""
        state = GraphExecutionState.model_validate_json(data)

        if state.type != "GraphExecution":
            raise ValueError(f"Invalid serialized data type: {state.type}")

        if state.version != "1.0":
            raise ValueError(f"Unsupported serialized version: {state.version}")

        if self.workflow_id != state.workflow_id:
            raise ValueError("Serialized workflow_id does not match aggregate identity")

        self.started = state.started
        self.completed = state.completed
        self.aborted = state.aborted
        self.error = _deserialize_error(state.error)
        self.exceptions_count = state.exceptions_count
        self.node_executions = {
            item.node_id: NodeExecution(
                node_id=item.node_id,
                state=item.state,
                retry_count=item.retry_count,
                execution_id=item.execution_id,
                error=item.error,
            )
            for item in state.node_executions
        }

    def record_node_failure(self) -> None:
        """Increment the count of node failures encountered during execution."""
        self.exceptions_count += 1
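The aggregate enforces a simple lifecycle: not started → running → completed (or aborted). A self-contained sketch of that state machine, using a hypothetical `ExecutionSketch` dataclass rather than the full class:

```python
from dataclasses import dataclass


@dataclass
class ExecutionSketch:
    """Hypothetical mini version of GraphExecution's lifecycle flags."""

    workflow_id: str
    started: bool = False
    completed: bool = False
    aborted: bool = False

    def start(self) -> None:
        if self.started:
            raise RuntimeError("already started")
        self.started = True

    def complete(self) -> None:
        if not self.started:
            raise RuntimeError("not started")
        self.completed = True

    @property
    def is_running(self) -> bool:
        # Same predicate as GraphExecution.is_running above
        return self.started and not self.completed and not self.aborted


run = ExecutionSketch(workflow_id="wf-1")
assert not run.is_running  # created but not yet started

run.start()
assert run.is_running

run.complete()
assert not run.is_running  # completed runs are no longer "running"
```

The real `dumps()`/`loads()` pair then snapshots exactly these flags (plus the per-node states) through the pydantic `GraphExecutionState` model, which is what makes pause/resume across processes possible.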
45
api/core/workflow/graph_engine/domain/node_execution.py
Normal file
@@ -0,0 +1,45 @@
"""
NodeExecution entity representing a node's execution state.
"""

from dataclasses import dataclass

from core.workflow.enums import NodeState


@dataclass
class NodeExecution:
    """
    Entity representing the execution state of a single node.

    This is a mutable entity that tracks the runtime state of a node
    during graph execution.
    """

    node_id: str
    state: NodeState = NodeState.UNKNOWN
    retry_count: int = 0
    execution_id: str | None = None
    error: str | None = None

    def mark_started(self, execution_id: str) -> None:
        """Mark the node as started with an execution ID."""
        self.state = NodeState.TAKEN
        self.execution_id = execution_id

    def mark_taken(self) -> None:
        """Mark the node as successfully completed."""
        self.state = NodeState.TAKEN
        self.error = None

    def mark_failed(self, error: str) -> None:
        """Mark the node as failed with an error."""
        self.error = error

    def mark_skipped(self) -> None:
        """Mark the node as skipped."""
        self.state = NodeState.SKIPPED

    def increment_retry(self) -> None:
        """Increment the retry count for this node."""
        self.retry_count += 1
@@ -1,6 +0,0 @@
from .graph import Graph
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .runtime_route_state import RuntimeRouteState

__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
33
api/core/workflow/graph_engine/entities/commands.py
Normal file
@@ -0,0 +1,33 @@
"""
GraphEngine command entities for external control.

This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""

from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


class CommandType(StrEnum):
    """Types of commands that can be sent to GraphEngine."""

    ABORT = "abort"
    PAUSE = "pause"
    RESUME = "resume"


class GraphEngineCommand(BaseModel):
    """Base class for all GraphEngine commands."""

    command_type: CommandType = Field(..., description="Type of command")
    payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")


class AbortCommand(GraphEngineCommand):
    """Command to abort a running workflow execution."""

    command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
    reason: str | None = Field(default=None, description="Optional reason for abort")
@ -1,277 +0,0 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
###########################################
|
||||
# Graph Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
"""outputs"""
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
exceptions_count: int = Field(description="exception count", default=0)
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
exceptions_count: int = Field(..., description="exception count")
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
###########################################
|
||||
# Node Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
    id: str = Field(..., description="node execution id")
    node_id: str = Field(..., description="node id")
    node_type: NodeType = Field(..., description="node type")
    node_data: BaseNodeData = Field(..., description="node data")
    route_node_state: RouteNodeState = Field(..., description="route node state")
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
    in_loop_id: Optional[str] = None
    """loop id if node is in loop"""
    # The version of the node, or "1" if not specified.
    node_version: str = "1"


class NodeRunStartedEvent(BaseNodeEvent):
    predecessor_node_id: Optional[str] = None
    """predecessor node id"""
    parallel_mode_run_id: Optional[str] = None
    """iteration node parallel mode run id"""
    agent_strategy: Optional[AgentNodeStrategyInit] = None


class NodeRunStreamChunkEvent(BaseNodeEvent):
    chunk_content: str = Field(..., description="chunk content")
    from_variable_selector: Optional[list[str]] = None
    """from variable selector"""


class NodeRunRetrieverResourceEvent(BaseNodeEvent):
    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")


class NodeRunSucceededEvent(BaseNodeEvent):
    pass


class NodeRunFailedEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeRunExceptionEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeInIterationFailedEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeInLoopFailedEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeRunRetryEvent(NodeRunStartedEvent):
    error: str = Field(..., description="error")
    retry_index: int = Field(..., description="which retry attempt is about to be performed")
    start_at: datetime = Field(..., description="retry start time")

###########################################
# Parallel Branch Events
###########################################


class BaseParallelBranchEvent(GraphEngineEvent):
    parallel_id: str = Field(..., description="parallel id")
    """parallel id"""
    parallel_start_node_id: str = Field(..., description="parallel start node id")
    """parallel start node id"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
    in_loop_id: Optional[str] = None
    """loop id if node is in loop"""


class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
    pass


class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
    pass


class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
    error: str = Field(..., description="failed reason")

###########################################
# Iteration Events
###########################################


class BaseIterationEvent(GraphEngineEvent):
    iteration_id: str = Field(..., description="iteration node execution id")
    iteration_node_id: str = Field(..., description="iteration node id")
    iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
    iteration_node_data: BaseNodeData = Field(..., description="node data")
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    parallel_mode_run_id: Optional[str] = None
    """iteration run in parallel mode run id"""


class IterationRunStartedEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    predecessor_node_id: Optional[str] = None


class IterationRunNextEvent(BaseIterationEvent):
    index: int = Field(..., description="index")
    pre_iteration_output: Optional[Any] = None
    duration: Optional[float] = None


class IterationRunSucceededEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    outputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0
    iteration_duration_map: Optional[dict[str, float]] = None


class IterationRunFailedEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    outputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0
    error: str = Field(..., description="failed reason")

###########################################
# Loop Events
###########################################


class BaseLoopEvent(GraphEngineEvent):
    loop_id: str = Field(..., description="loop node execution id")
    loop_node_id: str = Field(..., description="loop node id")
    loop_node_type: NodeType = Field(..., description="node type, loop")
    loop_node_data: BaseNodeData = Field(..., description="node data")
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    parallel_mode_run_id: Optional[str] = None
    """loop run in parallel mode run id"""


class LoopRunStartedEvent(BaseLoopEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    predecessor_node_id: Optional[str] = None


class LoopRunNextEvent(BaseLoopEvent):
    index: int = Field(..., description="index")
    pre_loop_output: Optional[Any] = None
    duration: Optional[float] = None


class LoopRunSucceededEvent(BaseLoopEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    outputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0
    loop_duration_map: Optional[dict[str, float]] = None


class LoopRunFailedEvent(BaseLoopEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Optional[Mapping[str, Any]] = None
    outputs: Optional[Mapping[str, Any]] = None
    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0
    error: str = Field(..., description="failed reason")

###########################################
# Agent Events
###########################################


class BaseAgentEvent(GraphEngineEvent):
    pass


class AgentLogEvent(BaseAgentEvent):
    id: str = Field(..., description="id")
    label: str = Field(..., description="label")
    node_execution_id: str = Field(..., description="node execution id")
    parent_id: str | None = Field(..., description="parent id")
    error: str | None = Field(..., description="error")
    status: str = Field(..., description="status")
    data: Mapping[str, Any] = Field(..., description="data")
    metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
    node_id: str = Field(..., description="agent node id")


InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
import uuid
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, cast

from pydantic import BaseModel, Field

from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam

class GraphEdge(BaseModel):
    source_node_id: str = Field(..., description="source node id")
    target_node_id: str = Field(..., description="target node id")
    run_condition: Optional[RunCondition] = None
    """run condition"""


class GraphParallel(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
    start_from_node_id: str = Field(..., description="start from node id")
    parent_parallel_id: Optional[str] = None
    """parent parallel id"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id"""
    end_to_node_id: Optional[str] = None
    """end to node id"""

class Graph(BaseModel):
    root_node_id: str = Field(..., description="root node id of the graph")
    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
    node_id_config_mapping: dict[str, dict] = Field(
        default_factory=dict, description="node configs mapping (node id: node config)"
    )
    edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict, description="graph edge mapping (source node id: edges)"
    )
    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict, description="reverse graph edge mapping (target node id: edges)"
    )
    parallel_mapping: dict[str, GraphParallel] = Field(
        default_factory=dict, description="graph parallel mapping (parallel id: parallel)"
    )
    node_parallel_mapping: dict[str, str] = Field(
        default_factory=dict, description="graph node parallel mapping (node id: parallel id)"
    )
    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes")
    end_stream_param: EndStreamParam = Field(..., description="end stream param")
    @classmethod
    def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph":
        """
        Init graph

        :param graph_config: graph config
        :param root_node_id: root node id
        :return: graph
        """
        # edge configs
        edge_configs = graph_config.get("edges")
        if edge_configs is None:
            edge_configs = []
        # node configs
        node_configs = graph_config.get("nodes")
        if not node_configs:
            raise ValueError("Graph must have at least one node")

        edge_configs = cast(list, edge_configs)
        node_configs = cast(list, node_configs)

        # reorganize edges mapping
        edge_mapping: dict[str, list[GraphEdge]] = {}
        reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
        target_edge_ids = set()
        fail_branch_source_node_id = [
            node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
        ]
        for edge_config in edge_configs:
            source_node_id = edge_config.get("source")
            if not source_node_id:
                continue

            if source_node_id not in edge_mapping:
                edge_mapping[source_node_id] = []

            target_node_id = edge_config.get("target")
            if not target_node_id:
                continue

            if target_node_id not in reverse_edge_mapping:
                reverse_edge_mapping[target_node_id] = []

            target_edge_ids.add(target_node_id)

            # parse run condition
            run_condition = None
            if edge_config.get("sourceHandle"):
                if (
                    edge_config.get("source") in fail_branch_source_node_id
                    and edge_config.get("sourceHandle") != "fail-branch"
                ):
                    run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
                elif edge_config.get("sourceHandle") != "source":
                    run_condition = RunCondition(
                        type="branch_identify", branch_identify=edge_config.get("sourceHandle")
                    )

            graph_edge = GraphEdge(
                source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
            )

            edge_mapping[source_node_id].append(graph_edge)
            reverse_edge_mapping[target_node_id].append(graph_edge)

        # fetch nodes that have no predecessor node
        root_node_configs = []
        all_node_id_config_mapping: dict[str, dict] = {}
        for node_config in node_configs:
            node_id = node_config.get("id")
            if not node_id:
                continue

            if node_id not in target_edge_ids:
                root_node_configs.append(node_config)

            all_node_id_config_mapping[node_id] = node_config

        root_node_ids = [node_config.get("id") for node_config in root_node_configs]

        # fetch root node
        if not root_node_id:
            # if no root node id, use any start node (START or trigger types) as root node
            root_node_id = next(
                (
                    node_config.get("id")
                    for node_config in root_node_configs
                    if NodeType(node_config.get("data", {}).get("type", "")).is_start_node
                ),
                None,
            )

        if not root_node_id or root_node_id not in root_node_ids:
            raise ValueError(f"Root node id {root_node_id} not found in the graph")

        # Check whether it is connected to the previous node
        cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping)

        # fetch all node ids from root node
        node_ids = [root_node_id]
        cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id)

        node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}

        # init parallel mapping
        parallel_mapping: dict[str, GraphParallel] = {}
        node_parallel_mapping: dict[str, str] = {}
        cls._recursively_add_parallels(
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            start_node_id=root_node_id,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
        )

        # Check if it exceeds N layers of parallel
        for parallel in parallel_mapping.values():
            if parallel.parent_parallel_id:
                cls._check_exceed_parallel_limit(
                    parallel_mapping=parallel_mapping,
                    level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
                    parent_parallel_id=parallel.parent_parallel_id,
                )

        # init answer stream generate routes
        answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping
        )

        # init end stream param
        end_stream_param = EndStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            node_parallel_mapping=node_parallel_mapping,
        )

        # init graph
        graph = cls(
            root_node_id=root_node_id,
            node_ids=node_ids,
            node_id_config_mapping=node_id_config_mapping,
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
            answer_stream_generate_routes=answer_stream_generate_routes,
            end_stream_param=end_stream_param,
        )

        return graph

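A minimal, self-contained sketch of the first step of `Graph.init`: reorganizing the flat `edges` config into forward and reverse adjacency maps, skipping edges whose `source` or `target` is missing, much as the loop above does. The tuple edge representation and the function name here are illustrative stand-ins for the pydantic `GraphEdge` model.

```python
from collections import defaultdict


def build_edge_mappings(edge_configs: list[dict]) -> tuple[dict, dict]:
    """Build forward and reverse adjacency maps from a flat edge-config list."""
    edge_mapping: dict[str, list[tuple[str, str]]] = defaultdict(list)
    reverse_edge_mapping: dict[str, list[tuple[str, str]]] = defaultdict(list)
    for edge in edge_configs:
        source, target = edge.get("source"), edge.get("target")
        if not source or not target:
            continue  # ignore dangling or malformed edges
        edge_mapping[source].append((source, target))
        reverse_edge_mapping[target].append((source, target))
    return dict(edge_mapping), dict(reverse_edge_mapping)
```

With `[{"source": "start", "target": "llm"}, {"source": "llm", "target": "end"}]`, the forward map keys on `start` and `llm` while the reverse map keys on `llm` and `end`, which is the shape the traversal helpers below rely on.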
    def add_extra_edge(
        self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None
    ) -> None:
        """
        Add extra edge to the graph

        :param source_node_id: source node id
        :param target_node_id: target node id
        :param run_condition: run condition
        """
        if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
            return

        if source_node_id not in self.edge_mapping:
            self.edge_mapping[source_node_id] = []

        if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
            return

        graph_edge = GraphEdge(
            source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
        )

        self.edge_mapping[source_node_id].append(graph_edge)

    def get_leaf_node_ids(self) -> list[str]:
        """
        Get leaf node ids of the graph

        :return: leaf node ids
        """
        leaf_node_ids = []
        for node_id in self.node_ids:
            if node_id not in self.edge_mapping or (
                len(self.edge_mapping[node_id]) == 1
                and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
            ):
                leaf_node_ids.append(node_id)

        return leaf_node_ids

    @classmethod
    def _recursively_add_node_ids(
        cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str
    ) -> None:
        """
        Recursively add node ids

        :param node_ids: node ids
        :param edge_mapping: edge mapping
        :param node_id: node id
        """
        for graph_edge in edge_mapping.get(node_id, []):
            if graph_edge.target_node_id in node_ids:
                continue

            node_ids.append(graph_edge.target_node_id)
            cls._recursively_add_node_ids(
                node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id
            )

    @classmethod
    def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None:
        """
        Check whether it is connected to the previous node
        """
        last_node_id = route[-1]

        for graph_edge in edge_mapping.get(last_node_id, []):
            if not graph_edge.target_node_id:
                continue

            if graph_edge.target_node_id in route:
                raise ValueError(
                    f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
                )

            new_route = route.copy()
            new_route.append(graph_edge.target_node_id)
            cls._check_connected_to_previous_node(
                route=new_route,
                edge_mapping=edge_mapping,
            )

    @classmethod
    def _recursively_add_parallels(
        cls,
        edge_mapping: dict[str, list[GraphEdge]],
        reverse_edge_mapping: dict[str, list[GraphEdge]],
        start_node_id: str,
        parallel_mapping: dict[str, GraphParallel],
        node_parallel_mapping: dict[str, str],
        parent_parallel: Optional[GraphParallel] = None,
    ) -> None:
        """
        Recursively add parallel ids

        :param edge_mapping: edge mapping
        :param start_node_id: start from node id
        :param parallel_mapping: parallel mapping
        :param node_parallel_mapping: node parallel mapping
        :param parent_parallel: parent parallel
        """
        target_node_edges = edge_mapping.get(start_node_id, [])
        parallel = None
        if len(target_node_edges) > 1:
            # fetch all node ids in current parallels
            parallel_branch_node_ids = defaultdict(list)
            condition_edge_mappings = defaultdict(list)
            for graph_edge in target_node_edges:
                if graph_edge.run_condition is None:
                    parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
                else:
                    condition_hash = graph_edge.run_condition.hash
                    condition_edge_mappings[condition_hash].append(graph_edge)

            for condition_hash, graph_edges in condition_edge_mappings.items():
                if len(graph_edges) > 1:
                    for graph_edge in graph_edges:
                        parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)

            condition_parallels = {}
            for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
                # any target node id in node_parallel_mapping
                parallel = None
                if condition_parallel_branch_node_ids:
                    parent_parallel_id = parent_parallel.id if parent_parallel else None

                    parallel = GraphParallel(
                        start_from_node_id=start_node_id,
                        parent_parallel_id=parent_parallel_id,
                        parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
                    )
                    parallel_mapping[parallel.id] = parallel
                    condition_parallels[condition_hash] = parallel

                    in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                        edge_mapping=edge_mapping,
                        reverse_edge_mapping=reverse_edge_mapping,
                        parallel_branch_node_ids=condition_parallel_branch_node_ids,
                    )

                    # collect all branches node ids
                    parallel_node_ids = []
                    for _, node_ids in in_branch_node_ids.items():
                        for node_id in node_ids:
                            in_parent_parallel = True
                            if parent_parallel_id:
                                in_parent_parallel = False
                                for parallel_node_id, parallel_id in node_parallel_mapping.items():
                                    if parallel_id == parent_parallel_id and parallel_node_id == node_id:
                                        in_parent_parallel = True
                                        break

                            if in_parent_parallel:
                                parallel_node_ids.append(node_id)
                                node_parallel_mapping[node_id] = parallel.id

                    outside_parallel_target_node_ids = set()
                    for node_id in parallel_node_ids:
                        if node_id == parallel.start_from_node_id:
                            continue

                        node_edges = edge_mapping.get(node_id)
                        if not node_edges:
                            continue

                        if len(node_edges) > 1:
                            continue

                        target_node_id = node_edges[0].target_node_id
                        if target_node_id in parallel_node_ids:
                            continue

                        if parent_parallel_id:
                            parent_parallel = parallel_mapping.get(parent_parallel_id)
                            if not parent_parallel:
                                continue

                        if (
                            (
                                node_parallel_mapping.get(target_node_id)
                                and node_parallel_mapping.get(target_node_id) == parent_parallel_id
                            )
                            or (
                                parent_parallel
                                and parent_parallel.end_to_node_id
                                and target_node_id == parent_parallel.end_to_node_id
                            )
                            or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
                        ):
                            outside_parallel_target_node_ids.add(target_node_id)

                    if len(outside_parallel_target_node_ids) == 1:
                        if (
                            parent_parallel
                            and parent_parallel.end_to_node_id
                            and parallel.end_to_node_id == parent_parallel.end_to_node_id
                        ):
                            parallel.end_to_node_id = None
                        else:
                            parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

            if condition_edge_mappings:
                for condition_hash, graph_edges in condition_edge_mappings.items():
                    for graph_edge in graph_edges:
                        current_parallel = cls._get_current_parallel(
                            parallel_mapping=parallel_mapping,
                            graph_edge=graph_edge,
                            parallel=condition_parallels.get(condition_hash),
                            parent_parallel=parent_parallel,
                        )

                        cls._recursively_add_parallels(
                            edge_mapping=edge_mapping,
                            reverse_edge_mapping=reverse_edge_mapping,
                            start_node_id=graph_edge.target_node_id,
                            parallel_mapping=parallel_mapping,
                            node_parallel_mapping=node_parallel_mapping,
                            parent_parallel=current_parallel,
                        )
            else:
                for graph_edge in target_node_edges:
                    current_parallel = cls._get_current_parallel(
                        parallel_mapping=parallel_mapping,
                        graph_edge=graph_edge,
                        parallel=parallel,
                        parent_parallel=parent_parallel,
                    )

                    cls._recursively_add_parallels(
                        edge_mapping=edge_mapping,
                        reverse_edge_mapping=reverse_edge_mapping,
                        start_node_id=graph_edge.target_node_id,
                        parallel_mapping=parallel_mapping,
                        node_parallel_mapping=node_parallel_mapping,
                        parent_parallel=current_parallel,
                    )
        else:
            for graph_edge in target_node_edges:
                current_parallel = cls._get_current_parallel(
                    parallel_mapping=parallel_mapping,
                    graph_edge=graph_edge,
                    parallel=parallel,
                    parent_parallel=parent_parallel,
                )

                cls._recursively_add_parallels(
                    edge_mapping=edge_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    start_node_id=graph_edge.target_node_id,
                    parallel_mapping=parallel_mapping,
                    node_parallel_mapping=node_parallel_mapping,
                    parent_parallel=current_parallel,
                )

    @classmethod
    def _get_current_parallel(
        cls,
        parallel_mapping: dict[str, GraphParallel],
        graph_edge: GraphEdge,
        parallel: Optional[GraphParallel] = None,
        parent_parallel: Optional[GraphParallel] = None,
    ) -> Optional[GraphParallel]:
        """
        Get current parallel
        """
        current_parallel = None
        if parallel:
            current_parallel = parallel
        elif parent_parallel:
            if not parent_parallel.end_to_node_id or (
                parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
            ):
                current_parallel = parent_parallel
            else:
                # fetch parent parallel's parent parallel
                parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
                if parent_parallel_parent_parallel_id:
                    parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
                    if parent_parallel_parent_parallel and (
                        not parent_parallel_parent_parallel.end_to_node_id
                        or (
                            parent_parallel_parent_parallel.end_to_node_id
                            and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
                        )
                    ):
                        current_parallel = parent_parallel_parent_parallel

        return current_parallel

    @classmethod
    def _check_exceed_parallel_limit(
        cls,
        parallel_mapping: dict[str, GraphParallel],
        level_limit: int,
        parent_parallel_id: str,
        current_level: int = 1,
    ) -> None:
        """
        Check if it exceeds N layers of parallel
        """
        parent_parallel = parallel_mapping.get(parent_parallel_id)
        if not parent_parallel:
            return

        current_level += 1
        if current_level > level_limit:
            raise ValueError(f"Exceeds {level_limit} layers of parallel")

        if parent_parallel.parent_parallel_id:
            cls._check_exceed_parallel_limit(
                parallel_mapping=parallel_mapping,
                level_limit=level_limit,
                parent_parallel_id=parent_parallel.parent_parallel_id,
                current_level=current_level,
            )

    @classmethod
    def _recursively_add_parallel_node_ids(
        cls,
        branch_node_ids: list[str],
        edge_mapping: dict[str, list[GraphEdge]],
        merge_node_id: str,
        start_node_id: str,
    ) -> None:
        """
        Recursively add node ids

        :param branch_node_ids: in branch node ids
        :param edge_mapping: edge mapping
        :param merge_node_id: merge node id
        :param start_node_id: start node id
        """
        for graph_edge in edge_mapping.get(start_node_id, []):
            if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids:
                branch_node_ids.append(graph_edge.target_node_id)
                cls._recursively_add_parallel_node_ids(
                    branch_node_ids=branch_node_ids,
                    edge_mapping=edge_mapping,
                    merge_node_id=merge_node_id,
                    start_node_id=graph_edge.target_node_id,
                )

    @classmethod
    def _fetch_all_node_ids_in_parallels(
        cls,
        edge_mapping: dict[str, list[GraphEdge]],
        reverse_edge_mapping: dict[str, list[GraphEdge]],
        parallel_branch_node_ids: list[str],
    ) -> dict[str, list[str]]:
        """
        Fetch all node ids in parallels
        """
        routes_node_ids: dict[str, list[str]] = {}
        for parallel_branch_node_id in parallel_branch_node_ids:
            routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]

            # fetch routes node ids
            cls._recursively_fetch_routes(
                edge_mapping=edge_mapping,
                start_node_id=parallel_branch_node_id,
                routes_node_ids=routes_node_ids[parallel_branch_node_id],
            )

        # fetch leaf node ids from routes node ids
        leaf_node_ids: dict[str, list[str]] = {}
        merge_branch_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            for node_id in node_ids:
                if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
                    if branch_node_id not in leaf_node_ids:
                        leaf_node_ids[branch_node_id] = []

                    leaf_node_ids[branch_node_id].append(node_id)

                for branch_node_id2, inner_route2 in routes_node_ids.items():
                    if (
                        branch_node_id != branch_node_id2
                        and node_id in inner_route2
                        and len(reverse_edge_mapping.get(node_id, [])) > 1
                        and cls._is_node_in_routes(
                            reverse_edge_mapping=reverse_edge_mapping,
                            start_node_id=node_id,
                            routes_node_ids=routes_node_ids,
                        )
                    ):
                        if node_id not in merge_branch_node_ids:
                            merge_branch_node_ids[node_id] = []

                        if branch_node_id2 not in merge_branch_node_ids[node_id]:
                            merge_branch_node_ids[node_id].append(branch_node_id2)

        # sorted merge_branch_node_ids by branch_node_ids length desc
        merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))

        duplicate_end_node_ids = {}
        for node_id, branch_node_ids in merge_branch_node_ids.items():
            for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
                if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
                    if (node_id, node_id2) not in duplicate_end_node_ids and (
                        node_id2,
                        node_id,
                    ) not in duplicate_end_node_ids:
                        duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids

        for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
            # check which node is after
            if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
                if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
                    del merge_branch_node_ids[node_id2]
            elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
                if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
                    del merge_branch_node_ids[node_id]

        branches_merge_node_ids: dict[str, str] = {}
        for node_id, branch_node_ids in merge_branch_node_ids.items():
            if len(branch_node_ids) <= 1:
                continue

            for branch_node_id in branch_node_ids:
                if branch_node_id in branches_merge_node_ids:
                    continue

                branches_merge_node_ids[branch_node_id] = node_id

        in_branch_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            in_branch_node_ids[branch_node_id] = []
            if branch_node_id not in branches_merge_node_ids:
                # all node ids in current branch is in this thread
                in_branch_node_ids[branch_node_id].append(branch_node_id)
                in_branch_node_ids[branch_node_id].extend(node_ids)
            else:
                merge_node_id = branches_merge_node_ids[branch_node_id]
                if merge_node_id != branch_node_id:
                    in_branch_node_ids[branch_node_id].append(branch_node_id)

                # fetch all node ids from branch_node_id and merge_node_id
                cls._recursively_add_parallel_node_ids(
                    branch_node_ids=in_branch_node_ids[branch_node_id],
                    edge_mapping=edge_mapping,
                    merge_node_id=merge_node_id,
                    start_node_id=branch_node_id,
                )

        return in_branch_node_ids

@classmethod
|
||||
def _recursively_fetch_routes(
|
||||
cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
Recursively fetch route
|
||||
"""
|
||||
if start_node_id not in edge_mapping:
|
||||
return
|
||||
|
||||
for graph_edge in edge_mapping[start_node_id]:
|
||||
# find next node ids
|
||||
if graph_edge.target_node_id not in routes_node_ids:
|
||||
routes_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _is_node_in_routes(
|
||||
cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]]
|
||||
) -> bool:
|
||||
"""
|
||||
Recursively check if the node is in the routes
|
||||
"""
|
||||
if start_node_id not in reverse_edge_mapping:
|
||||
return False
|
||||
|
||||
all_routes_node_ids = set()
|
||||
parallel_start_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
all_routes_node_ids.update(node_ids)
|
||||
|
||||
if branch_node_id in reverse_edge_mapping:
|
||||
for graph_edge in reverse_edge_mapping[branch_node_id]:
|
||||
if graph_edge.source_node_id not in parallel_start_node_ids:
|
||||
parallel_start_node_ids[graph_edge.source_node_id] = []
|
||||
|
||||
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
|
||||
|
||||
for _, branch_node_ids in parallel_start_node_ids.items():
|
||||
if set(branch_node_ids) == set(routes_node_ids.keys()):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
|
||||
"""
|
||||
is node2 after node1
|
||||
"""
|
||||
if node1_id not in edge_mapping:
|
||||
return False
|
||||
|
||||
for graph_edge in edge_mapping[node1_id]:
|
||||
if graph_edge.target_node_id == node2_id:
|
||||
return True
|
||||
|
||||
if cls._is_node2_after_node1(
|
||||
node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
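The reachability helper above can be exercised standalone. A minimal sketch (not the Dify classes — a plain `{source: [targets]}` dict stands in for the `GraphEdge` mapping) of the same recursive check; like the original, it carries no visited set, so it assumes an acyclic edge mapping:

```python
def is_node2_after_node1(node1_id: str, node2_id: str, edge_mapping: dict[str, list[str]]) -> bool:
    """Return True if node2 is reachable by following edges out of node1."""
    if node1_id not in edge_mapping:
        return False
    for target in edge_mapping[node1_id]:
        if target == node2_id:
            return True
        # depth-first descent into the successor's outgoing edges
        if is_node2_after_node1(target, node2_id, edge_mapping):
            return True
    return False

# diamond graph: start -> {a, b} -> end
edges = {"start": ["a", "b"], "a": ["end"], "b": ["end"]}
print(is_node2_after_node1("start", "end", edges))  # True
print(is_node2_after_node1("end", "start", edges))  # False
```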
@@ -1,31 +0,0 @@
from typing import Any

from pydantic import BaseModel, Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState


class GraphRuntimeState(BaseModel):
    variable_pool: VariablePool = Field(..., description="variable pool")
    """variable pool"""

    start_at: float = Field(..., description="start time")
    """start time"""
    total_tokens: int = 0
    """total tokens"""
    llm_usage: LLMUsage = LLMUsage.empty_usage()
    """llm usage info"""

    # The `outputs` field stores the final output values generated by executing workflows or chatflows.
    #
    # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
    # after a serialization and deserialization round trip.
    outputs: dict[str, Any] = Field(default_factory=dict)

    node_run_steps: int = 0
    """node run steps"""

    node_run_state: RuntimeRouteState = RuntimeRouteState()
    """node run state"""
@@ -1,118 +0,0 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from libs.datetime_utils import naive_utc_now


class RouteNodeState(BaseModel):
    class Status(Enum):
        RUNNING = "running"
        SUCCESS = "success"
        FAILED = "failed"
        PAUSED = "paused"
        EXCEPTION = "exception"

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    """node state id"""

    node_id: str
    """node id"""

    node_run_result: Optional[NodeRunResult] = None
    """node run result"""

    status: Status = Status.RUNNING
    """node status"""

    start_at: datetime
    """start time"""

    paused_at: Optional[datetime] = None
    """paused time"""

    finished_at: Optional[datetime] = None
    """finished time"""

    failed_reason: Optional[str] = None
    """failed reason"""

    paused_by: Optional[str] = None
    """paused by"""

    index: int = 1

    def set_finished(self, run_result: NodeRunResult) -> None:
        """
        Node finished

        :param run_result: run result
        """
        if self.status in {
            RouteNodeState.Status.SUCCESS,
            RouteNodeState.Status.FAILED,
            RouteNodeState.Status.EXCEPTION,
        }:
            raise Exception(f"Route state {self.id} already finished")

        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
            self.status = RouteNodeState.Status.SUCCESS
        elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
            self.status = RouteNodeState.Status.FAILED
            self.failed_reason = run_result.error
        elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
            self.status = RouteNodeState.Status.EXCEPTION
            self.failed_reason = run_result.error
        else:
            raise Exception(f"Invalid route status {run_result.status}")

        self.node_run_result = run_result
        self.finished_at = naive_utc_now()


class RuntimeRouteState(BaseModel):
    routes: dict[str, list[str]] = Field(
        default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
    )

    node_state_mapping: dict[str, RouteNodeState] = Field(
        default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
    )

    def create_node_state(self, node_id: str) -> RouteNodeState:
        """
        Create node state

        :param node_id: node id
        """
        state = RouteNodeState(node_id=node_id, start_at=naive_utc_now())
        self.node_state_mapping[state.id] = state
        return state

    def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
        """
        Add route to the graph state

        :param source_node_state_id: source node state id
        :param target_node_state_id: target node state id
        """
        if source_node_state_id not in self.routes:
            self.routes[source_node_state_id] = []

        self.routes[source_node_state_id].append(target_node_state_id)

    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]:
        """
        Get routes with node state by source node state id

        :param source_node_state_id: source node state id
        :return: routes with node state
        """
        return [
            self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
        ]
211	api/core/workflow/graph_engine/error_handler.py	Normal file
@@ -0,0 +1,211 @@
"""
Main error handler that coordinates error strategies.
"""

import logging
import time
from typing import TYPE_CHECKING, final

from core.workflow.enums import (
    ErrorStrategy as ErrorStrategyEnum,
)
from core.workflow.enums import (
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from core.workflow.graph import Graph
from core.workflow.graph_events import (
    GraphNodeEventBase,
    NodeRunExceptionEvent,
    NodeRunFailedEvent,
    NodeRunRetryEvent,
)
from core.workflow.node_events import NodeRunResult

if TYPE_CHECKING:
    from .domain import GraphExecution

logger = logging.getLogger(__name__)


@final
class ErrorHandler:
    """
    Coordinates error handling strategies for node failures.

    This acts as a facade for the various error strategies,
    selecting and applying the appropriate strategy based on
    node configuration.
    """

    def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
        """
        Initialize the error handler.

        Args:
            graph: The workflow graph
            graph_execution: The graph execution state
        """
        self._graph = graph
        self._graph_execution = graph_execution

    def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
        """
        Handle a node failure event.

        Selects and applies the appropriate error strategy based on
        the node's configuration.

        Args:
            event: The node failure event

        Returns:
            Optional new event to process, or None to abort
        """
        node = self._graph.nodes[event.node_id]
        # Get retry count from NodeExecution
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        retry_count = node_execution.retry_count

        # First check if retry is configured and not exhausted
        if node.retry and retry_count < node.retry_config.max_retries:
            result = self._handle_retry(event, retry_count)
            if result:
                # Retry count will be incremented when NodeRunRetryEvent is handled
                return result

        # Apply configured error strategy
        strategy = node.error_strategy

        match strategy:
            case None:
                return self._handle_abort(event)
            case ErrorStrategyEnum.FAIL_BRANCH:
                return self._handle_fail_branch(event)
            case ErrorStrategyEnum.DEFAULT_VALUE:
                return self._handle_default_value(event)

    def _handle_abort(self, event: NodeRunFailedEvent):
        """
        Handle error by aborting execution.

        This is the default strategy when no other strategy is specified.
        It stops the entire graph execution when a node fails.

        Args:
            event: The failure event

        Returns:
            None - signals abortion
        """
        logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
        # Return None to signal that execution should stop

    def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int):
        """
        Handle error by retrying the node.

        This strategy re-attempts node execution up to a configured
        maximum number of retries with configurable intervals.

        Args:
            event: The failure event
            retry_count: Current retry attempt count

        Returns:
            NodeRunRetryEvent if retry should occur, None otherwise
        """
        node = self._graph.nodes[event.node_id]

        # Check if we've exceeded max retries
        if not node.retry or retry_count >= node.retry_config.max_retries:
            return None

        # Wait for retry interval
        time.sleep(node.retry_config.retry_interval_seconds)

        # Create retry event
        return NodeRunRetryEvent(
            id=event.id,
            node_title=node.title,
            node_id=event.node_id,
            node_type=event.node_type,
            node_run_result=event.node_run_result,
            start_at=event.start_at,
            error=event.error,
            retry_index=retry_count + 1,
        )

    def _handle_fail_branch(self, event: NodeRunFailedEvent):
        """
        Handle error by taking the fail branch.

        This strategy converts failures to exceptions and routes execution
        through a designated fail-branch edge.

        Args:
            event: The failure event

        Returns:
            NodeRunExceptionEvent to continue via fail branch
        """
        outputs = {
            "error_message": event.node_run_result.error,
            "error_type": event.node_run_result.error_type,
        }

        return NodeRunExceptionEvent(
            id=event.id,
            node_id=event.node_id,
            node_type=event.node_type,
            start_at=event.start_at,
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.EXCEPTION,
                inputs=event.node_run_result.inputs,
                process_data=event.node_run_result.process_data,
                outputs=outputs,
                edge_source_handle="fail-branch",
                metadata={
                    WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH,
                },
            ),
            error=event.error,
        )

    def _handle_default_value(self, event: NodeRunFailedEvent):
        """
        Handle error by using default values.

        This strategy allows nodes to fail gracefully by providing
        predefined default output values.

        Args:
            event: The failure event

        Returns:
            NodeRunExceptionEvent with default values
        """
        node = self._graph.nodes[event.node_id]

        outputs = {
            **node.default_value_dict,
            "error_message": event.node_run_result.error,
            "error_type": event.node_run_result.error_type,
        }

        return NodeRunExceptionEvent(
            id=event.id,
            node_id=event.node_id,
            node_type=event.node_type,
            start_at=event.start_at,
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.EXCEPTION,
                inputs=event.node_run_result.inputs,
                process_data=event.node_run_result.process_data,
                outputs=outputs,
                metadata={
                    WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE,
                },
            ),
            error=event.error,
        )
14	api/core/workflow/graph_engine/event_management/__init__.py	Normal file
@@ -0,0 +1,14 @@
"""
Event management subsystem for graph engine.

This package handles event routing, collection, and emission for
workflow graph execution events.
"""

from .event_handlers import EventHandler
from .event_manager import EventManager

__all__ = [
    "EventHandler",
    "EventManager",
]
@@ -0,0 +1,311 @@
"""
Event handler implementations for different event types.
"""

import logging
from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final

from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
    GraphNodeEventBase,
    NodeRunAgentLogEvent,
    NodeRunExceptionEvent,
    NodeRunFailedEvent,
    NodeRunIterationFailedEvent,
    NodeRunIterationNextEvent,
    NodeRunIterationStartedEvent,
    NodeRunIterationSucceededEvent,
    NodeRunLoopFailedEvent,
    NodeRunLoopNextEvent,
    NodeRunLoopStartedEvent,
    NodeRunLoopSucceededEvent,
    NodeRunRetryEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)

from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator

if TYPE_CHECKING:
    from ..error_handler import ErrorHandler
    from ..graph_state_manager import GraphStateManager
    from ..graph_traversal import EdgeProcessor
    from .event_manager import EventManager

logger = logging.getLogger(__name__)


@final
class EventHandler:
    """
    Registry of event handlers for different event types.

    This centralizes the business logic for handling specific events,
    keeping it separate from the routing and collection infrastructure.
    """

    def __init__(
        self,
        graph: Graph,
        graph_runtime_state: GraphRuntimeState,
        graph_execution: GraphExecution,
        response_coordinator: ResponseStreamCoordinator,
        event_collector: "EventManager",
        edge_processor: "EdgeProcessor",
        state_manager: "GraphStateManager",
        error_handler: "ErrorHandler",
    ) -> None:
        """
        Initialize the event handler registry.

        Args:
            graph: The workflow graph
            graph_runtime_state: Runtime state with variable pool
            graph_execution: Graph execution aggregate
            response_coordinator: Response stream coordinator
            event_collector: Event manager for collecting events
            edge_processor: Edge processor for edge traversal
            state_manager: Unified state manager
            error_handler: Error handler
        """
        self._graph = graph
        self._graph_runtime_state = graph_runtime_state
        self._graph_execution = graph_execution
        self._response_coordinator = response_coordinator
        self._event_collector = event_collector
        self._edge_processor = edge_processor
        self._state_manager = state_manager
        self._error_handler = error_handler

    def dispatch(self, event: GraphNodeEventBase) -> None:
        """
        Handle any node event by dispatching to the appropriate handler.

        Args:
            event: The event to handle
        """
        # Events in loops or iterations are always collected
        if event.in_loop_id or event.in_iteration_id:
            self._event_collector.collect(event)
            return
        return self._dispatch(event)

    @singledispatchmethod
    def _dispatch(self, event: GraphNodeEventBase) -> None:
        self._event_collector.collect(event)
        logger.warning("Unhandled event type: %s", type(event).__name__)

    @_dispatch.register(NodeRunIterationStartedEvent)
    @_dispatch.register(NodeRunIterationNextEvent)
    @_dispatch.register(NodeRunIterationSucceededEvent)
    @_dispatch.register(NodeRunIterationFailedEvent)
    @_dispatch.register(NodeRunLoopStartedEvent)
    @_dispatch.register(NodeRunLoopNextEvent)
    @_dispatch.register(NodeRunLoopSucceededEvent)
    @_dispatch.register(NodeRunLoopFailedEvent)
    @_dispatch.register(NodeRunAgentLogEvent)
    def _(self, event: GraphNodeEventBase) -> None:
        self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunStartedEvent) -> None:
        """
        Handle node started event.

        Args:
            event: The node started event
        """
        # Track execution in domain model
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        is_initial_attempt = node_execution.retry_count == 0
        node_execution.mark_started(event.id)

        # Track in response coordinator for stream ordering
        self._response_coordinator.track_node_execution(event.node_id, event.id)

        # Collect the event only for the first attempt; retries remain silent
        if is_initial_attempt:
            self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunStreamChunkEvent) -> None:
        """
        Handle stream chunk event with full processing.

        Args:
            event: The stream chunk event
        """
        # Process with response coordinator
        streaming_events = list(self._response_coordinator.intercept_event(event))

        # Collect all events
        for stream_event in streaming_events:
            self._event_collector.collect(stream_event)

    @_dispatch.register
    def _(self, event: NodeRunSucceededEvent) -> None:
        """
        Handle node success by coordinating subsystems.

        This method coordinates between different subsystems to process
        node completion, handle edges, and trigger downstream execution.

        Args:
            event: The node succeeded event
        """
        # Update domain model
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.mark_taken()

        # Store outputs in variable pool
        self._store_node_outputs(event.node_id, event.node_run_result.outputs)

        # Forward to response coordinator and emit streaming events
        streaming_events = self._response_coordinator.intercept_event(event)
        for stream_event in streaming_events:
            self._event_collector.collect(stream_event)

        # Process edges and get ready nodes
        node = self._graph.nodes[event.node_id]
        if node.execution_type == NodeExecutionType.BRANCH:
            ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
                event.node_id, event.node_run_result.edge_source_handle
            )
        else:
            ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)

        # Collect streaming events from edge processing
        for edge_event in edge_streaming_events:
            self._event_collector.collect(edge_event)

        # Enqueue ready nodes
        for node_id in ready_nodes:
            self._state_manager.enqueue_node(node_id)
            self._state_manager.start_execution(node_id)

        # Update execution tracking
        self._state_manager.finish_execution(event.node_id)

        # Handle response node outputs
        if node.execution_type == NodeExecutionType.RESPONSE:
            self._update_response_outputs(event.node_run_result.outputs)

        # Collect the event
        self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunFailedEvent) -> None:
        """
        Handle node failure using error handler.

        Args:
            event: The node failed event
        """
        # Update domain model
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.mark_failed(event.error)
        self._graph_execution.record_node_failure()

        result = self._error_handler.handle_node_failure(event)

        if result:
            # Process the resulting event (retry, exception, etc.)
            self.dispatch(result)
        else:
            # Abort execution
            self._graph_execution.fail(RuntimeError(event.error))
            self._event_collector.collect(event)
            self._state_manager.finish_execution(event.node_id)

    @_dispatch.register
    def _(self, event: NodeRunExceptionEvent) -> None:
        """
        Handle node exception event (fail-branch strategy).

        Args:
            event: The node exception event
        """
        # Node continues via fail-branch/default-value, treat as completion
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.mark_taken()

        # Persist outputs produced by the exception strategy (e.g. default values)
        self._store_node_outputs(event.node_id, event.node_run_result.outputs)

        node = self._graph.nodes[event.node_id]

        if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
            ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
        elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
            ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
                event.node_id, event.node_run_result.edge_source_handle
            )
        else:
            raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")

        for edge_event in edge_streaming_events:
            self._event_collector.collect(edge_event)

        for node_id in ready_nodes:
            self._state_manager.enqueue_node(node_id)
            self._state_manager.start_execution(node_id)

        # Update response outputs if applicable
        if node.execution_type == NodeExecutionType.RESPONSE:
            self._update_response_outputs(event.node_run_result.outputs)

        self._state_manager.finish_execution(event.node_id)

        # Collect the exception event for observers
        self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunRetryEvent) -> None:
        """
        Handle node retry event.

        Args:
            event: The node retry event
        """
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.increment_retry()

        # Finish the previous attempt before re-queuing the node
        self._state_manager.finish_execution(event.node_id)

        # Emit retry event for observers
        self._event_collector.collect(event)

        # Re-queue node for execution
        self._state_manager.enqueue_node(event.node_id)
        self._state_manager.start_execution(event.node_id)

    def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
        """
        Store node outputs in the variable pool.

        Args:
            node_id: ID of the node that produced the outputs
            outputs: Mapping of output variable names to values
        """
        for variable_name, variable_value in outputs.items():
            self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)

    def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
        """Update response outputs for response nodes."""
        # TODO: Design a mechanism for nodes to notify the engine about how to update outputs
        # in runtime state, rather than allowing nodes to directly access runtime state.
        for key, value in outputs.items():
            if key == "answer":
                existing = self._graph_runtime_state.get_output("answer", "")
                if existing:
                    self._graph_runtime_state.set_output("answer", f"{existing}{value}")
                else:
                    self._graph_runtime_state.set_output("answer", value)
            else:
                self._graph_runtime_state.set_output(key, value)
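The `EventHandler` above routes events by type with `functools.singledispatchmethod`, where the unannotated base method acts as the catch-all. A self-contained sketch of that pattern (the event classes here are stand-ins, not the Dify event types):

```python
from functools import singledispatchmethod


class Started: ...
class Succeeded: ...
class Unknown: ...


class Dispatcher:
    @singledispatchmethod
    def handle(self, event) -> str:
        # fallback for any event type without a registered handler
        return "collected-unhandled"

    @handle.register
    def _(self, event: Started) -> str:
        return "started"

    @handle.register
    def _(self, event: Succeeded) -> str:
        return "succeeded"


d = Dispatcher()
print(d.handle(Started()))  # started
print(d.handle(Unknown()))  # collected-unhandled
```

Registration reads the handler's type annotation, so each `_` overload binds to the event class named in its signature.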
174	api/core/workflow/graph_engine/event_management/event_manager.py	Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Unified event manager for collecting and emitting events.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock implementation that allows multiple concurrent readers
|
||||
but only one writer at a time.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._read_ready = threading.Condition(threading.RLock())
|
||||
self._readers = 0
|
||||
|
||||
def acquire_read(self) -> None:
|
||||
"""Acquire a read lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers += 1
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._read_ready.notify_all()
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def acquire_write(self) -> None:
|
||||
"""Acquire a write lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
while self._readers > 0:
|
||||
_ = self._read_ready.wait()
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release a write lock."""
|
||||
self._read_ready.release()
|
||||
|
||||
@contextmanager
|
||||
def read_lock(self):
|
||||
"""Return a context manager for read locking."""
|
||||
self.acquire_read()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release_read()
|
||||
|
||||
@contextmanager
|
||||
def write_lock(self):
|
||||
"""Return a context manager for write locking."""
|
||||
self.acquire_write()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release_write()
|
||||
|
||||
|
||||
@final
|
||||
class EventManager:
|
||||
"""
|
||||
Unified event manager that collects, buffers, and emits events.
|
||||
|
||||
This class combines event collection with event emission, providing
|
||||
thread-safe event management with support for notifying layers and
|
||||
streaming events to external consumers.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the event manager."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = ReadWriteLock()
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def set_layers(self, layers: list[GraphEngineLayer]) -> None:
|
||||
"""
|
||||
Set the layers to notify on event collection.
|
||||
|
||||
Args:
|
||||
layers: List of layers to notify
|
||||
"""
|
||||
self._layers = layers
|
||||
|
||||
def collect(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Thread-safe method to collect an event.
|
||||
|
||||
Args:
|
||||
event: The event to collect
|
||||
"""
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
Get new events starting from a specific index.
|
||||
|
||||
Args:
|
||||
start_index: The index to start from
|
||||
|
||||
Returns:
|
||||
List of new events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return list(self._events[start_index:])
|
||||
|
||||
def _event_count(self) -> int:
|
||||
"""
|
||||
Get the current count of collected events.
|
||||
|
||||
Returns:
|
||||
Number of collected events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return len(self._events)
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete to stop the event emission generator."""
|
||||
self._execution_complete.set()
|
||||
|
||||
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generator that yields events as they're collected.
|
||||
|
||||
Yields:
|
||||
GraphEngineEvent instances as they're processed
|
||||
"""
|
||||
yielded_count = 0
|
||||
|
||||
while not self._execution_complete.is_set() or yielded_count < self._event_count():
|
||||
# Get new events since last yield
|
||||
new_events = self._get_new_events(yielded_count)
|
||||
|
||||
# Yield any new events
|
||||
for event in new_events:
|
||||
yield event
|
||||
yielded_count += 1
|
||||
|
||||
# Small sleep to avoid busy waiting
|
||||
if not self._execution_complete.is_set() and not new_events:
|
||||
time.sleep(0.001)
|
||||
|
||||
def _notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Notify all layers of an event.
|
||||
|
||||
Layer exceptions are caught and logged to prevent disrupting collection.
|
||||
|
||||
Args:
|
||||
event: The event to send to layers
|
||||
"""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_event(event)
|
||||
except Exception:
|
||||
# Silently ignore layer errors during collection
|
||||
pass
|
||||
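The collector above pairs a lock-guarded event list with a completion flag that drives a polling generator. A minimal standalone sketch of the same producer/consumer pattern (using a plain `threading.Lock` and string events in place of the engine's read/write lock and `GraphEngineEvent` types, which are assumptions here):

```python
import threading
import time
from collections.abc import Generator


class MiniEventCollector:
    """Stand-in for the collector above: a lock-guarded list plus a
    completion flag that stops the polling generator."""

    def __init__(self) -> None:
        self._events: list[str] = []
        self._lock = threading.Lock()
        self._complete = threading.Event()

    def collect(self, event: str) -> None:
        with self._lock:
            self._events.append(event)

    def mark_complete(self) -> None:
        self._complete.set()

    def emit_events(self) -> Generator[str, None, None]:
        yielded = 0
        # Keep polling until completion is signalled AND every collected
        # event has been yielded.
        while not self._complete.is_set() or yielded < len(self._events):
            with self._lock:
                new = self._events[yielded:]
            for event in new:
                yield event
                yielded += 1
            if not new and not self._complete.is_set():
                time.sleep(0.001)  # avoid busy waiting


collector = MiniEventCollector()


def produce() -> None:
    for i in range(5):
        collector.collect(f"event-{i}")
    collector.mark_complete()


producer = threading.Thread(target=produce)
producer.start()
received = list(collector.emit_events())
producer.join()
print(received)  # all five events, in collection order
```

The loop condition is the key detail: after `mark_complete()` the generator drains any remaining events before exiting, so no event collected before completion is lost.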
File diff suppressed because it is too large (Load Diff)

### api/core/workflow/graph_engine/graph_state_manager.py (new file, 288 lines)
"""
|
||||
Graph state manager that combines node, edge, and execution tracking.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
|
||||
class EdgeStateAnalysis(TypedDict):
|
||||
"""Analysis result for edge states."""
|
||||
|
||||
has_unknown: bool
|
||||
has_taken: bool
|
||||
all_skipped: bool
|
||||
|
||||
|
||||
@final
|
||||
class GraphStateManager:
|
||||
def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
|
||||
"""
|
||||
Initialize the state manager.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
ready_queue: Queue for nodes ready to execute
|
||||
"""
|
||||
self._graph = graph
|
||||
self._ready_queue = ready_queue
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Execution tracking state
|
||||
self._executing_nodes: set[str] = set()
|
||||
|
||||
# ============= Node State Operations =============
|
||||
|
||||
def enqueue_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as TAKEN and add it to the ready queue.
|
||||
|
||||
This combines the state transition and enqueueing operations
|
||||
that always occur together when preparing a node for execution.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to enqueue
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.TAKEN
|
||||
self._ready_queue.put(node_id)
|
||||
|
||||
def mark_node_skipped(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as SKIPPED.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.SKIPPED
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is ready to be executed.
|
||||
|
||||
A node is ready when all its incoming edges from taken branches
|
||||
have been satisfied.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
with self._lock:
|
||||
# Get all incoming edges to this node
|
||||
incoming_edges = self._graph.get_incoming_edges(node_id)
|
||||
|
||||
# If no incoming edges, node is always ready
|
||||
if not incoming_edges:
|
||||
return True
|
||||
|
||||
# If any edge is UNKNOWN, node is not ready
|
||||
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
|
||||
return False
|
||||
|
||||
# Node is ready if at least one edge is TAKEN
|
||||
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
|
||||
|
||||
def get_node_state(self, node_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The current node state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.nodes[node_id].state
|
||||
|
||||
# ============= Edge State Operations =============
|
||||
|
||||
def mark_edge_taken(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as TAKEN.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.TAKEN
|
||||
|
||||
def mark_edge_skipped(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as SKIPPED.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.SKIPPED
|
||||
|
||||
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
|
||||
"""
|
||||
Analyze the states of edges and return summary flags.
|
||||
|
||||
Args:
|
||||
edges: List of edges to analyze
|
||||
|
||||
Returns:
|
||||
Analysis result with state flags
|
||||
"""
|
||||
with self._lock:
|
||||
states = {edge.state for edge in edges}
|
||||
|
||||
return EdgeStateAnalysis(
|
||||
has_unknown=NodeState.UNKNOWN in states,
|
||||
has_taken=NodeState.TAKEN in states,
|
||||
all_skipped=states == {NodeState.SKIPPED} if states else True,
|
||||
)
|
||||
|
||||
def get_edge_state(self, edge_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of an edge.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge
|
||||
|
||||
Returns:
|
||||
The current edge state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
A tuple of (selected_edges, unselected_edges)
|
||||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
selected_edges.append(edge)
|
||||
else:
|
||||
unselected_edges.append(edge)
|
||||
|
||||
return selected_edges, unselected_edges
|
||||
|
||||
# ============= Execution Tracking Operations =============
|
||||
|
||||
def start_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node starting execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.add(node_id)
|
||||
|
||||
def finish_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as no longer executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node finishing execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.discard(node_id)
|
||||
|
||||
def is_executing(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is currently executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is executing
|
||||
"""
|
||||
with self._lock:
|
||||
return node_id in self._executing_nodes
|
||||
|
||||
def get_executing_count(self) -> int:
|
||||
"""
|
||||
Get the count of currently executing nodes.
|
||||
|
||||
Returns:
|
||||
Number of executing nodes
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self._executing_nodes)
|
||||
|
||||
def get_executing_nodes(self) -> set[str]:
|
||||
"""
|
||||
Get a copy of the set of executing node IDs.
|
||||
|
||||
Returns:
|
||||
Set of node IDs currently executing
|
||||
"""
|
||||
with self._lock:
|
||||
return self._executing_nodes.copy()
|
||||
|
||||
def clear_executing(self) -> None:
|
||||
"""Clear all executing nodes."""
|
||||
with self._lock:
|
||||
self._executing_nodes.clear()
|
||||
|
||||
# ============= Composite Operations =============
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
Check if graph execution is complete.
|
||||
|
||||
Execution is complete when:
|
||||
- Ready queue is empty
|
||||
- No nodes are executing
|
||||
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
with self._lock:
|
||||
return self._ready_queue.empty() and len(self._executing_nodes) == 0
|
||||
|
||||
def get_queue_depth(self) -> int:
|
||||
"""
|
||||
Get the current depth of the ready queue.
|
||||
|
||||
Returns:
|
||||
Number of nodes in the ready queue
|
||||
"""
|
||||
return self._ready_queue.qsize()
|
||||
|
||||
def get_execution_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get execution statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with execution statistics
|
||||
"""
|
||||
with self._lock:
|
||||
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
|
||||
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
|
||||
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
|
||||
|
||||
return {
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"executing": len(self._executing_nodes),
|
||||
"taken_nodes": taken_nodes,
|
||||
"skipped_nodes": skipped_nodes,
|
||||
"unknown_nodes": unknown_nodes,
|
||||
}
|
||||
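The readiness rule in `is_node_ready` can be restated as a tiny standalone function. `EdgeState` here is an illustrative stand-in for the engine's `NodeState` enum, not an import from it:

```python
from enum import Enum


class EdgeState(Enum):
    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"


def is_node_ready(incoming: list[EdgeState]) -> bool:
    # No incoming edges: a source node is always ready.
    if not incoming:
        return True
    # Any unresolved edge blocks readiness.
    if any(s is EdgeState.UNKNOWN for s in incoming):
        return False
    # Ready only if at least one branch actually reached the node.
    return any(s is EdgeState.TAKEN for s in incoming)


print(is_node_ready([]))                                      # True
print(is_node_ready([EdgeState.TAKEN, EdgeState.UNKNOWN]))    # False
print(is_node_ready([EdgeState.TAKEN, EdgeState.SKIPPED]))    # True
print(is_node_ready([EdgeState.SKIPPED, EdgeState.SKIPPED]))  # False
```

Note the asymmetry: a single UNKNOWN edge vetoes readiness, while a single TAKEN edge (once nothing is UNKNOWN) is sufficient even if every other edge was skipped.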
### api/core/workflow/graph_engine/graph_traversal/__init__.py (new file, 14 lines)
"""
|
||||
Graph traversal subsystem for graph engine.
|
||||
|
||||
This package handles graph navigation, edge processing,
|
||||
and skip propagation logic.
|
||||
"""
|
||||
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
__all__ = [
|
||||
"EdgeProcessor",
|
||||
"SkipPropagator",
|
||||
]
|
||||
### api/core/workflow/graph_engine/graph_traversal/edge_processor.py (new file, 201 lines)
"""
|
||||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class EdgeProcessor:
|
||||
"""
|
||||
Processes edges during graph execution.
|
||||
|
||||
This handles marking edges as taken or skipped, notifying
|
||||
the response coordinator, triggering downstream node execution,
|
||||
and managing branch node logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
skip_propagator: "SkipPropagator",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the edge processor.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
skip_propagator: Propagator for skip states
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
self._response_coordinator = response_coordinator
|
||||
self._skip_propagator = skip_propagator
|
||||
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges after a node succeeds.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
selected_handle: For branch nodes, the selected edge handle
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
|
||||
"""
|
||||
node = self._graph.nodes[node_id]
|
||||
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
return self._process_branch_node_edges(node_id, selected_handle)
|
||||
else:
|
||||
return self._process_non_branch_node_edges(node_id)
|
||||
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for non-branch nodes (mark all as TAKEN).
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_branch_node_edges(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for branch nodes.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no edge was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} did not select any edge")
|
||||
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Process unselected edges first (mark as skipped)
|
||||
for edge in unselected_edges:
|
||||
self._process_skipped_edge(edge)
|
||||
|
||||
# Process selected edges
|
||||
for edge in selected_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Mark edge as taken and check downstream node.
|
||||
|
||||
Args:
|
||||
edge: The edge to process
|
||||
|
||||
Returns:
|
||||
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
|
||||
"""
|
||||
# Mark edge as taken
|
||||
self._state_manager.mark_edge_taken(edge.id)
|
||||
|
||||
# Notify response coordinator and get streaming events
|
||||
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes: list[str] = []
|
||||
if self._state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, streaming_events
|
||||
|
||||
def _process_skipped_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
Mark edge as skipped.
|
||||
|
||||
Args:
|
||||
edge: The edge to skip
|
||||
"""
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
@@ -0,0 +1,95 @@
"""
|
||||
Skip state propagation through the graph.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
|
||||
|
||||
@final
|
||||
class SkipPropagator:
|
||||
"""
|
||||
Propagates skip states through the graph.
|
||||
|
||||
When a node is skipped, this ensures all downstream nodes
|
||||
that depend solely on it are also skipped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the skip propagator.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
|
||||
def propagate_skip_from_edge(self, edge_id: str) -> None:
|
||||
"""
|
||||
Recursively propagate skip state from a skipped edge.
|
||||
|
||||
Rules:
|
||||
- If a node has any UNKNOWN incoming edges, stop processing
|
||||
- If all incoming edges are SKIPPED, skip the node and its edges
|
||||
- If any incoming edge is TAKEN, the node may still execute
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the skipped edge to start from
|
||||
"""
|
||||
downstream_node_id = self._graph.edges[edge_id].head
|
||||
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
|
||||
|
||||
# Analyze edge states
|
||||
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
|
||||
|
||||
# Stop if there are unknown edges (not yet processed)
|
||||
if edge_states["has_unknown"]:
|
||||
return
|
||||
|
||||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self._state_manager.enqueue_node(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
if edge_states["all_skipped"]:
|
||||
self._propagate_skip_to_node(downstream_node_id)
|
||||
|
||||
def _propagate_skip_to_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node and all its outgoing edges as skipped.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
# Mark node as skipped
|
||||
self._state_manager.mark_node_skipped(node_id)
|
||||
|
||||
# Mark all outgoing edges as skipped and propagate
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
for edge in outgoing_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
|
||||
"""
|
||||
Skip all paths from unselected branch edges.
|
||||
|
||||
Args:
|
||||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
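The effect of skip propagation can be sketched without the engine's types: a node is skipped exactly when every one of its predecessors is skipped, applied to a fixpoint. The graph shape and names below are illustrative, not the engine's API:

```python
def skipped_nodes(preds: dict[str, list[str]], skipped_roots: set[str]) -> set[str]:
    """Return all nodes skipped by propagation.

    preds maps each node to its predecessor nodes. A node with
    predecessors is skipped when ALL of them are skipped; nodes in
    skipped_roots are skipped unconditionally (the unselected branch
    targets). Iterates to a fixpoint instead of recursing.
    """
    skipped = set(skipped_roots)
    changed = True
    while changed:
        changed = False
        for node, parents in preds.items():
            if node not in skipped and parents and all(p in skipped for p in parents):
                skipped.add(node)
                changed = True
    return skipped


# A branch takes the left path. The right child and its sole descendant
# are skipped, but the join node (fed by both paths) still runs because
# one of its incoming paths was taken.
preds = {
    "left": ["branch"],
    "right": ["branch"],
    "right2": ["right"],
    "join": ["left", "right2"],
}
print(sorted(skipped_nodes(preds, {"right"})))  # ['right', 'right2']
```

This mirrors the rules in `propagate_skip_from_edge`: a node with any taken incoming edge survives, so skip never leaks past a join that another branch reaches.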
### api/core/workflow/graph_engine/layers/README.md (new file, 52 lines)
# Layers

Pluggable middleware for engine extensions.

## Components

### Layer (base)

Abstract base class for layers.

- `initialize()` - Receive runtime context
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook

### DebugLoggingLayer

Comprehensive execution logging.

- Configurable detail levels
- Tracks execution statistics
- Truncates long values

## Usage

```python
debug_layer = DebugLoggingLayer(
    level="INFO",
    include_outputs=True,
)

engine = GraphEngine(graph)
engine.add_layer(debug_layer)
engine.run()
```

## Custom Layers

```python
class MetricsLayer(Layer):
    def on_event(self, event):
        if isinstance(event, NodeRunSucceededEvent):
            self.metrics[event.node_id] = event.elapsed_time
```

## Configuration

**DebugLoggingLayer Options:**

- `level` - Log level (INFO, DEBUG, ERROR)
- `include_inputs/outputs` - Log data values
- `max_value_length` - Truncate long values
### api/core/workflow/graph_engine/layers/__init__.py (new file, 16 lines)
"""
|
||||
Layer system for GraphEngine extensibility.
|
||||
|
||||
This module provides the layer infrastructure for extending GraphEngine functionality
|
||||
with middleware-like components that can observe events and interact with execution.
|
||||
"""
|
||||
|
||||
from .base import GraphEngineLayer
|
||||
from .debug_logging import DebugLoggingLayer
|
||||
from .execution_limits import ExecutionLimitsLayer
|
||||
|
||||
__all__ = [
|
||||
"DebugLoggingLayer",
|
||||
"ExecutionLimitsLayer",
|
||||
"GraphEngineLayer",
|
||||
]
|
||||
### api/core/workflow/graph_engine/layers/base.py (new file, 85 lines)
"""
|
||||
Base layer class for GraphEngine extensions.
|
||||
|
||||
This module provides the abstract base class for implementing layers that can
|
||||
intercept and respond to GraphEngine events.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
"""
|
||||
Abstract base class for GraphEngine layers.
|
||||
|
||||
Layers are middleware-like components that can:
|
||||
- Observe all events emitted by the GraphEngine
|
||||
- Access the graph runtime state
|
||||
- Send commands to control execution
|
||||
|
||||
Subclasses should override the constructor to accept configuration parameters,
|
||||
then implement the three lifecycle methods.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
Initialize the layer with engine dependencies.
|
||||
|
||||
Called by GraphEngine before execution starts to inject the read-only runtime state
|
||||
and command channel. This allows layers to observe engine context and send
|
||||
commands, but prevents direct state modification.
|
||||
|
||||
Args:
|
||||
graph_runtime_state: Read-only view of the runtime state
|
||||
command_channel: Channel for sending commands to the engine
|
||||
"""
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
|
||||
This is called after the engine has been initialized but before any nodes
|
||||
are executed. Layers can use this to set up resources or log start information.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
This method receives all events generated during graph execution, including:
|
||||
- Graph lifecycle events (start, success, failure)
|
||||
- Node execution events (start, success, failure, retry)
|
||||
- Stream events for response nodes
|
||||
- Container events (iteration, loop)
|
||||
|
||||
Args:
|
||||
event: The event emitted by the engine
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
This is called after all nodes have been executed or when execution is
|
||||
aborted. Layers can use this to clean up resources or log final state.
|
||||
|
||||
Args:
|
||||
error: The exception that caused execution to fail, or None if successful
|
||||
"""
|
||||
pass
|
||||
### api/core/workflow/graph_engine/layers/debug_logging.py (new file, 250 lines)
"""
|
||||
Debug logging layer for GraphEngine.
|
||||
|
||||
This module provides a layer that logs all events and state changes during
|
||||
graph execution for debugging purposes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
class DebugLoggingLayer(GraphEngineLayer):
|
||||
"""
|
||||
A layer that provides comprehensive logging of GraphEngine execution.
|
||||
|
||||
This layer logs all events with configurable detail levels, helping developers
|
||||
debug workflow execution and understand the flow of events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
level: str = "INFO",
|
||||
include_inputs: bool = False,
|
||||
include_outputs: bool = True,
|
||||
include_process_data: bool = False,
|
||||
logger_name: str = "GraphEngine.Debug",
|
||||
max_value_length: int = 500,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the debug logging layer.
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
include_inputs: Whether to log node input values
|
||||
include_outputs: Whether to log node output values
|
||||
include_process_data: Whether to log node process data
|
||||
logger_name: Name of the logger to use
|
||||
max_value_length: Maximum length of logged values (truncated if longer)
|
||||
"""
|
||||
super().__init__()
|
||||
self.level = level
|
||||
self.include_inputs = include_inputs
|
||||
self.include_outputs = include_outputs
|
||||
self.include_process_data = include_process_data
|
||||
self.max_value_length = max_value_length
|
||||
|
||||
# Set up logger
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
log_level = getattr(logging, level.upper(), logging.INFO)
|
||||
self.logger.setLevel(log_level)
|
||||
|
||||
# Track execution stats
|
||||
self.node_count = 0
|
||||
self.success_count = 0
|
||||
self.failure_count = 0
|
||||
self.retry_count = 0
|
||||
|
||||
def _truncate_value(self, value: Any) -> str:
|
||||
"""Truncate long values for logging."""
|
||||
str_value = str(value)
|
||||
if len(str_value) > self.max_value_length:
|
||||
return str_value[: self.max_value_length] + "... (truncated)"
|
||||
return str_value
|
||||
|
||||
def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
|
||||
"""Format a dictionary or mapping for logging with truncation."""
|
||||
if not data:
|
||||
return "{}"
|
||||
|
||||
formatted_items: list[str] = []
|
||||
for key, value in data.items():
|
||||
formatted_value = self._truncate_value(value)
|
||||
formatted_items.append(f" {key}: {formatted_value}")
|
||||
|
||||
return "{\n" + ",\n".join(formatted_items) + "\n}"
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Log graph execution start."""
|
||||
self.logger.info("=" * 80)
|
||||
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if self.graph_runtime_state:
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""Log individual events based on their type."""
|
||||
event_class = event.__class__.__name__
|
||||
|
||||
# Graph-level events
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.logger.debug("Graph run started event")
|
||||
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.logger.info("✅ Graph run succeeded")
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.logger.warning("⚠️ Graph run partially succeeded")
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
|
            if self.include_outputs and event.outputs:
                self.logger.info("  Final outputs: %s", self._format_dict(event.outputs))

        elif isinstance(event, GraphRunFailedEvent):
            self.logger.error("❌ Graph run failed: %s", event.error)
            if event.exceptions_count > 0:
                self.logger.error("  Total exceptions: %s", event.exceptions_count)

        elif isinstance(event, GraphRunAbortedEvent):
            self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
            if event.outputs:
                self.logger.info("  Partial outputs: %s", self._format_dict(event.outputs))

        # Node-level events
        # Retry before Started because Retry subclasses Started
        elif isinstance(event, NodeRunRetryEvent):
            self.retry_count += 1
            self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
            self.logger.warning("  Previous error: %s", event.error)

        elif isinstance(event, NodeRunStartedEvent):
            self.node_count += 1
            self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)

            if self.include_inputs and event.node_run_result.inputs:
                self.logger.debug("  Inputs: %s", self._format_dict(event.node_run_result.inputs))

        elif isinstance(event, NodeRunSucceededEvent):
            self.success_count += 1
            self.logger.info("✅ Node succeeded: %s", event.node_id)

            if self.include_outputs and event.node_run_result.outputs:
                self.logger.debug("  Outputs: %s", self._format_dict(event.node_run_result.outputs))

            if self.include_process_data and event.node_run_result.process_data:
                self.logger.debug("  Process data: %s", self._format_dict(event.node_run_result.process_data))

        elif isinstance(event, NodeRunFailedEvent):
            self.failure_count += 1
            self.logger.error("❌ Node failed: %s", event.node_id)
            self.logger.error("  Error: %s", event.error)

            if event.node_run_result.error:
                self.logger.error("  Details: %s", event.node_run_result.error)

        elif isinstance(event, NodeRunExceptionEvent):
            self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
            self.logger.warning("  Error: %s", event.error)

        elif isinstance(event, NodeRunStreamChunkEvent):
            # Log stream chunks at debug level to avoid spam
            final_indicator = " (FINAL)" if event.is_final else ""
            self.logger.debug(
                "📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
            )

        # Iteration events
        elif isinstance(event, NodeRunIterationStartedEvent):
            self.logger.info("🔁 Iteration started: %s", event.node_id)

        elif isinstance(event, NodeRunIterationNextEvent):
            self.logger.debug("  Iteration next: %s (index: %s)", event.node_id, event.index)

        elif isinstance(event, NodeRunIterationSucceededEvent):
            self.logger.info("✅ Iteration succeeded: %s", event.node_id)
            if self.include_outputs and event.outputs:
                self.logger.debug("  Outputs: %s", self._format_dict(event.outputs))

        elif isinstance(event, NodeRunIterationFailedEvent):
            self.logger.error("❌ Iteration failed: %s", event.node_id)
            self.logger.error("  Error: %s", event.error)

        # Loop events
        elif isinstance(event, NodeRunLoopStartedEvent):
            self.logger.info("🔄 Loop started: %s", event.node_id)

        elif isinstance(event, NodeRunLoopNextEvent):
            self.logger.debug("  Loop iteration: %s (index: %s)", event.node_id, event.index)

        elif isinstance(event, NodeRunLoopSucceededEvent):
            self.logger.info("✅ Loop succeeded: %s", event.node_id)
            if self.include_outputs and event.outputs:
                self.logger.debug("  Outputs: %s", self._format_dict(event.outputs))

        elif isinstance(event, NodeRunLoopFailedEvent):
            self.logger.error("❌ Loop failed: %s", event.node_id)
            self.logger.error("  Error: %s", event.error)

        else:
            # Log unknown events at debug level
            self.logger.debug("Event: %s", event_class)

    @override
    def on_graph_end(self, error: Exception | None) -> None:
        """Log graph execution end with summary statistics."""
        self.logger.info("=" * 80)

        if error:
            self.logger.error("🔴 GRAPH EXECUTION FAILED")
            self.logger.error("  Error: %s", error)
        else:
            self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")

        # Log execution statistics
        self.logger.info("Execution Statistics:")
        self.logger.info("  Total nodes executed: %s", self.node_count)
        self.logger.info("  Successful nodes: %s", self.success_count)
        self.logger.info("  Failed nodes: %s", self.failure_count)
        self.logger.info("  Node retries: %s", self.retry_count)

        # Log final state if available
        if self.graph_runtime_state and self.include_outputs:
            if self.graph_runtime_state.outputs:
                self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))

        self.logger.info("=" * 80)
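The dispatch chain above checks `NodeRunRetryEvent` before `NodeRunStartedEvent` because Retry subclasses Started. A standalone sketch (with hypothetical event classes, not the Dify ones) shows why the order matters in an `isinstance`-based chain:

```python
class StartedEvent:
    pass


class RetryEvent(StartedEvent):  # subclasses StartedEvent, like NodeRunRetryEvent
    pass


def dispatch(event: StartedEvent) -> str:
    # Correct order: most specific class first
    if isinstance(event, RetryEvent):
        return "retry"
    elif isinstance(event, StartedEvent):
        return "started"
    return "unknown"


def dispatch_wrong(event: StartedEvent) -> str:
    # Base class first: the retry branch becomes unreachable,
    # because every RetryEvent is also a StartedEvent
    if isinstance(event, StartedEvent):
        return "started"
    elif isinstance(event, RetryEvent):
        return "retry"
    return "unknown"
```

With the base class tested first, retry events silently fall into the generic "started" branch, which is exactly the bug the comment in the layer guards against.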
150 api/core/workflow/graph_engine/layers/execution_limits.py Normal file
@ -0,0 +1,150 @@
"""
|
||||
Execution limits layer for GraphEngine.
|
||||
|
||||
This layer monitors workflow execution to enforce limits on:
|
||||
- Maximum execution steps
|
||||
- Maximum execution time
|
||||
|
||||
When limits are exceeded, the layer automatically aborts execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(Enum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
TIME_LIMIT = "time_limit"
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
"""
|
||||
Layer that enforces execution limits for workflows.
|
||||
|
||||
Monitors:
|
||||
- Step count: Tracks number of node executions
|
||||
- Time limit: Monitors total execution time
|
||||
|
||||
Automatically aborts execution when limits are exceeded.
|
||||
"""
|
||||
|
||||
def __init__(self, max_steps: int, max_time: int) -> None:
|
||||
"""
|
||||
Initialize the execution limits layer.
|
||||
|
||||
Args:
|
||||
max_steps: Maximum number of execution steps allowed
|
||||
max_time: Maximum execution time in seconds allowed
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
self.max_time = max_time
|
||||
|
||||
# Runtime tracking
|
||||
self.start_time: float | None = None
|
||||
self.step_count = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# State tracking
|
||||
self._execution_started = False
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False # Track if abort command has been sent
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Called when graph execution starts."""
|
||||
self.start_time = time.time()
|
||||
self.step_count = 0
|
||||
self._execution_started = True
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False
|
||||
|
||||
self.logger.debug("Execution limits monitoring started")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
Monitors execution progress and enforces limits.
|
||||
"""
|
||||
if not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Track step count for node execution events
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self.step_count += 1
|
||||
self.logger.debug("Step %d started: %s", self.step_count, event.node_id)
|
||||
|
||||
# Check step limit when node execution completes
|
||||
if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
|
||||
if self._reached_step_limitation():
|
||||
self._send_abort_command(LimitType.STEP_LIMIT)
|
||||
|
||||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
self._execution_ended = True
|
||||
|
||||
if self.start_time:
|
||||
total_time = time.time() - self.start_time
|
||||
self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)
|
||||
|
||||
def _reached_step_limitation(self) -> bool:
|
||||
"""Check if step count limit has been exceeded."""
|
||||
return self.step_count > self.max_steps
|
||||
|
||||
def _reached_time_limitation(self) -> bool:
|
||||
"""Check if time limit has been exceeded."""
|
||||
return self.start_time is not None and (time.time() - self.start_time) > self.max_time
|
||||
|
||||
def _send_abort_command(self, limit_type: LimitType) -> None:
|
||||
"""
|
||||
Send abort command due to limit violation.
|
||||
|
||||
Args:
|
||||
limit_type: Type of limit exceeded
|
||||
"""
|
||||
if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Format detailed reason message
|
||||
if limit_type == LimitType.STEP_LIMIT:
|
||||
reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
|
||||
elif limit_type == LimitType.TIME_LIMIT:
|
||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||
|
||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||
|
||||
try:
|
||||
# Send abort command to the engine
|
||||
abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
|
||||
self.command_channel.send_command(abort_command)
|
||||
|
||||
# Mark that abort has been sent to prevent duplicate commands
|
||||
self._abort_sent = True
|
||||
|
||||
self.logger.debug("Abort command sent to engine")
|
||||
|
||||
except Exception:
|
||||
self.logger.exception("Failed to send abort command")
|
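The two limit predicates can be exercised in isolation. A minimal sketch of the same checks, stripped of the layer/channel machinery (the `LimitTracker` name and shape are illustrative, not the Dify API):

```python
import time


class LimitTracker:
    def __init__(self, max_steps: int, max_time: float) -> None:
        self.max_steps = max_steps
        self.max_time = max_time
        self.step_count = 0
        self.start_time = time.time()

    def record_step(self) -> None:
        # Called once per node execution, like the NodeRunStartedEvent branch
        self.step_count += 1

    def step_limit_exceeded(self) -> bool:
        # Same predicate as _reached_step_limitation: strictly greater than,
        # so exactly max_steps steps is still allowed
        return self.step_count > self.max_steps

    def time_limit_exceeded(self) -> bool:
        return (time.time() - self.start_time) > self.max_time
```

Note the strict inequality: a workflow may run exactly `max_steps` steps; the abort fires only on the step beyond the limit.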
50 api/core/workflow/graph_engine/manager.py Normal file
@ -0,0 +1,50 @@
"""
|
||||
GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngineManager:
|
||||
"""
|
||||
Manager for sending control commands to GraphEngine instances.
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
Args:
|
||||
task_id: The task ID of the workflow to stop
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
# Create Redis channel for this task
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
# Create and send abort command
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
|
||||
try:
|
||||
channel.send_command(abort_command)
|
||||
except Exception:
|
||||
# Silently fail if Redis is unavailable
|
||||
# The legacy stop flag mechanism will still work
|
||||
pass
|
14 api/core/workflow/graph_engine/orchestration/__init__.py Normal file
@ -0,0 +1,14 @@
"""
|
||||
Orchestration subsystem for graph engine.
|
||||
|
||||
This package coordinates the overall execution flow between
|
||||
different subsystems.
|
||||
"""
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
__all__ = [
|
||||
"Dispatcher",
|
||||
"ExecutionCoordinator",
|
||||
]
|
104 api/core/workflow/graph_engine/orchestration/dispatcher.py Normal file
@ -0,0 +1,104 @@
"""
|
||||
Main dispatcher for processing events from workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
|
||||
from ..event_management import EventManager
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class Dispatcher:
|
||||
"""
|
||||
Main dispatcher that processes events from the event queue.
|
||||
|
||||
This runs in a separate thread and coordinates event processing
|
||||
with timeout and completion detection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dispatcher.
|
||||
|
||||
Args:
|
||||
event_queue: Queue of events from workers
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event manager for collecting unhandled events
|
||||
execution_coordinator: Coordinator for execution flow
|
||||
event_emitter: Optional event manager to signal completion
|
||||
"""
|
||||
self._event_queue = event_queue
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._execution_coordinator = execution_coordinator
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the dispatcher thread."""
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
# Check for commands
|
||||
self._execution_coordinator.check_commands()
|
||||
|
||||
# Check for scaling
|
||||
self._execution_coordinator.check_scaling()
|
||||
|
||||
# Process events
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Check if execution is complete
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
self._execution_coordinator.mark_failed(e)
|
||||
|
||||
finally:
|
||||
self._execution_coordinator.mark_complete()
|
||||
# Signal the event emitter that execution is complete
|
||||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
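The core loop pattern is: drain a queue with a short timeout, and treat "queue empty *and* coordinator reports done" as the exit condition, so a momentarily empty queue does not end the run while nodes are still in flight. A minimal stand-in sketch of that pattern (all names here are illustrative):

```python
import queue
import threading
from collections.abc import Callable


def run_dispatcher(
    events: "queue.Queue[str]",
    handled: list[str],
    is_complete: Callable[[], bool],
) -> None:
    while True:
        try:
            event = events.get(timeout=0.05)
            handled.append(event)  # stands in for event_handler.dispatch(event)
            events.task_done()
        except queue.Empty:
            # Only exit when the coordinator says no work remains;
            # an empty queue alone is not proof of completion
            if is_complete():
                break
```

Run it in a thread, exactly as `Dispatcher.start` does with `_dispatcher_loop`:

```python
q: "queue.Queue[str]" = queue.Queue()
for e in ["a", "b", "c"]:
    q.put(e)
handled: list[str] = []
t = threading.Thread(target=run_dispatcher, args=(q, handled, q.empty), daemon=True)
t.start()
t.join(timeout=5)
```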
api/core/workflow/graph_engine/orchestration/execution_coordinator.py Normal file
@ -0,0 +1,87 @@
"""
|
||||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..event_management import EventManager
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionCoordinator:
|
||||
"""
|
||||
Coordinates overall execution flow between subsystems.
|
||||
|
||||
This provides high-level coordination methods used by the
|
||||
dispatcher to manage execution state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_execution: GraphExecution,
|
||||
state_manager: GraphStateManager,
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
command_processor: CommandProcessor,
|
||||
worker_pool: WorkerPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the execution coordinator.
|
||||
|
||||
Args:
|
||||
graph_execution: Graph execution aggregate
|
||||
state_manager: Unified state manager
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event manager for collecting events
|
||||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self._graph_execution = graph_execution
|
||||
self._state_manager = state_manager
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
def check_commands(self) -> None:
|
||||
"""Process any pending commands."""
|
||||
self._command_processor.process_commands()
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
"""Check and perform worker scaling if needed."""
|
||||
self._worker_pool.check_and_scale()
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
Check if execution is complete.
|
||||
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
# Check if aborted or failed
|
||||
if self._graph_execution.aborted or self._graph_execution.has_error:
|
||||
return True
|
||||
|
||||
# Complete if no work remains
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
if not self._graph_execution.completed:
|
||||
self._graph_execution.complete()
|
||||
|
||||
def mark_failed(self, error: Exception) -> None:
|
||||
"""
|
||||
Mark execution as failed.
|
||||
|
||||
Args:
|
||||
error: The error that caused failure
|
||||
"""
|
||||
self._graph_execution.fail(error)
|
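The completion check combines two independent conditions: the run was terminated early (aborted or errored), or no queued/in-flight work remains. A dependency-free sketch of that predicate (`ExecutionState` is a hypothetical stand-in for the real `GraphExecution`/`GraphStateManager` pair):

```python
from dataclasses import dataclass


@dataclass
class ExecutionState:
    aborted: bool = False
    has_error: bool = False
    queued_nodes: int = 0
    running_nodes: int = 0


def is_execution_complete(state: ExecutionState) -> bool:
    # Early termination wins regardless of remaining work
    if state.aborted or state.has_error:
        return True
    # Otherwise complete only when nothing is queued or running
    return state.queued_nodes == 0 and state.running_nodes == 0
```

Ordering matters: an aborted run may still have queued nodes, and the dispatcher must stop anyway.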
41 api/core/workflow/graph_engine/protocols/command_channel.py Normal file
@ -0,0 +1,41 @@
"""
|
||||
CommandChannel protocol for GraphEngine command communication.
|
||||
|
||||
This protocol defines the interface for sending and receiving commands
|
||||
to/from a GraphEngine instance, supporting both local and distributed scenarios.
|
||||
"""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
class CommandChannel(Protocol):
|
||||
"""
|
||||
Protocol for bidirectional command communication with GraphEngine.
|
||||
|
||||
Since each GraphEngine instance processes only one workflow execution,
|
||||
this channel is dedicated to that single execution.
|
||||
"""
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch pending commands for this GraphEngine instance.
|
||||
|
||||
Called by GraphEngine to poll for commands that need to be processed.
|
||||
|
||||
Returns:
|
||||
List of pending commands (may be empty)
|
||||
"""
|
||||
...
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to be processed by this GraphEngine instance.
|
||||
|
||||
Called by external systems to send control commands to the running workflow.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
...
|
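Because `CommandChannel` is a structural `Protocol`, any class with matching `fetch_commands`/`send_command` methods satisfies it without inheriting. A minimal in-memory implementation in that shape, useful for single-process tests (the `Command` dataclass is a stand-in for `GraphEngineCommand`):

```python
import threading
from dataclasses import dataclass


@dataclass
class Command:
    name: str


class InMemoryCommandChannel:
    """Thread-safe channel: send_command appends, fetch_commands drains."""

    def __init__(self) -> None:
        self._commands: list[Command] = []
        self._lock = threading.Lock()

    def send_command(self, command: Command) -> None:
        with self._lock:
            self._commands.append(command)

    def fetch_commands(self) -> list[Command]:
        # Drain semantics: each command is delivered exactly once
        with self._lock:
            pending, self._commands = self._commands, []
        return pending
```

The drain-on-fetch behavior matches the polling model in the dispatcher loop: commands queue up between polls and are handed over in one batch.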
12 api/core/workflow/graph_engine/ready_queue/__init__.py Normal file
@ -0,0 +1,12 @@
"""
|
||||
Ready queue implementations for GraphEngine.
|
||||
|
||||
This package contains the protocol and implementations for managing
|
||||
the queue of nodes ready for execution.
|
||||
"""
|
||||
|
||||
from .factory import create_ready_queue_from_state
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
from .protocol import ReadyQueue, ReadyQueueState
|
||||
|
||||
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]
|
35 api/core/workflow/graph_engine/ready_queue/factory.py Normal file
@ -0,0 +1,35 @@
"""
|
||||
Factory for creating ReadyQueue instances from serialized state.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
from .protocol import ReadyQueueState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protocol import ReadyQueue
|
||||
|
||||
|
||||
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
|
||||
"""
|
||||
Create a ReadyQueue instance from a serialized state.
|
||||
|
||||
Args:
|
||||
state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue
|
||||
|
||||
Returns:
|
||||
A ReadyQueue instance initialized with the given state
|
||||
|
||||
Raises:
|
||||
ValueError: If the queue type is unknown or version is unsupported
|
||||
"""
|
||||
if state.type == "InMemoryReadyQueue":
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
|
||||
queue = InMemoryReadyQueue()
|
||||
# Always pass as JSON string to loads()
|
||||
queue.loads(state.model_dump_json())
|
||||
return queue
|
||||
else:
|
||||
raise ValueError(f"Unknown ready queue type: {state.type}")
|
140 api/core/workflow/graph_engine/ready_queue/in_memory.py Normal file
@ -0,0 +1,140 @@
"""
|
||||
In-memory implementation of the ReadyQueue protocol.
|
||||
|
||||
This implementation wraps Python's standard queue.Queue and adds
|
||||
serialization capabilities for state storage.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import final
|
||||
|
||||
from .protocol import ReadyQueue, ReadyQueueState
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryReadyQueue(ReadyQueue):
|
||||
"""
|
||||
In-memory ready queue implementation with serialization support.
|
||||
|
||||
This implementation uses Python's queue.Queue internally and provides
|
||||
methods to serialize and restore the queue state.
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int = 0) -> None:
|
||||
"""
|
||||
Initialize the in-memory ready queue.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum size of the queue (0 for unlimited)
|
||||
"""
|
||||
self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize)
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""
|
||||
Add a node ID to the ready queue.
|
||||
|
||||
Args:
|
||||
item: The node ID to add to the queue
|
||||
"""
|
||||
self._queue.put(item)
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""
|
||||
Retrieve and remove a node ID from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for an item (None for blocking)
|
||||
|
||||
Returns:
|
||||
The node ID retrieved from the queue
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires and no item is available
|
||||
"""
|
||||
if timeout is None:
|
||||
return self._queue.get(block=True)
|
||||
return self._queue.get(timeout=timeout)
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""
|
||||
Indicate that a previously retrieved task is complete.
|
||||
|
||||
Used by worker threads to signal task completion for
|
||||
join() synchronization.
|
||||
"""
|
||||
self._queue.task_done()
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""
|
||||
Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
True if the queue has no items, False otherwise
|
||||
"""
|
||||
return self._queue.empty()
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""
|
||||
Get the approximate size of the queue.
|
||||
|
||||
Returns:
|
||||
The approximate number of items in the queue
|
||||
"""
|
||||
return self._queue.qsize()
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""
|
||||
Serialize the queue state to a JSON string for storage.
|
||||
|
||||
Returns:
|
||||
A JSON string containing the serialized queue state
|
||||
"""
|
||||
# Extract all items from the queue without removing them
|
||||
items: list[str] = []
|
||||
temp_items: list[str] = []
|
||||
|
||||
# Drain the queue temporarily to get all items
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
temp_items.append(item)
|
||||
items.append(item)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Put items back in the same order
|
||||
for item in temp_items:
|
||||
self._queue.put(item)
|
||||
|
||||
state = ReadyQueueState(
|
||||
type="InMemoryReadyQueue",
|
||||
version="1.0",
|
||||
items=items,
|
||||
)
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""
|
||||
Restore the queue state from a JSON string.
|
||||
|
||||
Args:
|
||||
data: The JSON string containing the serialized queue state to restore
|
||||
"""
|
||||
state = ReadyQueueState.model_validate_json(data)
|
||||
|
||||
if state.type != "InMemoryReadyQueue":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported version: {state.version}")
|
||||
|
||||
# Clear the current queue
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Restore items
|
||||
for item in state.items:
|
||||
self._queue.put(item)
|
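The `dumps()` above snapshots a live `queue.Queue` by draining it and putting every item back, since `queue.Queue` exposes no non-destructive iterator. A dependency-free sketch of the same drain-and-restore pattern, using `json` in place of pydantic:

```python
import json
import queue


def snapshot(q: "queue.Queue[str]") -> str:
    # Drain the queue to capture its items...
    items: list[str] = []
    while not q.empty():
        items.append(q.get_nowait())
    # ...then put them back in the same FIFO order
    for item in items:
        q.put(item)
    return json.dumps({"type": "InMemoryReadyQueue", "version": "1.0", "items": items})


def restore(data: str) -> "queue.Queue[str]":
    state = json.loads(data)
    q: "queue.Queue[str]" = queue.Queue()
    for item in state["items"]:
        q.put(item)
    return q
```

Note this snapshot is only safe while no worker is concurrently consuming the queue; the engine calls it when execution state is checkpointed, not mid-flight.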
104 api/core/workflow/graph_engine/ready_queue/protocol.py Normal file
@ -0,0 +1,104 @@
"""
|
||||
ReadyQueue protocol for GraphEngine node execution queue.
|
||||
|
||||
This protocol defines the interface for managing the queue of nodes ready
|
||||
for execution, supporting both in-memory and persistent storage scenarios.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ReadyQueueState(BaseModel):
|
||||
"""
|
||||
Pydantic model for serialized ready queue state.
|
||||
|
||||
This defines the structure of the data returned by dumps()
|
||||
and expected by loads() for ready queue serialization.
|
||||
"""
|
||||
|
||||
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
|
||||
version: str = Field(description="Serialization format version")
|
||||
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
|
||||
|
||||
|
||||
class ReadyQueue(Protocol):
|
||||
"""
|
||||
Protocol for managing nodes ready for execution in GraphEngine.
|
||||
|
||||
This protocol defines the interface that any ready queue implementation
|
||||
must provide, enabling both in-memory queues and persistent queues
|
||||
that can be serialized for state storage.
|
||||
"""
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""
|
||||
Add a node ID to the ready queue.
|
||||
|
||||
Args:
|
||||
item: The node ID to add to the queue
|
||||
"""
|
||||
...
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""
|
||||
Retrieve and remove a node ID from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for an item (None for blocking)
|
||||
|
||||
Returns:
|
||||
The node ID retrieved from the queue
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires and no item is available
|
||||
"""
|
||||
...
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""
|
||||
Indicate that a previously retrieved task is complete.
|
||||
|
||||
Used by worker threads to signal task completion for
|
||||
join() synchronization.
|
||||
"""
|
||||
...
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""
|
||||
Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
True if the queue has no items, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""
|
||||
Get the approximate size of the queue.
|
||||
|
||||
Returns:
|
||||
The approximate number of items in the queue
|
||||
"""
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""
|
||||
Serialize the queue state to a JSON string for storage.
|
||||
|
||||
Returns:
|
||||
A JSON string containing the serialized queue state
|
||||
that can be persisted and later restored
|
||||
"""
|
||||
...
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""
|
||||
Restore the queue state from a JSON string.
|
||||
|
||||
Args:
|
||||
data: The JSON string containing the serialized queue state to restore
|
||||
"""
|
||||
...
|
@ -0,0 +1,10 @@
"""
|
||||
ResponseStreamCoordinator - Coordinates streaming output from response nodes
|
||||
|
||||
This component manages response streaming sessions and ensures ordered streaming
|
||||
of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
from .coordinator import ResponseStreamCoordinator
|
||||
|
||||
__all__ = ["ResponseStreamCoordinator"]
|
@ -0,0 +1,696 @@
"""
|
||||
Main ResponseStreamCoordinator implementation.
|
||||
|
||||
This module contains the public ResponseStreamCoordinator class that manages
|
||||
response streaming sessions and ensures ordered streaming of responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import Literal, TypeAlias, final
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type definitions
|
||||
NodeID: TypeAlias = str
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
class ResponseSessionState(BaseModel):
|
||||
"""Serializable representation of a response session."""
|
||||
|
||||
node_id: str
|
||||
index: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class StreamBufferState(BaseModel):
|
||||
"""Serializable representation of buffered stream chunks."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StreamPositionState(BaseModel):
|
||||
"""Serializable representation for stream read positions."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
position: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorState(BaseModel):
|
||||
"""Serialized snapshot of ResponseStreamCoordinator."""
|
||||
|
||||
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
|
||||
version: str = Field(default="1.0")
|
||||
response_nodes: Sequence[str] = Field(default_factory=list)
|
||||
active_session: ResponseSessionState | None = None
|
||||
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
node_execution_ids: dict[str, str] = Field(default_factory=dict)
|
||||
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
|
||||
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
|
||||
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
|
||||
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@final
class ResponseStreamCoordinator:
    """
    Manages response streaming sessions without relying on global state.

    Ensures ordered streaming of responses based on upstream node outputs and constants.
    """

    def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
        """
        Initialize the coordinator with a variable pool and graph.

        Args:
            variable_pool: VariablePool instance for accessing node variables
            graph: Graph instance for looking up node information
        """
        self._variable_pool = variable_pool
        self._graph = graph
        self._active_session: ResponseSession | None = None
        self._waiting_sessions: deque[ResponseSession] = deque()
        self._lock = RLock()

        # Internal stream management (replacing OutputRegistry)
        self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
        self._stream_positions: dict[tuple[str, ...], int] = {}
        self._closed_streams: set[tuple[str, ...]] = set()

        # Track response nodes
        self._response_nodes: set[NodeID] = set()

        # Store paths for each response node
        self._paths_maps: dict[NodeID, list[Path]] = {}

        # Track node execution IDs and types for proper event forwarding
        self._node_execution_ids: dict[NodeID, str] = {}  # node_id -> execution_id

        # Track response sessions to ensure only one per node
        self._response_sessions: dict[NodeID, ResponseSession] = {}  # node_id -> session

    def register(self, response_node_id: NodeID) -> None:
        with self._lock:
            if response_node_id in self._response_nodes:
                return
            self._response_nodes.add(response_node_id)

            # Build and save paths map for this response node
            paths_map = self._build_paths_map(response_node_id)
            self._paths_maps[response_node_id] = paths_map

            # Create and store response session for this node
            response_node = self._graph.nodes[response_node_id]
            session = ResponseSession.from_node(response_node)
            self._response_sessions[response_node_id] = session

    def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
        """Track the execution ID for a node when it starts executing.

        Args:
            node_id: The ID of the node
            execution_id: The execution ID from NodeRunStartedEvent
        """
        with self._lock:
            self._node_execution_ids[node_id] = execution_id

    def _get_or_create_execution_id(self, node_id: NodeID) -> str:
        """Get the execution ID for a node, creating one if it doesn't exist.

        Args:
            node_id: The ID of the node

        Returns:
            The execution ID for the node
        """
        with self._lock:
            if node_id not in self._node_execution_ids:
                self._node_execution_ids[node_id] = str(uuid4())
            return self._node_execution_ids[node_id]

    def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
        """
        Build a paths map for a response node by finding all paths from the root
        node to the response node, recording branch edges along each path.

        Args:
            response_node_id: ID of the response node to analyze

        Returns:
            List of Path objects, where each path contains branch edge IDs
        """
        # Get root node ID
        root_node_id = self._graph.root_node.id

        # If root is the response node, return empty path
        if root_node_id == response_node_id:
            return [Path()]

        # Extract variable selectors from the response node's template
        response_node = self._graph.nodes[response_node_id]
        response_session = ResponseSession.from_node(response_node)
        template = response_session.template

        # Collect all variable selectors from the template
        variable_selectors: set[tuple[str, ...]] = set()
        for segment in template.segments:
            if isinstance(segment, VariableSegment):
                variable_selectors.add(tuple(segment.selector[:2]))

        # Step 1: Find all complete paths from root to response node
        all_complete_paths: list[list[EdgeID]] = []

        def find_paths(
            current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
        ) -> None:
            """Recursively find all paths from current node to target node."""
            if current_node_id == target_node_id:
                # Found a complete path, store it
                all_complete_paths.append(current_path.copy())
                return

            # Mark as visited to avoid cycles
            visited.add(current_node_id)

            # Explore outgoing edges
            outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
            for edge in outgoing_edges:
                edge_id = edge.id
                next_node_id = edge.head

                # Skip if already visited in this path
                if next_node_id not in visited:
                    # Add edge to path and recurse
                    new_path = current_path + [edge_id]
                    find_paths(next_node_id, target_node_id, new_path, visited.copy())

        # Start searching from root node
        find_paths(root_node_id, response_node_id, [], set())

        # Step 2: For each complete path, keep only edges whose source node can block streaming
        filtered_paths: list[Path] = []
        for path in all_complete_paths:
            blocking_edges: list[str] = []
            for edge_id in path:
                edge = self._graph.edges[edge_id]
                source_node = self._graph.nodes[edge.tail]

                # Check if node is a branch/container (original behavior)
                if source_node.execution_type in {
                    NodeExecutionType.BRANCH,
                    NodeExecutionType.CONTAINER,
                } or source_node.blocks_variable_output(variable_selectors):
                    blocking_edges.append(edge_id)

            # Keep the path even if it's empty
            filtered_paths.append(Path(edges=blocking_edges))

        return filtered_paths
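The path search in `_build_paths_map` is a plain depth-first enumeration of root-to-target paths with per-branch visited sets. A self-contained sketch of the same traversal over an adjacency-list graph (illustrative names, not the engine's `Graph` API):

```python
def find_all_paths(graph: dict[str, list[tuple[str, str]]], root: str, target: str) -> list[list[str]]:
    """Return every edge-ID path from root to target, avoiding cycles."""
    paths: list[list[str]] = []

    def dfs(node: str, path: list[str], visited: set[str]) -> None:
        if node == target:
            paths.append(path.copy())
            return
        visited.add(node)
        for edge_id, nxt in graph.get(node, []):
            if nxt not in visited:
                # Copy `visited` per branch so sibling branches may revisit nodes
                dfs(nxt, path + [edge_id], visited.copy())

    dfs(root, [], set())
    return paths


# Diamond graph: root -> a -> end, root -> b -> end
graph = {
    "root": [("e1", "a"), ("e2", "b")],
    "a": [("e3", "end")],
    "b": [("e4", "end")],
}
paths = find_all_paths(graph, "root", "end")
# → [["e1", "e3"], ["e2", "e4"]]
```

The coordinator then filters each path down to its blocking edges (branch/container sources), which is what `on_edge_taken` later prunes.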
    def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
        """
        Handle when an edge is taken (selected by a branch node).

        This method updates the paths for all response nodes by removing
        the taken edge. If any response node has an empty path after removal,
        it means the node is now deterministically reachable and should start.

        Args:
            edge_id: The ID of the edge that was taken

        Returns:
            List of events to emit from starting new sessions
        """
        events: list[NodeRunStreamChunkEvent] = []

        with self._lock:
            # Check each response node in order
            for response_node_id in self._response_nodes:
                if response_node_id not in self._paths_maps:
                    continue

                paths = self._paths_maps[response_node_id]
                has_reachable_path = False

                # Update each path by removing the taken edge
                for path in paths:
                    # Remove the taken edge from this path
                    path.remove_edge(edge_id)

                    # Check if this path is now empty (node is reachable)
                    if path.is_empty():
                        has_reachable_path = True

                # If node is now reachable (has empty path), start/queue session
                if has_reachable_path:
                    # Pass the node_id to the activation method
                    # The method will handle checking and removing from map
                    events.extend(self._active_or_queue_session(response_node_id))
            return events

    def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
        """
        Start a session immediately if no active session, otherwise queue it.
        Only activates sessions that exist in the _response_sessions map.

        Args:
            node_id: The ID of the response node to activate

        Returns:
            List of events from flush attempt if session started immediately
        """
        events: list[NodeRunStreamChunkEvent] = []

        # Get the session from our map (only activate if it exists)
        session = self._response_sessions.get(node_id)
        if not session:
            return events

        # Remove from map to ensure it won't be activated again
        del self._response_sessions[node_id]

        if self._active_session is None:
            self._active_session = session

            # Try to flush immediately
            events.extend(self.try_flush())
        else:
            # Queue the session if another is active
            self._waiting_sessions.append(session)

        return events

    def intercept_event(
        self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
    ) -> Sequence[NodeRunStreamChunkEvent]:
        with self._lock:
            if isinstance(event, NodeRunStreamChunkEvent):
                self._append_stream_chunk(event.selector, event)
                if event.is_final:
                    self._close_stream(event.selector)
                return self.try_flush()
            else:
                # Skip because we share the same variable pool.
                #
                # for variable_name, variable_value in event.node_run_result.outputs.items():
                #     self._variable_pool.add((event.node_id, variable_name), variable_value)
                return self.try_flush()

    def _create_stream_chunk_event(
        self,
        node_id: str,
        execution_id: str,
        selector: Sequence[str],
        chunk: str,
        is_final: bool = False,
    ) -> NodeRunStreamChunkEvent:
        """Create a stream chunk event with consistent structure.

        For selectors with special prefixes (sys, env, conversation), we use the
        active response node's information since these are not actual node IDs.
        """
        # Check if this is a special selector that doesn't correspond to a node
        if selector and selector[0] not in self._graph.nodes and self._active_session:
            # Use the active response node for special selectors
            response_node = self._graph.nodes[self._active_session.node_id]
            return NodeRunStreamChunkEvent(
                id=execution_id,
                node_id=response_node.id,
                node_type=response_node.node_type,
                selector=selector,
                chunk=chunk,
                is_final=is_final,
            )

        # Standard case: selector refers to an actual node
        node = self._graph.nodes[node_id]
        return NodeRunStreamChunkEvent(
            id=execution_id,
            node_id=node.id,
            node_type=node.node_type,
            selector=selector,
            chunk=chunk,
            is_final=is_final,
        )

    def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
        """Process a variable segment. Returns (events, is_complete).

        Handles both regular node selectors and special system selectors (sys, env, conversation).
        For special selectors, we attribute the output to the active response node.
        """
        events: list[NodeRunStreamChunkEvent] = []
        source_selector_prefix = segment.selector[0] if segment.selector else ""
        is_complete = False

        # Determine which node to attribute the output to:
        # for special selectors (sys, env, conversation), use the active response node;
        # for regular selectors, use the source node.
        if self._active_session and source_selector_prefix not in self._graph.nodes:
            # Special selector - use active response node
            output_node_id = self._active_session.node_id
        else:
            # Regular node selector
            output_node_id = source_selector_prefix
        execution_id = self._get_or_create_execution_id(output_node_id)

        # Stream all available chunks
        while self._has_unread_stream(segment.selector):
            if event := self._pop_stream_chunk(segment.selector):
                # For special selectors, we need to update the event to use
                # the active response node's information
                if self._active_session and source_selector_prefix not in self._graph.nodes:
                    response_node = self._graph.nodes[self._active_session.node_id]
                    # Create a new event with the response node's information
                    # but keep the original selector
                    updated_event = NodeRunStreamChunkEvent(
                        id=execution_id,
                        node_id=response_node.id,
                        node_type=response_node.node_type,
                        selector=event.selector,  # Keep original selector
                        chunk=event.chunk,
                        is_final=event.is_final,
                    )
                    events.append(updated_event)
                else:
                    # Regular node selector - use event as is
                    events.append(event)

        # Check whether the stream is closed to determine if this segment is complete
        stream_closed = self._is_stream_closed(segment.selector)
        if stream_closed:
            is_complete = True

        elif value := self._variable_pool.get(segment.selector):
            # Process scalar value
            is_last_segment = bool(
                self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
            )
            events.append(
                self._create_stream_chunk_event(
                    node_id=output_node_id,
                    execution_id=execution_id,
                    selector=segment.selector,
                    chunk=value.markdown,
                    is_final=is_last_segment,
                )
            )
            is_complete = True

        return events, is_complete

    def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
        """Process a text segment. Returns the emitted events."""
        assert self._active_session is not None
        current_response_node = self._graph.nodes[self._active_session.node_id]

        # Use _get_or_create_execution_id to ensure we have a consistent ID
        execution_id = self._get_or_create_execution_id(current_response_node.id)

        is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
        event = self._create_stream_chunk_event(
            node_id=current_response_node.id,
            execution_id=execution_id,
            selector=[current_response_node.id, "answer"],  # FIXME(-LAN-)
            chunk=segment.text,
            is_final=is_last_segment,
        )
        return [event]

    def try_flush(self) -> list[NodeRunStreamChunkEvent]:
        with self._lock:
            if not self._active_session:
                return []

            template = self._active_session.template
            response_node_id = self._active_session.node_id

            events: list[NodeRunStreamChunkEvent] = []

            # Process segments sequentially from current index
            while self._active_session.index < len(template.segments):
                segment = template.segments[self._active_session.index]

                if isinstance(segment, VariableSegment):
                    # Check if the source node for this variable is skipped.
                    # Only check for actual nodes, not special selectors (sys, env, conversation).
                    source_selector_prefix = segment.selector[0] if segment.selector else ""
                    if source_selector_prefix in self._graph.nodes:
                        source_node = self._graph.nodes[source_selector_prefix]

                        if source_node.state == NodeState.SKIPPED:
                            # Skip this variable segment if the source node is skipped
                            self._active_session.index += 1
                            continue

                    segment_events, is_complete = self._process_variable_segment(segment)
                    events.extend(segment_events)

                    # Only advance index if this variable segment is complete
                    if is_complete:
                        self._active_session.index += 1
                    else:
                        # Wait for more data
                        break

                else:
                    segment_events = self._process_text_segment(segment)
                    events.extend(segment_events)
                    self._active_session.index += 1

            if self._active_session.is_complete():
                # End current session and get events from starting next session
                next_session_events = self.end_session(response_node_id)
                events.extend(next_session_events)

            return events

    def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
        """
        End the active session for a response node.
        Automatically starts the next waiting session if available.

        Args:
            node_id: ID of the response node ending its session

        Returns:
            List of events from starting the next session
        """
        with self._lock:
            events: list[NodeRunStreamChunkEvent] = []

            if self._active_session and self._active_session.node_id == node_id:
                self._active_session = None

                # Try to start next waiting session
                if self._waiting_sessions:
                    next_session = self._waiting_sessions.popleft()
                    self._active_session = next_session

                    # Immediately try to flush any available segments
                    events = self.try_flush()

            return events
    # ============= Internal Stream Management Methods =============

    def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
        """
        Append a stream chunk to the internal buffer.

        Args:
            selector: List of strings identifying the stream location
            event: The NodeRunStreamChunkEvent to append

        Raises:
            ValueError: If the stream is already closed
        """
        key = tuple(selector)

        if key in self._closed_streams:
            raise ValueError(f"Stream {'.'.join(selector)} is already closed")

        if key not in self._stream_buffers:
            self._stream_buffers[key] = []
            self._stream_positions[key] = 0

        self._stream_buffers[key].append(event)

    def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
        """
        Pop the next unread stream chunk from the buffer.

        Args:
            selector: List of strings identifying the stream location

        Returns:
            The next event, or None if no unread events available
        """
        key = tuple(selector)

        if key not in self._stream_buffers:
            return None

        position = self._stream_positions.get(key, 0)
        buffer = self._stream_buffers[key]

        if position >= len(buffer):
            return None

        event = buffer[position]
        self._stream_positions[key] = position + 1
        return event

    def _has_unread_stream(self, selector: Sequence[str]) -> bool:
        """
        Check if the stream has unread events.

        Args:
            selector: List of strings identifying the stream location

        Returns:
            True if there are unread events, False otherwise
        """
        key = tuple(selector)

        if key not in self._stream_buffers:
            return False

        position = self._stream_positions.get(key, 0)
        return position < len(self._stream_buffers[key])

    def _close_stream(self, selector: Sequence[str]) -> None:
        """
        Mark a stream as closed (no more chunks can be appended).

        Args:
            selector: List of strings identifying the stream location
        """
        key = tuple(selector)
        self._closed_streams.add(key)

    def _is_stream_closed(self, selector: Sequence[str]) -> bool:
        """
        Check if a stream is closed.

        Args:
            selector: List of strings identifying the stream location

        Returns:
            True if the stream is closed, False otherwise
        """
        key = tuple(selector)
        return key in self._closed_streams
    def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
        """Convert an in-memory session into its serializable form."""

        if session is None:
            return None
        return ResponseSessionState(node_id=session.node_id, index=session.index)

    def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
        """Rebuild a response session from serialized data."""

        node = self._graph.nodes.get(session_state.node_id)
        if node is None:
            raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")

        session = ResponseSession.from_node(node)
        session.index = session_state.index
        return session

    def dumps(self) -> str:
        """Serialize coordinator state to JSON."""

        with self._lock:
            state = ResponseStreamCoordinatorState(
                response_nodes=sorted(self._response_nodes),
                active_session=self._serialize_session(self._active_session),
                waiting_sessions=[
                    session_state
                    for session in list(self._waiting_sessions)
                    if (session_state := self._serialize_session(session)) is not None
                ],
                pending_sessions=[
                    session_state
                    for _, session in sorted(self._response_sessions.items())
                    if (session_state := self._serialize_session(session)) is not None
                ],
                node_execution_ids=dict(sorted(self._node_execution_ids.items())),
                paths_map={
                    node_id: [path.edges.copy() for path in paths]
                    for node_id, paths in sorted(self._paths_maps.items())
                },
                stream_buffers=[
                    StreamBufferState(
                        selector=selector,
                        events=[event.model_copy(deep=True) for event in events],
                    )
                    for selector, events in sorted(self._stream_buffers.items())
                ],
                stream_positions=[
                    StreamPositionState(selector=selector, position=position)
                    for selector, position in sorted(self._stream_positions.items())
                ],
                closed_streams=sorted(self._closed_streams),
            )
            return state.model_dump_json()

    def loads(self, data: str) -> None:
        """Restore coordinator state from JSON."""

        state = ResponseStreamCoordinatorState.model_validate_json(data)

        if state.type != "ResponseStreamCoordinator":
            raise ValueError(f"Invalid serialized data type: {state.type}")

        if state.version != "1.0":
            raise ValueError(f"Unsupported serialized version: {state.version}")

        with self._lock:
            self._response_nodes = set(state.response_nodes)
            self._paths_maps = {
                node_id: [Path(edges=list(path_edges)) for path_edges in paths]
                for node_id, paths in state.paths_map.items()
            }
            self._node_execution_ids = dict(state.node_execution_ids)

            self._stream_buffers = {
                tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
                for buffer in state.stream_buffers
            }
            self._stream_positions = {
                tuple(position.selector): position.position for position in state.stream_positions
            }
            for selector in self._stream_buffers:
                self._stream_positions.setdefault(selector, 0)

            self._closed_streams = {tuple(selector) for selector in state.closed_streams}

            self._waiting_sessions = deque(
                self._session_from_state(session_state) for session_state in state.waiting_sessions
            )
            self._response_sessions = {
                session_state.node_id: self._session_from_state(session_state)
                for session_state in state.pending_sessions
            }
            self._active_session = self._session_from_state(state.active_session) if state.active_session else None
35	api/core/workflow/graph_engine/response_coordinator/path.py	Normal file
@@ -0,0 +1,35 @@
"""
Internal path representation for response coordinator.

This module contains the private Path class used internally by ResponseStreamCoordinator
to track execution paths to response nodes.
"""

from dataclasses import dataclass, field
from typing import TypeAlias

EdgeID: TypeAlias = str


@dataclass
class Path:
    """
    Represents a path of branch edges that must be taken to reach a response node.

    Note: This is an internal class not exposed in the public API.
    """

    edges: list[EdgeID] = field(default_factory=list[EdgeID])

    def contains_edge(self, edge_id: EdgeID) -> bool:
        """Check if this path contains the given edge."""
        return edge_id in self.edges

    def remove_edge(self, edge_id: EdgeID) -> None:
        """Remove the given edge from this path in place."""
        if self.contains_edge(edge_id):
            self.edges.remove(edge_id)

    def is_empty(self) -> bool:
        """Check if the path has no edges (node is reachable)."""
        return len(self.edges) == 0
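A response node becomes streamable once every blocking edge on at least one of its paths has been taken. A sketch of that pruning with a standalone copy of the `Path` shape above (the node IDs and edge IDs are illustrative):

```python
from dataclasses import dataclass, field


@dataclass
class Path:
    edges: list[str] = field(default_factory=list)

    def remove_edge(self, edge_id: str) -> None:
        if edge_id in self.edges:
            self.edges.remove(edge_id)

    def is_empty(self) -> bool:
        return not self.edges


# Two candidate routes to the same answer node, each gated by one branch edge
paths = [Path(edges=["if-true"]), Path(edges=["if-false"])]


def on_edge_taken(edge_id: str) -> bool:
    """Prune the taken edge everywhere; reachable once any path empties."""
    for p in paths:
        p.remove_edge(edge_id)
    return any(p.is_empty() for p in paths)


reachable = on_edge_taken("if-true")  # → True
```

This mirrors how the coordinator's `on_edge_taken` decides whether to activate or queue a response session.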
@@ -0,0 +1,52 @@
"""
Internal response session management for response coordinator.

This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""

from dataclasses import dataclass

from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode


@dataclass
class ResponseSession:
    """
    Represents an active response streaming session.

    Note: This is an internal class not exposed in the public API.
    """

    node_id: str
    template: Template  # Template object from the response node
    index: int = 0  # Current position in the template segments

    @classmethod
    def from_node(cls, node: Node) -> "ResponseSession":
        """
        Create a ResponseSession from an AnswerNode, EndNode, or KnowledgeIndexNode.

        Args:
            node: Must be an AnswerNode, EndNode, or KnowledgeIndexNode instance

        Returns:
            ResponseSession configured with the node's streaming template

        Raises:
            TypeError: If node is not a supported response node type
        """
        if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
            raise TypeError(f"Unsupported response node type: {type(node).__name__}")
        return cls(
            node_id=node.id,
            template=node.get_streaming_template(),
        )

    def is_complete(self) -> bool:
        """Check if all segments in the template have been processed."""
        return self.index >= len(self.template.segments)
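A session is essentially a cursor over a template's segments; `is_complete` just compares the cursor to the segment count. Minimal sketch (the segment list is a stand-in for `Template.segments`):

```python
from dataclasses import dataclass


@dataclass
class Session:
    segments: list[str]  # stand-in for Template.segments
    index: int = 0

    def advance(self) -> None:
        self.index += 1

    def is_complete(self) -> bool:
        return self.index >= len(self.segments)


s = Session(segments=["Hello, ", "{{llm.text}}", "!"])
while not s.is_complete():
    s.advance()
```

In the real coordinator, `try_flush` only advances the cursor past a variable segment once its source stream has closed, so a session can sit partially complete while upstream chunks are still arriving.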
142	api/core/workflow/graph_engine/worker.py	Normal file
@@ -0,0 +1,142 @@
"""
Worker - Thread implementation for queue-based node execution

Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""

import contextvars
import queue
import threading
import time
from datetime import datetime
from typing import final
from uuid import uuid4

from flask import Flask
from typing_extensions import override

from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts

from .ready_queue import ReadyQueue


@final
class Worker(threading.Thread):
    """
    Worker thread that executes nodes from the ready queue.

    Workers continuously pull node IDs from the ready_queue, execute the
    corresponding nodes, and push the resulting events to the event_queue
    for the dispatcher to process.
    """

    def __init__(
        self,
        ready_queue: ReadyQueue,
        event_queue: queue.Queue[GraphNodeEventBase],
        graph: Graph,
        worker_id: int = 0,
        flask_app: Flask | None = None,
        context_vars: contextvars.Context | None = None,
    ) -> None:
        """
        Initialize worker thread.

        Args:
            ready_queue: Ready queue containing node IDs ready for execution
            event_queue: Queue for pushing execution events
            graph: Graph containing nodes to execute
            worker_id: Unique identifier for this worker
            flask_app: Optional Flask application for context preservation
            context_vars: Optional context variables to preserve in worker thread
        """
        super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
        self._ready_queue = ready_queue
        self._event_queue = event_queue
        self._graph = graph
        self._worker_id = worker_id
        self._flask_app = flask_app
        self._context_vars = context_vars
        self._stop_event = threading.Event()
        self._last_task_time = time.time()

    def stop(self) -> None:
        """Signal the worker to stop processing."""
        self._stop_event.set()

    @property
    def is_idle(self) -> bool:
        """Check if the worker is currently idle."""
        # Worker is idle if it hasn't processed a task recently (within 0.2 seconds)
        return (time.time() - self._last_task_time) > 0.2

    @property
    def idle_duration(self) -> float:
        """Get the duration in seconds since the worker last processed a task."""
        return time.time() - self._last_task_time

    @property
    def worker_id(self) -> int:
        """Get the worker's ID."""
        return self._worker_id

    @override
    def run(self) -> None:
        """
        Main worker loop.

        Continuously pulls node IDs from ready_queue, executes them,
        and pushes events to event_queue until stopped.
        """
        while not self._stop_event.is_set():
            # Try to get a node ID from the ready queue (with timeout)
            try:
                node_id = self._ready_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            self._last_task_time = time.time()
            node = self._graph.nodes[node_id]
            try:
                self._execute_node(node)
                self._ready_queue.task_done()
            except Exception as e:
                # Report the failure as an event instead of crashing the thread
                error_event = NodeRunFailedEvent(
                    id=str(uuid4()),
                    node_id=node_id,
                    node_type=node.node_type,
                    in_iteration_id=None,
                    error=str(e),
                    start_at=datetime.now(),
                )
                self._event_queue.put(error_event)

    def _execute_node(self, node: Node) -> None:
        """
        Execute a single node and handle its events.

        Args:
            node: The node instance to execute
        """
        # Execute the node with preserved context if Flask app is provided
        if self._flask_app and self._context_vars:
            with preserve_flask_contexts(
                flask_app=self._flask_app,
                context_vars=self._context_vars,
            ):
                # Execute the node
                node_events = node.run()
                for event in node_events:
                    # Forward event to dispatcher immediately for streaming
                    self._event_queue.put(event)
        else:
            # Execute without context preservation
            node_events = node.run()
            for event in node_events:
                # Forward event to dispatcher immediately for streaming
                self._event_queue.put(event)
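The worker loop above is a standard queue-consumer pattern: poll with a short timeout so the stop flag stays responsive, and report failures as events rather than letting the thread die. A runnable sketch with plain stdlib queues (no engine types; `task_id.upper()` stands in for `node.run()`):

```python
import queue
import threading


def worker_loop(ready: "queue.Queue[str]", events: "queue.Queue[tuple[str, str]]", stop: threading.Event) -> None:
    """Pull task IDs, 'execute' them, and push (status, payload) events."""
    while not stop.is_set():
        try:
            task_id = ready.get(timeout=0.1)  # timeout keeps the stop flag responsive
        except queue.Empty:
            continue
        try:
            events.put(("ok", task_id.upper()))  # stand-in for node.run()
        except Exception as e:
            events.put(("failed", str(e)))  # failure becomes an event, not a crash
        finally:
            ready.task_done()


ready: "queue.Queue[str]" = queue.Queue()
events: "queue.Queue[tuple[str, str]]" = queue.Queue()
stop = threading.Event()

t = threading.Thread(target=worker_loop, args=(ready, events, stop), daemon=True)
t.start()
ready.put("node-1")
ready.join()  # wait until the task is processed
stop.set()
t.join()
status, payload = events.get_nowait()  # → ("ok", "NODE-1")
```

The 0.1-second poll interval trades a little idle CPU for prompt shutdown, the same trade the engine's `Worker.run` makes.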
12	api/core/workflow/graph_engine/worker_management/__init__.py	Normal file
@@ -0,0 +1,12 @@
"""
Worker management subsystem for graph engine.

This package manages the worker pool, including creation,
scaling, and activity tracking.
"""

from .worker_pool import WorkerPool

__all__ = [
    "WorkerPool",
]
291	api/core/workflow/graph_engine/worker_management/worker_pool.py	Normal file
@@ -0,0 +1,291 @@
"""
Simple worker pool that consolidates functionality.

This is a simpler implementation that merges WorkerPool, ActivityTracker,
DynamicScaler, and WorkerFactory into a single class.
"""

import logging
import queue
import threading
from typing import TYPE_CHECKING, final

from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase

from ..ready_queue import ReadyQueue
from ..worker import Worker

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from contextvars import Context

    from flask import Flask


@final
class WorkerPool:
    """
    Simple worker pool with integrated management.

    This class consolidates all worker management functionality into
    a single, simpler implementation without excessive abstraction.
    """

    def __init__(
        self,
        ready_queue: ReadyQueue,
        event_queue: queue.Queue[GraphNodeEventBase],
        graph: Graph,
        flask_app: "Flask | None" = None,
        context_vars: "Context | None" = None,
        min_workers: int | None = None,
        max_workers: int | None = None,
        scale_up_threshold: int | None = None,
        scale_down_idle_time: float | None = None,
    ) -> None:
        """
        Initialize the simple worker pool.

        Args:
            ready_queue: Ready queue for nodes ready for execution
            event_queue: Queue for worker events
            graph: The workflow graph
            flask_app: Optional Flask app for context preservation
            context_vars: Optional context variables
            min_workers: Minimum number of workers
            max_workers: Maximum number of workers
            scale_up_threshold: Queue depth to trigger scale up
            scale_down_idle_time: Seconds before scaling down idle workers
        """
        self._ready_queue = ready_queue
        self._event_queue = event_queue
        self._graph = graph
        self._flask_app = flask_app
self._context_vars = context_vars
|
||||
|
||||
# Scaling parameters with defaults
|
||||
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
|
||||
self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
|
||||
self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
|
||||
self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
|
||||
# Worker management
|
||||
self._workers: list[Worker] = []
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
def start(self, initial_count: int | None = None) -> None:
|
||||
"""
|
||||
Start the worker pool.
|
||||
|
||||
Args:
|
||||
initial_count: Number of workers to start with (auto-calculated if None)
|
||||
"""
|
||||
with self._lock:
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Calculate initial worker count
|
||||
if initial_count is None:
|
||||
node_count = len(self._graph.nodes)
|
||||
if node_count < 10:
|
||||
initial_count = self._min_workers
|
||||
elif node_count < 50:
|
||||
initial_count = min(self._min_workers + 1, self._max_workers)
|
||||
else:
|
||||
initial_count = min(self._min_workers + 2, self._max_workers)
|
||||
|
||||
logger.debug(
|
||||
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
|
||||
initial_count,
|
||||
node_count,
|
||||
self._min_workers,
|
||||
self._max_workers,
|
||||
)
|
||||
|
||||
# Create initial workers
|
||||
for _ in range(initial_count):
|
||||
self._create_worker()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop all workers in the pool."""
|
||||
with self._lock:
|
||||
self._running = False
|
||||
worker_count = len(self._workers)
|
||||
|
||||
if worker_count > 0:
|
||||
logger.debug("Stopping worker pool: %d workers", worker_count)
|
||||
|
||||
# Stop all workers
|
||||
for worker in self._workers:
|
||||
worker.stop()
|
||||
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
def _create_worker(self) -> None:
|
||||
"""Create and start a new worker."""
|
||||
worker_id = self._worker_counter
|
||||
self._worker_counter += 1
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
self._workers.append(worker)
|
||||
|
||||
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
|
||||
"""Remove a specific worker from the pool."""
|
||||
# Stop the worker
|
||||
worker.stop()
|
||||
|
||||
# Wait for it to finish
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
# Remove from list
|
||||
if worker in self._workers:
|
||||
self._workers.remove(worker)
|
||||
|
||||
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
|
||||
"""
|
||||
Try to scale up workers if needed.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
|
||||
Returns:
|
||||
True if scaled up, False otherwise
|
||||
"""
|
||||
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
|
||||
old_count = current_count
|
||||
self._create_worker()
|
||||
|
||||
logger.debug(
|
||||
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
queue_depth,
|
||||
self._scale_up_threshold,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
|
||||
"""
|
||||
Try to scale down workers if we have excess capacity.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
active_count: Number of active workers
|
||||
idle_count: Number of idle workers
|
||||
|
||||
Returns:
|
||||
True if scaled down, False otherwise
|
||||
"""
|
||||
# Skip if we're at minimum or have no idle workers
|
||||
if current_count <= self._min_workers or idle_count == 0:
|
||||
return False
|
||||
|
||||
# Check if we have excess capacity
|
||||
has_excess_capacity = (
|
||||
queue_depth <= active_count # Active workers can handle current queue
|
||||
or idle_count > active_count # More idle than active workers
|
||||
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
|
||||
)
|
||||
|
||||
if not has_excess_capacity:
|
||||
return False
|
||||
|
||||
# Find and remove idle workers that have been idle long enough
|
||||
workers_to_remove: list[tuple[Worker, int]] = []
|
||||
|
||||
for worker in self._workers:
|
||||
# Check if worker is idle and has exceeded idle time threshold
|
||||
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
|
||||
# Don't remove if it would leave us unable to handle the queue
|
||||
remaining_workers = current_count - len(workers_to_remove) - 1
|
||||
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
|
||||
workers_to_remove.append((worker, worker.worker_id))
|
||||
# Only remove one worker per check to avoid aggressive scaling
|
||||
break
|
||||
|
||||
# Remove idle workers if any found
|
||||
if workers_to_remove:
|
||||
old_count = current_count
|
||||
for worker, worker_id in workers_to_remove:
|
||||
self._remove_worker(worker, worker_id)
|
||||
|
||||
logger.debug(
|
||||
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
|
||||
"queue_depth=%d, active=%d, idle=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
len(workers_to_remove),
|
||||
self._scale_down_idle_time,
|
||||
queue_depth,
|
||||
active_count,
|
||||
idle_count - len(workers_to_remove),
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_and_scale(self) -> None:
|
||||
"""Check and perform scaling if needed."""
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
current_count = len(self._workers)
|
||||
queue_depth = self._ready_queue.qsize()
|
||||
|
||||
# Count active vs idle workers by querying their state directly
|
||||
idle_count = sum(1 for worker in self._workers if worker.is_idle)
|
||||
active_count = current_count - idle_count
|
||||
|
||||
# Try to scale up if queue is backing up
|
||||
self._try_scale_up(queue_depth, current_count)
|
||||
|
||||
# Try to scale down if we have excess capacity
|
||||
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
|
||||
|
||||
def get_worker_count(self) -> int:
|
||||
"""Get current number of workers."""
|
||||
with self._lock:
|
||||
return len(self._workers)
|
||||
|
||||
def get_status(self) -> dict[str, int]:
|
||||
"""
|
||||
Get pool status information.
|
||||
|
||||
Returns:
|
||||
Dictionary with status information
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_workers": len(self._workers),
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"min_workers": self._min_workers,
|
||||
"max_workers": self._max_workers,
|
||||
}
|
||||
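The pool's sizing decisions are easy to reason about when extracted as pure functions. This sketch mirrors the heuristics in `start` and `_try_scale_up` above; the function names are illustrative, not part of the engine's API.

```python
def initial_worker_count(node_count: int, min_workers: int, max_workers: int) -> int:
    """Mirror of WorkerPool.start's sizing heuristic: larger graphs
    get one or two extra workers, capped at max_workers."""
    if node_count < 10:
        return min_workers
    if node_count < 50:
        return min(min_workers + 1, max_workers)
    return min(min_workers + 2, max_workers)


def should_scale_up(queue_depth: int, current: int, max_workers: int, threshold: int) -> bool:
    """Mirror of _try_scale_up's condition: grow only while the backlog
    strictly exceeds the threshold and headroom remains."""
    return queue_depth > threshold and current < max_workers
```

Because scale-up adds at most one worker per `check_and_scale` call (and scale-down removes at most one), pool size changes by at most one worker per tick, which keeps scaling gradual under bursty queues.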
**`api/core/workflow/graph_events/__init__.py`** (new file, 72 lines)

```python
# Agent events
from .agent import NodeRunAgentLogEvent

# Base events
from .base import (
    BaseGraphEvent,
    GraphEngineEvent,
    GraphNodeEventBase,
)

# Graph events
from .graph import (
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)

# Iteration events
from .iteration import (
    NodeRunIterationFailedEvent,
    NodeRunIterationNextEvent,
    NodeRunIterationStartedEvent,
    NodeRunIterationSucceededEvent,
)

# Loop events
from .loop import (
    NodeRunLoopFailedEvent,
    NodeRunLoopNextEvent,
    NodeRunLoopStartedEvent,
    NodeRunLoopSucceededEvent,
)

# Node events
from .node import (
    NodeRunExceptionEvent,
    NodeRunFailedEvent,
    NodeRunRetrieverResourceEvent,
    NodeRunRetryEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)

__all__ = [
    "BaseGraphEvent",
    "GraphEngineEvent",
    "GraphNodeEventBase",
    "GraphRunAbortedEvent",
    "GraphRunFailedEvent",
    "GraphRunPartialSucceededEvent",
    "GraphRunStartedEvent",
    "GraphRunSucceededEvent",
    "NodeRunAgentLogEvent",
    "NodeRunExceptionEvent",
    "NodeRunFailedEvent",
    "NodeRunIterationFailedEvent",
    "NodeRunIterationNextEvent",
    "NodeRunIterationStartedEvent",
    "NodeRunIterationSucceededEvent",
    "NodeRunLoopFailedEvent",
    "NodeRunLoopNextEvent",
    "NodeRunLoopStartedEvent",
    "NodeRunLoopSucceededEvent",
    "NodeRunRetrieverResourceEvent",
    "NodeRunRetryEvent",
    "NodeRunStartedEvent",
    "NodeRunStreamChunkEvent",
    "NodeRunSucceededEvent",
]
```
**`api/core/workflow/graph_events/agent.py`** (new file, 17 lines)

```python
from collections.abc import Mapping
from typing import Any

from pydantic import Field

from .base import GraphAgentNodeEventBase


class NodeRunAgentLogEvent(GraphAgentNodeEventBase):
    message_id: str = Field(..., description="message id")
    label: str = Field(..., description="label")
    node_execution_id: str = Field(..., description="node execution id")
    parent_id: str | None = Field(..., description="parent id")
    error: str | None = Field(..., description="error")
    status: str = Field(..., description="status")
    data: Mapping[str, Any] = Field(..., description="data")
    metadata: Mapping[str, object] = Field(default_factory=dict)
```
**`api/core/workflow/graph_events/base.py`** (new file, 31 lines)

```python
from pydantic import BaseModel, Field

from core.workflow.enums import NodeType
from core.workflow.node_events import NodeRunResult


class GraphEngineEvent(BaseModel):
    pass


class BaseGraphEvent(GraphEngineEvent):
    pass


class GraphNodeEventBase(GraphEngineEvent):
    id: str = Field(..., description="node execution id")
    node_id: str
    node_type: NodeType

    in_iteration_id: str | None = None
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""

    # The version of the node, or "1" if not specified.
    node_version: str = "1"
    node_run_result: NodeRunResult = Field(default_factory=NodeRunResult)


class GraphAgentNodeEventBase(GraphNodeEventBase):
    pass
```
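A practical consequence of this hierarchy is that consumers can route on the base classes rather than enumerating every concrete event. The sketch below uses plain classes (not the pydantic models above) purely to show the `isinstance` dispatch pattern; `route` is a hypothetical consumer, not engine code.

```python
class GraphEngineEvent:
    """Simplified stand-in for the pydantic base event."""


class GraphNodeEventBase(GraphEngineEvent):
    """Node-level events always carry the id of the node that emitted them."""

    def __init__(self, node_id: str) -> None:
        self.node_id = node_id


class NodeRunSucceededEvent(GraphNodeEventBase):
    pass


class GraphRunSucceededEvent(GraphEngineEvent):
    pass


def route(event: GraphEngineEvent) -> str:
    # One isinstance check on the base class covers every node event,
    # present and future; everything else is a graph-level event.
    if isinstance(event, GraphNodeEventBase):
        return f"node:{event.node_id}"
    return "graph"
```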
**`api/core/workflow/graph_events/graph.py`** (new file, 28 lines)

```python
from pydantic import Field

from core.workflow.graph_events import BaseGraphEvent


class GraphRunStartedEvent(BaseGraphEvent):
    pass


class GraphRunSucceededEvent(BaseGraphEvent):
    outputs: dict[str, object] = Field(default_factory=dict)


class GraphRunFailedEvent(BaseGraphEvent):
    error: str = Field(..., description="failed reason")
    exceptions_count: int = Field(description="exception count", default=0)


class GraphRunPartialSucceededEvent(BaseGraphEvent):
    exceptions_count: int = Field(..., description="exception count")
    outputs: dict[str, object] = Field(default_factory=dict)


class GraphRunAbortedEvent(BaseGraphEvent):
    """Event emitted when a graph run is aborted by user command."""

    reason: str | None = Field(default=None, description="reason for abort")
    outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any")
```
**`api/core/workflow/graph_events/iteration.py`** (new file, 40 lines)

```python
from collections.abc import Mapping
from datetime import datetime
from typing import Any

from pydantic import Field

from .base import GraphNodeEventBase


class NodeRunIterationStartedEvent(GraphNodeEventBase):
    node_title: str
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    predecessor_node_id: str | None = None


class NodeRunIterationNextEvent(GraphNodeEventBase):
    node_title: str
    index: int = Field(..., description="index")
    pre_iteration_output: Any = None


class NodeRunIterationSucceededEvent(GraphNodeEventBase):
    node_title: str
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    outputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    steps: int = 0


class NodeRunIterationFailedEvent(GraphNodeEventBase):
    node_title: str
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    outputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    steps: int = 0
    error: str = Field(..., description="failed reason")
```
**`api/core/workflow/graph_events/loop.py`** (new file, 40 lines)

```python
from collections.abc import Mapping
from datetime import datetime
from typing import Any

from pydantic import Field

from .base import GraphNodeEventBase


class NodeRunLoopStartedEvent(GraphNodeEventBase):
    node_title: str
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    predecessor_node_id: str | None = None


class NodeRunLoopNextEvent(GraphNodeEventBase):
    node_title: str
    index: int = Field(..., description="index")
    pre_loop_output: Any = None


class NodeRunLoopSucceededEvent(GraphNodeEventBase):
    node_title: str
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    outputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    steps: int = 0


class NodeRunLoopFailedEvent(GraphNodeEventBase):
    node_title: str
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    outputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    steps: int = 0
    error: str = Field(..., description="failed reason")
```
**`api/core/workflow/graph_events/node.py`** (new file, 53 lines)

```python
from collections.abc import Sequence
from datetime import datetime

from pydantic import Field

from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit

from .base import GraphNodeEventBase


class NodeRunStartedEvent(GraphNodeEventBase):
    node_title: str
    predecessor_node_id: str | None = None
    agent_strategy: AgentNodeStrategyInit | None = None
    start_at: datetime = Field(..., description="node start time")

    # FIXME(-LAN-): only for ToolNode
    provider_type: str = ""
    provider_id: str = ""


class NodeRunStreamChunkEvent(GraphNodeEventBase):
    # Spec-compliant fields
    selector: Sequence[str] = Field(
        ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
    )
    chunk: str = Field(..., description="the actual chunk content")
    is_final: bool = Field(default=False, description="indicates if this is the last chunk")


class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")


class NodeRunSucceededEvent(GraphNodeEventBase):
    start_at: datetime = Field(..., description="node start time")


class NodeRunFailedEvent(GraphNodeEventBase):
    error: str = Field(..., description="error")
    start_at: datetime = Field(..., description="node start time")


class NodeRunExceptionEvent(GraphNodeEventBase):
    error: str = Field(..., description="error")
    start_at: datetime = Field(..., description="node start time")


class NodeRunRetryEvent(NodeRunStartedEvent):
    error: str = Field(..., description="error")
    retry_index: int = Field(..., description="which retry attempt is about to be performed")
```
**`api/core/workflow/node_events/__init__.py`** (new file, 40 lines)

```python
from .agent import AgentLogEvent
from .base import NodeEventBase, NodeRunResult
from .iteration import (
    IterationFailedEvent,
    IterationNextEvent,
    IterationStartedEvent,
    IterationSucceededEvent,
)
from .loop import (
    LoopFailedEvent,
    LoopNextEvent,
    LoopStartedEvent,
    LoopSucceededEvent,
)
from .node import (
    ModelInvokeCompletedEvent,
    RunRetrieverResourceEvent,
    RunRetryEvent,
    StreamChunkEvent,
    StreamCompletedEvent,
)

__all__ = [
    "AgentLogEvent",
    "IterationFailedEvent",
    "IterationNextEvent",
    "IterationStartedEvent",
    "IterationSucceededEvent",
    "LoopFailedEvent",
    "LoopNextEvent",
    "LoopStartedEvent",
    "LoopSucceededEvent",
    "ModelInvokeCompletedEvent",
    "NodeEventBase",
    "NodeRunResult",
    "RunRetrieverResourceEvent",
    "RunRetryEvent",
    "StreamChunkEvent",
    "StreamCompletedEvent",
]
```
**`api/core/workflow/node_events/agent.py`** (new file, 18 lines)

```python
from collections.abc import Mapping
from typing import Any

from pydantic import Field

from .base import NodeEventBase


class AgentLogEvent(NodeEventBase):
    message_id: str = Field(..., description="id")
    label: str = Field(..., description="label")
    node_execution_id: str = Field(..., description="node execution id")
    parent_id: str | None = Field(..., description="parent id")
    error: str | None = Field(..., description="error")
    status: str = Field(..., description="status")
    data: Mapping[str, Any] = Field(..., description="data")
    metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata")
    node_id: str = Field(..., description="node id")
```
**`api/core/workflow/node_events/base.py`** (new file, 40 lines)

```python
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class NodeEventBase(BaseModel):
    """Base class for all node events"""

    pass


def _default_metadata():
    v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
    return v


class NodeRunResult(BaseModel):
    """
    Node Run Result.
    """

    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.PENDING

    inputs: Mapping[str, Any] = Field(default_factory=dict)
    process_data: Mapping[str, Any] = Field(default_factory=dict)
    outputs: Mapping[str, Any] = Field(default_factory=dict)
    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata)
    llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)

    edge_source_handle: str = "source"  # source handle id of node with multiple branches

    error: str = ""
    error_type: str = ""

    # single step node run retry
    retry_index: int = 0
```
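`NodeRunResult` leans on default *factories* for every mutable field, so each result gets its own fresh mapping instead of a shared class-level default. The same property can be shown with stdlib dataclasses; `Status` and `Result` below are simplified stand-ins for `WorkflowNodeExecutionStatus` and `NodeRunResult`, not the real models.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any


class Status(str, Enum):
    # Hypothetical stand-in for WorkflowNodeExecutionStatus.
    PENDING = "pending"
    SUCCEEDED = "succeeded"
    FAILED = "failed"


@dataclass
class Result:
    # Minimal analogue of NodeRunResult: every mutable field uses a
    # default factory so instances never share state.
    status: Status = Status.PENDING
    inputs: dict[str, Any] = field(default_factory=dict)
    outputs: dict[str, Any] = field(default_factory=dict)
    error: str = ""


a, b = Result(), Result()
a.outputs["answer"] = 42  # mutating one instance must not leak into the other
```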
**`api/core/workflow/node_events/iteration.py`** (new file, 36 lines)

```python
from collections.abc import Mapping
from datetime import datetime
from typing import Any

from pydantic import Field

from .base import NodeEventBase


class IterationStartedEvent(NodeEventBase):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    predecessor_node_id: str | None = None


class IterationNextEvent(NodeEventBase):
    index: int = Field(..., description="index")
    pre_iteration_output: Any = None


class IterationSucceededEvent(NodeEventBase):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    outputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    steps: int = 0


class IterationFailedEvent(NodeEventBase):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    outputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    steps: int = 0
    error: str = Field(..., description="failed reason")
```
**`api/core/workflow/node_events/loop.py`** (new file, 36 lines)

```python
from collections.abc import Mapping
from datetime import datetime
from typing import Any

from pydantic import Field

from .base import NodeEventBase


class LoopStartedEvent(NodeEventBase):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    predecessor_node_id: str | None = None


class LoopNextEvent(NodeEventBase):
    index: int = Field(..., description="index")
    pre_loop_output: Any = None


class LoopSucceededEvent(NodeEventBase):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    outputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    steps: int = 0


class LoopFailedEvent(NodeEventBase):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, object] = Field(default_factory=dict)
    outputs: Mapping[str, object] = Field(default_factory=dict)
    metadata: Mapping[str, object] = Field(default_factory=dict)
    steps: int = 0
    error: str = Field(..., description="failed reason")
```
**`api/core/workflow/node_events/node.py`** (new file, 41 lines)

```python
from collections.abc import Sequence
from datetime import datetime

from pydantic import Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_events import NodeRunResult

from .base import NodeEventBase


class RunRetrieverResourceEvent(NodeEventBase):
    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")


class ModelInvokeCompletedEvent(NodeEventBase):
    text: str
    usage: LLMUsage
    finish_reason: str | None = None
    reasoning_content: str | None = None


class RunRetryEvent(NodeEventBase):
    error: str = Field(..., description="error")
    retry_index: int = Field(..., description="Retry attempt number")
    start_at: datetime = Field(..., description="Retry start time")


class StreamChunkEvent(NodeEventBase):
    # Spec-compliant fields
    selector: Sequence[str] = Field(
        ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
    )
    chunk: str = Field(..., description="the actual chunk content")
    is_final: bool = Field(default=False, description="indicates if this is the last chunk")


class StreamCompletedEvent(NodeEventBase):
    node_run_result: NodeRunResult = Field(..., description="run result")
```
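A consumer of `StreamChunkEvent` typically groups chunks by their `selector` (the output location) and concatenates them in arrival order to rebuild each stream. The helper below is a hypothetical illustration of that grouping, not engine code.

```python
def assemble_streams(chunks: list[tuple[tuple[str, ...], str]]) -> dict[tuple[str, ...], str]:
    """Group chunk text by selector and concatenate in arrival order."""
    grouped: dict[tuple[str, ...], list[str]] = {}
    for selector, text in chunks:
        grouped.setdefault(selector, []).append(text)
    return {selector: "".join(parts) for selector, parts in grouped.items()}


# Chunks from two interleaved streams, keyed by ('node_id', 'output_name').
result = assemble_streams([
    (("nodeA", "text"), "Hel"),
    (("nodeB", "text"), "Hi"),
    (("nodeA", "text"), "lo"),
])
```

The `is_final` flag on the event is what lets a real consumer close a stream early instead of waiting for the node to finish.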
Modified file (re-exporting `NodeType` from the shared enums module):

```diff
@@ -1,3 +1,3 @@
-from .enums import NodeType
+from core.workflow.enums import NodeType
 
 __all__ = ["NodeType"]
```
Modified file (agent node; `-`/`+` markers reconstructed from the old/new line pairs of the extracted split diff):

```diff
@@ -1,6 +1,6 @@
 import json
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from packaging.version import Version
 from pydantic import ValidationError
@@ -9,16 +9,12 @@ from sqlalchemy.orm import Session
 from core.agent.entities import AgentToolEntity
 from core.agent.plugin_entities import AgentStrategyParameter
-from core.agent.strategy.plugin import PluginAgentStrategy
 from core.file import File, FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.request import InvokeCredentials
-from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.plugin.impl.plugin import PluginInstaller
 from core.provider_manager import ProviderManager
 from core.tools.entities.tool_entities import (
     ToolIdentity,
@@ -29,17 +25,25 @@ from core.tools.entities.tool_entities import (
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.variables.segments import ArrayFileSegment, StringSegment
-from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
-from core.workflow.graph_engine.entities.event import AgentLogEvent
+from core.workflow.entities import VariablePool
+from core.workflow.enums import (
+    ErrorStrategy,
+    NodeType,
+    SystemVariableKey,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
+from core.workflow.node_events import (
+    AgentLogEvent,
+    NodeEventBase,
+    NodeRunResult,
+    StreamChunkEvent,
+    StreamCompletedEvent,
+)
 from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
-from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
-from core.workflow.nodes.enums import ErrorStrategy, NodeType
-from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
-from core.workflow.utils.variable_template_parser import VariableTemplateParser
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from factories import file_factory
 from factories.agent_factory import get_plugin_agent_strategy
@@ -57,19 +61,23 @@ from .exc import (
     ToolFileNotFoundError,
 )
 
+if TYPE_CHECKING:
+    from core.agent.strategy.plugin import PluginAgentStrategy
+    from core.plugin.entities.request import InvokeCredentials
+
 
-class AgentNode(BaseNode):
+class AgentNode(Node):
     """
     Agent Node
     """
 
-    _node_type = NodeType.AGENT
+    node_type = NodeType.AGENT
     _node_data: AgentNodeData
 
-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = AgentNodeData.model_validate(data)
 
-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
 
     def _get_retry_config(self) -> RetryConfig:
@@ -78,7 +86,7 @@ class AgentNode(BaseNode):
     def _get_title(self) -> str:
         return self._node_data.title
 
-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc
 
     def _get_default_value_dict(self) -> dict[str, Any]:
@@ -91,7 +99,9 @@ class AgentNode(BaseNode):
     def version(cls) -> str:
         return "1"
 
-    def _run(self) -> Generator:
+    def _run(self) -> Generator[NodeEventBase, None, None]:
+        from core.plugin.impl.exc import PluginDaemonClientSideError
+
         try:
             strategy = get_plugin_agent_strategy(
                 tenant_id=self.tenant_id,
@@ -99,12 +109,12 @@ class AgentNode(BaseNode):
                 agent_strategy_name=self._node_data.agent_strategy_name,
             )
         except Exception as e:
-            yield RunCompletedEvent(
-                run_result=NodeRunResult(
+            yield StreamCompletedEvent(
+                node_run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs={},
                     error=f"Failed to get agent strategy: {str(e)}",
                 )
            ),
```
|
||||
)
|
||||
return
|
||||
|
||||
@ -139,8 +149,8 @@ class AgentNode(BaseNode):
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
@ -158,16 +168,16 @@ class AgentNode(BaseNode):
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_type=self.type_,
|
||||
node_id=self.node_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
@ -181,7 +191,7 @@ class AgentNode(BaseNode):
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: PluginAgentStrategy,
|
||||
strategy: "PluginAgentStrategy",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
@ -320,7 +330,7 @@ class AgentNode(BaseNode):
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
@ -339,10 +349,11 @@ class AgentNode(BaseNode):
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> InvokeCredentials:
|
||||
) -> "InvokeCredentials":
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
@ -388,6 +399,8 @@ class AgentNode(BaseNode):
|
||||
Get agent strategy icon
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
try:
|
||||
@ -401,7 +414,7 @@ class AgentNode(BaseNode):
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
@ -450,7 +463,9 @@ class AgentNode(BaseNode):
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
def _filter_mcp_type_tool(
|
||||
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
@ -473,11 +488,13 @@ class AgentNode(BaseNode):
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator:
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
@ -491,7 +508,7 @@ class AgentNode(BaseNode):
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage: LLMUsage | None = None
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
@ -553,7 +570,11 @@ class AgentNode(BaseNode):
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
@ -564,13 +585,17 @@ class AgentNode(BaseNode):
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
if message.message.json_object is not None:
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
@ -587,8 +612,10 @@ class AgentNode(BaseNode):
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
@ -639,7 +666,7 @@ class AgentNode(BaseNode):
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
id=message.message.id,
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
@ -652,7 +679,7 @@ class AgentNode(BaseNode):
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.id == agent_log.id:
|
||||
if log.message_id == agent_log.message_id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
@ -673,7 +700,7 @@ class AgentNode(BaseNode):
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.id,
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
@ -689,8 +716,24 @@ class AgentNode(BaseNode):
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
|
||||
@@ -1,4 +1,4 @@
from enum import Enum, StrEnum
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union

from pydantic import BaseModel
@@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
    agent_parameters: dict[str, AgentInput]


class ParamsAutoGenerated(Enum):
    CLOSE = 0
    OPEN = 1
class ParamsAutoGenerated(IntEnum):
    CLOSE = auto()
    OPEN = auto()


class AgentOldVersionModelFeatures(StrEnum):
@@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"
    AGENT_THOUGHT = "agent-thought"
    VISION = "vision"
    VISION = auto()
    STREAM_TOOL_CALL = "stream-tool-call"
    DOCUMENT = "document"
    VIDEO = "video"
    AUDIO = "audio"
    DOCUMENT = auto()
    VIDEO = auto()
    AUDIO = auto()
@@ -1,6 +1,3 @@
from typing import Optional


class AgentNodeError(Exception):
    """Base exception for all agent node errors."""

@@ -12,7 +9,7 @@ class AgentNodeError(Exception):
class AgentStrategyError(AgentNodeError):
    """Exception raised when there's an error with the agent strategy."""

    def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
    def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
        self.strategy_name = strategy_name
        self.provider_name = provider_name
        super().__init__(message)
@@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
class AgentStrategyNotFoundError(AgentStrategyError):
    """Exception raised when the specified agent strategy is not found."""

    def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
    def __init__(self, strategy_name: str, provider_name: str | None = None):
        super().__init__(
            f"Agent strategy '{strategy_name}' not found"
            + (f" for provider '{provider_name}'" if provider_name else ""),
@@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
class AgentInvocationError(AgentNodeError):
    """Exception raised when there's an error invoking the agent."""

    def __init__(self, message: str, original_error: Optional[Exception] = None):
    def __init__(self, message: str, original_error: Exception | None = None):
        self.original_error = original_error
        super().__init__(message)

@@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
class AgentParameterError(AgentNodeError):
    """Exception raised when there's an error with agent parameters."""

    def __init__(self, message: str, parameter_name: Optional[str] = None):
    def __init__(self, message: str, parameter_name: str | None = None):
        self.parameter_name = parameter_name
        super().__init__(message)

@@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
class AgentVariableError(AgentNodeError):
    """Exception raised when there's an error with variables in the agent node."""

    def __init__(self, message: str, variable_name: Optional[str] = None):
    def __init__(self, message: str, variable_name: str | None = None):
        self.variable_name = variable_name
        super().__init__(message)

@@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
class ToolFileError(AgentNodeError):
    """Exception raised when there's an error with a tool file."""

    def __init__(self, message: str, file_id: Optional[str] = None):
    def __init__(self, message: str, file_id: str | None = None):
        self.file_id = file_id
        super().__init__(message)

@@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
class AgentMessageTransformError(AgentNodeError):
    """Exception raised when there's an error transforming agent messages."""

    def __init__(self, message: str, original_error: Optional[Exception] = None):
    def __init__(self, message: str, original_error: Exception | None = None):
        self.original_error = original_error
        super().__init__(message)

@@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
class AgentModelError(AgentNodeError):
    """Exception raised when there's an error with the model used by the agent."""

    def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
    def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
        self.model_name = model_name
        self.provider = provider
        super().__init__(message)
@@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
class AgentMemoryError(AgentNodeError):
    """Exception raised when there's an error with the agent's memory."""

    def __init__(self, message: str, conversation_id: Optional[str] = None):
    def __init__(self, message: str, conversation_id: str | None = None):
        self.conversation_id = conversation_id
        super().__init__(message)

@@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
    def __init__(
        self,
        message: str,
        variable_name: Optional[str] = None,
        expected_type: Optional[str] = None,
        actual_type: Optional[str] = None,
        variable_name: str | None = None,
        expected_type: str | None = None,
        actual_type: str | None = None,
    ):
        self.variable_name = variable_name
        self.expected_type = expected_type
@@ -1,4 +0,0 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute

__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]

@@ -1,31 +1,26 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any

from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
    AnswerNodeData,
    GenerateRouteChunk,
    TextGenerateRouteChunk,
    VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser


class AnswerNode(BaseNode):
    _node_type = NodeType.ANSWER
class AnswerNode(Node):
    node_type = NodeType.ANSWER
    execution_type = NodeExecutionType.RESPONSE

    _node_data: AnswerNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = AnswerNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
@@ -34,7 +29,7 @@ class AnswerNode(BaseNode):
    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
@@ -48,35 +43,29 @@ class AnswerNode(BaseNode):
        return "1"

    def _run(self) -> NodeRunResult:
        """
        Run node
        :return:
        """
        # generate routes
        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)

        answer = ""
        files = []
        for part in generate_routes:
            if part.type == GenerateRouteChunk.ChunkType.VAR:
                part = cast(VarGenerateRouteChunk, part)
                value_selector = part.value_selector
                variable = self.graph_runtime_state.variable_pool.get(value_selector)
                if variable:
                    if isinstance(variable, FileSegment):
                        files.append(variable.value)
                    elif isinstance(variable, ArrayFileSegment):
                        files.extend(variable.value)
                    answer += variable.markdown
            else:
                part = cast(TextGenerateRouteChunk, part)
                answer += part.text

        segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
        files = self._extract_files_from_segments(segments.value)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
            outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
        )

    def _extract_files_from_segments(self, segments: Sequence[Segment]):
        """Extract all files from segments containing FileSegment or ArrayFileSegment instances.

        FileSegment contains a single file, while ArrayFileSegment contains multiple files.
        This method flattens all files into a single list.
        """
        files = []
        for segment in segments:
            if isinstance(segment, FileSegment):
                # Single file - wrap in list for consistency
                files.append(segment.value)
            elif isinstance(segment, ArrayFileSegment):
                # Multiple files - extend the list
                files.extend(segment.value)
        return files

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
@@ -96,3 +85,12 @@ class AnswerNode(BaseNode):
            variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector

        return variable_mapping

    def get_streaming_template(self) -> Template:
        """
        Get the template for streaming.

        Returns:
            Template instance for this Answer node
        """
        return Template.from_answer_template(self._node_data.answer)

@@ -1,174 +0,0 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes.answer.entities import (
    AnswerNodeData,
    AnswerStreamGenerateRoute,
    GenerateRouteChunk,
    TextGenerateRouteChunk,
    VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser


class AnswerStreamGeneratorRouter:
    @classmethod
    def init(
        cls,
        node_id_config_mapping: dict[str, dict],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
    ) -> AnswerStreamGenerateRoute:
        """
        Get stream generate routes.
        :return:
        """
        # parse stream output node value selectors of answer nodes
        answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
        for answer_node_id, node_config in node_id_config_mapping.items():
            if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
                continue

            # get generate route for stream output
            generate_route = cls._extract_generate_route_selectors(node_config)
            answer_generate_route[answer_node_id] = generate_route

        # fetch answer dependencies
        answer_node_ids = list(answer_generate_route.keys())
        answer_dependencies = cls._fetch_answers_dependencies(
            answer_node_ids=answer_node_ids,
            reverse_edge_mapping=reverse_edge_mapping,
            node_id_config_mapping=node_id_config_mapping,
        )

        return AnswerStreamGenerateRoute(
            answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
        )

    @classmethod
    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
        """
        Extract generate route from node data
        :param node_data: node data object
        :return:
        """
        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        value_selector_mapping = {
            variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
        }

        variable_keys = list(value_selector_mapping.keys())

        # format answer template
        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
        template_variable_keys = template_parser.variable_keys

        # Take the intersection of variable_keys and template_variable_keys
        variable_keys = list(set(variable_keys) & set(template_variable_keys))

        template = node_data.answer
        for var in variable_keys:
            template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")

        generate_routes: list[GenerateRouteChunk] = []
        for part in template.split("Ω"):
            if part:
                if cls._is_variable(part, variable_keys):
                    var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
                    value_selector = value_selector_mapping[var_key]
                    generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
                else:
                    generate_routes.append(TextGenerateRouteChunk(text=part))

        return generate_routes

    @classmethod
    def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
        """
        Extract generate route selectors
        :param config: node config
        :return:
        """
        node_data = AnswerNodeData(**config.get("data", {}))
        return cls.extract_generate_route_from_node_data(node_data)

    @classmethod
    def _is_variable(cls, part, variable_keys):
        cleaned_part = part.replace("{{", "").replace("}}", "")
        return part.startswith("{{") and cleaned_part in variable_keys

    @classmethod
    def _fetch_answers_dependencies(
        cls,
        answer_node_ids: list[str],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
        node_id_config_mapping: dict[str, dict],
    ) -> dict[str, list[str]]:
        """
        Fetch answer dependencies
        :param answer_node_ids: answer node ids
        :param reverse_edge_mapping: reverse edge mapping
        :param node_id_config_mapping: node id config mapping
        :return:
        """
        answer_dependencies: dict[str, list[str]] = {}
        for answer_node_id in answer_node_ids:
            if answer_dependencies.get(answer_node_id) is None:
                answer_dependencies[answer_node_id] = []

            cls._recursive_fetch_answer_dependencies(
                current_node_id=answer_node_id,
                answer_node_id=answer_node_id,
                node_id_config_mapping=node_id_config_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                answer_dependencies=answer_dependencies,
            )

        return answer_dependencies

    @classmethod
    def _recursive_fetch_answer_dependencies(
        cls,
        current_node_id: str,
        answer_node_id: str,
        node_id_config_mapping: dict[str, dict],
        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
        answer_dependencies: dict[str, list[str]],
    ) -> None:
        """
        Recursive fetch answer dependencies
        :param current_node_id: current node id
        :param answer_node_id: answer node id
        :param node_id_config_mapping: node id config mapping
        :param reverse_edge_mapping: reverse edge mapping
        :param answer_dependencies: answer dependencies
        :return:
        """
        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
        for edge in reverse_edges:
            source_node_id = edge.source_node_id
            if source_node_id not in node_id_config_mapping:
                continue
            source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
            source_node_data = node_id_config_mapping[source_node_id].get("data", {})
            if (
                source_node_type
                in {
                    NodeType.ANSWER,
                    NodeType.IF_ELSE,
                    NodeType.QUESTION_CLASSIFIER,
                    NodeType.ITERATION,
                    NodeType.LOOP,
                    NodeType.VARIABLE_ASSIGNER,
                }
                or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
            ):
                answer_dependencies[answer_node_id].append(source_node_id)
            else:
                cls._recursive_fetch_answer_dependencies(
                    current_node_id=source_node_id,
                    answer_node_id=answer_node_id,
                    node_id_config_mapping=node_id_config_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    answer_dependencies=answer_dependencies,
                )
@@ -1,199 +0,0 @@
import logging
from collections.abc import Generator
from typing import cast

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    NodeRunExceptionEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk

logger = logging.getLogger(__name__)


class AnswerStreamProcessor(StreamProcessor):
    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        super().__init__(graph, variable_pool)
        self.generate_routes = graph.answer_stream_generate_routes
        self.route_position = {}
        for answer_node_id in self.generate_routes.answer_generate_route:
            self.route_position[answer_node_id] = 0
        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}

    def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
        for event in generator:
            if isinstance(event, NodeRunStartedEvent):
                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
                    self.reset()

                yield event
            elif isinstance(event, NodeRunStreamChunkEvent):
                if event.in_iteration_id or event.in_loop_id:
                    yield event
                    continue

                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ]
                else:
                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
                    self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
                        stream_out_answer_node_ids
                    )

                for _ in stream_out_answer_node_ids:
                    yield event
            elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
                yield event
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    # update self.route_position after all stream event finished
                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
                        self.route_position[answer_node_id] += 1

                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]

                self._remove_unreachable_nodes(event)

                # generate stream outputs
                yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
            else:
                yield event

    def reset(self) -> None:
        self.route_position = {}
        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
            self.route_position[answer_node_id] = 0
        self.rest_node_ids = self.graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids = {}

    def _generate_stream_outputs_when_node_finished(
        self, event: NodeRunSucceededEvent
    ) -> Generator[GraphEngineEvent, None, None]:
        """
        Generate stream outputs.
        :param event: node run succeeded event
        :return:
        """
        for answer_node_id in self.route_position:
            # all depends on answer node id not in rest node ids
            if event.route_node_state.node_id != answer_node_id and (
                answer_node_id not in self.rest_node_ids
                or not all(
                    dep_id not in self.rest_node_ids
                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
                )
            ):
                continue

            route_position = self.route_position[answer_node_id]
            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]

            for route_chunk in route_chunks:
                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
                    yield NodeRunStreamChunkEvent(
                        id=event.id,
                        node_id=event.node_id,
                        node_type=event.node_type,
                        node_data=event.node_data,
                        chunk_content=route_chunk.text,
                        route_node_state=event.route_node_state,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        from_variable_selector=[answer_node_id, "answer"],
                        node_version=event.node_version,
                    )
                else:
                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                    value_selector = route_chunk.value_selector
                    if not value_selector:
                        break

                    value = self.variable_pool.get(value_selector)

                    if value is None:
                        break

                    text = value.markdown

                    if text:
                        yield NodeRunStreamChunkEvent(
                            id=event.id,
                            node_id=event.node_id,
                            node_type=event.node_type,
                            node_data=event.node_data,
                            chunk_content=text,
                            from_variable_selector=list(value_selector),
                            route_node_state=event.route_node_state,
                            parallel_id=event.parallel_id,
                            parallel_start_node_id=event.parallel_start_node_id,
                            node_version=event.node_version,
                        )

                self.route_position[answer_node_id] += 1

    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.from_variable_selector:
|
||||
return []
|
||||
|
||||
stream_output_value_selector = event.from_variable_selector
|
||||
stream_out_answer_node_ids = []
|
||||
for answer_node_id, route_position in self.route_position.items():
|
||||
if answer_node_id not in self.rest_node_ids:
|
||||
continue
|
||||
# Remove current node id from answer dependencies to support stream output if it is a success branch
|
||||
answer_dependencies = self.generate_routes.answer_dependencies
|
||||
edge_mapping = self.graph.edge_mapping.get(event.node_id)
|
||||
success_edge = (
|
||||
next(
|
||||
(
|
||||
edge
|
||||
for edge in edge_mapping
|
||||
if edge.run_condition
|
||||
and edge.run_condition.type == "branch_identify"
|
||||
and edge.run_condition.branch_identify == "success-branch"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if edge_mapping
|
||||
else None
|
||||
)
|
||||
if (
|
||||
event.node_id in answer_dependencies[answer_node_id]
|
||||
and success_edge
|
||||
and success_edge.target_node_id == answer_node_id
|
||||
):
|
||||
answer_dependencies[answer_node_id].remove(event.node_id)
|
||||
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
|
||||
# all depends on answer node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids):
|
||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
||||
continue
|
||||
|
||||
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
|
||||
|
||||
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
|
||||
continue
|
||||
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if value_selector != stream_output_value_selector:
|
||||
continue
|
||||
|
||||
stream_out_answer_node_ids.append(answer_node_id)
|
||||
|
||||
return stream_out_answer_node_ids
|
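The processor above advances through a pre-compiled "generate route" per Answer node: the node's template is split into TEXT chunks (emitted verbatim) and VAR chunks (resolved from the variable pool), and `route_position` records how far streaming has progressed so it can resume once a pending variable is produced. A minimal sketch of that idea, using hypothetical chunk classes rather than Dify's actual `TextGenerateRouteChunk`/`VarGenerateRouteChunk` types:

```python
from dataclasses import dataclass

@dataclass
class TextChunk:
    text: str

@dataclass
class VarChunk:
    value_selector: tuple[str, ...]

# Toy variable pool: selector tuple -> resolved value
variable_pool = {("llm", "text"): "42"}

# The Answer template "The answer is {{llm.text}}." pre-split into route chunks
route = [TextChunk("The answer is "), VarChunk(("llm", "text")), TextChunk(".")]

def stream(route, pool):
    """Emit chunks in order; stop at the first unresolved variable."""
    position = 0
    out = []
    for chunk in route:
        if isinstance(chunk, TextChunk):
            out.append(chunk.text)
        else:
            value = pool.get(chunk.value_selector)
            if value is None:
                break  # variable not produced yet; resume later from `position`
            out.append(value)
        position += 1
    return position, "".join(out)

position, text = stream(route, variable_pool)
print(text)  # The answer is 42.
```

This mirrors the `break`-on-missing-value behavior in `_generate_stream_outputs_when_node_finished`: an unresolved VAR chunk halts the walk without advancing `route_position` past it.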
@@ -1,109 +0,0 @@

```python
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph

logger = logging.getLogger(__name__)


class StreamProcessor(ABC):
    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        self.graph = graph
        self.variable_pool = variable_pool
        self.rest_node_ids = graph.node_ids.copy()

    @abstractmethod
    def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
        raise NotImplementedError

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
        finished_node_id = event.route_node_state.node_id
        if finished_node_id not in self.rest_node_ids:
            return

        # remove finished node id
        self.rest_node_ids.remove(finished_node_id)

        run_result = event.route_node_state.node_run_result
        if not run_result:
            return

        if run_result.edge_source_handle:
            reachable_node_ids: list[str] = []
            unreachable_first_node_ids: list[str] = []
            if finished_node_id not in self.graph.edge_mapping:
                logger.warning("node %s has no edge mapping", finished_node_id)
                return
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (
                    edge.run_condition
                    and edge.run_condition.branch_identify
                    and run_result.edge_source_handle == edge.run_condition.branch_identify
                ):
                    # remove unreachable nodes
                    # FIXME: because of the code branch can combine directly, so for answer node
                    # we remove the node maybe shortcut the answer node, so comment this code for now
                    # there is not effect on the answer node and the workflow, when we have a better solution
                    # we can open this code. Issues: #11542 #9560 #10638 #10564
                    # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
                    # if "answer" in ids:
                    #     continue
                    # else:
                    #     reachable_node_ids.extend(ids)

                    # The branch_identify parameter is added to ensure that
                    # only nodes in the correct logical branch are included.
                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
                    reachable_node_ids.extend(ids)
                else:
                    # if the condition edge in parallel, and the target node is not in parallel, we should not remove it
                    # Issues: #13626
                    if (
                        finished_node_id in self.graph.node_parallel_mapping
                        and edge.target_node_id not in self.graph.node_parallel_mapping
                    ):
                        continue
                    unreachable_first_node_ids.append(edge.target_node_id)
            unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
        if node_id not in self.rest_node_ids:
            self.rest_node_ids.append(node_id)
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue

            # Only follow edges that match the branch_identify or have no run_condition
            if edge.run_condition and edge.run_condition.branch_identify:
                if not branch_identify or edge.run_condition.branch_identify != branch_identify:
                    continue

            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
        """
        remove target node ids until merge
        """
        if node_id not in self.rest_node_ids:
            return

        if node_id in reachable_node_ids:
            return

        self.rest_node_ids.remove(node_id)
        self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))

        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id in reachable_node_ids:
                continue

            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
```
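The pruning walk in the deleted `StreamProcessor` starts from the first node of each not-taken branch and removes every downstream node from `rest_node_ids`, stopping wherever the branch merges back into nodes reachable from the taken branch. A toy illustration of that traversal, using a plain adjacency dict as a stand-in for Dify's `Graph.edge_mapping` (node names are made up for the example):

```python
# Simplified edge mapping: "if" branches to "a" (taken) and "b" (not taken);
# both branches merge at "merge" before the "answer" node.
edge_mapping = {
    "if": ["a", "b"],
    "a": ["merge"],
    "b": ["b2"],
    "b2": ["merge"],
    "merge": ["answer"],
}

def fetch_reachable(node_id):
    """Collect every node downstream of node_id."""
    ids = []
    for target in edge_mapping.get(node_id, []):
        ids.append(target)
        ids.extend(fetch_reachable(target))
    return ids

def remove_unreachable(node_id, reachable, rest):
    """Drop node_id and its descendants from rest, stopping at the merge."""
    if node_id not in rest or node_id in reachable:
        return
    rest.remove(node_id)
    for target in edge_mapping.get(node_id, []):
        if target not in reachable:
            remove_unreachable(target, reachable, rest)

rest = ["a", "b", "b2", "merge", "answer"]
reachable = ["a"] + fetch_reachable("a")   # nodes on the taken branch
remove_unreachable("b", reachable, rest)   # prune the not-taken branch
print(rest)  # ['a', 'merge', 'answer']
```

Only "b" and "b2" are pruned; "merge" and "answer" survive because they are also reachable from the taken branch, which is exactly why the guard against removing shared downstream nodes matters.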
```diff
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum
+from enum import StrEnum, auto

 from pydantic import BaseModel, Field

@@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
     Generate Route Chunk.
     """

-    class ChunkType(Enum):
-        VAR = "var"
-        TEXT = "text"
+    class ChunkType(StrEnum):
+        VAR = auto()
+        TEXT = auto()

     type: ChunkType = Field(..., description="generate route chunk type")

```
@@ -1,11 +1,9 @@

```python
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .node import BaseNode

__all__ = [
    "BaseIterationNodeData",
    "BaseIterationState",
    "BaseLoopNodeData",
    "BaseLoopState",
    "BaseNode",
    "BaseNodeData",
]
```