Merge remote-tracking branch 'origin/main' into feat/trigger

yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

api/core/workflow/README.md Normal file
@ -0,0 +1,132 @@
# Workflow
## Project Overview
This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.
## Architecture
### Core Components
The graph engine follows a layered architecture with strict dependency rules:
1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution
   - **Manager** - External control interface for stop/pause/resume commands
   - **Worker** - Node execution runtime
   - **Command Processing** - Handles control commands (abort, pause, resume)
   - **Event Management** - Event propagation and layer notifications
   - **Graph Traversal** - Edge processing and skip propagation
   - **Response Coordinator** - Path tracking and session management
   - **Layers** - Pluggable middleware (debug logging, execution limits)
   - **Command Channels** - Communication channels (InMemory, Redis)
1. **Graph** (`graph/`) - Graph structure and runtime state
   - **Graph Template** - Workflow definition
   - **Edge** - Node connections with conditions
   - **Runtime State Protocol** - State management interface
1. **Nodes** (`nodes/`) - Node implementations
   - **Base** - Abstract node classes and variable parsing
   - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.
1. **Events** (`node_events/`) - Event system
   - **Base** - Event protocols
   - **Node Events** - Node lifecycle events
1. **Entities** (`entities/`) - Domain models
   - **Variable Pool** - Variable storage
   - **Graph Init Params** - Initialization configuration
## Key Design Patterns
### Command Channel Pattern
External workflow control via Redis or in-memory channels:
```python
# Send stop command to running workflow
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
channel.send_command(AbortCommand(reason="User requested"))
```
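To make the pattern concrete, here is a minimal single-process sketch of a command channel. It is an illustration only: `InMemoryChannel`, `fetch_commands()`, and this `AbortCommand` dataclass are hypothetical stand-ins written for this example, and the engine's real classes may expose different names and signatures.

```python
from collections import deque
from dataclasses import dataclass


@dataclass(frozen=True)
class AbortCommand:
    """Hypothetical stand-in for the engine's abort command."""

    reason: str


class InMemoryChannel:
    """Hypothetical single-process channel: a FIFO of pending commands."""

    def __init__(self) -> None:
        self._pending: deque[AbortCommand] = deque()

    def send_command(self, command: AbortCommand) -> None:
        # The controlling side enqueues a command for the running workflow.
        self._pending.append(command)

    def fetch_commands(self) -> list[AbortCommand]:
        # The engine side drains everything sent since the last fetch.
        drained = list(self._pending)
        self._pending.clear()
        return drained


channel = InMemoryChannel()
channel.send_command(AbortCommand(reason="User requested"))
commands = channel.fetch_commands()
```

The sender and the engine only share the channel object, so the engine can poll for commands between node executions without the caller knowing anything about engine internals.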
### Layer System
Extensible middleware for cross-cutting concerns:
```python
engine = GraphEngine(graph)
engine.add_layer(DebugLoggingLayer(level="INFO"))
engine.add_layer(ExecutionLimitsLayer(max_nodes=100))
```
### Event-Driven Architecture
All node executions emit events for monitoring and integration:
- `NodeRunStartedEvent` - Node execution begins
- `NodeRunSucceededEvent` - Node completes successfully
- `NodeRunFailedEvent` - Node encounters an error
- `GraphRunStartedEvent` / `GraphRunCompletedEvent` - Workflow lifecycle
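A consumer of these events typically dispatches on the event type. The sketch below assumes simplified event classes that share only their names with the list above; the `node_id` and `error` fields and the `summarize()` helper are assumptions made for illustration, not the engine's actual definitions.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class NodeRunStartedEvent:
    node_id: str  # assumed field


@dataclass(frozen=True)
class NodeRunSucceededEvent:
    node_id: str  # assumed field


@dataclass(frozen=True)
class NodeRunFailedEvent:
    node_id: str  # assumed field
    error: str  # assumed field


def summarize(events: list[object]) -> tuple[Counter, dict[str, str]]:
    """Count events by class name and collect the error message of each failed node."""
    counts: Counter = Counter()
    failures: dict[str, str] = {}
    for event in events:
        counts[type(event).__name__] += 1
        if isinstance(event, NodeRunFailedEvent):
            failures[event.node_id] = event.error
    return counts, failures


counts, failures = summarize(
    [
        NodeRunStartedEvent("llm"),
        NodeRunSucceededEvent("llm"),
        NodeRunStartedEvent("http"),
        NodeRunFailedEvent("http", "timeout"),
    ]
)
```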
### Variable Pool
Centralized variable storage with namespace isolation:
```python
# Variables scoped by node_id
pool.add(["node1", "output"], value)
result = pool.get(["node1", "output"])
```
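The namespace isolation can be pictured as a two-level mapping keyed first by node ID and then by variable name. `TinyVariablePool` below is a hypothetical, much simplified model written for this example; the real `VariablePool` wraps values in segments and supports nested attribute access, which this sketch omits.

```python
from collections import defaultdict
from collections.abc import Sequence
from typing import Any


class TinyVariablePool:
    """Hypothetical simplified pool: node_id -> variable name -> value."""

    def __init__(self) -> None:
        self._store: defaultdict[str, dict[str, Any]] = defaultdict(dict)

    def add(self, selector: Sequence[str], value: Any) -> None:
        if len(selector) < 2:
            raise ValueError("selector must contain a node_id and a variable name")
        node_id, name = selector[0], selector[1]
        self._store[node_id][name] = value

    def get(self, selector: Sequence[str]) -> Any:
        if len(selector) < 2:
            return None
        node_id, name = selector[0], selector[1]
        # dict.get on the outer mapping avoids creating empty namespaces on read.
        return self._store.get(node_id, {}).get(name)


pool = TinyVariablePool()
pool.add(["node1", "output"], 42)
pool.add(["node2", "output"], "other")
```

Two nodes can use the same variable name without colliding, because the node ID is part of every selector.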
## Import Architecture Rules
The codebase enforces strict layering via import-linter:
1. **Workflow Layers** (top to bottom):
   - graph_engine → graph_events → graph → nodes → node_events → entities
1. **Graph Engine Internal Layers**:
   - orchestration → command_processing → event_management → graph_traversal → domain
1. **Domain Isolation**:
   - Domain models cannot import from infrastructure layers
1. **Command Channel Independence**:
   - InMemory and Redis channels must remain independent
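The layering rule means a package may import from packages below it in the list, never from packages above it. The following sketch is not import-linter itself; `find_violations()` is a hypothetical helper that only illustrates the rule for the workflow layers, given `(importer, imported)` pairs.

```python
# Layers from top (index 0) to bottom, as listed in the workflow layer rule.
LAYERS = ["graph_engine", "graph_events", "graph", "nodes", "node_events", "entities"]


def find_violations(imports: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Return the (importer, imported) pairs where a lower layer imports a higher one."""
    rank = {name: index for index, name in enumerate(LAYERS)}
    return [
        (importer, imported)
        for importer, imported in imports
        if rank[importer] > rank[imported]
    ]


violations = find_violations(
    [
        ("graph_engine", "nodes"),  # allowed: higher layer imports lower layer
        ("nodes", "entities"),  # allowed
        ("entities", "graph"),  # violation: bottom layer imports upward
    ]
)
```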
## Common Tasks
### Adding a New Node Type
1. Create node class in `nodes/<node_type>/`
1. Inherit from `BaseNode` or appropriate base class
1. Implement `_run()` method
1. Register in `nodes/node_mapping.py`
1. Add tests in `tests/unit_tests/core/workflow/nodes/`
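The steps above can be sketched as follows. This is a self-contained illustration, not the project's code: the `BaseNode` class, its constructor arguments, the `run()` wrapper, and the `NODE_MAPPING` dictionary are assumptions standing in for the real base class and for `nodes/node_mapping.py`.

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseNode(ABC):
    """Hypothetical stand-in for the real base class."""

    node_type: str = ""

    def __init__(self, node_id: str, inputs: dict[str, Any]) -> None:
        self.node_id = node_id
        self.inputs = inputs

    @abstractmethod
    def _run(self) -> dict[str, Any]:
        """Execute the node and return its outputs."""

    def run(self) -> dict[str, Any]:
        return self._run()


class UppercaseNode(BaseNode):
    """Example node type that upper-cases its `text` input."""

    node_type = "uppercase"

    def _run(self) -> dict[str, Any]:
        return {"text": str(self.inputs.get("text", "")).upper()}


# Hypothetical registry playing the role of the node mapping module.
NODE_MAPPING: dict[str, type[BaseNode]] = {UppercaseNode.node_type: UppercaseNode}

node = NODE_MAPPING["uppercase"]("node1", {"text": "hello"})
result = node.run()
```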
### Implementing a Custom Layer
1. Create class inheriting from `Layer` base
1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
1. Add to engine via `engine.add_layer()`
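As a rough illustration of these steps, the sketch below defines a layer that counts events by type. The `Layer` base class shown here, the `error` parameter of `on_graph_end()`, and the `run_with_layers()` driver are assumptions made so the example runs on its own; the real base class and engine may use different signatures.

```python
from collections import Counter
from typing import Optional


class Layer:
    """Hypothetical base class with the three lifecycle hooks."""

    def on_graph_start(self) -> None:
        pass

    def on_event(self, event: object) -> None:
        pass

    def on_graph_end(self, error: Optional[Exception]) -> None:
        pass


class EventCounterLayer(Layer):
    """Counts observed events by class name."""

    def __init__(self) -> None:
        self.counts: Counter = Counter()
        self.finished = False

    def on_graph_start(self) -> None:
        self.counts.clear()
        self.finished = False

    def on_event(self, event: object) -> None:
        self.counts[type(event).__name__] += 1

    def on_graph_end(self, error: Optional[Exception]) -> None:
        self.finished = error is None


def run_with_layers(layers: list[Layer], events: list[object]) -> None:
    """Toy driver standing in for the engine: notify every layer in order."""
    for layer in layers:
        layer.on_graph_start()
    for event in events:
        for layer in layers:
            layer.on_event(event)
    for layer in layers:
        layer.on_graph_end(None)


counter = EventCounterLayer()
run_with_layers([counter], ["started", "chunk", 3])
```

Keeping state inside the layer, rather than in the engine, is what lets concerns such as logging or execution limits be added and removed independently.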
### Debugging Workflow Execution
Enable debug logging layer:
```python
debug_layer = DebugLoggingLayer(
    level="DEBUG",
    include_inputs=True,
    include_outputs=True,
)
```

@ -1,7 +0,0 @@
from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback

__all__ = [
    "WorkflowCallback",
    "WorkflowLoggingCallback",
]

@ -1,12 +0,0 @@
from abc import ABC, abstractmethod

from core.workflow.graph_engine.entities.event import GraphEngineEvent


class WorkflowCallback(ABC):
    @abstractmethod
    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Published event
        """
        raise NotImplementedError

@ -1,263 +0,0 @@
from typing import Optional
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from .base_workflow_callback import WorkflowCallback
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id: Optional[str] = None
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunPartialSucceededEvent):
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(event=event)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(event=event)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(event=event)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(event=event)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(event=event)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(event=event)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(event=event)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(event=event)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(event=event)
elif isinstance(event, LoopRunStartedEvent):
self.on_workflow_loop_started(event=event)
elif isinstance(event, LoopRunNextEvent):
self.on_workflow_loop_next(event=event)
elif isinstance(event, LoopRunSucceededEvent | LoopRunFailedEvent):
self.on_workflow_loop_completed(event=event)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
"""
Workflow node execute started
"""
self.print_text("\n[NodeRunStartedEvent]", color="yellow")
self.print_text(f"Node ID: {event.node_id}", color="yellow")
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
self.print_text(f"Type: {event.node_type.value}", color="yellow")
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
Workflow node execute succeeded
"""
route_node_state = event.route_node_state
self.print_text("\n[NodeRunSucceededEvent]", color="green")
self.print_text(f"Node ID: {event.node_id}", color="green")
self.print_text(f"Node Title: {event.node_data.title}", color="green")
self.print_text(f"Type: {event.node_type.value}", color="green")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="green",
)
self.print_text(
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="green",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="green",
)
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color="green",
)
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
"""
Workflow node execute failed
"""
route_node_state = event.route_node_state
self.print_text("\n[NodeRunFailedEvent]", color="red")
self.print_text(f"Node ID: {event.node_id}", color="red")
self.print_text(f"Node Title: {event.node_data.title}", color="red")
self.print_text(f"Type: {event.node_type.value}", color="red")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color="red",
)
self.print_text(
f"Process Data: "
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="red",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="red",
)
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
Publish text chunk
"""
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text("\n[NodeRunStreamChunkEvent]")
self.print_text(f"Node ID: {route_node_state.node_id}")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
)
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
if event.in_loop_id:
self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")
def on_workflow_parallel_completed(
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = "blue"
elif isinstance(event, ParallelBranchRunFailedEvent):
color = "red"
self.print_text(
"\n[ParallelBranchRunSucceededEvent]"
if isinstance(event, ParallelBranchRunSucceededEvent)
else "\n[ParallelBranchRunFailedEvent]",
color=color,
)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
if event.in_loop_id:
self.print_text(f"Loop ID: {event.in_loop_id}", color=color)
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
"""
Publish iteration started
"""
self.print_text("\n[IterationRunStartedEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
"""
Publish iteration next
"""
self.print_text("\n[IterationRunNextEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
self.print_text(f"Iteration Index: {event.index}", color="blue")
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
"""
Publish iteration completed
"""
self.print_text(
"\n[IterationRunSucceededEvent]"
if isinstance(event, IterationRunSucceededEvent)
else "\n[IterationRunFailedEvent]",
color="blue",
)
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
def on_workflow_loop_started(self, event: LoopRunStartedEvent) -> None:
"""
Publish loop started
"""
self.print_text("\n[LoopRunStartedEvent]", color="blue")
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
"""
Publish loop next
"""
self.print_text("\n[LoopRunNextEvent]", color="blue")
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
self.print_text(f"Loop Index: {event.index}", color="blue")
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
"""
Publish loop completed
"""
self.print_text(
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
color="blue",
)
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(f"{text_to_print}", end=end)
def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"

@ -1,3 +1,4 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

@ -20,7 +20,7 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable") -> None:
def update(self, conversation_id: str, variable: "Variable"):
"""
Updates the value of the specified conversation variable in the underlying storage.

@ -0,0 +1,18 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .run_condition import RunCondition
from .variable_pool import VariablePool, VariableValue
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution

__all__ = [
    "AgentNodeStrategyInit",
    "GraphInitParams",
    "GraphRuntimeState",
    "RunCondition",
    "VariablePool",
    "VariableValue",
    "WorkflowExecution",
    "WorkflowNodeExecution",
]

@ -0,0 +1,8 @@
from pydantic import BaseModel


class AgentNodeStrategyInit(BaseModel):
    """Agent node strategy initialization data."""

    name: str
    icon: str | None = None

@ -3,19 +3,18 @@ from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_type: WorkflowType = Field(..., description="workflow type")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
user_from: str = Field(
..., description="user from, account or end-user"
) # Should be UserFrom enum: 'account' | 'end-user'
invoke_from: str = Field(
..., description="invoke from, service-api, web-app, explore or debugger"
) # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger'
call_depth: int = Field(..., description="call depth")

@ -0,0 +1,160 @@
from copy import deepcopy

from pydantic import BaseModel, PrivateAttr

from core.model_runtime.entities.llm_entities import LLMUsage

from .variable_pool import VariablePool


class GraphRuntimeState(BaseModel):
    # Private attributes to prevent direct modification
    _variable_pool: VariablePool = PrivateAttr()
    _start_at: float = PrivateAttr()
    _total_tokens: int = PrivateAttr(default=0)
    _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
    _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
    _node_run_steps: int = PrivateAttr(default=0)
    _ready_queue_json: str = PrivateAttr()
    _graph_execution_json: str = PrivateAttr()
    _response_coordinator_json: str = PrivateAttr()

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue_json: str = "",
        graph_execution_json: str = "",
        response_coordinator_json: str = "",
        **kwargs: object,
    ):
        """Initialize the GraphRuntimeState with validation."""
        super().__init__(**kwargs)

        # Initialize private attributes with validation
        self._variable_pool = variable_pool

        self._start_at = start_at

        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens

        if llm_usage is None:
            llm_usage = LLMUsage.empty_usage()
        self._llm_usage = llm_usage

        if outputs is None:
            outputs = {}
        self._outputs = deepcopy(outputs)

        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps

        self._ready_queue_json = ready_queue_json
        self._graph_execution_json = graph_execution_json
        self._response_coordinator_json = response_coordinator_json

    @property
    def variable_pool(self) -> VariablePool:
        """Get the variable pool."""
        return self._variable_pool

    @property
    def start_at(self) -> float:
        """Get the start time."""
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        """Set the start time."""
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        """Get the total tokens count."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int):
        """Set the total tokens count."""
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        """Get the LLM usage info."""
        # Return a copy to prevent external modification
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage):
        """Set the LLM usage info."""
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, object]:
        """Get a copy of the outputs dictionary."""
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, object]) -> None:
        """Set the outputs dictionary."""
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        """Set a single output value."""
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        """Get a single output value."""
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        """Update multiple output values."""
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        """Get the node run steps count."""
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        """Set the node run steps count."""
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        """Increment the node run steps by 1."""
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        """Add tokens to the total count."""
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    @property
    def ready_queue_json(self) -> str:
        """Get a copy of the ready queue state."""
        return self._ready_queue_json

    @property
    def graph_execution_json(self) -> str:
        """Get a copy of the serialized graph execution state."""
        return self._graph_execution_json

    @property
    def response_coordinator_json(self) -> str:
        """Get a copy of the serialized response coordinator state."""
        return self._response_coordinator_json

@ -1,34 +0,0 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[Mapping[str, Any]] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed
# single step node run retry
retry_index: int = 0
class AgentNodeStrategyInit(BaseModel):
name: str
icon: str | None = None

@ -1,5 +1,5 @@
import hashlib
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel
@ -10,10 +10,10 @@ class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
branch_identify: str | None = None
"""branch identify like: sourceHandle, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
conditions: list[Condition] | None = None
"""conditions to run the node, required when type is condition"""
@property

@ -1,12 +0,0 @@
from collections.abc import Sequence

from pydantic import BaseModel


class VariableSelector(BaseModel):
    """
    Variable Selector.
    """

    variable: str
    value_selector: Sequence[str]

@ -9,12 +9,17 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.system_variable import SystemVariable
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, File]
VariableValue = Union[str, int, float, dict[str, object], list[object], File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@ -40,14 +45,18 @@ class VariablePool(BaseModel):
)
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
default_factory=list[VariableUnion],
)
conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list[VariableUnion],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
def model_post_init(self, context: Any, /):
# Create a mapping from field names to SystemVariableKey enum values
self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
@ -56,8 +65,18 @@ class VariablePool(BaseModel):
# Add conversation variables to the variable pool
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
for rag_var in self.rag_pipeline_variables:
node_id = rag_var.variable.belong_to_node_id
key = rag_var.variable.variable
value = rag_var.value
rag_pipeline_variables_map[node_id][key] = value
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
def add(self, selector: Sequence[str], value: Any, /) -> None:
def add(self, selector: Sequence[str], value: Any, /):
"""
Add a variable to the variable pool.
@ -161,11 +180,11 @@ class VariablePool(BaseModel):
# Return result as Segment
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
def _extract_value(self, obj: Any) -> Any:
def _extract_value(self, obj: Any):
"""Extract the actual value from an ObjectSegment."""
return obj.value if isinstance(obj, ObjectSegment) else obj
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
"""Get a nested attribute from a dictionary-like object."""
if not isinstance(obj, dict):
return None
@ -191,7 +210,7 @@ class VariablePool(BaseModel):
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
segments = []
segments: list[Segment] = []
for part in filter(lambda x: x, parts):
if "." in part and (variable := self.get(part.split("."))):
segments.append(variable)

@ -7,31 +7,14 @@ implementation details like tenant_id, app_id, etc.
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from libs.datetime_utils import naive_utc_now
class WorkflowType(StrEnum):
"""
Workflow Type Enum for domain layer
"""
WORKFLOW = "workflow"
CHAT = "chat"
class WorkflowExecutionStatus(StrEnum):
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
class WorkflowExecution(BaseModel):
"""
Domain model for workflow execution based on WorkflowRun but without
@ -45,7 +28,7 @@ class WorkflowExecution(BaseModel):
graph: Mapping[str, Any] = Field(...)
inputs: Mapping[str, Any] = Field(...)
outputs: Optional[Mapping[str, Any]] = None
outputs: Mapping[str, Any] | None = None
status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
error_message: str = Field(default="")
@ -54,7 +37,7 @@ class WorkflowExecution(BaseModel):
exceptions_count: int = Field(default=0)
started_at: datetime = Field(...)
finished_at: Optional[datetime] = None
finished_at: datetime | None = None
@property
def elapsed_time(self) -> float:

@ -8,50 +8,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from core.workflow.nodes.enums import NodeType
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
TRIGGER_INFO = "trigger_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
class WorkflowNodeExecutionStatus(StrEnum):
"""
Node Execution Status Enum.
"""
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
RETRY = "retry"
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class WorkflowNodeExecution(BaseModel):
@ -78,42 +39,95 @@ class WorkflowNodeExecution(BaseModel):
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
# In most scenarios, `id` should be used as the primary identifier.
node_execution_id: Optional[str] = None
node_execution_id: str | None = None
workflow_id: str # ID of the workflow this node belongs to
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging)
# --------- Core identification fields ends ---------
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
predecessor_node_id: str | None = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
title: str # Display title of the node
# Execution data
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
# The `inputs` and `outputs` fields hold the full content
inputs: Mapping[str, Any] | None = None # Input variables used by this node
process_data: Mapping[str, Any] | None = None # Intermediate processing data
outputs: Mapping[str, Any] | None = None # Output variables produced by this node
# Execution state
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
error: Optional[str] = None # Error message if execution failed
error: str | None = None # Error message if execution failed
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.)
# Timing information
created_at: datetime # When execution started
finished_at: Optional[datetime] = None # When execution completed
finished_at: datetime | None = None # When execution completed
_truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None)
_truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None)
_truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None)
def get_truncated_inputs(self) -> Mapping[str, Any] | None:
return self._truncated_inputs
def get_truncated_outputs(self) -> Mapping[str, Any] | None:
return self._truncated_outputs
def get_truncated_process_data(self) -> Mapping[str, Any] | None:
return self._truncated_process_data
def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None):
self._truncated_inputs = truncated_inputs
def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None):
self._truncated_outputs = truncated_outputs
def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None):
self._truncated_process_data = truncated_process_data
def get_response_inputs(self) -> Mapping[str, Any] | None:
inputs = self.get_truncated_inputs()
if inputs is not None:
return inputs
return self.inputs
@property
def inputs_truncated(self):
return self._truncated_inputs is not None
@property
def outputs_truncated(self):
return self._truncated_outputs is not None
@property
def process_data_truncated(self):
return self._truncated_process_data is not None
def get_response_outputs(self) -> Mapping[str, Any] | None:
outputs = self.get_truncated_outputs()
if outputs is not None:
return outputs
return self.outputs
def get_response_process_data(self) -> Mapping[str, Any] | None:
process_data = self.get_truncated_process_data()
if process_data is not None:
return process_data
return self.process_data
def update_from_mapping(
self,
inputs: Optional[Mapping[str, Any]] = None,
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
) -> None:
inputs: Mapping[str, Any] | None = None,
process_data: Mapping[str, Any] | None = None,
outputs: Mapping[str, Any] | None = None,
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
):
"""
Update the model from mappings.


@ -1,4 +1,12 @@
from enum import StrEnum
from enum import Enum, StrEnum
class NodeState(Enum):
"""State of a node or edge during workflow execution."""
UNKNOWN = "unknown"
TAKEN = "taken"
SKIPPED = "skipped"
class SystemVariableKey(StrEnum):
@ -14,3 +22,120 @@ class SystemVariableKey(StrEnum):
APP_ID = "app_id"
WORKFLOW_ID = "workflow_id"
WORKFLOW_EXECUTION_ID = "workflow_run_id"
# RAG Pipeline
DOCUMENT_ID = "document_id"
ORIGINAL_DOCUMENT_ID = "original_document_id"
BATCH = "batch"
DATASET_ID = "dataset_id"
DATASOURCE_TYPE = "datasource_type"
DATASOURCE_INFO = "datasource_info"
INVOKE_FROM = "invoke_from"
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
TRIGGER_WEBHOOK = "trigger-webhook"
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
class NodeExecutionType(StrEnum):
"""Node execution type classification."""
EXECUTABLE = "executable" # Regular nodes that execute and produce outputs
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
ROOT = "root" # Nodes that can serve as execution entry points
class ErrorStrategy(StrEnum):
FAIL_BRANCH = "fail-branch"
DEFAULT_VALUE = "default-value"
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
class WorkflowType(StrEnum):
"""
Workflow Type Enum for domain layer
"""
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
class WorkflowExecutionStatus(StrEnum):
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
TRIGGER_INFO = "trigger_info"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
class WorkflowNodeExecutionStatus(StrEnum):
PENDING = "pending" # Node is scheduled but not yet executing
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
STOPPED = "stopped"
PAUSED = "paused"
# Legacy statuses - kept for backward compatibility
RETRY = "retry" # Legacy: replaced by retry mechanism in error handling
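
Because these enums derive from `StrEnum`, members compare equal to their raw string values, which is what lets graph code test a node config's `type` string directly against `NodeType` members. The following is a minimal sketch under stated assumptions: `NodeType` here is a trimmed stand-in rather than the real module, and the `StrEnum` fallback for interpreters older than Python 3.11 and the two helper functions are hypothetical additions for illustration.

```python
from enum import Enum

try:
    from enum import StrEnum  # available from Python 3.11
except ImportError:  # assumption: fallback so the sketch also runs on older interpreters
    class StrEnum(str, Enum):
        pass


class NodeType(StrEnum):
    """Trimmed stand-in for the NodeType enum shown in the diff."""

    START = "start"
    DATASOURCE = "datasource"
    LLM = "llm"


def is_entry_type(raw_type: str) -> bool:
    """Return True if a raw config string names a START or DATASOURCE node."""
    # List membership uses ==, and str-based enum members equal their values.
    return raw_type in [NodeType.START, NodeType.DATASOURCE]


def parse_node_type(raw_type: str) -> NodeType:
    """Look a member up by value; raises ValueError for unknown strings."""
    return NodeType(raw_type)
```

Lookup by value returns the existing member rather than a new object, so identity checks such as `parse_node_type("llm") is NodeType.LLM` hold.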


@ -1,8 +1,16 @@
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.node import Node
class WorkflowNodeRunFailedError(Exception):
def __init__(self, node: BaseNode, err_msg: str):
def __init__(self, node: Node, err_msg: str):
self._node = node
self._error = err_msg
super().__init__(f"Node {node.title} run failed: {err_msg}")
@property
def node(self) -> Node:
return self._node
@property
def error(self) -> str:
return self._error
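
The new `node` and `error` properties let callers read the failure context without parsing the exception message. A minimal sketch of that usage, assuming a hypothetical `FakeNode` stand-in (only `title` is needed) and a hypothetical `describe_failure` helper that is not part of the diff:

```python
from dataclasses import dataclass


@dataclass
class FakeNode:
    """Hypothetical stand-in for Node; only `title` is used here."""

    title: str


class WorkflowNodeRunFailedError(Exception):
    """Mirrors the exception in the diff, typed against the stand-in node."""

    def __init__(self, node: FakeNode, err_msg: str):
        self._node = node
        self._error = err_msg
        super().__init__(f"Node {node.title} run failed: {err_msg}")

    @property
    def node(self) -> FakeNode:
        return self._node

    @property
    def error(self) -> str:
        return self._error


def describe_failure(exc: WorkflowNodeRunFailedError) -> dict:
    """Build a structured record from the exception's properties."""
    return {"node_title": exc.node.title, "error": exc.error, "message": str(exc)}
```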


@ -0,0 +1,16 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
__all__ = [
"Edge",
"Graph",
"GraphTemplate",
"NodeFactory",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
]


@ -0,0 +1,15 @@
import uuid
from dataclasses import dataclass, field
from core.workflow.enums import NodeState
@dataclass
class Edge:
"""Edge connecting two nodes in a workflow graph."""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
tail: str = "" # tail node id (source)
head: str = "" # head node id (target)
source_handle: str = "source" # source handle for conditional branching
state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state
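
The `default_factory` on `id` matters: a plain default would be evaluated once and shared, while the factory runs per instance. A small sketch that repeats the dataclass with a local `NodeState` stand-in so it runs on its own; the node ids used are made up for illustration.

```python
import uuid
from dataclasses import dataclass, field
from enum import Enum


class NodeState(Enum):
    """Local copy of the three states from the diff."""

    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"


@dataclass
class Edge:
    id: str = field(default_factory=lambda: str(uuid.uuid4()))  # fresh id per edge
    tail: str = ""  # source node id
    head: str = ""  # target node id
    source_handle: str = "source"
    state: NodeState = field(default=NodeState.UNKNOWN)


# Two edges between the same pair of nodes still get distinct ids.
first = Edge(tail="start", head="llm")
second = Edge(tail="start", head="llm")
```

Note that `Graph._build_edges` in this commit passes explicit ids of the form `edge_<n>`, so the UUID default only applies to edges constructed without an id.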


@ -0,0 +1,346 @@
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .edge import Edge
logger = logging.getLogger(__name__)
class NodeFactory(Protocol):
"""
Protocol for creating Node instances from node data dictionaries.
This protocol decouples the Graph class from specific node mapping implementations,
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: dict[str, object]) -> Node:
"""
Create a Node instance from node configuration data.
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
"""
...
@final
class Graph:
"""Graph representation with nodes and edges for workflow execution."""
def __init__(
self,
*,
nodes: dict[str, Node] | None = None,
edges: dict[str, Edge] | None = None,
in_edges: dict[str, list[str]] | None = None,
out_edges: dict[str, list[str]] | None = None,
root_node: Node,
):
"""
Initialize Graph instance.
:param nodes: graph nodes mapping (node id: node object)
:param edges: graph edges mapping (edge id: edge object)
:param in_edges: incoming edges mapping (node id: list of edge ids)
:param out_edges: outgoing edges mapping (node id: list of edge ids)
:param root_node: root node object
"""
self.nodes = nodes or {}
self.edges = edges or {}
self.in_edges = in_edges or {}
self.out_edges = out_edges or {}
self.root_node = root_node
@classmethod
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
node_configs_map: dict[str, dict[str, object]] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id or not isinstance(node_id, str):
continue
node_configs_map[node_id] = node_config
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, Mapping[str, object]],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
"""
Find the root node ID if not specified.
:param node_configs_map: mapping of node ID to node config
:param edge_configs: list of edge configurations
:param root_node_id: explicitly specified root node ID
:return: determined root node ID
"""
if root_node_id:
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
return root_node_id
# Find nodes with no incoming edges
nodes_with_incoming: set[str] = set()
for edge_config in edge_configs:
target = edge_config.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid].get("data")
if not is_str_dict(node_data):
continue
node_type = node_data.get("type")
if not isinstance(node_type, str):
continue
if node_type in [NodeType.START, NodeType.DATASOURCE]:
start_node_id = nid
break
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
if not root_node_id:
raise ValueError("Unable to determine root node ID")
return root_node_id
@classmethod
def _build_edges(
cls, edge_configs: list[dict[str, object]]
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
"""
Build edge objects and mappings from edge configurations.
:param edge_configs: list of edge configurations
:return: tuple of (edges dict, in_edges dict, out_edges dict)
"""
edges: dict[str, Edge] = {}
in_edges: dict[str, list[str]] = defaultdict(list)
out_edges: dict[str, list[str]] = defaultdict(list)
edge_counter = 0
for edge_config in edge_configs:
source = edge_config.get("source")
target = edge_config.get("target")
if not is_str(source) or not is_str(target):
continue
# Create edge
edge_id = f"edge_{edge_counter}"
edge_counter += 1
source_handle = edge_config.get("sourceHandle", "source")
if not is_str(source_handle):
continue
edge = Edge(
id=edge_id,
tail=source,
head=target,
source_handle=source_handle,
)
edges[edge_id] = edge
out_edges[source].append(edge_id)
in_edges[target].append(edge_id)
return edges, dict(in_edges), dict(out_edges)
@classmethod
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_factory: "NodeFactory",
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
:param node_configs_map: mapping of node ID to node config
:param node_factory: factory for creating node instances
:return: mapping of node ID to node instance
"""
nodes: dict[str, Node] = {}
for node_id, node_config in node_configs_map.items():
try:
node_instance = node_factory.create_node(node_config)
except Exception:
logger.exception("Failed to create node instance for node_id %s", node_id)
raise
nodes[node_id] = node_instance
return nodes
@classmethod
def _mark_inactive_root_branches(
cls,
nodes: dict[str, Node],
edges: dict[str, Edge],
in_edges: dict[str, list[str]],
out_edges: dict[str, list[str]],
active_root_id: str,
) -> None:
"""
Mark nodes and edges from inactive root branches as skipped.
Algorithm:
1. Mark inactive root nodes as skipped
2. For skipped nodes, mark all their outgoing edges as skipped
3. For each edge marked as skipped, check its target node:
- If ALL incoming edges are skipped, mark the node as skipped
- Otherwise, leave the node state unchanged
:param nodes: mapping of node ID to node instance
:param edges: mapping of edge ID to edge instance
:param in_edges: mapping of node ID to incoming edge IDs
:param out_edges: mapping of node ID to outgoing edge IDs
:param active_root_id: ID of the active root node
"""
# Find all root nodes (nodes with ROOT execution type; incoming edges are not checked here)
top_level_roots: list[str] = [
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
]
# If there's only one root or the active root is not a top-level root, no marking needed
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
return
# Mark inactive root nodes as skipped
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
for root_id in inactive_roots:
if root_id in nodes:
nodes[root_id].state = NodeState.SKIPPED
# Recursively mark downstream nodes and edges
def mark_downstream(node_id: str) -> None:
"""Recursively mark downstream nodes and edges as skipped."""
if nodes[node_id].state != NodeState.SKIPPED:
return
# If this node is skipped, mark all its outgoing edges as skipped
out_edge_ids = out_edges.get(node_id, [])
for edge_id in out_edge_ids:
edge = edges[edge_id]
edge.state = NodeState.SKIPPED
# Check the target node of this edge
target_node = nodes[edge.head]
in_edge_ids = in_edges.get(target_node.id, [])
in_edge_states = [edges[eid].state for eid in in_edge_ids]
# If all incoming edges are skipped, mark the node as skipped
if all(state == NodeState.SKIPPED for state in in_edge_states):
target_node.state = NodeState.SKIPPED
# Recursively process downstream nodes
mark_downstream(target_node.id)
# Process each inactive root and its downstream nodes
for root_id in inactive_roots:
mark_downstream(root_id)
@classmethod
def init(
cls,
*,
graph_config: Mapping[str, object],
node_factory: "NodeFactory",
root_node_id: str | None = None,
) -> "Graph":
"""
Initialize graph
:param graph_config: graph config containing nodes and edges
:param node_factory: factory for creating node instances from config data
:param root_node_id: root node id
:return: graph instance
"""
# Parse configs
edge_configs = graph_config.get("edges", [])
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)
# Find root node
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
# Build edges
edges, in_edges, out_edges = cls._build_edges(edge_configs)
# Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory)
# Get root node instance
root_node = nodes[root_node_id]
# Mark inactive root branches as skipped
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
return cls(
nodes=nodes,
edges=edges,
in_edges=in_edges,
out_edges=out_edges,
root_node=root_node,
)
@property
def node_ids(self) -> list[str]:
"""
Get list of node IDs (compatibility property for existing code)
:return: list of node IDs
"""
return list(self.nodes.keys())
def get_outgoing_edges(self, node_id: str) -> list[Edge]:
"""
Get all outgoing edges from a node (V2 method)
:param node_id: node id
:return: list of outgoing edges
"""
edge_ids = self.out_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
def get_incoming_edges(self, node_id: str) -> list[Edge]:
"""
Get all incoming edges to a node (V2 method)
:param node_id: node id
:return: list of incoming edges
"""
edge_ids = self.in_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
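
The skip-propagation rule described in `_mark_inactive_root_branches` can be sketched on plain data. This is an illustration under assumptions, not the method itself: `propagate_skips` is a hypothetical helper, edges are `(tail, head)` tuples identified by list index, and an explicit stack replaces the recursive helper used in the diff.

```python
from collections import defaultdict


def propagate_skips(edge_list, inactive_roots):
    """Return (skipped_node_ids, skipped_edge_indexes) for the inactive roots.

    A node is skipped only when every one of its incoming edges is skipped.
    """
    out_edges = defaultdict(list)
    in_edges = defaultdict(list)
    for idx, (tail, head) in enumerate(edge_list):
        out_edges[tail].append(idx)
        in_edges[head].append(idx)

    skipped_nodes = set(inactive_roots)
    skipped_edges = set()
    stack = list(inactive_roots)
    while stack:
        node_id = stack.pop()
        # Every outgoing edge of a skipped node is skipped.
        for idx in out_edges[node_id]:
            skipped_edges.add(idx)
            head = edge_list[idx][1]
            all_in_skipped = all(i in skipped_edges for i in in_edges[head])
            if head not in skipped_nodes and all_in_skipped:
                skipped_nodes.add(head)
                stack.append(head)
    return skipped_nodes, skipped_edges


# Two roots feed "shared"; only root "b" is inactive.
edges = [("a", "shared"), ("b", "shared"), ("b", "only_b"), ("only_b", "end_b")]
nodes_skipped, edges_skipped = propagate_skips(edges, ["b"])
```

In this example `shared` stays unskipped because its edge from the active root `a` is untouched, while everything reachable only through `b` is skipped.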


@ -0,0 +1,61 @@
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (read-only)."""
...
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (read-only)."""
...
class ReadOnlyGraphRuntimeState(Protocol):
"""
Read-only view of GraphRuntimeState for layers.
This protocol defines a read-only interface that prevents layers from
modifying the graph runtime state while still allowing observation.
All methods return defensive copies to ensure immutability.
"""
@property
def variable_pool(self) -> ReadOnlyVariablePool:
"""Get read-only access to the variable pool."""
...
@property
def start_at(self) -> float:
"""Get the start time (read-only)."""
...
@property
def total_tokens(self) -> int:
"""Get the total tokens count (read-only)."""
...
@property
def llm_usage(self) -> LLMUsage:
"""Get a copy of LLM usage info (read-only)."""
...
@property
def outputs(self) -> dict[str, Any]:
"""Get a defensive copy of outputs (read-only)."""
...
@property
def node_run_steps(self) -> int:
"""Get the node run steps count (read-only)."""
...
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
...
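
Since these are `typing.Protocol` classes, a concrete state object satisfies them structurally, without inheriting from them. A minimal sketch with a trimmed, hypothetical protocol (`ReadOnlyCounters`) and stand-in state class; the restriction is enforced by a static type checker, not by Python at runtime.

```python
from typing import Protocol


class ReadOnlyCounters(Protocol):
    """Hypothetical, trimmed protocol exposing only read accessors."""

    @property
    def total_tokens(self) -> int: ...

    @property
    def node_run_steps(self) -> int: ...


class RuntimeState:
    """Stand-in mutable state; it never inherits from the protocol."""

    def __init__(self) -> None:
        self.total_tokens = 0
        self.node_run_steps = 0


def summarize(state: ReadOnlyCounters) -> str:
    """A layer-style observer: it can read the counters but is typed not to write."""
    return f"{state.node_run_steps} steps, {state.total_tokens} tokens"


state = RuntimeState()
state.total_tokens = 42
state.node_run_steps = 3
summary = summarize(state)
```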


@ -0,0 +1,20 @@
from typing import Any
from pydantic import BaseModel, Field
class GraphTemplate(BaseModel):
"""
Graph Template for container nodes and subgraph expansion
According to GraphEngine V2 spec, GraphTemplate contains:
- nodes: mapping of node definitions
- edges: mapping of edge definitions
- root_ids: list of root node IDs
- output_selectors: list of output selectors for the template
"""
nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
root_ids: list[str] = Field(default_factory=list, description="root node IDs")
output_selectors: list[str] = Field(default_factory=list, description="output selectors")


@ -0,0 +1,77 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
class ReadOnlyVariablePoolWrapper:
"""Wrapper that provides read-only access to VariablePool."""
def __init__(self, variable_pool: VariablePool):
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (returns a defensive copy)."""
value = self._variable_pool.get([node_id, variable_key])
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (returns defensive copies)."""
variables: dict[str, object] = {}
if node_id in self._variable_pool.variable_dictionary:
for key, var in self._variable_pool.variable_dictionary[node_id].items():
# Variables have a value property that contains the actual data
variables[key] = deepcopy(var.value)
return variables
class ReadOnlyGraphRuntimeStateWrapper:
"""
Wrapper that provides read-only access to GraphRuntimeState.
This wrapper ensures that layers can observe the state without
modifying it. All returned values are defensive copies.
"""
def __init__(self, state: GraphRuntimeState):
self._state = state
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
@property
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
"""Get read-only access to the variable pool."""
return self._variable_pool_wrapper
@property
def start_at(self) -> float:
"""Get the start time (read-only)."""
return self._state.start_at
@property
def total_tokens(self) -> int:
"""Get the total tokens count (read-only)."""
return self._state.total_tokens
@property
def llm_usage(self) -> LLMUsage:
"""Get a copy of LLM usage info (read-only)."""
# Return a copy to prevent modification
return self._state.llm_usage.model_copy()
@property
def outputs(self) -> dict[str, Any]:
"""Get a defensive copy of outputs (read-only)."""
return deepcopy(self._state.outputs)
@property
def node_run_steps(self) -> int:
"""Get the node run steps count (read-only)."""
return self._state.node_run_steps
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
return deepcopy(self._state.get_output(key, default))
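
The defensive-copy idea behind these wrappers can be shown with stand-in classes. This is a sketch under assumptions: `State` and `ReadOnlyStateWrapper` are hypothetical, and deep-copying inside `get_output` is this sketch's way of meeting a "returns a copy" contract.

```python
from copy import deepcopy
from typing import Any


class State:
    """Stand-in for the mutable runtime state."""

    def __init__(self) -> None:
        self.outputs: dict = {"answer": {"text": "hi", "citations": []}}


class ReadOnlyStateWrapper:
    """Hands out deep copies so observers cannot mutate the wrapped state."""

    def __init__(self, state: State) -> None:
        self._state = state

    @property
    def outputs(self) -> dict:
        return deepcopy(self._state.outputs)

    def get_output(self, key: str, default: Any = None) -> Any:
        return deepcopy(self._state.outputs.get(key, default))


state = State()
view = ReadOnlyStateWrapper(state)

snapshot = view.outputs
snapshot["answer"]["citations"].append("doc-1")  # mutates the copy only

single = view.get_output("answer")
single["text"] = "changed"  # also a copy
```

A shallow copy would not be enough here, because the nested `citations` list would still be shared with the real state.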


@ -1,4 +1,3 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
__all__ = ["GraphEngine"]


@ -0,0 +1,33 @@
# Command Channels
Channel implementations for external workflow control.
## Components
### InMemoryChannel
Thread-safe in-memory queue for single-process deployments.
- `fetch_commands()` - Get pending commands
- `send_command()` - Add command to queue
### RedisChannel
Redis-based queue for distributed deployments.
- `fetch_commands()` - Get commands with JSON deserialization
- `send_command()` - Store commands with TTL
## Usage
```python
# Local execution
channel = InMemoryChannel()
channel.send_command(AbortCommand(graph_id="workflow-123"))
# Distributed execution
redis_channel = RedisChannel(
redis_client=redis_client,
channel_key="workflow:123:commands"
)
```
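
Both channels share the same drain semantics: `fetch_commands()` returns everything pending and leaves the queue empty. A self-contained sketch of that contract, using a hypothetical `Command` dataclass and `SketchChannel` class as stand-ins for the real command and channel types:

```python
from dataclasses import dataclass
from queue import Queue


@dataclass
class Command:
    """Hypothetical stand-in for GraphEngineCommand."""

    command_type: str
    reason: str = ""


class SketchChannel:
    """Stand-in channel exposing the two methods listed above."""

    def __init__(self) -> None:
        self._queue = Queue()

    def send_command(self, command: Command) -> None:
        self._queue.put(command)

    def fetch_commands(self) -> list:
        # Single consumer, so checking empty() before get_nowait() is safe here.
        drained = []
        while not self._queue.empty():
            drained.append(self._queue.get_nowait())
        return drained


channel = SketchChannel()
channel.send_command(Command("abort", reason="user clicked stop"))
channel.send_command(Command("pause"))

first_poll = channel.fetch_commands()   # both commands, in send order
second_poll = channel.fetch_commands()  # nothing left after the drain
```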


@ -0,0 +1,6 @@
"""Command channel implementations for GraphEngine."""
from .in_memory_channel import InMemoryChannel
from .redis_channel import RedisChannel
__all__ = ["InMemoryChannel", "RedisChannel"]


@ -0,0 +1,53 @@
"""
In-memory implementation of CommandChannel for local/testing scenarios.
This implementation uses a thread-safe queue for command communication
within a single process. Each instance handles commands for one workflow execution.
"""
from queue import Queue
from typing import final
from ..entities.commands import GraphEngineCommand
@final
class InMemoryChannel:
"""
In-memory command channel implementation using a thread-safe queue.
Each instance is dedicated to a single GraphEngine/workflow execution.
Suitable for local development, testing, and single-instance deployments.
"""
def __init__(self) -> None:
"""Initialize the in-memory channel with a single queue."""
self._queue: Queue[GraphEngineCommand] = Queue()
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch all pending commands from the queue.
Returns:
List of pending commands (drains the queue)
"""
commands: list[GraphEngineCommand] = []
# Drain all available commands from the queue
while not self._queue.empty():
try:
command = self._queue.get_nowait()
commands.append(command)
except Exception:
break
return commands
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to this channel's queue.
Args:
command: The command to send
"""
self._queue.put(command)
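
`queue.Queue` is safe to share between a sender thread and the engine loop, and `get_nowait()` raises `queue.Empty` once nothing is left. The sketch below shows a drain loop that catches only `Empty`, as an alternative to the broad `except Exception` above. The `drain` and `produce` helpers and the thread counts are assumptions for illustration.

```python
import threading
from queue import Empty, Queue


def drain(q: Queue) -> list:
    """Remove and return everything currently queued, without blocking."""
    items = []
    while True:
        try:
            items.append(q.get_nowait())
        except Empty:  # only "nothing left" ends the loop; other errors propagate
            return items


def produce(q: Queue, producer_id: int, count: int) -> None:
    """Simulate an external controller sending `count` commands."""
    for i in range(count):
        q.put((producer_id, i))


q = Queue()
threads = [threading.Thread(target=produce, args=(q, pid, 50)) for pid in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

received = drain(q)
```

Items from different producers may interleave, but each producer's own items arrive in the order they were sent.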


@ -0,0 +1,114 @@
"""
Redis-based implementation of CommandChannel for distributed scenarios.
This implementation uses Redis lists for command queuing, supporting
multi-instance deployments and cross-server communication.
Each instance uses a unique key for its command queue.
"""
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@final
class RedisChannel:
"""
Redis-based command channel implementation for distributed systems.
Each instance uses a unique Redis key for its command queue.
Commands are JSON-serialized for transport.
"""
def __init__(
self,
redis_client: "RedisClientWrapper",
channel_key: str,
command_ttl: int = 3600,
) -> None:
"""
Initialize the Redis channel.
Args:
redis_client: Redis client instance
channel_key: Unique key for this channel's command queue
command_ttl: TTL for command keys in seconds (default: 3600)
"""
self._redis = redis_client
self._key = channel_key
self._command_ttl = command_ttl
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch all pending commands from Redis.
Returns:
List of pending commands (drains the Redis list)
"""
commands: list[GraphEngineCommand] = []
# Use pipeline for atomic operations
with self._redis.pipeline() as pipe:
# Get all commands and clear the list atomically
pipe.lrange(self._key, 0, -1)
pipe.delete(self._key)
results = pipe.execute()
# Parse commands from JSON
if results[0]:
for command_json in results[0]:
try:
command_data = json.loads(command_json)
command = self._deserialize_command(command_data)
if command:
commands.append(command)
except (json.JSONDecodeError, ValueError):
# Skip invalid commands
continue
return commands
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to Redis.
Args:
command: The command to send
"""
command_json = json.dumps(command.model_dump())
# Push to list and set expiry
with self._redis.pipeline() as pipe:
pipe.rpush(self._key, command_json)
pipe.expire(self._key, self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
"""
Deserialize a command from dictionary data.
Args:
data: Command data dictionary
Returns:
Deserialized command or None if invalid
"""
command_type_value = data.get("command_type")
if not isinstance(command_type_value, str):
return None
try:
command_type = CommandType(command_type_value)
if command_type == CommandType.ABORT:
return AbortCommand(**data)
else:
# For other command types, use base class
return GraphEngineCommand(**data)
except (ValueError, TypeError):
return None
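
The serialize/deserialize round trip can be exercised without a Redis server. This is a sketch under assumptions: it uses plain dataclasses instead of the real Pydantic command models, a Python list in place of the Redis list, and hypothetical helper names (`serialize`, `deserialize`, `drain`).

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class Command:
    """Hypothetical stand-in for GraphEngineCommand."""

    command_type: str


@dataclass
class AbortCommand(Command):
    command_type: str = "abort"
    reason: Optional[str] = None


def serialize(command: Command) -> str:
    return json.dumps(asdict(command))


def deserialize(raw: str) -> Optional[Command]:
    """Return a command, or None when the payload is not a valid command."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict) or not isinstance(data.get("command_type"), str):
        return None
    try:
        if data["command_type"] == "abort":
            return AbortCommand(**data)
        return Command(**data)
    except TypeError:  # unexpected or missing fields
        return None


def drain(raw_items: list) -> list:
    """Parse every queued payload, skipping the invalid ones."""
    parsed = [deserialize(raw) for raw in raw_items]
    return [cmd for cmd in parsed if cmd is not None]


queued = [
    serialize(AbortCommand(reason="user")),
    "not json",
    json.dumps({"command_type": 5}),
    json.dumps({"command_type": "pause"}),
]
commands = drain(queued)
```

Skipping bad payloads instead of raising keeps one malformed entry from blocking the valid commands queued behind it.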


@ -0,0 +1,14 @@
"""
Command processing subsystem for graph engine.
This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
]


@ -0,0 +1,32 @@
"""
Command handler implementations.
"""
import logging
from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@final
class AbortCommandHandler(CommandHandler):
"""Handles abort commands."""
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
"""
Handle an abort command.
Args:
command: The abort command
execution: Graph execution to abort
"""
assert isinstance(command, AbortCommand)
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
execution.abort(command.reason or "User requested abort")


@ -0,0 +1,79 @@
"""
Main command processor for handling external commands.
"""
import logging
from typing import Protocol, final
from ..domain.graph_execution import GraphExecution
from ..entities.commands import GraphEngineCommand
from ..protocols.command_channel import CommandChannel
logger = logging.getLogger(__name__)
class CommandHandler(Protocol):
"""Protocol for command handlers."""
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
@final
class CommandProcessor:
"""
Processes external commands sent to the engine.
This polls the command channel and dispatches commands to
appropriate handlers.
"""
def __init__(
self,
command_channel: CommandChannel,
graph_execution: GraphExecution,
) -> None:
"""
Initialize the command processor.
Args:
command_channel: Channel for receiving commands
graph_execution: Graph execution aggregate
"""
self._command_channel = command_channel
self._graph_execution = graph_execution
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
"""
Register a handler for a command type.
Args:
command_type: Type of command to handle
handler: Handler for the command
"""
self._handlers[command_type] = handler
def process_commands(self) -> None:
"""Check for and process any pending commands."""
try:
commands = self._command_channel.fetch_commands()
for command in commands:
self._handle_command(command)
except Exception as e:
logger.warning("Error processing commands: %s", e)
def _handle_command(self, command: GraphEngineCommand) -> None:
"""
Handle a single command.
Args:
command: The command to handle
"""
handler = self._handlers.get(type(command))
if handler:
try:
handler.handle(command, self._graph_execution)
except Exception:
logger.exception("Error handling command %s", command.__class__.__name__)
else:
logger.warning("No handler registered for command: %s", command.__class__.__name__)
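
Handlers are looked up with `type(command)`, so the match is on the exact class: a subclass of a registered command type does not inherit its handler. A minimal sketch of that dispatch rule, where `Processor`, `UrgentAbortCommand`, and the in-memory `log` list are hypothetical stand-ins:

```python
class Command:
    """Hypothetical base command."""


class AbortCommand(Command):
    def __init__(self, reason: str = "") -> None:
        self.reason = reason


class UrgentAbortCommand(AbortCommand):
    """Subclass used to show that lookup is by exact type."""


class Processor:
    def __init__(self) -> None:
        self._handlers = {}
        self.log = []

    def register_handler(self, command_type, handler) -> None:
        self._handlers[command_type] = handler

    def handle(self, command: Command) -> bool:
        """Dispatch by exact type; return False when no handler is registered."""
        handler = self._handlers.get(type(command))
        if handler is None:
            self.log.append(f"no handler for {type(command).__name__}")
            return False
        handler(command)
        return True


aborted = []
processor = Processor()
processor.register_handler(AbortCommand, lambda cmd: aborted.append(cmd.reason))

handled = processor.handle(AbortCommand("stop"))
handled_subclass = processor.handle(UrgentAbortCommand("now"))
```

If subclass commands should reuse a parent's handler, each subclass needs its own registration, or the lookup would have to walk the class hierarchy.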


@ -1,25 +0,0 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class RunConditionHandler(ABC):
def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition
@abstractmethod
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
raise NotImplementedError


@ -1,25 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
if not self.condition.branch_identify:
raise Exception("Branch identify is required")
run_result = previous_route_node_state.node_run_result
if not run_result:
return False
if not run_result.edge_source_handle:
return False
return self.condition.branch_identify == run_result.edge_source_handle


@ -1,27 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor


class ConditionRunConditionHandlerHandler(RunConditionHandler):
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        if not self.condition.conditions:
            return True

        # process condition
        condition_processor = ConditionProcessor()
        _, _, final_result = condition_processor.process_conditions(
            variable_pool=graph_runtime_state.variable_pool,
            conditions=self.condition.conditions,
            operator="and",
        )

        return final_result

View File

@ -1,25 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition


class ConditionManager:
    @staticmethod
    def get_condition_handler(
        init_params: GraphInitParams, graph: Graph, run_condition: RunCondition
    ) -> RunConditionHandler:
        """
        Get condition handler

        :param init_params: init params
        :param graph: graph
        :param run_condition: run condition
        :return: condition handler
        """
        if run_condition.type == "branch_identify":
            return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition)
        else:
            return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition)

View File

@ -0,0 +1,14 @@
"""
Domain models for graph engine.

This package contains the core domain entities, value objects, and aggregates
that represent the business concepts of workflow graph execution.
"""

from .graph_execution import GraphExecution
from .node_execution import NodeExecution

__all__ = [
    "GraphExecution",
    "NodeExecution",
]

View File

@ -0,0 +1,215 @@
"""GraphExecution aggregate root managing the overall graph execution state."""

from __future__ import annotations

from dataclasses import dataclass, field
from importlib import import_module
from typing import Literal

from pydantic import BaseModel, Field

from core.workflow.enums import NodeState

from .node_execution import NodeExecution


class GraphExecutionErrorState(BaseModel):
    """Serializable representation of an execution error."""

    module: str = Field(description="Module containing the exception class")
    qualname: str = Field(description="Qualified name of the exception class")
    message: str | None = Field(default=None, description="Exception message string")


class NodeExecutionState(BaseModel):
    """Serializable representation of a node execution entity."""

    node_id: str
    state: NodeState = Field(default=NodeState.UNKNOWN)
    retry_count: int = Field(default=0)
    execution_id: str | None = Field(default=None)
    error: str | None = Field(default=None)


class GraphExecutionState(BaseModel):
    """Pydantic model describing serialized GraphExecution state."""

    type: Literal["GraphExecution"] = Field(default="GraphExecution")
    version: str = Field(default="1.0")
    workflow_id: str
    started: bool = Field(default=False)
    completed: bool = Field(default=False)
    aborted: bool = Field(default=False)
    error: GraphExecutionErrorState | None = Field(default=None)
    exceptions_count: int = Field(default=0)
    node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])


def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
    """Convert an exception into its serializable representation."""
    if error is None:
        return None

    return GraphExecutionErrorState(
        module=error.__class__.__module__,
        qualname=error.__class__.__qualname__,
        message=str(error),
    )


def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
    """Locate an exception class from its module and qualified name."""
    module = import_module(module_name)
    attr: object = module
    for part in qualname.split("."):
        attr = getattr(attr, part)

    if isinstance(attr, type) and issubclass(attr, Exception):
        return attr

    raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")


def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
    """Reconstruct an exception instance from serialized data."""
    if state is None:
        return None

    try:
        exception_class = _resolve_exception_class(state.module, state.qualname)
        if state.message is None:
            return exception_class()
        return exception_class(state.message)
    except Exception:
        # Fallback to RuntimeError when reconstruction fails
        if state.message is None:
            return RuntimeError(state.qualname)
        return RuntimeError(state.message)


@dataclass
class GraphExecution:
    """
    Aggregate root for graph execution.

    This manages the overall execution state of a workflow graph,
    coordinating between multiple node executions.
    """

    workflow_id: str
    started: bool = False
    completed: bool = False
    aborted: bool = False
    error: Exception | None = None
    node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
    exceptions_count: int = 0

    def start(self) -> None:
        """Mark the graph execution as started."""
        if self.started:
            raise RuntimeError("Graph execution already started")
        self.started = True

    def complete(self) -> None:
        """Mark the graph execution as completed."""
        if not self.started:
            raise RuntimeError("Cannot complete execution that hasn't started")
        if self.completed:
            raise RuntimeError("Graph execution already completed")
        self.completed = True

    def abort(self, reason: str) -> None:
        """Abort the graph execution."""
        self.aborted = True
        self.error = RuntimeError(f"Aborted: {reason}")

    def fail(self, error: Exception) -> None:
        """Mark the graph execution as failed."""
        self.error = error
        self.completed = True

    def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
        """Get or create a node execution entity."""
        if node_id not in self.node_executions:
            self.node_executions[node_id] = NodeExecution(node_id=node_id)
        return self.node_executions[node_id]

    @property
    def is_running(self) -> bool:
        """Check if the execution is currently running."""
        return self.started and not self.completed and not self.aborted

    @property
    def has_error(self) -> bool:
        """Check if the execution has encountered an error."""
        return self.error is not None

    @property
    def error_message(self) -> str | None:
        """Get the error message if an error exists."""
        if not self.error:
            return None
        return str(self.error)

    def dumps(self) -> str:
        """Serialize the aggregate state into a JSON string."""
        node_states = [
            NodeExecutionState(
                node_id=node_id,
                state=node_execution.state,
                retry_count=node_execution.retry_count,
                execution_id=node_execution.execution_id,
                error=node_execution.error,
            )
            for node_id, node_execution in sorted(self.node_executions.items())
        ]

        state = GraphExecutionState(
            workflow_id=self.workflow_id,
            started=self.started,
            completed=self.completed,
            aborted=self.aborted,
            error=_serialize_error(self.error),
            exceptions_count=self.exceptions_count,
            node_executions=node_states,
        )

        return state.model_dump_json()

    def loads(self, data: str) -> None:
        """Restore aggregate state from a serialized JSON string."""
        state = GraphExecutionState.model_validate_json(data)

        if state.type != "GraphExecution":
            raise ValueError(f"Invalid serialized data type: {state.type}")
        if state.version != "1.0":
            raise ValueError(f"Unsupported serialized version: {state.version}")
        if self.workflow_id != state.workflow_id:
            raise ValueError("Serialized workflow_id does not match aggregate identity")

        self.started = state.started
        self.completed = state.completed
        self.aborted = state.aborted
        self.error = _deserialize_error(state.error)
        self.exceptions_count = state.exceptions_count
        self.node_executions = {
            item.node_id: NodeExecution(
                node_id=item.node_id,
                state=item.state,
                retry_count=item.retry_count,
                execution_id=item.execution_id,
                error=item.error,
            )
            for item in state.node_executions
        }

    def record_node_failure(self) -> None:
        """Increment the count of node failures encountered during execution."""
        self.exceptions_count += 1
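
Note on the error round trip in the file above: exceptions are persisted as a module name, a qualified class name, and a message, then rebuilt by importing the module and walking the qualified name, with a `RuntimeError` fallback when the class cannot be resolved. The following is a minimal standalone sketch of that idea using only the standard library. The function names `serialize_error` and `deserialize_error` and the plain-dict payload are assumptions made for illustration; the file itself uses Pydantic models and private helpers.

```python
from importlib import import_module


def serialize_error(error: Exception) -> dict[str, str]:
    """Capture the class location and message needed to rebuild the exception."""
    cls = error.__class__
    return {"module": cls.__module__, "qualname": cls.__qualname__, "message": str(error)}


def deserialize_error(state: dict[str, str]) -> Exception:
    """Rebuild the exception; fall back to RuntimeError if the class cannot be resolved."""
    try:
        attr: object = import_module(state["module"])
        # Walk dotted qualified names such as "Outer.InnerError".
        for part in state["qualname"].split("."):
            attr = getattr(attr, part)
        if isinstance(attr, type) and issubclass(attr, Exception):
            return attr(state["message"])
        raise TypeError(f"{state['qualname']} is not an Exception subclass")
    except Exception:
        # Illustrative fallback: keep the message even when the type is lost.
        return RuntimeError(state["message"])


restored = deserialize_error(serialize_error(ValueError("bad input")))
unknown = deserialize_error({"module": "no_such_module_xyz", "qualname": "Missing", "message": "lost"})
```

The fallback means a restored execution always carries some error object, at the cost of possibly losing the original exception type.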

View File

@ -0,0 +1,45 @@
"""
NodeExecution entity representing a node's execution state.
"""

from dataclasses import dataclass

from core.workflow.enums import NodeState


@dataclass
class NodeExecution:
    """
    Entity representing the execution state of a single node.

    This is a mutable entity that tracks the runtime state of a node
    during graph execution.
    """

    node_id: str
    state: NodeState = NodeState.UNKNOWN
    retry_count: int = 0
    execution_id: str | None = None
    error: str | None = None

    def mark_started(self, execution_id: str) -> None:
        """Mark the node as started with an execution ID."""
        self.state = NodeState.TAKEN
        self.execution_id = execution_id

    def mark_taken(self) -> None:
        """Mark the node as successfully completed."""
        self.state = NodeState.TAKEN
        self.error = None

    def mark_failed(self, error: str) -> None:
        """Mark the node as failed with an error."""
        self.error = error

    def mark_skipped(self) -> None:
        """Mark the node as skipped."""
        self.state = NodeState.SKIPPED

    def increment_retry(self) -> None:
        """Increment the retry count for this node."""
        self.retry_count += 1
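
Note on how the entity above might be driven: the diff defines only the state-tracking methods, not the code that calls them. The sketch below shows one plausible retry loop built on them. `run_with_retries`, the `max_retries` parameter, the execution ID `"exec-1"`, and the `NodeState` stand-in with its string values are all assumptions for illustration, not the engine's actual worker logic.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum


class NodeState(str, Enum):
    """Stand-in for core.workflow.enums.NodeState (values are assumed)."""

    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"


@dataclass
class NodeExecution:
    """Trimmed copy of the entity above, limited to what the sketch needs."""

    node_id: str
    state: NodeState = NodeState.UNKNOWN
    retry_count: int = 0
    execution_id: str | None = None
    error: str | None = None

    def mark_started(self, execution_id: str) -> None:
        self.state = NodeState.TAKEN
        self.execution_id = execution_id

    def mark_taken(self) -> None:
        self.state = NodeState.TAKEN
        self.error = None

    def mark_failed(self, error: str) -> None:
        self.error = error

    def increment_retry(self) -> None:
        self.retry_count += 1


def run_with_retries(execution: NodeExecution, attempt: Callable[[], None], max_retries: int) -> bool:
    """Hypothetical driver: call `attempt` until it succeeds or retries run out."""
    execution.mark_started(execution_id="exec-1")
    while True:
        try:
            attempt()
        except Exception as exc:
            execution.mark_failed(str(exc))
            if execution.retry_count >= max_retries:
                return False
            execution.increment_retry()
        else:
            execution.mark_taken()
            return True


calls = {"n": 0}


def flaky() -> None:
    """Fail on the first two calls, then succeed."""
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")


node = NodeExecution(node_id="llm-1")
succeeded = run_with_retries(node, flaky, max_retries=3)
```

Because `mark_taken` clears `error`, a node that eventually succeeds keeps its retry count but not the message from its last failed attempt.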

View File

@ -1,6 +0,0 @@
from .graph import Graph
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .runtime_route_state import RuntimeRouteState

__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

View File

@ -0,0 +1,33 @@
"""
GraphEngine command entities for external control.

This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""

from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


class CommandType(StrEnum):
    """Types of commands that can be sent to GraphEngine."""

    ABORT = "abort"
    PAUSE = "pause"
    RESUME = "resume"


class GraphEngineCommand(BaseModel):
    """Base class for all GraphEngine commands."""

    command_type: CommandType = Field(..., description="Type of command")
    payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")


class AbortCommand(GraphEngineCommand):
    """Command to abort a running workflow execution."""

    command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
    reason: str | None = Field(default=None, description="Optional reason for abort")
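
Note on how these command entities could travel over a channel: the diff defines the command types but not how they are encoded or consumed. The sketch below is a standard-library approximation of a JSON round trip followed by a dispatch on the command type. The dataclass `Command`, the helpers `encode`, `decode`, and `handle`, and the dict used as engine state are hypothetical stand-ins; the real module uses Pydantic models and `StrEnum`.

```python
from __future__ import annotations

import json
from dataclasses import dataclass
from enum import Enum


class CommandType(str, Enum):
    """Mirrors the command types above; str-based so values serialize as plain strings."""

    ABORT = "abort"
    PAUSE = "pause"
    RESUME = "resume"


@dataclass
class Command:
    """Simplified stand-in for GraphEngineCommand / AbortCommand."""

    command_type: CommandType
    reason: str | None = None


def encode(command: Command) -> str:
    """Serialize a command to JSON, as a channel such as Redis would require."""
    return json.dumps({"command_type": command.command_type.value, "reason": command.reason})


def decode(raw: str) -> Command:
    """Parse a JSON payload back into a command; unknown types raise ValueError."""
    data = json.loads(raw)
    return Command(command_type=CommandType(data["command_type"]), reason=data.get("reason"))


def handle(command: Command, state: dict[str, object]) -> None:
    """Hypothetical dispatcher that updates engine state based on the command type."""
    if command.command_type is CommandType.ABORT:
        state["aborted"] = True
        state["error"] = f"Aborted: {command.reason or 'no reason given'}"
    elif command.command_type is CommandType.PAUSE:
        state["paused"] = True
    elif command.command_type is CommandType.RESUME:
        state["paused"] = False


state: dict[str, object] = {"aborted": False, "paused": False}
handle(decode(encode(Command(CommandType.ABORT, reason="user requested stop"))), state)
```

Keeping the wire format to plain strings lets the same payload pass through an in-memory queue or an external store without a custom encoder.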

View File

@ -1,277 +0,0 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class GraphEngineEvent(BaseModel):
pass
###########################################
# Graph Events
###########################################
class BaseGraphEvent(GraphEngineEvent):
pass
class GraphRunStartedEvent(BaseGraphEvent):
pass
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: Optional[dict[str, Any]] = None
"""outputs"""
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: int = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
exceptions_count: int = Field(..., description="exception count")
outputs: Optional[dict[str, Any]] = None
###########################################
# Node Events
###########################################
class BaseNodeEvent(GraphEngineEvent):
id: str = Field(..., description="node execution id")
node_id: str = Field(..., description="node id")
node_type: NodeType = Field(..., description="node type")
node_data: BaseNodeData = Field(..., description="node data")
route_node_state: RouteNodeState = Field(..., description="route node state")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
parallel_mode_run_id: Optional[str] = None
"""iteration node parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(BaseNodeEvent):
pass
class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunExceptionEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInLoopFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
start_at: datetime = Field(..., description="retry start time")
###########################################
# Parallel Branch Events
###########################################
class BaseParallelBranchEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id")
"""parallel id"""
parallel_start_node_id: str = Field(..., description="parallel start node id")
"""parallel start node id"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
pass
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
pass
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
error: str = Field(..., description="failed reason")
###########################################
# Iteration Events
###########################################
class BaseIterationEvent(GraphEngineEvent):
iteration_id: str = Field(..., description="iteration node execution id")
iteration_node_id: str = Field(..., description="iteration node id")
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
iteration_node_data: BaseNodeData = Field(..., description="node data")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
class IterationRunNextEvent(BaseIterationEvent):
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = None
duration: Optional[float] = None
class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")
###########################################
# Loop Events
###########################################
class BaseLoopEvent(GraphEngineEvent):
loop_id: str = Field(..., description="loop node execution id")
loop_node_id: str = Field(..., description="loop node id")
loop_node_type: NodeType = Field(..., description="node type, loop")
loop_node_data: BaseNodeData = Field(..., description="node data")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""loop run in parallel mode run id"""
class LoopRunStartedEvent(BaseLoopEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
class LoopRunNextEvent(BaseLoopEvent):
index: int = Field(..., description="index")
pre_loop_output: Optional[Any] = None
duration: Optional[float] = None
class LoopRunSucceededEvent(BaseLoopEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
loop_duration_map: Optional[dict[str, float]] = None
class LoopRunFailedEvent(BaseLoopEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")
###########################################
# Agent Events
###########################################
class BaseAgentEvent(GraphEngineEvent):
pass
class AgentLogEvent(BaseAgentEvent):
id: str = Field(..., description="id")
label: str = Field(..., description="label")
node_execution_id: str = Field(..., description="node execution id")
parent_id: str | None = Field(..., description="parent id")
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
node_id: str = Field(..., description="agent node id")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent

View File

@ -1,719 +0,0 @@
import uuid
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam
class GraphEdge(BaseModel):
source_node_id: str = Field(..., description="source node id")
target_node_id: str = Field(..., description="target node id")
run_condition: Optional[RunCondition] = None
"""run condition"""
class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = None
"""parent parallel id"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
class Graph(BaseModel):
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field(
default_factory=dict, description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict, description="graph edge mapping (source node id: edges)"
)
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict, description="reverse graph edge mapping (target node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict, description="graph parallel mapping (parallel id: parallel)"
)
node_parallel_mapping: dict[str, str] = Field(
default_factory=dict, description="graph node parallel mapping (node id: parallel id)"
)
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes")
end_stream_param: EndStreamParam = Field(..., description="end stream param")
@classmethod
def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph":
"""
Init graph
:param graph_config: graph config
:param root_node_id: root node id
:return: graph
"""
# edge configs
edge_configs = graph_config.get("edges")
if edge_configs is None:
edge_configs = []
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
edge_configs = cast(list, edge_configs)
node_configs = cast(list, node_configs)
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
fail_branch_source_node_id = [
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
]
for edge_config in edge_configs:
source_node_id = edge_config.get("source")
if not source_node_id:
continue
if source_node_id not in edge_mapping:
edge_mapping[source_node_id] = []
target_node_id = edge_config.get("target")
if not target_node_id:
continue
if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = []
target_edge_ids.add(target_node_id)
# parse run condition
run_condition = None
if edge_config.get("sourceHandle"):
if (
edge_config.get("source") in fail_branch_source_node_id
and edge_config.get("sourceHandle") != "fail-branch"
):
run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
elif edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(
type="branch_identify", branch_identify=edge_config.get("sourceHandle")
)
graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
)
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id:
continue
if node_id not in target_edge_ids:
root_node_configs.append(node_config)
all_node_id_config_mapping[node_id] = node_config
root_node_ids = [node_config.get("id") for node_config in root_node_configs]
# fetch root node
if not root_node_id:
# if no root node id, use any start node (START or trigger types) as root node
root_node_id = next(
(
node_config.get("id")
for node_config in root_node_configs
if NodeType(node_config.get("data", {}).get("type", "")).is_start_node
),
None,
)
if not root_node_id or root_node_id not in root_node_ids:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Reject cycles: fail if any path from the root leads back to a node already on that path
cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping)
# fetch all node ids from root node
node_ids = [root_node_id]
cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id)
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
# init parallel mapping
parallel_mapping: dict[str, GraphParallel] = {}
node_parallel_mapping: dict[str, str] = {}
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=root_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
)
# Check that nested parallels do not exceed the configured depth limit
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
parent_parallel_id=parallel.parent_parallel_id,
)
# init answer stream generate routes
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping
)
# init end stream param
end_stream_param = EndStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
node_parallel_mapping=node_parallel_mapping,
)
# init graph
graph = cls(
root_node_id=root_node_id,
node_ids=node_ids,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes,
end_stream_param=end_stream_param,
)
return graph
def add_extra_edge(
self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None
) -> None:
"""
Add extra edge to the graph
:param source_node_id: source node id
:param target_node_id: target node id
:param run_condition: run condition
"""
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
return
if source_node_id not in self.edge_mapping:
self.edge_mapping[source_node_id] = []
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
return
graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
)
self.edge_mapping[source_node_id].append(graph_edge)
def get_leaf_node_ids(self) -> list[str]:
"""
Get leaf node ids of the graph
:return: leaf node ids
"""
leaf_node_ids = []
for node_id in self.node_ids:
if node_id not in self.edge_mapping or (
len(self.edge_mapping[node_id]) == 1
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
):
leaf_node_ids.append(node_id)
return leaf_node_ids
@classmethod
def _recursively_add_node_ids(
cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str
) -> None:
"""
Recursively add node ids
:param node_ids: node ids
:param edge_mapping: edge mapping
:param node_id: node id
"""
for graph_edge in edge_mapping.get(node_id, []):
if graph_edge.target_node_id in node_ids:
continue
node_ids.append(graph_edge.target_node_id)
cls._recursively_add_node_ids(
node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id
)
@classmethod
def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None:
"""
Check that no edge leads back to a node already on the current route (cycle detection)
"""
last_node_id = route[-1]
for graph_edge in edge_mapping.get(last_node_id, []):
if not graph_edge.target_node_id:
continue
if graph_edge.target_node_id in route:
raise ValueError(
f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
)
new_route = route.copy()
new_route.append(graph_edge.target_node_id)
cls._check_connected_to_previous_node(
route=new_route,
edge_mapping=edge_mapping,
)
@classmethod
def _recursively_add_parallels(
cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str],
parent_parallel: Optional[GraphParallel] = None,
) -> None:
"""
Recursively add parallel ids
:param edge_mapping: edge mapping
:param start_node_id: start from node id
:param parallel_mapping: parallel mapping
:param node_parallel_mapping: node parallel mapping
:param parent_parallel: parent parallel
"""
target_node_edges = edge_mapping.get(start_node_id, [])
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = defaultdict(list)
condition_edge_mappings = defaultdict(list)
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
condition_edge_mappings[condition_hash].append(graph_edge)
for condition_hash, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
for graph_edge in graph_edges:
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
condition_parallels = {}
for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
# any target node id in node_parallel_mapping
parallel = None
if condition_parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None
parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
)
parallel_mapping[parallel.id] = parallel
condition_parallels[condition_hash] = parallel
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=condition_parallel_branch_node_ids,
)
# collect all branches node ids
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break
if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id
outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue
node_edges = edge_mapping.get(node_id)
if not node_edges:
continue
if len(node_edges) > 1:
continue
target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue
if (
(
node_parallel_mapping.get(target_node_id)
and node_parallel_mapping.get(target_node_id) == parent_parallel_id
)
or (
parent_parallel
and parent_parallel.end_to_node_id
and target_node_id == parent_parallel.end_to_node_id
)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
if len(outside_parallel_target_node_ids) == 1:
if (
parent_parallel
and parent_parallel.end_to_node_id
and parallel.end_to_node_id == parent_parallel.end_to_node_id
):
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
if condition_edge_mappings:
for condition_hash, graph_edges in condition_edge_mappings.items():
for graph_edge in graph_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=condition_parallels.get(condition_hash),
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
@classmethod
def _get_current_parallel(
cls,
parallel_mapping: dict[str, GraphParallel],
graph_edge: GraphEdge,
parallel: Optional[GraphParallel] = None,
parent_parallel: Optional[GraphParallel] = None,
) -> Optional[GraphParallel]:
"""
Get current parallel
"""
current_parallel = None
if parallel:
current_parallel = parallel
elif parent_parallel:
if not parent_parallel.end_to_node_id or (
parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
):
current_parallel = parent_parallel
else:
# fetch parent parallel's parent parallel
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
if parent_parallel_parent_parallel_id:
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
if parent_parallel_parent_parallel and (
not parent_parallel_parent_parallel.end_to_node_id
or (
parent_parallel_parent_parallel.end_to_node_id
and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
)
):
current_parallel = parent_parallel_parent_parallel
return current_parallel
@classmethod
def _check_exceed_parallel_limit(
cls,
parallel_mapping: dict[str, GraphParallel],
level_limit: int,
parent_parallel_id: str,
current_level: int = 1,
) -> None:
"""
Check if it exceeds N layers of parallel
"""
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
return
current_level += 1
if current_level > level_limit:
raise ValueError(f"Exceeds {level_limit} layers of parallel")
if parent_parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=level_limit,
parent_parallel_id=parent_parallel.parent_parallel_id,
current_level=current_level,
)
@classmethod
def _recursively_add_parallel_node_ids(
cls,
branch_node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
merge_node_id: str,
start_node_id: str,
) -> None:
"""
Recursively add node ids
:param branch_node_ids: in branch node ids
:param edge_mapping: edge mapping
:param merge_node_id: merge node id
:param start_node_id: start node id
"""
for graph_edge in edge_mapping.get(start_node_id, []):
if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids:
branch_node_ids.append(graph_edge.target_node_id)
cls._recursively_add_parallel_node_ids(
branch_node_ids=branch_node_ids,
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=graph_edge.target_node_id,
)
@classmethod
def _fetch_all_node_ids_in_parallels(
cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
parallel_branch_node_ids: list[str],
) -> dict[str, list[str]]:
"""
Fetch all node ids in parallels
"""
routes_node_ids: dict[str, list[str]] = {}
for parallel_branch_node_id in parallel_branch_node_ids:
routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
# fetch routes node ids
cls._recursively_fetch_routes(
edge_mapping=edge_mapping,
start_node_id=parallel_branch_node_id,
routes_node_ids=routes_node_ids[parallel_branch_node_id],
)
# fetch leaf node ids from routes node ids
leaf_node_ids: dict[str, list[str]] = {}
merge_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
if branch_node_id not in leaf_node_ids:
leaf_node_ids[branch_node_id] = []
leaf_node_ids[branch_node_id].append(node_id)
for branch_node_id2, inner_route2 in routes_node_ids.items():
if (
branch_node_id != branch_node_id2
and node_id in inner_route2
and len(reverse_edge_mapping.get(node_id, [])) > 1
and cls._is_node_in_routes(
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=node_id,
routes_node_ids=routes_node_ids,
)
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []
if branch_node_id2 not in merge_branch_node_ids[node_id]:
merge_branch_node_ids[node_id].append(branch_node_id2)
# sort merge_branch_node_ids by the number of branch node ids, descending
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
duplicate_end_node_ids = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
if (node_id, node_id2) not in duplicate_end_node_ids and (
node_id2,
node_id,
) not in duplicate_end_node_ids:
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
# check which node is after
if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
branches_merge_node_ids: dict[str, str] = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
if len(branch_node_ids) <= 1:
continue
for branch_node_id in branch_node_ids:
if branch_node_id in branches_merge_node_ids:
continue
branches_merge_node_ids[branch_node_id] = node_id
in_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
in_branch_node_ids[branch_node_id] = []
if branch_node_id not in branches_merge_node_ids:
# all node ids in the current branch are in this thread
in_branch_node_ids[branch_node_id].append(branch_node_id)
in_branch_node_ids[branch_node_id].extend(node_ids)
else:
merge_node_id = branches_merge_node_ids[branch_node_id]
if merge_node_id != branch_node_id:
in_branch_node_ids[branch_node_id].append(branch_node_id)
# fetch all node ids from branch_node_id and merge_node_id
cls._recursively_add_parallel_node_ids(
branch_node_ids=in_branch_node_ids[branch_node_id],
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=branch_node_id,
)
return in_branch_node_ids
@classmethod
def _recursively_fetch_routes(
cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
) -> None:
"""
Recursively fetch route
"""
if start_node_id not in edge_mapping:
return
for graph_edge in edge_mapping[start_node_id]:
# find next node ids
if graph_edge.target_node_id not in routes_node_ids:
routes_node_ids.append(graph_edge.target_node_id)
cls._recursively_fetch_routes(
edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids
)
@classmethod
def _is_node_in_routes(
cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]]
) -> bool:
"""
Recursively check if the node is in the routes
"""
if start_node_id not in reverse_edge_mapping:
return False
all_routes_node_ids = set()
parallel_start_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
all_routes_node_ids.update(node_ids)
if branch_node_id in reverse_edge_mapping:
for graph_edge in reverse_edge_mapping[branch_node_id]:
if graph_edge.source_node_id not in parallel_start_node_ids:
parallel_start_node_ids[graph_edge.source_node_id] = []
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
for _, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()):
return True
return False
@classmethod
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
"""
Check whether node2 comes after node1, i.e. is reachable from it by following edges
"""
if node1_id not in edge_mapping:
return False
for graph_edge in edge_mapping[node1_id]:
if graph_edge.target_node_id == node2_id:
return True
if cls._is_node2_after_node1(
node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping
):
return True
return False
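The reachability check in `_is_node2_after_node1` can be sketched in isolation as below. This is a minimal illustration rather than the module's code: `Edge` is a hypothetical stand-in for `GraphEdge`, `is_after` is an invented name, and the `seen` set is an added assumption that guards against cyclic graphs, which the recursive original does not track.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Edge:
    """Hypothetical stand-in for GraphEdge; only the two node ids are modelled."""

    source_node_id: str
    target_node_id: str


def is_after(node1_id, node2_id, edge_mapping, _seen=None):
    """Return True if node2 is reachable from node1 by following outgoing edges."""
    seen = _seen if _seen is not None else set()
    if node1_id in seen:
        # Already explored from this node; avoids infinite recursion on cycles.
        return False
    seen.add(node1_id)
    for edge in edge_mapping.get(node1_id, []):
        if edge.target_node_id == node2_id:
            return True
        if is_after(edge.target_node_id, node2_id, edge_mapping, seen):
            return True
    return False


# A small diamond-shaped graph: start -> (a, b) -> merge -> end
edge_mapping = {
    "start": [Edge("start", "a"), Edge("start", "b")],
    "a": [Edge("a", "merge")],
    "b": [Edge("b", "merge")],
    "merge": [Edge("merge", "end")],
}
```

With this mapping, `is_after("start", "end", edge_mapping)` holds while the reverse direction does not, which is the ordering test used above to decide which of two duplicate merge nodes to drop.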


@ -1,31 +0,0 @@
from typing import Any
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
start_at: float = Field(..., description="start time")
"""start time"""
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
#
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
# after a serialization and deserialization round trip.
outputs: dict[str, Any] = Field(default_factory=dict)
node_run_steps: int = 0
"""node run steps"""
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""


@ -1,118 +0,0 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from libs.datetime_utils import naive_utc_now
class RouteNodeState(BaseModel):
class Status(Enum):
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
EXCEPTION = "exception"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""
node_id: str
"""node id"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.RUNNING
"""node status"""
start_at: datetime
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None:
"""
Node finished
:param run_result: run result
"""
if self.status in {
RouteNodeState.Status.SUCCESS,
RouteNodeState.Status.FAILED,
RouteNodeState.Status.EXCEPTION,
}:
raise Exception(f"Route state {self.id} already finished")
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
self.status = RouteNodeState.Status.SUCCESS
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
self.status = RouteNodeState.Status.FAILED
self.failed_reason = run_result.error
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
self.status = RouteNodeState.Status.EXCEPTION
self.failed_reason = run_result.error
else:
raise Exception(f"Invalid route status {run_result.status}")
self.node_run_result = run_result
self.finished_at = naive_utc_now()
class RuntimeRouteState(BaseModel):
routes: dict[str, list[str]] = Field(
default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
)
node_state_mapping: dict[str, RouteNodeState] = Field(
default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
)
def create_node_state(self, node_id: str) -> RouteNodeState:
"""
Create node state
:param node_id: node id
"""
state = RouteNodeState(node_id=node_id, start_at=naive_utc_now())
self.node_state_mapping[state.id] = state
return state
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
"""
Add route to the graph state
:param source_node_state_id: source node state id
:param target_node_state_id: target node state id
"""
if source_node_state_id not in self.routes:
self.routes[source_node_state_id] = []
self.routes[source_node_state_id].append(target_node_state_id)
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]:
"""
Get routes with node state by source node state id
:param source_node_state_id: source node state id
:return: routes with node state
"""
return [
self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
]


@ -0,0 +1,211 @@
"""
Main error handler that coordinates error strategies.
"""
import logging
import time
from typing import TYPE_CHECKING, final
from core.workflow.enums import (
ErrorStrategy as ErrorStrategyEnum,
)
from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetryEvent,
)
from core.workflow.node_events import NodeRunResult
if TYPE_CHECKING:
from .domain import GraphExecution
logger = logging.getLogger(__name__)
@final
class ErrorHandler:
"""
Coordinates error handling strategies for node failures.
This acts as a facade for the various error strategies,
selecting and applying the appropriate strategy based on
node configuration.
"""
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
"""
Initialize the error handler.
Args:
graph: The workflow graph
graph_execution: The graph execution state
"""
self._graph = graph
self._graph_execution = graph_execution
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
"""
Handle a node failure event.
Selects and applies the appropriate error strategy based on
the node's configuration.
Args:
event: The node failure event
Returns:
Optional new event to process, or None to abort
"""
node = self._graph.nodes[event.node_id]
# Get retry count from NodeExecution
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
retry_count = node_execution.retry_count
# First check if retry is configured and not exhausted
if node.retry and retry_count < node.retry_config.max_retries:
result = self._handle_retry(event, retry_count)
if result:
# Retry count will be incremented when NodeRunRetryEvent is handled
return result
# Apply configured error strategy
strategy = node.error_strategy
match strategy:
case None:
return self._handle_abort(event)
case ErrorStrategyEnum.FAIL_BRANCH:
return self._handle_fail_branch(event)
case ErrorStrategyEnum.DEFAULT_VALUE:
return self._handle_default_value(event)
def _handle_abort(self, event: NodeRunFailedEvent):
"""
Handle error by aborting execution.
This is the default strategy when no other strategy is specified.
It stops the entire graph execution when a node fails.
Args:
event: The failure event
Returns:
None - signals abortion
"""
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
# Return None to signal that execution should stop
def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int):
"""
Handle error by retrying the node.
This strategy re-attempts node execution up to a configured
maximum number of retries with configurable intervals.
Args:
event: The failure event
retry_count: Current retry attempt count
Returns:
NodeRunRetryEvent if retry should occur, None otherwise
"""
node = self._graph.nodes[event.node_id]
# Check if we've exceeded max retries
if not node.retry or retry_count >= node.retry_config.max_retries:
return None
# Wait for retry interval
time.sleep(node.retry_config.retry_interval_seconds)
# Create retry event
return NodeRunRetryEvent(
id=event.id,
node_title=node.title,
node_id=event.node_id,
node_type=event.node_type,
node_run_result=event.node_run_result,
start_at=event.start_at,
error=event.error,
retry_index=retry_count + 1,
)
def _handle_fail_branch(self, event: NodeRunFailedEvent):
"""
Handle error by taking the fail branch.
This strategy converts failures to exceptions and routes execution
through a designated fail-branch edge.
Args:
event: The failure event
Returns:
NodeRunExceptionEvent to continue via fail branch
"""
outputs = {
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,
}
return NodeRunExceptionEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
edge_source_handle="fail-branch",
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH,
},
),
error=event.error,
)
def _handle_default_value(self, event: NodeRunFailedEvent):
"""
Handle error by using default values.
This strategy allows nodes to fail gracefully by providing
predefined default output values.
Args:
event: The failure event
Returns:
NodeRunExceptionEvent with default values
"""
node = self._graph.nodes[event.node_id]
outputs = {
**node.default_value_dict,
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,
}
return NodeRunExceptionEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE,
},
),
error=event.error,
)
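The order in which `handle_node_failure` picks a strategy (retry first, then the configured error strategy, otherwise abort) can be summarised with a small standalone sketch. Everything here is hypothetical and simplified: `Strategy`, `NodeConfig`, and `decide` are illustrative names, the enum values are assumptions, and the function returns a label instead of building an event.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Strategy(Enum):
    """Hypothetical mirror of the two non-default error strategies."""

    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"


@dataclass
class NodeConfig:
    """Hypothetical, flattened view of a node's retry and error settings."""

    retry_enabled: bool = False
    max_retries: int = 0
    error_strategy: Optional[Strategy] = None


def decide(node: NodeConfig, retry_count: int) -> str:
    """Return which path a failure would take, in the same order as the handler."""
    # Retries take precedence while attempts remain.
    if node.retry_enabled and retry_count < node.max_retries:
        return "retry"
    # Once retries are exhausted, fall back to the configured strategy.
    if node.error_strategy is Strategy.FAIL_BRANCH:
        return "fail-branch"
    if node.error_strategy is Strategy.DEFAULT_VALUE:
        return "default-value"
    # No strategy configured: the whole graph execution is aborted.
    return "abort"
```

For a node with two retries and a fail-branch strategy, the first two failures yield `"retry"` and the third yields `"fail-branch"`.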


@ -0,0 +1,14 @@
"""
Event management subsystem for graph engine.
This package handles event routing, collection, and emission for
workflow graph execution events.
"""
from .event_handlers import EventHandler
from .event_manager import EventManager
__all__ = [
"EventHandler",
"EventManager",
]


@ -0,0 +1,311 @@
"""
Event handler implementations for different event types.
"""
import logging
from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from ..error_handler import ErrorHandler
from ..graph_state_manager import GraphStateManager
from ..graph_traversal import EdgeProcessor
from .event_manager import EventManager
logger = logging.getLogger(__name__)
@final
class EventHandler:
"""
Registry of event handlers for different event types.
This centralizes the business logic for handling specific events,
keeping it separate from the routing and collection infrastructure.
"""
def __init__(
self,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator,
event_collector: "EventManager",
edge_processor: "EdgeProcessor",
state_manager: "GraphStateManager",
error_handler: "ErrorHandler",
) -> None:
"""
Initialize the event handler registry.
Args:
graph: The workflow graph
graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator
event_collector: Event manager for collecting events
edge_processor: Edge processor for edge traversal
state_manager: Unified state manager
error_handler: Error handler
"""
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._event_collector = event_collector
self._edge_processor = edge_processor
self._state_manager = state_manager
self._error_handler = error_handler
def dispatch(self, event: GraphNodeEventBase) -> None:
"""
Handle any node event by dispatching to the appropriate handler.
Args:
event: The event to handle
"""
# Events in loops or iterations are always collected
if event.in_loop_id or event.in_iteration_id:
self._event_collector.collect(event)
return
return self._dispatch(event)
@singledispatchmethod
def _dispatch(self, event: GraphNodeEventBase) -> None:
self._event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__)
@_dispatch.register(NodeRunIterationStartedEvent)
@_dispatch.register(NodeRunIterationNextEvent)
@_dispatch.register(NodeRunIterationSucceededEvent)
@_dispatch.register(NodeRunIterationFailedEvent)
@_dispatch.register(NodeRunLoopStartedEvent)
@_dispatch.register(NodeRunLoopNextEvent)
@_dispatch.register(NodeRunLoopSucceededEvent)
@_dispatch.register(NodeRunLoopFailedEvent)
@_dispatch.register(NodeRunAgentLogEvent)
def _(self, event: GraphNodeEventBase) -> None:
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStartedEvent) -> None:
"""
Handle node started event.
Args:
event: The node started event
"""
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event only for the first attempt; retries remain silent
if is_initial_attempt:
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None:
"""
Handle stream chunk event with full processing.
Args:
event: The stream chunk event
"""
# Process with response coordinator
streaming_events = list(self._response_coordinator.intercept_event(event))
# Collect all events
for stream_event in streaming_events:
self._event_collector.collect(stream_event)
@_dispatch.register
def _(self, event: NodeRunSucceededEvent) -> None:
"""
Handle node success by coordinating subsystems.
This method coordinates between different subsystems to process
node completion, handle edges, and trigger downstream execution.
Args:
event: The node succeeded event
"""
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event)
for stream_event in streaming_events:
self._event_collector.collect(stream_event)
# Process edges and get ready nodes
node = self._graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
# Collect streaming events from edge processing
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
# Enqueue ready nodes
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._state_manager.finish_execution(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
# Collect the event
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunFailedEvent) -> None:
"""
Handle node failure using error handler.
Args:
event: The node failed event
"""
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
result = self._error_handler.handle_node_failure(event)
if result:
# Process the resulting event (retry, exception, etc.)
self.dispatch(result)
else:
# Abort execution
self._graph_execution.fail(RuntimeError(event.error))
self._event_collector.collect(event)
self._state_manager.finish_execution(event.node_id)
@_dispatch.register
def _(self, event: NodeRunExceptionEvent) -> None:
"""
Handle node exception event (fail-branch or default-value strategy).
Args:
event: The node exception event
"""
# Node continues via fail-branch/default-value, treat as completion
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
node = self._graph.nodes[event.node_id]
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update response outputs if applicable
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
self._state_manager.finish_execution(event.node_id)
# Collect the exception event for observers
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None:
"""
Handle node retry event.
Args:
event: The node retry event
"""
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
# Finish the previous attempt before re-queuing the node
self._state_manager.finish_execution(event.node_id)
# Emit retry event for observers
self._event_collector.collect(event)
# Re-queue node for execution
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
"""
Store node outputs in the variable pool.
Args:
node_id: The ID of the node that produced the outputs
outputs: The outputs to store, keyed by variable name
"""
for variable_name, variable_value in outputs.items():
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
"""Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in outputs.items():
if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "")
if existing:
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
else:
self._graph_runtime_state.set_output("answer", value)
else:
self._graph_runtime_state.set_output(key, value)
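The handler routing in `EventHandler` relies on `functools.singledispatchmethod`, including stacked `register` calls that bind several event types to one handler. The sketch below shows that mechanism on its own; the event classes and the `Handler` class are hypothetical stand-ins, and the handlers only record what they saw.

```python
from functools import singledispatchmethod


class Event:
    """Hypothetical base event."""


class Started(Event):
    pass


class IterationNext(Event):
    pass


class LoopNext(Event):
    pass


class Unknown(Event):
    pass


class Handler:
    def __init__(self) -> None:
        self.log: list[tuple[str, str]] = []

    @singledispatchmethod
    def dispatch(self, event: Event) -> None:
        # Fallback for event types without a registered handler.
        self.log.append(("fallback", type(event).__name__))

    # Stacked registrations: both types share one pass-through handler.
    @dispatch.register(IterationNext)
    @dispatch.register(LoopNext)
    def _(self, event: Event) -> None:
        self.log.append(("collect", type(event).__name__))

    @dispatch.register(Started)
    def _(self, event: Started) -> None:
        self.log.append(("started", type(event).__name__))
```

Dispatch is keyed on the runtime class of the first argument after `self`, so an unregistered subclass such as `Unknown` falls through to the base implementation, much like the "Unhandled event type" warning path above.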


@ -0,0 +1,174 @@
"""
Unified event manager for collecting and emitting events.
"""
import threading
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import final
from core.workflow.graph_events import GraphEngineEvent
from ..layers.base import GraphEngineLayer
@final
class ReadWriteLock:
"""
A read-write lock implementation that allows multiple concurrent readers
but only one writer at a time.
"""
def __init__(self) -> None:
self._read_ready = threading.Condition(threading.RLock())
self._readers = 0
def acquire_read(self) -> None:
"""Acquire a read lock."""
_ = self._read_ready.acquire()
try:
self._readers += 1
finally:
self._read_ready.release()
def release_read(self) -> None:
"""Release a read lock."""
_ = self._read_ready.acquire()
try:
self._readers -= 1
if self._readers == 0:
self._read_ready.notify_all()
finally:
self._read_ready.release()
def acquire_write(self) -> None:
"""Acquire a write lock."""
_ = self._read_ready.acquire()
while self._readers > 0:
_ = self._read_ready.wait()
def release_write(self) -> None:
"""Release a write lock."""
self._read_ready.release()
@contextmanager
def read_lock(self):
"""Return a context manager for read locking."""
self.acquire_read()
try:
yield
finally:
self.release_read()
@contextmanager
def write_lock(self):
"""Return a context manager for write locking."""
self.acquire_write()
try:
yield
finally:
self.release_write()
@final
class EventManager:
    """
    Unified event manager that collects, buffers, and emits events.

    This class combines event collection with event emission, providing
    thread-safe event management with support for notifying layers and
    streaming events to external consumers.
    """

    def __init__(self) -> None:
        """Initialize the event manager."""
        self._events: list[GraphEngineEvent] = []
        self._lock = ReadWriteLock()
        self._layers: list[GraphEngineLayer] = []
        self._execution_complete = threading.Event()

    def set_layers(self, layers: list[GraphEngineLayer]) -> None:
        """
        Set the layers to notify on event collection.

        Args:
            layers: List of layers to notify
        """
        self._layers = layers

    def collect(self, event: GraphEngineEvent) -> None:
        """
        Thread-safe method to collect an event.

        Args:
            event: The event to collect
        """
        with self._lock.write_lock():
            self._events.append(event)
            self._notify_layers(event)

    def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
        """
        Get new events starting from a specific index.

        Args:
            start_index: The index to start from

        Returns:
            List of new events
        """
        with self._lock.read_lock():
            return list(self._events[start_index:])

    def _event_count(self) -> int:
        """
        Get the current count of collected events.

        Returns:
            Number of collected events
        """
        with self._lock.read_lock():
            return len(self._events)

    def mark_complete(self) -> None:
        """Mark execution as complete to stop the event emission generator."""
        self._execution_complete.set()

    def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
        """
        Generator that yields events as they're collected.

        The generator keeps running until execution is marked complete
        and every collected event has been yielded.

        Yields:
            GraphEngineEvent instances as they're processed
        """
        yielded_count = 0

        while not self._execution_complete.is_set() or yielded_count < self._event_count():
            # Get new events since last yield
            new_events = self._get_new_events(yielded_count)

            # Yield any new events
            for event in new_events:
                yield event
                yielded_count += 1

            # Small sleep to avoid busy waiting
            if not self._execution_complete.is_set() and not new_events:
                time.sleep(0.001)

    def _notify_layers(self, event: GraphEngineEvent) -> None:
        """
        Notify all layers of an event.

        Layer exceptions are caught and suppressed (not logged) so that a
        faulty layer cannot disrupt event collection.

        Args:
            event: The event to send to layers
        """
        for layer in self._layers:
            try:
                layer.on_event(event)
            except Exception:
                # Silently ignore layer errors during collection
                pass


@ -0,0 +1,288 @@
"""
Graph state manager that combines node, edge, and execution tracking.
"""
import threading
from collections.abc import Sequence
from typing import TypedDict, final
from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
from .ready_queue import ReadyQueue
class EdgeStateAnalysis(TypedDict):
"""Analysis result for edge states."""
has_unknown: bool
has_taken: bool
all_skipped: bool
@final
class GraphStateManager:
def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
"""
Initialize the state manager.
Args:
graph: The workflow graph
ready_queue: Queue for nodes ready to execute
"""
self._graph = graph
self._ready_queue = ready_queue
self._lock = threading.RLock()
# Execution tracking state
self._executing_nodes: set[str] = set()
# ============= Node State Operations =============
def enqueue_node(self, node_id: str) -> None:
"""
Mark a node as TAKEN and add it to the ready queue.
This combines the state transition and enqueueing operations
that always occur together when preparing a node for execution.
Args:
node_id: The ID of the node to enqueue
"""
with self._lock:
self._graph.nodes[node_id].state = NodeState.TAKEN
self._ready_queue.put(node_id)
def mark_node_skipped(self, node_id: str) -> None:
"""
Mark a node as SKIPPED.
Args:
node_id: The ID of the node to skip
"""
with self._lock:
self._graph.nodes[node_id].state = NodeState.SKIPPED
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when none of its incoming edges is still UNKNOWN
and at least one of them is TAKEN. A node with no incoming edges
is always ready.
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
with self._lock:
# Get all incoming edges to this node
incoming_edges = self._graph.get_incoming_edges(node_id)
# If no incoming edges, node is always ready
if not incoming_edges:
return True
# If any edge is UNKNOWN, node is not ready
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
return False
# Node is ready if at least one edge is TAKEN
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
def get_node_state(self, node_id: str) -> NodeState:
"""
Get the current state of a node.
Args:
node_id: The ID of the node
Returns:
The current node state
"""
with self._lock:
return self._graph.nodes[node_id].state
# ============= Edge State Operations =============
def mark_edge_taken(self, edge_id: str) -> None:
"""
Mark an edge as TAKEN.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self._graph.edges[edge_id].state = NodeState.TAKEN
def mark_edge_skipped(self, edge_id: str) -> None:
"""
Mark an edge as SKIPPED.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self._graph.edges[edge_id].state = NodeState.SKIPPED
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
"""
Analyze the states of edges and return summary flags.
Args:
edges: List of edges to analyze
Returns:
Analysis result with state flags
"""
with self._lock:
states = {edge.state for edge in edges}
return EdgeStateAnalysis(
has_unknown=NodeState.UNKNOWN in states,
has_taken=NodeState.TAKEN in states,
all_skipped=states == {NodeState.SKIPPED} if states else True,
)
def get_edge_state(self, edge_id: str) -> NodeState:
"""
Get the current state of an edge.
Args:
edge_id: The ID of the edge
Returns:
The current edge state
"""
with self._lock:
return self._graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
A tuple of (selected_edges, unselected_edges)
"""
with self._lock:
outgoing_edges = self._graph.get_outgoing_edges(node_id)
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
selected_edges.append(edge)
else:
unselected_edges.append(edge)
return selected_edges, unselected_edges
# ============= Execution Tracking Operations =============
def start_execution(self, node_id: str) -> None:
"""
Mark a node as executing.
Args:
node_id: The ID of the node starting execution
"""
with self._lock:
self._executing_nodes.add(node_id)
def finish_execution(self, node_id: str) -> None:
"""
Mark a node as no longer executing.
Args:
node_id: The ID of the node finishing execution
"""
with self._lock:
self._executing_nodes.discard(node_id)
def is_executing(self, node_id: str) -> bool:
"""
Check if a node is currently executing.
Args:
node_id: The ID of the node to check
Returns:
True if the node is executing
"""
with self._lock:
return node_id in self._executing_nodes
def get_executing_count(self) -> int:
"""
Get the count of currently executing nodes.
Returns:
Number of executing nodes
"""
with self._lock:
return len(self._executing_nodes)
def get_executing_nodes(self) -> set[str]:
"""
Get a copy of the set of executing node IDs.
Returns:
Set of node IDs currently executing
"""
with self._lock:
return self._executing_nodes.copy()
def clear_executing(self) -> None:
"""Clear all executing nodes."""
with self._lock:
self._executing_nodes.clear()
# ============= Composite Operations =============
def is_execution_complete(self) -> bool:
"""
Check if graph execution is complete.
Execution is complete when:
- Ready queue is empty
- No nodes are executing
Returns:
True if execution is complete
"""
with self._lock:
return self._ready_queue.empty() and len(self._executing_nodes) == 0
def get_queue_depth(self) -> int:
"""
Get the current depth of the ready queue.
Returns:
Number of nodes in the ready queue
"""
return self._ready_queue.qsize()
def get_execution_stats(self) -> dict[str, int]:
"""
Get execution statistics.
Returns:
Dictionary with execution statistics
"""
with self._lock:
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
return {
"queue_depth": self._ready_queue.qsize(),
"executing": len(self._executing_nodes),
"taken_nodes": taken_nodes,
"skipped_nodes": skipped_nodes,
"unknown_nodes": unknown_nodes,
}


@ -0,0 +1,14 @@
"""
Graph traversal subsystem for graph engine.
This package handles graph navigation, edge processing,
and skip propagation logic.
"""
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator
__all__ = [
"EdgeProcessor",
"SkipPropagator",
]


@ -0,0 +1,201 @@
"""
Edge processing logic for graph traversal.
"""
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Edge, Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent
from ..graph_state_manager import GraphStateManager
from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from .skip_propagator import SkipPropagator
@final
class EdgeProcessor:
"""
Processes edges during graph execution.
This handles marking edges as taken or skipped, notifying
the response coordinator, triggering downstream node execution,
and managing branch node logic.
"""
def __init__(
self,
graph: Graph,
state_manager: GraphStateManager,
response_coordinator: ResponseStreamCoordinator,
skip_propagator: "SkipPropagator",
) -> None:
"""
Initialize the edge processor.
Args:
graph: The workflow graph
state_manager: Unified state manager
response_coordinator: Response stream coordinator
skip_propagator: Propagator for skip states
"""
self._graph = graph
self._state_manager = state_manager
self._response_coordinator = response_coordinator
self._skip_propagator = skip_propagator
def process_node_success(
self, node_id: str, selected_handle: str | None = None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges after a node succeeds.
Args:
node_id: The ID of the succeeded node
selected_handle: For branch nodes, the selected edge handle
Returns:
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
"""
node = self._graph.nodes[node_id]
if node.execution_type == NodeExecutionType.BRANCH:
return self._process_branch_node_edges(node_id, selected_handle)
else:
return self._process_non_branch_node_edges(node_id)
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for non-branch nodes (mark all as TAKEN).
Args:
node_id: The ID of the succeeded node
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
"""
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_branch_node_edges(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for branch nodes.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no edge was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge")
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Process unselected edges first (mark as skipped)
for edge in unselected_edges:
self._process_skipped_edge(edge)
# Process selected edges
for edge in selected_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Mark edge as taken and check downstream node.
Args:
edge: The edge to process
Returns:
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
"""
# Mark edge as taken
self._state_manager.mark_edge_taken(edge.id)
# Notify response coordinator and get streaming events
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes: list[str] = []
if self._state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, streaming_events
def _process_skipped_edge(self, edge: Edge) -> None:
"""
Mark edge as skipped.
Args:
edge: The edge to skip
"""
self._state_manager.mark_edge_skipped(edge.id)
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected branch
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no branch was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges into selected and unselected
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self._skip_propagator.skip_branch_paths(unselected_edges)
# Process selected edges and get ready nodes and streaming events
return self.process_node_success(node_id, selected_handle)
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
"""
Validate that a branch selection is valid.
Args:
node_id: The ID of the branch node
selected_handle: The handle to validate
Returns:
True if the selection is valid
"""
outgoing_edges = self._graph.get_outgoing_edges(node_id)
valid_handles = {edge.source_handle for edge in outgoing_edges}
return selected_handle in valid_handles


@ -0,0 +1,95 @@
"""
Skip state propagation through the graph.
"""

from collections.abc import Sequence
from typing import final

from core.workflow.graph import Edge, Graph

from ..graph_state_manager import GraphStateManager


@final
class SkipPropagator:
    """
    Propagates skip states through the graph.

    When a node is skipped, this ensures that downstream nodes whose
    incoming edges are all skipped are skipped as well.
    """

    def __init__(
        self,
        graph: Graph,
        state_manager: GraphStateManager,
    ) -> None:
        """
        Initialize the skip propagator.

        Args:
            graph: The workflow graph
            state_manager: Unified state manager
        """
        self._graph = graph
        self._state_manager = state_manager

    def propagate_skip_from_edge(self, edge_id: str) -> None:
        """
        Recursively propagate skip state from a skipped edge.

        Rules:
        - If the downstream node has any UNKNOWN incoming edges, stop processing
        - Otherwise, if any incoming edge is TAKEN, enqueue the node for execution
        - If all incoming edges are SKIPPED, skip the node and its outgoing edges

        Args:
            edge_id: The ID of the skipped edge to start from
        """
        downstream_node_id = self._graph.edges[edge_id].head
        incoming_edges = self._graph.get_incoming_edges(downstream_node_id)

        # Analyze edge states
        edge_states = self._state_manager.analyze_edge_states(incoming_edges)

        # Stop if there are unknown edges (not yet processed)
        if edge_states["has_unknown"]:
            return

        # All edges are resolved and at least one is taken: the node can run
        if edge_states["has_taken"]:
            # Enqueue node
            self._state_manager.enqueue_node(downstream_node_id)
            return

        # All edges are skipped, propagate skip to this node
        if edge_states["all_skipped"]:
            self._propagate_skip_to_node(downstream_node_id)

    def _propagate_skip_to_node(self, node_id: str) -> None:
        """
        Mark a node and all its outgoing edges as skipped.

        Args:
            node_id: The ID of the node to skip
        """
        # Mark node as skipped
        self._state_manager.mark_node_skipped(node_id)

        # Mark all outgoing edges as skipped and propagate
        outgoing_edges = self._graph.get_outgoing_edges(node_id)
        for edge in outgoing_edges:
            self._state_manager.mark_edge_skipped(edge.id)
            # Recursively propagate skip
            self.propagate_skip_from_edge(edge.id)

    def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
        """
        Skip all paths from unselected branch edges.

        Args:
            unselected_edges: List of edges not taken by the branch
        """
        for edge in unselected_edges:
            self._state_manager.mark_edge_skipped(edge.id)
            self.propagate_skip_from_edge(edge.id)


@ -0,0 +1,52 @@
# Layers
Pluggable middleware for engine extensions.
## Components
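### GraphEngineLayer (base)
Abstract base class for layers.
- `initialize()` - Receive the read-only runtime state and the command channel
- `on_graph_start()` - Execution start hook
- `on_event()` - Called for every event the engine emits
- `on_graph_end()` - Execution end hook; receives the error, or `None` on success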
### DebugLoggingLayer
Comprehensive execution logging.
- Configurable detail levels
- Tracks execution statistics
- Truncates long values
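
The truncation rule is simple enough to sketch as a standalone function. The version below mirrors the `_truncate_value` helper in `debug_logging.py`; the free-function form and the name `truncate_value` are assumptions made for illustration, not part of the layer's public API.

```python
from typing import Any


def truncate_value(value: Any, max_value_length: int = 500) -> str:
    """Return str(value), cut to max_value_length characters plus a marker."""
    text = str(value)
    if len(text) > max_value_length:
        return text[:max_value_length] + "... (truncated)"
    return text


print(truncate_value("x" * 12, max_value_length=5))  # xxxxx... (truncated)
print(truncate_value({"a": 1}))  # {'a': 1}
```

Note that the marker is appended after the cut, so a truncated string is longer than `max_value_length` by the length of the marker.
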
## Usage
```python
debug_layer = DebugLoggingLayer(
    level="INFO",
    include_outputs=True,
)
engine = GraphEngine(graph)
engine.add_layer(debug_layer)
engine.run()
```
## Custom Layers
```python
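class MetricsLayer(GraphEngineLayer):
    def __init__(self) -> None:
        super().__init__()
        self.metrics = {}

    # on_graph_start() and on_graph_end() must also be implemented (omitted here).
    def on_event(self, event):
        if isinstance(event, NodeRunSucceededEvent):
            self.metrics[event.node_id] = event.elapsed_time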
```
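
Because `on_graph_start`, `on_event`, and `on_graph_end` are all abstract on the base class, a layer needs all three before it can be instantiated. A more complete sketch follows. To keep it runnable on its own, `GraphEngineLayer` and `NodeRunSucceededEvent` are redefined here as simplified stand-ins (assumptions, not the engine's real classes), and the hooks are driven by hand rather than by a real engine.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional


# --- Simplified stand-ins (assumptions) for the engine's real types ---
@dataclass
class NodeRunSucceededEvent:
    node_id: str


class GraphEngineLayer(ABC):
    @abstractmethod
    def on_graph_start(self) -> None: ...

    @abstractmethod
    def on_event(self, event: object) -> None: ...

    @abstractmethod
    def on_graph_end(self, error: Optional[Exception]) -> None: ...


class MetricsLayer(GraphEngineLayer):
    """Counts successful runs per node ID."""

    def __init__(self) -> None:
        super().__init__()
        self.success_counts: dict = {}
        self.finished = False

    def on_graph_start(self) -> None:
        # Reset state so the layer can be reused across runs.
        self.success_counts.clear()
        self.finished = False

    def on_event(self, event: object) -> None:
        if isinstance(event, NodeRunSucceededEvent):
            self.success_counts[event.node_id] = self.success_counts.get(event.node_id, 0) + 1

    def on_graph_end(self, error: Optional[Exception]) -> None:
        self.finished = True


# Drive the hooks by hand: start, then events, then end.
layer = MetricsLayer()
layer.on_graph_start()
for node_id in ("llm", "code", "llm"):
    layer.on_event(NodeRunSucceededEvent(node_id=node_id))
layer.on_graph_end(error=None)
print(layer.success_counts)  # {'llm': 2, 'code': 1}
```
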
## Configuration
**DebugLoggingLayer Options:**
- `level` - Log level (DEBUG, INFO, WARNING, ERROR)
- `include_inputs`, `include_outputs`, `include_process_data` - Log data values
- `max_value_length` - Truncate long values
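
As a rough sketch of how the `level` option is interpreted: `DebugLoggingLayer` resolves the name with `getattr(logging, level.upper(), logging.INFO)`, so the value is case-insensitive and an unrecognized name falls back to `INFO`. The helper name `resolve_level` below is hypothetical and exists only for this example.

```python
import logging


def resolve_level(level: str) -> int:
    """Map a level name to a logging constant; unknown names fall back to INFO."""
    return getattr(logging, level.upper(), logging.INFO)


print(resolve_level("debug") == logging.DEBUG)   # True
print(resolve_level("verbose") == logging.INFO)  # True
```
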


@ -0,0 +1,16 @@
"""
Layer system for GraphEngine extensibility.
This module provides the layer infrastructure for extending GraphEngine functionality
with middleware-like components that can observe events and interact with execution.
"""
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
]


@ -0,0 +1,85 @@
"""
Base layer class for GraphEngine extensions.

This module provides the abstract base class for implementing layers that can
intercept and respond to GraphEngine events.
"""

from abc import ABC, abstractmethod

from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent


class GraphEngineLayer(ABC):
    """
    Abstract base class for GraphEngine layers.

    Layers are middleware-like components that can:
    - Observe all events emitted by the GraphEngine
    - Access the graph runtime state
    - Send commands to control execution

    Subclasses should override the constructor to accept configuration parameters,
    then implement the three lifecycle methods.
    """

    def __init__(self) -> None:
        """Initialize the layer. Subclasses can override with custom parameters."""
        self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
        self.command_channel: CommandChannel | None = None

    def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
        """
        Initialize the layer with engine dependencies.

        Called by GraphEngine before execution starts to inject the read-only runtime state
        and command channel. This allows layers to observe engine context and send
        commands, but prevents direct state modification.

        Args:
            graph_runtime_state: Read-only view of the runtime state
            command_channel: Channel for sending commands to the engine
        """
        self.graph_runtime_state = graph_runtime_state
        self.command_channel = command_channel

    @abstractmethod
    def on_graph_start(self) -> None:
        """
        Called when graph execution starts.

        This is called after the engine has been initialized but before any nodes
        are executed. Layers can use this to set up resources or log start information.
        """
        pass

    @abstractmethod
    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Called for every event emitted by the engine.

        This method receives all events generated during graph execution, including:
        - Graph lifecycle events (start, success, failure)
        - Node execution events (start, success, failure, retry)
        - Stream events for response nodes
        - Container events (iteration, loop)

        Args:
            event: The event emitted by the engine
        """
        pass

    @abstractmethod
    def on_graph_end(self, error: Exception | None) -> None:
        """
        Called when graph execution ends.

        This is called after all nodes have been executed or when execution is
        aborted. Layers can use this to clean up resources or log final state.

        Args:
            error: The exception that caused execution to fail, or None if successful
        """
        pass


@ -0,0 +1,250 @@
"""
Debug logging layer for GraphEngine.
This module provides a layer that logs all events and state changes during
graph execution for debugging purposes.
"""
import logging
from collections.abc import Mapping
from typing import Any, final
from typing_extensions import override
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from .base import GraphEngineLayer
@final
class DebugLoggingLayer(GraphEngineLayer):
"""
A layer that provides comprehensive logging of GraphEngine execution.
This layer logs all events with configurable detail levels, helping developers
debug workflow execution and understand the flow of events.
"""
def __init__(
self,
level: str = "INFO",
include_inputs: bool = False,
include_outputs: bool = True,
include_process_data: bool = False,
logger_name: str = "GraphEngine.Debug",
max_value_length: int = 500,
) -> None:
"""
Initialize the debug logging layer.
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
include_inputs: Whether to log node input values
include_outputs: Whether to log node output values
include_process_data: Whether to log node process data
logger_name: Name of the logger to use
max_value_length: Maximum length of logged values (truncated if longer)
"""
super().__init__()
self.level = level
self.include_inputs = include_inputs
self.include_outputs = include_outputs
self.include_process_data = include_process_data
self.max_value_length = max_value_length
# Set up logger
self.logger = logging.getLogger(logger_name)
log_level = getattr(logging, level.upper(), logging.INFO)
self.logger.setLevel(log_level)
# Track execution stats
self.node_count = 0
self.success_count = 0
self.failure_count = 0
self.retry_count = 0
def _truncate_value(self, value: Any) -> str:
"""Truncate long values for logging."""
str_value = str(value)
if len(str_value) > self.max_value_length:
return str_value[: self.max_value_length] + "... (truncated)"
return str_value
def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
"""Format a dictionary or mapping for logging with truncation."""
if not data:
return "{}"
formatted_items: list[str] = []
for key, value in data.items():
formatted_value = self._truncate_value(value)
formatted_items.append(f" {key}: {formatted_value}")
return "{\n" + ",\n".join(formatted_items) + "\n}"
@override
def on_graph_start(self) -> None:
"""Log graph execution start."""
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""Log individual events based on their type."""
event_class = event.__class__.__name__
# Graph-level events
if isinstance(event, GraphRunStartedEvent):
self.logger.debug("Graph run started event")
elif isinstance(event, GraphRunSucceededEvent):
self.logger.info("✅ Graph run succeeded")
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self.logger.warning("⚠️ Graph run partially succeeded")
if event.exceptions_count > 0:
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0:
self.logger.error(" Total exceptions: %s", event.exceptions_count)
elif isinstance(event, GraphRunAbortedEvent):
self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
if event.outputs:
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
# Node-level events
# Check NodeRunRetryEvent before NodeRunStartedEvent because Retry subclasses Started.
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStartedEvent):
self.node_count += 1
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
if self.include_inputs and event.node_run_result.inputs:
self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs))
elif isinstance(event, NodeRunSucceededEvent):
self.success_count += 1
self.logger.info("✅ Node succeeded: %s", event.node_id)
if self.include_outputs and event.node_run_result.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs))
if self.include_process_data and event.node_run_result.process_data:
self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data))
elif isinstance(event, NodeRunFailedEvent):
self.failure_count += 1
self.logger.error("❌ Node failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
if event.node_run_result.error:
self.logger.error(" Details: %s", event.node_run_result.error)
elif isinstance(event, NodeRunExceptionEvent):
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
self.logger.warning(" Error: %s", event.error)
elif isinstance(event, NodeRunStreamChunkEvent):
# Log stream chunks at debug level to avoid spam
final_indicator = " (FINAL)" if event.is_final else ""
self.logger.debug(
"📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
)
# Iteration events
elif isinstance(event, NodeRunIterationStartedEvent):
self.logger.info("🔁 Iteration started: %s", event.node_id)
elif isinstance(event, NodeRunIterationNextEvent):
self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index)
elif isinstance(event, NodeRunIterationSucceededEvent):
self.logger.info("✅ Iteration succeeded: %s", event.node_id)
if self.include_outputs and event.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, NodeRunIterationFailedEvent):
self.logger.error("❌ Iteration failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
# Loop events
elif isinstance(event, NodeRunLoopStartedEvent):
self.logger.info("🔄 Loop started: %s", event.node_id)
elif isinstance(event, NodeRunLoopNextEvent):
self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index)
elif isinstance(event, NodeRunLoopSucceededEvent):
self.logger.info("✅ Loop succeeded: %s", event.node_id)
if self.include_outputs and event.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, NodeRunLoopFailedEvent):
self.logger.error("❌ Loop failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
else:
# Log unknown events at debug level
self.logger.debug("Event: %s", event_class)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Log graph execution end with summary statistics."""
self.logger.info("=" * 80)
if error:
self.logger.error("🔴 GRAPH EXECUTION FAILED")
self.logger.error(" Error: %s", error)
else:
self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")
# Log execution statistics
self.logger.info("Execution Statistics:")
self.logger.info(" Total nodes executed: %s", self.node_count)
self.logger.info(" Successful nodes: %s", self.success_count)
self.logger.info(" Failed nodes: %s", self.failure_count)
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)


@ -0,0 +1,150 @@
"""
Execution limits layer for GraphEngine.

This layer monitors workflow execution to enforce limits on:
- Maximum execution steps
- Maximum execution time

When limits are exceeded, the layer automatically aborts execution.
"""

import logging
import time
from enum import Enum
from typing import final

from typing_extensions import override

from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import GraphEngineLayer
from core.workflow.graph_events import (
    GraphEngineEvent,
    NodeRunStartedEvent,
)
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent


class LimitType(Enum):
    """Types of execution limits that can be exceeded."""

    STEP_LIMIT = "step_limit"
    TIME_LIMIT = "time_limit"


@final
class ExecutionLimitsLayer(GraphEngineLayer):
    """
    Layer that enforces execution limits for workflows.

    Monitors:
    - Step count: Tracks number of node executions
    - Time limit: Monitors total execution time

    Automatically aborts execution when limits are exceeded.
    """

    def __init__(self, max_steps: int, max_time: int) -> None:
        """
        Initialize the execution limits layer.

        Args:
            max_steps: Maximum number of execution steps allowed
            max_time: Maximum execution time in seconds allowed
        """
        super().__init__()
        self.max_steps = max_steps
        self.max_time = max_time

        # Runtime tracking
        self.start_time: float | None = None
        self.step_count = 0
        self.logger = logging.getLogger(__name__)

        # State tracking
        self._execution_started = False
        self._execution_ended = False
        self._abort_sent = False  # Track if abort command has been sent

    @override
    def on_graph_start(self) -> None:
        """Called when graph execution starts."""
        self.start_time = time.time()
        self.step_count = 0
        self._execution_started = True
        self._execution_ended = False
        self._abort_sent = False

        self.logger.debug("Execution limits monitoring started")

    @override
    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Called for every event emitted by the engine.

        Monitors execution progress and enforces limits.
        """
        if not self._execution_started or self._execution_ended or self._abort_sent:
            return

        # Track step count for node execution events
        if isinstance(event, NodeRunStartedEvent):
            self.step_count += 1
            self.logger.debug("Step %d started: %s", self.step_count, event.node_id)

        # Check step and time limits when node execution completes
        if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
            if self._reached_step_limitation():
                self._send_abort_command(LimitType.STEP_LIMIT)

            if self._reached_time_limitation():
                self._send_abort_command(LimitType.TIME_LIMIT)

    @override
    def on_graph_end(self, error: Exception | None) -> None:
        """Called when graph execution ends."""
        if self._execution_started and not self._execution_ended:
            self._execution_ended = True

            if self.start_time:
                total_time = time.time() - self.start_time
                self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)

    def _reached_step_limitation(self) -> bool:
        """Check if step count limit has been exceeded."""
        return self.step_count > self.max_steps

    def _reached_time_limitation(self) -> bool:
        """Check if time limit has been exceeded."""
        return self.start_time is not None and (time.time() - self.start_time) > self.max_time

    def _send_abort_command(self, limit_type: LimitType) -> None:
        """
        Send abort command due to limit violation.

        Args:
            limit_type: Type of limit exceeded
        """
        if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
            return

        # Format detailed reason message
        if limit_type == LimitType.STEP_LIMIT:
            reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
        else:  # LimitType.TIME_LIMIT; `else` keeps `reason` bound on every path
            elapsed_time = time.time() - self.start_time if self.start_time else 0
            reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"

        self.logger.warning("Execution limit exceeded: %s", reason)

        try:
            # Send abort command to the engine
            abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
            self.command_channel.send_command(abort_command)

            # Mark that abort has been sent to prevent duplicate commands
            self._abort_sent = True

            self.logger.debug("Abort command sent to engine")

        except Exception:
            self.logger.exception("Failed to send abort command")


@ -0,0 +1,50 @@
"""
GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from extensions.ext_redis import redis_client
@final
class GraphEngineManager:
"""
Manager for sending control commands to GraphEngine instances.
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Only the stop operation is implemented in this class.
"""
@staticmethod
def send_stop_command(task_id: str, reason: str | None = None) -> None:
"""
Send a stop command to a running workflow.
Args:
task_id: The task ID of the workflow to stop
reason: Optional reason for stopping (defaults to "User requested stop")
"""
if not task_id:
return
# Create Redis channel for this task
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(redis_client, channel_key)
# Create and send abort command
abort_command = AbortCommand(reason=reason or "User requested stop")
try:
channel.send_command(abort_command)
except Exception:
# Silently fail if Redis is unavailable
# The legacy stop flag mechanism will still work
pass


@ -0,0 +1,14 @@
"""
Orchestration subsystem for graph engine.
This package coordinates the overall execution flow between
different subsystems.
"""
from .dispatcher import Dispatcher
from .execution_coordinator import ExecutionCoordinator
__all__ = [
"Dispatcher",
"ExecutionCoordinator",
]


@ -0,0 +1,104 @@
"""
Main dispatcher for processing events from workers.
"""
import logging
import queue
import threading
import time
from typing import TYPE_CHECKING, final
from core.workflow.graph_events.base import GraphNodeEventBase
from ..event_management import EventManager
from .execution_coordinator import ExecutionCoordinator
if TYPE_CHECKING:
from ..event_management import EventHandler
logger = logging.getLogger(__name__)
@final
class Dispatcher:
"""
Main dispatcher that processes events from the event queue.
This runs in a separate thread and coordinates event processing
with timeout and completion detection.
"""
def __init__(
self,
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
event_collector: EventManager,
execution_coordinator: ExecutionCoordinator,
event_emitter: EventManager | None = None,
) -> None:
"""
Initialize the dispatcher.
Args:
event_queue: Queue of events from workers
event_handler: Event handler registry for processing events
event_collector: Event manager for collecting unhandled events
execution_coordinator: Coordinator for execution flow
event_emitter: Optional event manager to signal completion
"""
self._event_queue = event_queue
self._event_handler = event_handler
self._event_collector = event_collector
self._execution_coordinator = execution_coordinator
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._start_time: float | None = None
def start(self) -> None:
"""Start the dispatcher thread."""
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for commands
self._execution_coordinator.check_commands()
# Check for scaling
self._execution_coordinator.check_scaling()
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
# Check if execution is complete
if self._execution_coordinator.is_execution_complete():
break
except Exception as e:
logger.exception("Dispatcher error")
self._execution_coordinator.mark_failed(e)
finally:
self._execution_coordinator.mark_complete()
# Signal the event emitter that execution is complete
if self._event_emitter:
self._event_emitter.mark_complete()
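
The loop above polls the queue with a short timeout and only checks for completion when the queue is empty. A minimal standalone sketch of that pattern follows; the `Tracker` class is a hypothetical replacement for the execution coordinator and simply counts the events still expected.

```python
import queue
import threading


class Tracker:
    """Hypothetical completion tracker standing in for the execution coordinator."""

    def __init__(self, expected: int) -> None:
        self._remaining = expected
        self._lock = threading.Lock()
        self.handled: list[str] = []

    def handle(self, event: str) -> None:
        with self._lock:
            self.handled.append(event)
            self._remaining -= 1

    def is_execution_complete(self) -> bool:
        with self._lock:
            return self._remaining == 0


def dispatcher_loop(events: "queue.Queue[str]", tracker: Tracker, stop_event: threading.Event) -> None:
    while not stop_event.is_set():
        try:
            event = events.get(timeout=0.05)
        except queue.Empty:
            # Only an idle queue triggers the completion check.
            if tracker.is_execution_complete():
                break
            continue
        tracker.handle(event)
        events.task_done()


events: "queue.Queue[str]" = queue.Queue()
tracker = Tracker(expected=3)
stop_event = threading.Event()
thread = threading.Thread(target=dispatcher_loop, args=(events, tracker, stop_event), daemon=True)
thread.start()
for name in ("a", "b", "c"):
    events.put(name)
thread.join(timeout=5.0)
```

Checking completion only on `queue.Empty` means that events already queued are drained before the loop exits.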


@ -0,0 +1,87 @@
"""
Execution coordinator for managing overall workflow execution.
"""
from typing import TYPE_CHECKING, final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventManager
from ..graph_state_manager import GraphStateManager
from ..worker_management import WorkerPool
if TYPE_CHECKING:
from ..event_management import EventHandler
@final
class ExecutionCoordinator:
"""
Coordinates overall execution flow between subsystems.
This provides high-level coordination methods used by the
dispatcher to manage execution state.
"""
def __init__(
self,
graph_execution: GraphExecution,
state_manager: GraphStateManager,
event_handler: "EventHandler",
event_collector: EventManager,
command_processor: CommandProcessor,
worker_pool: WorkerPool,
) -> None:
"""
Initialize the execution coordinator.
Args:
graph_execution: Graph execution aggregate
state_manager: Unified state manager
event_handler: Event handler registry for processing events
event_collector: Event manager for collecting events
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self._graph_execution = graph_execution
self._state_manager = state_manager
self._event_handler = event_handler
self._event_collector = event_collector
self._command_processor = command_processor
self._worker_pool = worker_pool
def check_commands(self) -> None:
"""Process any pending commands."""
self._command_processor.process_commands()
def check_scaling(self) -> None:
"""Check and perform worker scaling if needed."""
self._worker_pool.check_and_scale()
def is_execution_complete(self) -> bool:
"""
Check if execution is complete.
Returns:
True if execution is complete
"""
# Check if aborted or failed
if self._graph_execution.aborted or self._graph_execution.has_error:
return True
# Complete if no work remains
return self._state_manager.is_execution_complete()
def mark_complete(self) -> None:
"""Mark execution as complete."""
if not self._graph_execution.completed:
self._graph_execution.complete()
def mark_failed(self, error: Exception) -> None:
"""
Mark execution as failed.
Args:
error: The error that caused failure
"""
self._graph_execution.fail(error)


@ -0,0 +1,41 @@
"""
CommandChannel protocol for GraphEngine command communication.
This protocol defines the interface for sending and receiving commands
to/from a GraphEngine instance, supporting both local and distributed scenarios.
"""
from typing import Protocol
from ..entities.commands import GraphEngineCommand
class CommandChannel(Protocol):
"""
Protocol for bidirectional command communication with GraphEngine.
Since each GraphEngine instance processes only one workflow execution,
this channel is dedicated to that single execution.
"""
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch pending commands for this GraphEngine instance.
Called by GraphEngine to poll for commands that need to be processed.
Returns:
List of pending commands (may be empty)
"""
...
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to be processed by this GraphEngine instance.
Called by external systems to send control commands to the running workflow.
Args:
command: The command to send
"""
...
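
A class satisfies a protocol like the one above structurally, without inheriting from it. The sketch below shows one possible in-memory implementation. `Command` and `SimpleInMemoryChannel` are hypothetical names, and the choice to drain the buffer on every fetch is an assumption made for this example, not something the protocol itself specifies.

```python
import threading
from collections import deque
from typing import Protocol


class Command:
    """Hypothetical stand-in for GraphEngineCommand."""

    def __init__(self, name: str) -> None:
        self.name = name


class CommandChannel(Protocol):
    def fetch_commands(self) -> list[Command]: ...

    def send_command(self, command: Command) -> None: ...


class SimpleInMemoryChannel:
    """Thread-safe in-memory channel; matches CommandChannel by structure only."""

    def __init__(self) -> None:
        self._commands: deque[Command] = deque()
        self._lock = threading.Lock()

    def fetch_commands(self) -> list[Command]:
        # Assumption: a fetch drains the buffer so each command is seen once.
        with self._lock:
            commands = list(self._commands)
            self._commands.clear()
            return commands

    def send_command(self, command: Command) -> None:
        with self._lock:
            self._commands.append(command)


channel: CommandChannel = SimpleInMemoryChannel()
channel.send_command(Command("abort"))
first_poll = [command.name for command in channel.fetch_commands()]
second_poll = channel.fetch_commands()
```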


@ -0,0 +1,12 @@
"""
Ready queue implementations for GraphEngine.
This package contains the protocol and implementations for managing
the queue of nodes ready for execution.
"""
from .factory import create_ready_queue_from_state
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueue, ReadyQueueState
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]


@ -0,0 +1,35 @@
"""
Factory for creating ReadyQueue instances from serialized state.
"""
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueueState
if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
"""
Create a ReadyQueue instance from a serialized state.
Args:
state: The serialized queue state as a ReadyQueueState model
Returns:
A ReadyQueue instance initialized with the given state
Raises:
ValueError: If the queue type is unknown or version is unsupported
"""
if state.type == "InMemoryReadyQueue":
if state.version != "1.0":
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
queue = InMemoryReadyQueue()
# Always pass as JSON string to loads()
queue.loads(state.model_dump_json())
return queue
else:
raise ValueError(f"Unknown ready queue type: {state.type}")
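
The factory dispatches on the `type` field and then validates `version` before restoring. A stripped-down sketch of that dispatch, using a dataclass instead of the Pydantic model, might look as follows. `QueueState`, `ListQueue` and the `"RedisReadyQueue"` type string are hypothetical and exist only for this illustration.

```python
from dataclasses import dataclass, field


@dataclass
class QueueState:
    """Hypothetical stand-in for ReadyQueueState."""

    type: str
    version: str
    items: list[str] = field(default_factory=list)


class ListQueue:
    """Hypothetical queue implementation holding node IDs in a list."""

    def __init__(self, items: list[str]) -> None:
        self.items = list(items)


def create_queue(state: QueueState) -> ListQueue:
    # Dispatch on the implementation type first, then check the format version.
    if state.type != "ListQueue":
        raise ValueError(f"Unknown ready queue type: {state.type}")
    if state.version != "1.0":
        raise ValueError(f"Unsupported ListQueue version: {state.version}")
    return ListQueue(state.items)


restored = create_queue(QueueState(type="ListQueue", version="1.0", items=["n1", "n2"]))

try:
    create_queue(QueueState(type="RedisReadyQueue", version="1.0"))
except ValueError as exc:
    error_message = str(exc)
```

Keeping the type and version in the serialized state lets new queue implementations or format revisions be added without silently misreading old snapshots.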


@ -0,0 +1,140 @@
"""
In-memory implementation of the ReadyQueue protocol.
This implementation wraps Python's standard queue.Queue and adds
serialization capabilities for state storage.
"""
import queue
from typing import final
from .protocol import ReadyQueue, ReadyQueueState
@final
class InMemoryReadyQueue(ReadyQueue):
"""
In-memory ready queue implementation with serialization support.
This implementation uses Python's queue.Queue internally and provides
methods to serialize and restore the queue state.
"""
def __init__(self, maxsize: int = 0) -> None:
"""
Initialize the in-memory ready queue.
Args:
maxsize: Maximum size of the queue (0 for unlimited)
"""
self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize)
def put(self, item: str) -> None:
"""
Add a node ID to the ready queue.
Args:
item: The node ID to add to the queue
"""
self._queue.put(item)
def get(self, timeout: float | None = None) -> str:
"""
Retrieve and remove a node ID from the queue.
Args:
timeout: Maximum time in seconds to wait for an item (None blocks until one is available)
Returns:
The node ID retrieved from the queue
Raises:
queue.Empty: If timeout expires and no item is available
"""
if timeout is None:
return self._queue.get(block=True)
return self._queue.get(timeout=timeout)
def task_done(self) -> None:
"""
Indicate that a previously retrieved task is complete.
Used by worker threads to signal task completion for
join() synchronization.
"""
self._queue.task_done()
def empty(self) -> bool:
"""
Check if the queue is empty.
Returns:
True if the queue has no items, False otherwise
"""
return self._queue.empty()
def qsize(self) -> int:
"""
Get the approximate size of the queue.
Returns:
The approximate number of items in the queue
"""
return self._queue.qsize()
def dumps(self) -> str:
"""
Serialize the queue state to a JSON string for storage.
Returns:
A JSON string containing the serialized queue state
"""
# Snapshot the queue by draining it and then putting the items back in the same order.
# Note: this is not atomic, so concurrent producers or consumers may interleave, and
# each put() also increments the queue's unfinished-task counter used by join().
items: list[str] = []
while not self._queue.empty():
try:
items.append(self._queue.get_nowait())
except queue.Empty:
break
# Put items back in the same order
for item in items:
self._queue.put(item)
state = ReadyQueueState(
type="InMemoryReadyQueue",
version="1.0",
items=items,
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""
Restore the queue state from a JSON string.
Args:
data: The JSON string containing the serialized queue state to restore
"""
state = ReadyQueueState.model_validate_json(data)
if state.type != "InMemoryReadyQueue":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported version: {state.version}")
# Clear the current queue
while not self._queue.empty():
try:
self._queue.get_nowait()
except queue.Empty:
break
# Restore items
for item in state.items:
self._queue.put(item)
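
The serialization round trip above can be tried without Pydantic. The sketch below uses the standard `json` module and module-level `dumps`/`loads` helpers, both of which are assumptions of this example; only the drain-and-refill approach and the `type`/`version`/`items` layout follow the code above.

```python
import json
import queue


def dumps(q: "queue.Queue[str]") -> str:
    """Snapshot a queue by draining it and refilling it in the same order."""
    items: list[str] = []
    while True:
        try:
            items.append(q.get_nowait())
        except queue.Empty:
            break
    for item in items:
        q.put(item)
    return json.dumps({"type": "InMemoryReadyQueue", "version": "1.0", "items": items})


def loads(data: str) -> "queue.Queue[str]":
    """Rebuild a queue from a snapshot, validating type and version first."""
    state = json.loads(data)
    if state["type"] != "InMemoryReadyQueue":
        raise ValueError(f"Invalid serialized data type: {state['type']}")
    if state["version"] != "1.0":
        raise ValueError(f"Unsupported version: {state['version']}")
    restored: "queue.Queue[str]" = queue.Queue()
    for item in state["items"]:
        restored.put(item)
    return restored


original: "queue.Queue[str]" = queue.Queue()
for node_id in ("start", "llm", "answer"):
    original.put(node_id)

payload = dumps(original)
restored = loads(payload)
restored_order = [restored.get_nowait() for _ in range(restored.qsize())]
```

The original queue still holds its three items after `dumps`, and the restored queue yields them in the same FIFO order.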


@ -0,0 +1,104 @@
"""
ReadyQueue protocol for GraphEngine node execution queue.
This protocol defines the interface for managing the queue of nodes ready
for execution, supporting both in-memory and persistent storage scenarios.
"""
from collections.abc import Sequence
from typing import Protocol
from pydantic import BaseModel, Field
class ReadyQueueState(BaseModel):
"""
Pydantic model for serialized ready queue state.
This defines the structure of the data returned by dumps()
and expected by loads() for ready queue serialization.
"""
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
version: str = Field(description="Serialization format version")
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
class ReadyQueue(Protocol):
"""
Protocol for managing nodes ready for execution in GraphEngine.
This protocol defines the interface that any ready queue implementation
must provide, enabling both in-memory queues and persistent queues
that can be serialized for state storage.
"""
def put(self, item: str) -> None:
"""
Add a node ID to the ready queue.
Args:
item: The node ID to add to the queue
"""
...
def get(self, timeout: float | None = None) -> str:
"""
Retrieve and remove a node ID from the queue.
Args:
timeout: Maximum time in seconds to wait for an item (None blocks until one is available)
Returns:
The node ID retrieved from the queue
Raises:
queue.Empty: If timeout expires and no item is available
"""
...
def task_done(self) -> None:
"""
Indicate that a previously retrieved task is complete.
Used by worker threads to signal task completion for
join() synchronization.
"""
...
def empty(self) -> bool:
"""
Check if the queue is empty.
Returns:
True if the queue has no items, False otherwise
"""
...
def qsize(self) -> int:
"""
Get the approximate size of the queue.
Returns:
The approximate number of items in the queue
"""
...
def dumps(self) -> str:
"""
Serialize the queue state to a JSON string for storage.
Returns:
A JSON string containing the serialized queue state
that can be persisted and later restored
"""
...
def loads(self, data: str) -> None:
"""
Restore the queue state from a JSON string.
Args:
data: The JSON string containing the serialized queue state to restore
"""
...


@ -0,0 +1,10 @@
"""
ResponseStreamCoordinator - Coordinates streaming output from response nodes
This component manages response streaming sessions and ensures ordered streaming
of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
__all__ = ["ResponseStreamCoordinator"]
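
To make "ordered streaming" concrete: a response template is a sequence of constant text segments and variable segments, and output is emitted strictly in template order even when upstream chunks arrive out of order. The sketch below is a simplified model of that idea. `Text`, `Variable` and `OrderedStreamer` are hypothetical names, plain strings stand in for stream chunk events, and sessions, locking and the variable pool are left out.

```python
from dataclasses import dataclass, field


@dataclass
class Text:
    text: str


@dataclass
class Variable:
    selector: tuple[str, ...]


@dataclass
class OrderedStreamer:
    """Toy model: emit template segments in order, buffering early chunks."""

    segments: list
    index: int = 0
    buffers: dict = field(default_factory=dict)
    positions: dict = field(default_factory=dict)
    closed: set = field(default_factory=set)

    def push(self, selector: tuple[str, ...], chunk: str, is_final: bool = False) -> list[str]:
        self.buffers.setdefault(selector, []).append(chunk)
        if is_final:
            self.closed.add(selector)
        return self.flush()

    def flush(self) -> list[str]:
        out: list[str] = []
        while self.index < len(self.segments):
            segment = self.segments[self.index]
            if isinstance(segment, Text):
                out.append(segment.text)
                self.index += 1
                continue
            buffer = self.buffers.get(segment.selector, [])
            position = self.positions.get(segment.selector, 0)
            out.extend(buffer[position:])
            self.positions[segment.selector] = len(buffer)
            if segment.selector not in self.closed:
                break  # wait for more chunks of this variable
            self.index += 1
        return out


streamer = OrderedStreamer(
    [Text("A: "), Variable(("llm1", "text")), Text(" B: "), Variable(("llm2", "text"))]
)
emitted: list[str] = []
emitted += streamer.flush()
emitted += streamer.push(("llm2", "text"), "late", is_final=True)  # buffered, not emitted yet
emitted += streamer.push(("llm1", "text"), "he")
emitted += streamer.push(("llm1", "text"), "llo", is_final=True)
result = "".join(emitted)
```

The chunk from `llm2` arrives first but is held back until the `llm1` stream is closed and the text between them has been emitted.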


@ -0,0 +1,696 @@
"""
Main ResponseStreamCoordinator implementation.
This module contains the public ResponseStreamCoordinator class that manages
response streaming sessions and ensures ordered streaming of responses.
"""
import logging
from collections import deque
from collections.abc import Sequence
from threading import RLock
from typing import Literal, TypeAlias, final
from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from .path import Path
from .session import ResponseSession
logger = logging.getLogger(__name__)
# Type definitions
NodeID: TypeAlias = str
EdgeID: TypeAlias = str
class ResponseSessionState(BaseModel):
"""Serializable representation of a response session."""
node_id: str
index: int = Field(default=0, ge=0)
class StreamBufferState(BaseModel):
"""Serializable representation of buffered stream chunks."""
selector: tuple[str, ...]
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
class StreamPositionState(BaseModel):
"""Serializable representation for stream read positions."""
selector: tuple[str, ...]
position: int = Field(default=0, ge=0)
class ResponseStreamCoordinatorState(BaseModel):
"""Serialized snapshot of ResponseStreamCoordinator."""
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
version: str = Field(default="1.0")
response_nodes: Sequence[str] = Field(default_factory=list)
active_session: ResponseSessionState | None = None
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
node_execution_ids: dict[str, str] = Field(default_factory=dict)
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
@final
class ResponseStreamCoordinator:
"""
Manages response streaming sessions without relying on global state.
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
"""
Initialize the coordinator with a variable pool and a graph.
Args:
variable_pool: VariablePool instance for accessing node variables
graph: Graph instance for looking up node information
"""
self._variable_pool = variable_pool
self._graph = graph
self._active_session: ResponseSession | None = None
self._waiting_sessions: deque[ResponseSession] = deque()
self._lock = RLock()
# Internal stream management (replacing OutputRegistry)
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
self._stream_positions: dict[tuple[str, ...], int] = {}
self._closed_streams: set[tuple[str, ...]] = set()
# Track response nodes
self._response_nodes: set[NodeID] = set()
# Store paths for each response node
self._paths_maps: dict[NodeID, list[Path]] = {}
# Track node execution IDs for proper event forwarding
self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id
# Track response sessions to ensure only one per node
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
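def register(self, response_node_id: NodeID) -> None:
"""Register a response node: build its paths map and create its pending session.
Registering the same node again is a no-op.
Args:
response_node_id: ID of the response node to register
"""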
with self._lock:
if response_node_id in self._response_nodes:
return
self._response_nodes.add(response_node_id)
# Build and save paths map for this response node
paths_map = self._build_paths_map(response_node_id)
self._paths_maps[response_node_id] = paths_map
# Create and store response session for this node
response_node = self._graph.nodes[response_node_id]
session = ResponseSession.from_node(response_node)
self._response_sessions[response_node_id] = session
def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
"""Track the execution ID for a node when it starts executing.
Args:
node_id: The ID of the node
execution_id: The execution ID from NodeRunStartedEvent
"""
with self._lock:
self._node_execution_ids[node_id] = execution_id
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
"""Get the execution ID for a node, creating one if it doesn't exist.
Args:
node_id: The ID of the node
Returns:
The execution ID for the node
"""
with self._lock:
if node_id not in self._node_execution_ids:
self._node_execution_ids[node_id] = str(uuid4())
return self._node_execution_ids[node_id]
def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
"""
Build a paths map for a response node by finding all paths from root node
to the response node, recording branch edges along each path.
Args:
response_node_id: ID of the response node to analyze
Returns:
List of Path objects, where each path contains branch edge IDs
"""
# Get root node ID
root_node_id = self._graph.root_node.id
# If root is the response node, return empty path
if root_node_id == response_node_id:
return [Path()]
# Extract variable selectors from the response node's template
response_node = self._graph.nodes[response_node_id]
response_session = ResponseSession.from_node(response_node)
template = response_session.template
# Collect all variable selectors from the template
variable_selectors: set[tuple[str, ...]] = set()
for segment in template.segments:
if isinstance(segment, VariableSegment):
variable_selectors.add(tuple(segment.selector[:2]))
# Step 1: Find all complete paths from root to response node
all_complete_paths: list[list[EdgeID]] = []
def find_paths(
current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
) -> None:
"""Recursively find all paths from current node to target node."""
if current_node_id == target_node_id:
# Found a complete path, store it
all_complete_paths.append(current_path.copy())
return
# Mark as visited to avoid cycles
visited.add(current_node_id)
# Explore outgoing edges
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
for edge in outgoing_edges:
edge_id = edge.id
next_node_id = edge.head
# Skip if already visited in this path
if next_node_id not in visited:
# Add edge to path and recurse
new_path = current_path + [edge_id]
find_paths(next_node_id, target_node_id, new_path, visited.copy())
# Start searching from root node
find_paths(root_node_id, response_node_id, [], set())
# Step 2: For each complete path, filter edges based on node blocking behavior
filtered_paths: list[Path] = []
for path in all_complete_paths:
blocking_edges: list[str] = []
for edge_id in path:
edge = self._graph.edges[edge_id]
source_node = self._graph.nodes[edge.tail]
# Keep the edge if its source node is a branch/container (original behavior)
# or if it blocks output of the variables referenced by the template
if source_node.execution_type in {
NodeExecutionType.BRANCH,
NodeExecutionType.CONTAINER,
} or source_node.blocks_variable_output(variable_selectors):
blocking_edges.append(edge_id)
# Keep the path even if it's empty
filtered_paths.append(Path(edges=blocking_edges))
return filtered_paths
def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
"""
Handle when an edge is taken (selected by a branch node).
This method updates the paths for all response nodes by removing
the taken edge. If any response node has an empty path after removal,
it means the node is now deterministically reachable and should start.
Args:
edge_id: The ID of the edge that was taken
Returns:
List of events to emit from starting new sessions
"""
events: list[NodeRunStreamChunkEvent] = []
with self._lock:
# Check each registered response node (set iteration order is arbitrary)
for response_node_id in self._response_nodes:
if response_node_id not in self._paths_maps:
continue
paths = self._paths_maps[response_node_id]
has_reachable_path = False
# Update each path by removing the taken edge
for path in paths:
# Remove the taken edge from this path
path.remove_edge(edge_id)
# Check if this path is now empty (node is reachable)
if path.is_empty():
has_reachable_path = True
# If node is now reachable (has empty path), start/queue session
if has_reachable_path:
# Pass the node_id to the activation method
# The method will handle checking and removing from map
events.extend(self._active_or_queue_session(response_node_id))
return events
def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
"""
Start a session immediately if no active session, otherwise queue it.
Only activates sessions that exist in the _response_sessions map.
Args:
node_id: The ID of the response node to activate
Returns:
List of events from flush attempt if session started immediately
"""
events: list[NodeRunStreamChunkEvent] = []
# Get the session from our map (only activate if it exists)
session = self._response_sessions.get(node_id)
if not session:
return events
# Remove from map to ensure it won't be activated again
del self._response_sessions[node_id]
if self._active_session is None:
self._active_session = session
# Try to flush immediately
events.extend(self.try_flush())
else:
# Queue the session if another is active
self._waiting_sessions.append(session)
return events
def intercept_event(
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
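) -> Sequence[NodeRunStreamChunkEvent]:
"""
Intercept a node event and flush whatever output is ready.
Stream chunks are buffered by selector, and the stream is closed on the
final chunk. Succeeded events only trigger a flush attempt.
Args:
event: The stream chunk or succeeded event to intercept
Returns:
List of events that are ready to be emitted
"""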
with self._lock:
if isinstance(event, NodeRunStreamChunkEvent):
self._append_stream_chunk(event.selector, event)
if event.is_final:
self._close_stream(event.selector)
return self.try_flush()
else:
# Skipped because we share the same variable pool.
#
# for variable_name, variable_value in event.node_run_result.outputs.items():
# self._variable_pool.add((event.node_id, variable_name), variable_value)
return self.try_flush()
def _create_stream_chunk_event(
self,
node_id: str,
execution_id: str,
selector: Sequence[str],
chunk: str,
is_final: bool = False,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
# Use the active response node for special selectors
response_node = self._graph.nodes[self._active_session.node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=selector,
chunk=chunk,
is_final=is_final,
)
# Standard case: selector refers to an actual node
node = self._graph.nodes[node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=node.id,
node_type=node.node_type,
selector=selector,
chunk=chunk,
is_final=is_final,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
"""Process a variable segment. Returns (events, is_complete).
Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""
is_complete = False
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)
# Stream all available chunks
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector, # Keep original selector
chunk=event.chunk,
is_final=event.is_final,
)
events.append(updated_event)
else:
# Regular node selector - use event as is
events.append(event)
# The segment is complete once its stream has been closed
stream_closed = self._is_stream_closed(segment.selector)
if stream_closed:
is_complete = True
elif value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
)
)
is_complete = True
return events, is_complete
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns the events to emit (a text segment is always complete)."""
assert self._active_session is not None
current_response_node = self._graph.nodes[self._active_session.node_id]
# Use _get_or_create_execution_id to ensure we have a consistent ID
execution_id = self._get_or_create_execution_id(current_response_node.id)
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
event = self._create_stream_chunk_event(
node_id=current_response_node.id,
execution_id=execution_id,
selector=[current_response_node.id, "answer"], # FIXME(-LAN-)
chunk=segment.text,
is_final=is_last_segment,
)
return [event]
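def try_flush(self) -> list[NodeRunStreamChunkEvent]:
"""
Emit as many segments of the active session as are currently available.
Segments are processed in template order; processing stops at the first
variable segment that is still waiting for data.
Returns:
List of events ready to be emitted (empty if there is no active session)
"""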
with self._lock:
if not self._active_session:
return []
template = self._active_session.template
response_node_id = self._active_session.node_id
events: list[NodeRunStreamChunkEvent] = []
# Process segments sequentially from current index
while self._active_session.index < len(template.segments):
segment = template.segments[self._active_session.index]
if isinstance(segment, VariableSegment):
# Check if the source node for this variable is skipped
# Only check for actual nodes, not special selectors (sys, env, conversation)
source_selector_prefix = segment.selector[0] if segment.selector else ""
if source_selector_prefix in self._graph.nodes:
source_node = self._graph.nodes[source_selector_prefix]
if source_node.state == NodeState.SKIPPED:
# Skip this variable segment if the source node is skipped
self._active_session.index += 1
continue
segment_events, is_complete = self._process_variable_segment(segment)
events.extend(segment_events)
# Only advance index if this variable segment is complete
if is_complete:
self._active_session.index += 1
else:
# Wait for more data
break
else:
segment_events = self._process_text_segment(segment)
events.extend(segment_events)
self._active_session.index += 1
if self._active_session.is_complete():
# End current session and get events from starting next session
next_session_events = self.end_session(response_node_id)
events.extend(next_session_events)
return events
def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
"""
End the active session for a response node.
Automatically starts the next waiting session if available.
Args:
node_id: ID of the response node ending its session
Returns:
List of events from starting the next session
"""
with self._lock:
events: list[NodeRunStreamChunkEvent] = []
if self._active_session and self._active_session.node_id == node_id:
self._active_session = None
# Try to start next waiting session
if self._waiting_sessions:
next_session = self._waiting_sessions.popleft()
self._active_session = next_session
# Immediately try to flush any available segments
events = self.try_flush()
return events
# ============= Internal Stream Management Methods =============
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.
Args:
selector: List of strings identifying the stream location
event: The NodeRunStreamChunkEvent to append
Raises:
ValueError: If the stream is already closed
"""
key = tuple(selector)
if key in self._closed_streams:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
if key not in self._stream_buffers:
self._stream_buffers[key] = []
self._stream_positions[key] = 0
self._stream_buffers[key].append(event)
def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
"""
Pop the next unread stream chunk from the buffer.
Args:
selector: List of strings identifying the stream location
Returns:
The next event, or None if no unread events available
"""
key = tuple(selector)
if key not in self._stream_buffers:
return None
position = self._stream_positions.get(key, 0)
buffer = self._stream_buffers[key]
if position >= len(buffer):
return None
event = buffer[position]
self._stream_positions[key] = position + 1
return event
def _has_unread_stream(self, selector: Sequence[str]) -> bool:
"""
Check if the stream has unread events.
Args:
selector: List of strings identifying the stream location
Returns:
True if there are unread events, False otherwise
"""
key = tuple(selector)
if key not in self._stream_buffers:
return False
position = self._stream_positions.get(key, 0)
return position < len(self._stream_buffers[key])
def _close_stream(self, selector: Sequence[str]) -> None:
"""
Mark a stream as closed (no more chunks can be appended).
Args:
selector: List of strings identifying the stream location
"""
key = tuple(selector)
self._closed_streams.add(key)
def _is_stream_closed(self, selector: Sequence[str]) -> bool:
"""
Check if a stream is closed.
Args:
selector: List of strings identifying the stream location
Returns:
True if the stream is closed, False otherwise
"""
key = tuple(selector)
return key in self._closed_streams
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
"""Convert an in-memory session into its serializable form."""
if session is None:
return None
return ResponseSessionState(node_id=session.node_id, index=session.index)
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
"""Rebuild a response session from serialized data."""
node = self._graph.nodes.get(session_state.node_id)
if node is None:
raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
session = ResponseSession.from_node(node)
session.index = session_state.index
return session
def dumps(self) -> str:
"""Serialize coordinator state to JSON."""
with self._lock:
state = ResponseStreamCoordinatorState(
response_nodes=sorted(self._response_nodes),
active_session=self._serialize_session(self._active_session),
waiting_sessions=[
session_state
for session in list(self._waiting_sessions)
if (session_state := self._serialize_session(session)) is not None
],
pending_sessions=[
session_state
for _, session in sorted(self._response_sessions.items())
if (session_state := self._serialize_session(session)) is not None
],
node_execution_ids=dict(sorted(self._node_execution_ids.items())),
paths_map={
node_id: [path.edges.copy() for path in paths]
for node_id, paths in sorted(self._paths_maps.items())
},
stream_buffers=[
StreamBufferState(
selector=selector,
events=[event.model_copy(deep=True) for event in events],
)
for selector, events in sorted(self._stream_buffers.items())
],
stream_positions=[
StreamPositionState(selector=selector, position=position)
for selector, position in sorted(self._stream_positions.items())
],
closed_streams=sorted(self._closed_streams),
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""Restore coordinator state from JSON."""
state = ResponseStreamCoordinatorState.model_validate_json(data)
if state.type != "ResponseStreamCoordinator":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported serialized version: {state.version}")
with self._lock:
self._response_nodes = set(state.response_nodes)
self._paths_maps = {
node_id: [Path(edges=list(path_edges)) for path_edges in paths]
for node_id, paths in state.paths_map.items()
}
self._node_execution_ids = dict(state.node_execution_ids)
self._stream_buffers = {
tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
for buffer in state.stream_buffers
}
self._stream_positions = {
tuple(position.selector): position.position for position in state.stream_positions
}
for selector in self._stream_buffers:
self._stream_positions.setdefault(selector, 0)
self._closed_streams = {tuple(selector) for selector in state.closed_streams}
self._waiting_sessions = deque(
self._session_from_state(session_state) for session_state in state.waiting_sessions
)
self._response_sessions = {
session_state.node_id: self._session_from_state(session_state)
for session_state in state.pending_sessions
}
self._active_session = self._session_from_state(state.active_session) if state.active_session else None
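
The stream bookkeeping above (append, pop with a read cursor, close) can be sketched in isolation. This is a minimal illustration, not the coordinator itself: the `StreamBuffers` class and its method names are hypothetical stand-ins, and plain strings replace `NodeRunStreamChunkEvent` objects.

```python
from collections.abc import Sequence
from typing import Optional


class StreamBuffers:
    """Hypothetical stand-in for the coordinator's per-selector stream state."""

    def __init__(self) -> None:
        self._buffers: dict[tuple[str, ...], list[str]] = {}
        self._positions: dict[tuple[str, ...], int] = {}
        self._closed: set[tuple[str, ...]] = set()

    def append(self, selector: Sequence[str], chunk: str) -> None:
        key = tuple(selector)
        if key in self._closed:
            raise ValueError(f"Stream {'.'.join(selector)} is already closed")
        self._buffers.setdefault(key, []).append(chunk)
        self._positions.setdefault(key, 0)

    def pop(self, selector: Sequence[str]) -> Optional[str]:
        # Chunks are never deleted; only the read cursor advances,
        # so the full buffer stays available for serialization.
        key = tuple(selector)
        buffer = self._buffers.get(key, [])
        position = self._positions.get(key, 0)
        if position >= len(buffer):
            return None
        self._positions[key] = position + 1
        return buffer[position]

    def close(self, selector: Sequence[str]) -> None:
        self._closed.add(tuple(selector))


buffers = StreamBuffers()
buffers.append(["llm", "text"], "Hel")
buffers.append(["llm", "text"], "lo")
first = buffers.pop(["llm", "text"])
second = buffers.pop(["llm", "text"])
third = buffers.pop(["llm", "text"])  # cursor is at the end of the buffer
buffers.close(["llm", "text"])
```

Keeping the chunks and advancing a position, rather than removing items, appears to be what lets `dumps` and `loads` round-trip both the buffered events and the read offsets.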


@ -0,0 +1,35 @@
"""
Internal path representation for response coordinator.
This module contains the private Path class used internally by ResponseStreamCoordinator
to track execution paths to response nodes.
"""
from dataclasses import dataclass, field
from typing import TypeAlias
EdgeID: TypeAlias = str
@dataclass
class Path:
"""
Represents a path of branch edges that must be taken to reach a response node.
Note: This is an internal class not exposed in the public API.
"""
edges: list[EdgeID] = field(default_factory=list[EdgeID])
def contains_edge(self, edge_id: EdgeID) -> bool:
"""Check if this path contains the given edge."""
return edge_id in self.edges
def remove_edge(self, edge_id: EdgeID) -> None:
"""Remove the given edge from this path in place."""
if self.contains_edge(edge_id):
self.edges.remove(edge_id)
def is_empty(self) -> bool:
"""Check if the path has no edges (node is reachable)."""
return len(self.edges) == 0
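
As a rough illustration of how such a path might be consumed: each time a branch edge is taken, it is removed from every path, and a response node counts as reachable once one of its paths is empty. The `on_edge_taken` helper and the edge IDs below are hypothetical, and the `Path` class is a trimmed copy so the sketch runs on its own.

```python
from dataclasses import dataclass, field


@dataclass
class Path:
    """Trimmed copy of the Path class, for a standalone example."""

    edges: list[str] = field(default_factory=list)

    def remove_edge(self, edge_id: str) -> None:
        if edge_id in self.edges:
            self.edges.remove(edge_id)

    def is_empty(self) -> bool:
        return len(self.edges) == 0


# Hypothetical: one response node sits behind two branch edges.
paths = [Path(edges=["if-true", "case-a"])]


def on_edge_taken(edge_id: str) -> bool:
    """Drop the taken edge from every path; report whether any path is now clear."""
    for path in paths:
        path.remove_edge(edge_id)
    return any(path.is_empty() for path in paths)


reachable_after_first = on_edge_taken("if-true")
reachable_after_second = on_edge_taken("case-a")
```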


@ -0,0 +1,52 @@
"""
Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
@dataclass
class ResponseSession:
"""
Represents an active response streaming session.
Note: This is an internal class not exposed in the public API.
"""
node_id: str
template: Template # Template object from the response node
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> "ResponseSession":
"""
Create a ResponseSession from a response node.
Args:
node: Must be an AnswerNode, EndNode, or KnowledgeIndexNode instance
Returns:
ResponseSession configured with the node's streaming template
Raises:
TypeError: If node is not an AnswerNode, EndNode, or KnowledgeIndexNode
"""
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError(f"Unsupported response node type: {type(node).__name__}")
return cls(
node_id=node.id,
template=node.get_streaming_template(),
)
def is_complete(self) -> bool:
"""Check if all segments in the template have been processed."""
return self.index >= len(self.template.segments)
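
The interplay between a session's `index` and the coordinator's active and waiting sessions can be sketched as follows. This is a simplified model under stated assumptions: `Session` and `SessionQueue` are hypothetical names, template segments are plain strings, and the rule in `start` (queue a session when another one is active) is assumed rather than taken from the source.

```python
from collections import deque
from dataclasses import dataclass
from typing import Optional


@dataclass
class Session:
    node_id: str
    segments: list[str]
    index: int = 0  # position of the next segment to flush

    def is_complete(self) -> bool:
        return self.index >= len(self.segments)


class SessionQueue:
    def __init__(self) -> None:
        self.active: Optional[Session] = None
        self.waiting: deque[Session] = deque()

    def start(self, session: Session) -> None:
        # Assumption: only one session streams at a time; others wait in FIFO order.
        if self.active is None:
            self.active = session
        else:
            self.waiting.append(session)

    def end(self, node_id: str) -> None:
        # Mirrors end_session: clear the active session, then promote the next waiter.
        if self.active and self.active.node_id == node_id:
            self.active = None
            if self.waiting:
                self.active = self.waiting.popleft()


sessions = SessionQueue()
first = Session("answer_1", ["Hello ", "world"])
second = Session("answer_2", ["Bye"])
sessions.start(first)
sessions.start(second)
first.index = 2  # pretend both segments were flushed
done = first.is_complete()
sessions.end("answer_1")
active_after = sessions.active.node_id
```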


@ -0,0 +1,142 @@
"""
Worker - Thread implementation for queue-based node execution
Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
import contextvars
import queue
import threading
import time
from datetime import datetime
from typing import final
from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
@final
class Worker(threading.Thread):
"""
Worker thread that executes nodes from the ready queue.
Workers continuously pull node IDs from the ready_queue, execute the
corresponding nodes, and push the resulting events to the event_queue
for the dispatcher to process.
"""
def __init__(
self,
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
) -> None:
"""
Initialize worker thread.
Args:
ready_queue: Ready queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
@property
def is_idle(self) -> bool:
"""Check if the worker is currently idle."""
# Worker counts as idle once more than 0.2 seconds have passed since it last
# picked up a task (a node still running past that window is also reported idle)
return (time.time() - self._last_task_time) > 0.2
@property
def idle_duration(self) -> float:
"""Get the duration in seconds since the worker last processed a task."""
return time.time() - self._last_task_time
@property
def worker_id(self) -> int:
"""Get the worker's ID."""
return self._worker_id
@override
def run(self) -> None:
"""
Main worker loop.
Continuously pulls node IDs from ready_queue, executes them,
and pushes events to event_queue until stopped.
"""
while not self._stop_event.is_set():
# Try to get a node ID from the ready queue (with timeout)
try:
node_id = self._ready_queue.get(timeout=0.1)
except queue.Empty:
continue
self._last_task_time = time.time()
node = self._graph.nodes[node_id]
try:
self._execute_node(node)
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
id=str(uuid4()),
node_id=node.id,
node_type=NodeType.CODE,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),
)
self._event_queue.put(error_event)
def _execute_node(self, node: Node) -> None:
"""
Execute a single node and handle its events.
Args:
node: The node instance to execute
"""
# Execute the node with preserved context if Flask app is provided
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self._event_queue.put(event)
else:
# Execute without context preservation
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self._event_queue.put(event)
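
The pull, execute, push loop of the worker can be tried out with the standard library alone. The sketch below is a variation on the class above, not a copy: `MiniWorker` and `execute` are hypothetical, events are plain strings, and `task_done()` is called in a `finally` block so that `join()` on the ready queue also returns when a node fails.

```python
import queue
import threading
from collections.abc import Callable


class MiniWorker(threading.Thread):
    def __init__(
        self,
        ready_queue: "queue.Queue[str]",
        event_queue: "queue.Queue[str]",
        execute: Callable[[str], str],
    ) -> None:
        super().__init__(daemon=True)
        self._ready_queue = ready_queue
        self._event_queue = event_queue
        self._execute = execute
        self._stop_event = threading.Event()

    def stop(self) -> None:
        self._stop_event.set()

    def run(self) -> None:
        while not self._stop_event.is_set():
            try:
                # The short timeout lets the loop re-check the stop flag regularly.
                node_id = self._ready_queue.get(timeout=0.05)
            except queue.Empty:
                continue
            try:
                self._event_queue.put(self._execute(node_id))
            except Exception as exc:
                self._event_queue.put(f"failed:{node_id}:{exc}")
            finally:
                self._ready_queue.task_done()


def execute(node_id: str) -> str:
    if node_id == "bad":
        raise RuntimeError("boom")
    return f"succeeded:{node_id}"


ready: "queue.Queue[str]" = queue.Queue()
events: "queue.Queue[str]" = queue.Queue()
worker = MiniWorker(ready, events, execute)
worker.start()
ready.put("node_a")
ready.put("bad")
ready.join()  # returns once task_done() was called for both items
worker.stop()
worker.join(timeout=1.0)
results = [events.get_nowait(), events.get_nowait()]
```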


@ -0,0 +1,12 @@
"""
Worker management subsystem for graph engine.
This package manages the worker pool, including creation,
scaling, and activity tracking.
"""
from .worker_pool import WorkerPool
__all__ = [
"WorkerPool",
]


@ -0,0 +1,291 @@
"""
Worker pool that consolidates worker management in a single class.
It merges the responsibilities of WorkerPool, ActivityTracker,
DynamicScaler, and WorkerFactory into one simpler implementation.
"""
import logging
import queue
import threading
from typing import TYPE_CHECKING, final
from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..ready_queue import ReadyQueue
from ..worker import Worker
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from contextvars import Context
from flask import Flask
@final
class WorkerPool:
"""
Simple worker pool with integrated management.
This class consolidates all worker management functionality into
a single, simpler implementation without excessive abstraction.
"""
def __init__(
self,
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None:
"""
Initialize the simple worker pool.
Args:
ready_queue: Ready queue for nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Seconds before scaling down idle workers
"""
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
# Scaling parameters with defaults
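# Use explicit None checks so that a caller-supplied 0 is not replaced by the default
self._min_workers = min_workers if min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS
self._max_workers = max_workers if max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS
self._scale_up_threshold = (
scale_up_threshold if scale_up_threshold is not None else dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
)
self._scale_down_idle_time = (
scale_down_idle_time if scale_down_idle_time is not None else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
)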
# Worker management
self._workers: list[Worker] = []
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
# Worker state is queried directly from each worker instead of being tracked
# through callbacks, which avoids lock contention
def start(self, initial_count: int | None = None) -> None:
"""
Start the worker pool.
Args:
initial_count: Number of workers to start with (auto-calculated if None)
"""
with self._lock:
if self._running:
return
self._running = True
# Calculate initial worker count
if initial_count is None:
node_count = len(self._graph.nodes)
if node_count < 10:
initial_count = self._min_workers
elif node_count < 50:
initial_count = min(self._min_workers + 1, self._max_workers)
else:
initial_count = min(self._min_workers + 2, self._max_workers)
logger.debug(
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
initial_count,
node_count,
self._min_workers,
self._max_workers,
)
# Create initial workers
for _ in range(initial_count):
self._create_worker()
def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False
worker_count = len(self._workers)
if worker_count > 0:
logger.debug("Stopping worker pool: %d workers", worker_count)
# Stop all workers
for worker in self._workers:
worker.stop()
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=10.0)
self._workers.clear()
def _create_worker(self) -> None:
"""Create and start a new worker."""
worker_id = self._worker_counter
self._worker_counter += 1
worker = Worker(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
)
worker.start()
self._workers.append(worker)
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
"""Remove a specific worker from the pool."""
# Stop the worker
worker.stop()
# Wait for it to finish
if worker.is_alive():
worker.join(timeout=2.0)
# Remove from list
if worker in self._workers:
self._workers.remove(worker)
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
"""
Try to scale up workers if needed.
Args:
queue_depth: Current queue depth
current_count: Current number of workers
Returns:
True if scaled up, False otherwise
"""
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
old_count = current_count
self._create_worker()
logger.debug(
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
old_count,
len(self._workers),
queue_depth,
self._scale_up_threshold,
)
return True
return False
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
"""
Try to scale down workers if we have excess capacity.
Args:
queue_depth: Current queue depth
current_count: Current number of workers
active_count: Number of active workers
idle_count: Number of idle workers
Returns:
True if scaled down, False otherwise
"""
# Skip if we're at minimum or have no idle workers
if current_count <= self._min_workers or idle_count == 0:
return False
# Check if we have excess capacity
has_excess_capacity = (
queue_depth <= active_count # Active workers can handle current queue
or idle_count > active_count # More idle than active workers
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
)
if not has_excess_capacity:
return False
# Find and remove idle workers that have been idle long enough
workers_to_remove: list[tuple[Worker, int]] = []
for worker in self._workers:
# Check if worker is idle and has exceeded idle time threshold
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
# Don't remove if it would leave us unable to handle the queue
remaining_workers = current_count - len(workers_to_remove) - 1
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
workers_to_remove.append((worker, worker.worker_id))
# Only remove one worker per check to avoid aggressive scaling
break
# Remove idle workers if any found
if workers_to_remove:
old_count = current_count
for worker, worker_id in workers_to_remove:
self._remove_worker(worker, worker_id)
logger.debug(
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
"queue_depth=%d, active=%d, idle=%d)",
old_count,
len(self._workers),
len(workers_to_remove),
self._scale_down_idle_time,
queue_depth,
active_count,
idle_count - len(workers_to_remove),
)
return True
return False
def check_and_scale(self) -> None:
"""Check and perform scaling if needed."""
with self._lock:
if not self._running:
return
current_count = len(self._workers)
queue_depth = self._ready_queue.qsize()
# Count active vs idle workers by querying their state directly
idle_count = sum(1 for worker in self._workers if worker.is_idle)
active_count = current_count - idle_count
# Try to scale up if queue is backing up
self._try_scale_up(queue_depth, current_count)
# Try to scale down if we have excess capacity
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self._workers)
def get_status(self) -> dict[str, int]:
"""
Get pool status information.
Returns:
Dictionary with status information
"""
with self._lock:
return {
"total_workers": len(self._workers),
"queue_depth": self._ready_queue.qsize(),
"min_workers": self._min_workers,
"max_workers": self._max_workers,
}
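
The scaling rules in the pool are easier to reason about when written as pure functions. The sketch below restates the conditions from `start`, `_try_scale_up`, and `_try_scale_down`. The function names are hypothetical, and the idle-time check on individual workers is left out.

```python
def initial_worker_count(node_count: int, min_workers: int, max_workers: int) -> int:
    """Start with more workers for larger graphs, capped at max_workers."""
    if node_count < 10:
        return min_workers
    if node_count < 50:
        return min(min_workers + 1, max_workers)
    return min(min_workers + 2, max_workers)


def should_scale_up(queue_depth: int, current: int, threshold: int, max_workers: int) -> bool:
    """Add a worker when the queue backs up and the pool is below its cap."""
    return queue_depth > threshold and current < max_workers


def has_excess_capacity(queue_depth: int, active: int, idle: int) -> bool:
    """Any one of the three conditions marks the pool as over-provisioned."""
    return queue_depth <= active or idle > active or (queue_depth == 0 and idle > 0)


def can_remove_one(queue_depth: int, current: int, min_workers: int) -> bool:
    """Removing a worker must keep the minimum and at least half the queue depth."""
    remaining = current - 1
    return remaining >= min_workers and remaining >= max(1, queue_depth // 2)


small_graph = initial_worker_count(5, min_workers=1, max_workers=10)
medium_graph = initial_worker_count(30, min_workers=1, max_workers=10)
large_graph_capped = initial_worker_count(200, min_workers=1, max_workers=2)
```

Separating the decision from the side effects (creating or joining threads) also makes the thresholds straightforward to unit test.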


@ -0,0 +1,72 @@
# Agent events
from .agent import NodeRunAgentLogEvent
# Base events
from .base import (
BaseGraphEvent,
GraphEngineEvent,
GraphNodeEventBase,
)
# Graph events
from .graph import (
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
# Iteration events
from .iteration import (
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
)
# Loop events
from .loop import (
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
)
# Node events
from .node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
__all__ = [
"BaseGraphEvent",
"GraphEngineEvent",
"GraphNodeEventBase",
"GraphRunAbortedEvent",
"GraphRunFailedEvent",
"GraphRunPartialSucceededEvent",
"GraphRunStartedEvent",
"GraphRunSucceededEvent",
"NodeRunAgentLogEvent",
"NodeRunExceptionEvent",
"NodeRunFailedEvent",
"NodeRunIterationFailedEvent",
"NodeRunIterationNextEvent",
"NodeRunIterationStartedEvent",
"NodeRunIterationSucceededEvent",
"NodeRunLoopFailedEvent",
"NodeRunLoopNextEvent",
"NodeRunLoopStartedEvent",
"NodeRunLoopSucceededEvent",
"NodeRunRetrieverResourceEvent",
"NodeRunRetryEvent",
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
]


@ -0,0 +1,17 @@
from collections.abc import Mapping
from typing import Any
from pydantic import Field
from .base import GraphAgentNodeEventBase
class NodeRunAgentLogEvent(GraphAgentNodeEventBase):
message_id: str = Field(..., description="message id")
label: str = Field(..., description="label")
node_execution_id: str = Field(..., description="node execution id")
parent_id: str | None = Field(..., description="parent id")
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Mapping[str, object] = Field(default_factory=dict)


@ -0,0 +1,31 @@
from pydantic import BaseModel, Field
from core.workflow.enums import NodeType
from core.workflow.node_events import NodeRunResult
class GraphEngineEvent(BaseModel):
pass
class BaseGraphEvent(GraphEngineEvent):
pass
class GraphNodeEventBase(GraphEngineEvent):
id: str = Field(..., description="node execution id")
node_id: str
node_type: NodeType
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
node_run_result: NodeRunResult = Field(default_factory=NodeRunResult)
class GraphAgentNodeEventBase(GraphNodeEventBase):
pass


@ -0,0 +1,28 @@
from pydantic import Field
from core.workflow.graph_events import BaseGraphEvent
class GraphRunStartedEvent(BaseGraphEvent):
pass
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: dict[str, object] = Field(default_factory=dict)
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: int = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
exceptions_count: int = Field(..., description="exception count")
outputs: dict[str, object] = Field(default_factory=dict)
class GraphRunAbortedEvent(BaseGraphEvent):
"""Event emitted when a graph run is aborted by user command."""
reason: str | None = Field(default=None, description="reason for abort")
outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any")


@ -0,0 +1,40 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import Field
from .base import GraphNodeEventBase
class NodeRunIterationStartedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
class NodeRunIterationNextEvent(GraphNodeEventBase):
node_title: str
index: int = Field(..., description="index")
pre_iteration_output: Any = None
class NodeRunIterationSucceededEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
class NodeRunIterationFailedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: str = Field(..., description="failed reason")


@ -0,0 +1,40 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import Field
from .base import GraphNodeEventBase
class NodeRunLoopStartedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
class NodeRunLoopNextEvent(GraphNodeEventBase):
node_title: str
index: int = Field(..., description="index")
pre_loop_output: Any = None
class NodeRunLoopSucceededEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
class NodeRunLoopFailedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: str = Field(..., description="failed reason")


@ -0,0 +1,53 @@
from collections.abc import Sequence
from datetime import datetime
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from .base import GraphNodeEventBase
class NodeRunStartedEvent(GraphNodeEventBase):
node_title: str
predecessor_node_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
start_at: datetime = Field(..., description="node start time")
# FIXME(-LAN-): only for ToolNode
provider_type: str = ""
provider_id: str = ""
class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(GraphNodeEventBase):
start_at: datetime = Field(..., description="node start time")
class NodeRunFailedEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
class NodeRunExceptionEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")


@ -0,0 +1,40 @@
from .agent import AgentLogEvent
from .base import NodeEventBase, NodeRunResult
from .iteration import (
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
)
from .loop import (
LoopFailedEvent,
LoopNextEvent,
LoopStartedEvent,
LoopSucceededEvent,
)
from .node import (
ModelInvokeCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
__all__ = [
"AgentLogEvent",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",
"IterationSucceededEvent",
"LoopFailedEvent",
"LoopNextEvent",
"LoopStartedEvent",
"LoopSucceededEvent",
"ModelInvokeCompletedEvent",
"NodeEventBase",
"NodeRunResult",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"StreamChunkEvent",
"StreamCompletedEvent",
]


@ -0,0 +1,18 @@
from collections.abc import Mapping
from typing import Any
from pydantic import Field
from .base import NodeEventBase
class AgentLogEvent(NodeEventBase):
message_id: str = Field(..., description="message id")
label: str = Field(..., description="label")
node_execution_id: str = Field(..., description="node execution id")
parent_id: str | None = Field(..., description="parent id")
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata")
node_id: str = Field(..., description="node id")


@ -0,0 +1,40 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class NodeEventBase(BaseModel):
"""Base class for all node events"""
pass
def _default_metadata():
v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
return v
class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.PENDING
inputs: Mapping[str, Any] = Field(default_factory=dict)
process_data: Mapping[str, Any] = Field(default_factory=dict)
outputs: Mapping[str, Any] = Field(default_factory=dict)
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata)
llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
edge_source_handle: str = "source" # source handle id of node with multiple branches
error: str = ""
error_type: str = ""
# single step node run retry
retry_index: int = 0


@ -0,0 +1,36 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import Field
from .base import NodeEventBase
class IterationStartedEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
class IterationNextEvent(NodeEventBase):
index: int = Field(..., description="index")
pre_iteration_output: Any = None
class IterationSucceededEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
class IterationFailedEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: str = Field(..., description="failed reason")


@ -0,0 +1,36 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import Field
from .base import NodeEventBase
class LoopStartedEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
class LoopNextEvent(NodeEventBase):
index: int = Field(..., description="index")
pre_loop_output: Any = None
class LoopSucceededEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
class LoopFailedEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: str = Field(..., description="failed reason")


@ -0,0 +1,41 @@
from collections.abc import Sequence
from datetime import datetime
from pydantic import Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_events import NodeRunResult
from .base import NodeEventBase
class RunRetrieverResourceEvent(NodeEventBase):
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class ModelInvokeCompletedEvent(NodeEventBase):
text: str
usage: LLMUsage
finish_reason: str | None = None
reasoning_content: str | None = None
class RunRetryEvent(NodeEventBase):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="Retry attempt number")
start_at: datetime = Field(..., description="Retry start time")
class StreamChunkEvent(NodeEventBase):
# Spec-compliant fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
class StreamCompletedEvent(NodeEventBase):
node_run_result: NodeRunResult = Field(..., description="run result")
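A rough sketch of how a consumer might use the `selector`, `chunk`, and `is_final` fields defined above: it buffers chunks per selector and treats a stream as closed once a final chunk arrives. The dataclass is only a stand-in for the pydantic model, and `collect_streams` is a hypothetical helper that does not appear in this commit.

```python
from collections.abc import Iterable, Sequence
from dataclasses import dataclass


@dataclass(frozen=True)
class StreamChunk:
    """Stand-in for StreamChunkEvent (assumption: same three fields)."""

    selector: Sequence[str]
    chunk: str
    is_final: bool = False


def collect_streams(events: Iterable[StreamChunk]) -> dict[tuple[str, ...], str]:
    """Concatenate chunks per selector; reject data after a final chunk."""
    buffers: dict[tuple[str, ...], str] = {}
    closed: set[tuple[str, ...]] = set()
    for event in events:
        key = tuple(event.selector)
        if key in closed:
            raise ValueError(f"chunk received after final chunk for {key}")
        buffers[key] = buffers.get(key, "") + event.chunk
        if event.is_final:
            closed.add(key)
    return buffers


events = [
    StreamChunk(["nodeA", "text"], "Hel"),
    StreamChunk(["nodeA", "text"], "lo"),
    StreamChunk(["nodeA", "text"], "", is_final=True),  # empty closing chunk
]
result = collect_streams(events)
```

Keying the buffers by the full selector keeps the text stream and any per-variable streams of the same node separate.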


@ -1,3 +1,3 @@
from .enums import NodeType
from core.workflow.enums import NodeType
__all__ = ["NodeType"]


@ -1,6 +1,6 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
@ -9,16 +9,12 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
@ -29,17 +25,25 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
@ -57,19 +61,23 @@ from .exc import (
ToolFileNotFoundError,
)
if TYPE_CHECKING:
from core.agent.strategy.plugin import PluginAgentStrategy
from core.plugin.entities.request import InvokeCredentials
class AgentNode(BaseNode):
class AgentNode(Node):
"""
Agent Node
"""
_node_type = NodeType.AGENT
node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -78,7 +86,7 @@ class AgentNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -91,7 +99,9 @@ class AgentNode(BaseNode):
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
@ -99,12 +109,12 @@ class AgentNode(BaseNode):
agent_strategy_name=self._node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
)
),
)
return
@ -139,8 +149,8 @@ class AgentNode(BaseNode):
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
@ -158,16 +168,16 @@ class AgentNode(BaseNode):
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_type=self.type_,
node_id=self.node_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield RunCompletedEvent(
run_result=NodeRunResult(
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
@ -181,7 +191,7 @@ class AgentNode(BaseNode):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: PluginAgentStrategy,
strategy: "PluginAgentStrategy",
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -320,7 +330,7 @@ class AgentNode(BaseNode):
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
@ -339,10 +349,11 @@ class AgentNode(BaseNode):
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
) -> "InvokeCredentials":
"""
Generate credentials based on the given agent parameters.
"""
from core.plugin.entities.request import InvokeCredentials
credentials = InvokeCredentials()
@ -388,6 +399,8 @@ class AgentNode(BaseNode):
Get agent strategy icon
:return:
"""
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
try:
@ -401,7 +414,7 @@ class AgentNode(BaseNode):
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
@ -450,7 +463,9 @@ class AgentNode(BaseNode):
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _filter_mcp_type_tool(
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
@ -473,11 +488,13 @@ class AgentNode(BaseNode):
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator:
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
@ -491,7 +508,7 @@ class AgentNode(BaseNode):
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
@ -553,7 +570,11 @@ class AgentNode(BaseNode):
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
@ -564,13 +585,17 @@ class AgentNode(BaseNode):
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None:
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
@ -587,8 +612,10 @@ class AgentNode(BaseNode):
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
@ -639,7 +666,7 @@ class AgentNode(BaseNode):
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
@ -652,7 +679,7 @@ class AgentNode(BaseNode):
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
if log.message_id == agent_log.message_id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
@ -673,7 +700,7 @@ class AgentNode(BaseNode):
for log in agent_logs:
json_output.append(
{
"id": log.id,
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
@ -689,8 +716,24 @@ class AgentNode(BaseNode):
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,

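The hunks above move several plugin imports behind `TYPE_CHECKING` and into function bodies. A minimal sketch of that pattern, using the standard-library `decimal` module as a stand-in for the plugin modules (the function name `parse_amount` is hypothetical):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from decimal import Decimal


def parse_amount(raw: str) -> "Decimal":
    # Deferred import: the module is loaded on first call, not when this
    # module is imported. The quoted annotation avoids a NameError at
    # definition time, since Decimal is not bound at module level.
    from decimal import Decimal

    return Decimal(raw)


amount = parse_amount("1.50")
```

This arrangement is commonly used to break import cycles or to avoid paying for a heavy import at module load. Which of these motivated the change here is not stated in the diff.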

@ -1,4 +1,4 @@
from enum import Enum, StrEnum
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData):
agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1
class ParamsAutoGenerated(IntEnum):
CLOSE = auto()
OPEN = auto()
class AgentOldVersionModelFeatures(StrEnum):
@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum):
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()


@ -1,6 +1,3 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
@ -12,7 +9,7 @@ class AgentNodeError(Exception):
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
def __init__(self, strategy_name: str, provider_name: str | None = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
def __init__(self, message: str, parameter_name: str | None = None):
self.parameter_name = parameter_name
super().__init__(message)
@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
def __init__(self, message: str, variable_name: str | None = None):
self.variable_name = variable_name
super().__init__(message)
@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
def __init__(self, message: str, file_id: str | None = None):
self.file_id = file_id
super().__init__(message)
@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
def __init__(self, message: str, conversation_id: str | None = None):
self.conversation_id = conversation_id
super().__init__(message)
@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
variable_name: str | None = None,
expected_type: str | None = None,
actual_type: str | None = None,
):
self.variable_name = variable_name
self.expected_type = expected_type


@ -1,4 +0,0 @@
from .answer_node import AnswerNode
from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]


@ -1,31 +1,26 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER
class AnswerNode(Node):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -34,7 +29,7 @@ class AnswerNode(BaseNode):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -48,35 +43,29 @@ class AnswerNode(BaseNode):
return "1"
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = ""
files = []
for part in generate_routes:
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
variable = self.graph_runtime_state.variable_pool.get(value_selector)
if variable:
if isinstance(variable, FileSegment):
files.append(variable.value)
elif isinstance(variable, ArrayFileSegment):
files.extend(variable.value)
answer += variable.markdown
else:
part = cast(TextGenerateRouteChunk, part)
answer += part.text
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
files = self._extract_files_from_segments(segments.value)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
)
def _extract_files_from_segments(self, segments: Sequence[Segment]):
"""Extract all files from segments containing FileSegment or ArrayFileSegment instances.
FileSegment contains a single file, while ArrayFileSegment contains multiple files.
This method flattens all files into a single list.
"""
files = []
for segment in segments:
if isinstance(segment, FileSegment):
# Single file - wrap in list for consistency
files.append(segment.value)
elif isinstance(segment, ArrayFileSegment):
# Multiple files - extend the list
files.extend(segment.value)
return files
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -96,3 +85,12 @@ class AnswerNode(BaseNode):
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
def get_streaming_template(self) -> Template:
"""
Get the template for streaming.
Returns:
Template instance for this Answer node
"""
return Template.from_answer_template(self._node_data.answer)
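The flattening done by `_extract_files_from_segments` earlier in this file can be illustrated with stand-in classes. The dataclasses below are simplified assumptions (files are represented as plain strings), not the real `core.variables` types.

```python
from collections.abc import Sequence
from dataclasses import dataclass, field


@dataclass
class StringSegment:
    value: str


@dataclass
class FileSegment:
    value: str  # one file


@dataclass
class ArrayFileSegment:
    value: list[str] = field(default_factory=list)  # several files


def extract_files(segments: Sequence[object]) -> list[str]:
    """Collect files from file-bearing segments, preserving order."""
    files: list[str] = []
    for segment in segments:
        if isinstance(segment, FileSegment):
            files.append(segment.value)  # single file: add one item
        elif isinstance(segment, ArrayFileSegment):
            files.extend(segment.value)  # many files: add each item
    return files


segments = [
    StringSegment("Report: "),
    FileSegment("a.pdf"),
    ArrayFileSegment(["b.png", "c.png"]),
]
flattened = extract_files(segments)
```

Non-file segments are skipped, so the result holds only files, in template order.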


@ -1,174 +0,0 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
@classmethod
def init(
cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
for answer_node_id, node_config in node_id_config_mapping.items():
if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
continue
# get generate route for stream output
generate_route = cls._extract_generate_route_selectors(node_config)
answer_generate_route[answer_node_id] = generate_route
# fetch answer dependencies
answer_node_ids = list(answer_generate_route.keys())
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping,
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")
generate_routes: list[GenerateRouteChunk] = []
for part in template.split("Ω"):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "")
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector))
else:
generate_routes.append(TextGenerateRouteChunk(text=part))
return generate_routes
@classmethod
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace("{{", "").replace("}}", "")
return part.startswith("{{") and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(
cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Fetch answer dependencies
:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
answer_dependencies: dict[str, list[str]] = {}
for answer_node_id in answer_node_ids:
if answer_dependencies.get(answer_node_id) is None:
answer_dependencies[answer_node_id] = []
cls._recursive_fetch_answer_dependencies(
current_node_id=answer_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies,
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(
cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]],
) -> None:
"""
Recursive fetch answer dependencies
:param current_node_id: current node id
:param answer_node_id: answer node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param answer_dependencies: answer dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (
source_node_type
in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.LOOP,
NodeType.VARIABLE_ASSIGNER,
}
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(
current_node_id=source_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies,
)
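The removed `extract_generate_route_from_node_data` split an answer template into text and variable chunks by wrapping each known `{{var}}` in a sentinel character (`Ω`) and then splitting on it. The sketch below reproduces only that splitting step. The function name and the tuple-based chunk representation are hypothetical simplifications, and the variable keys are passed in directly rather than parsed from the template.

```python
def split_template(template: str, variable_keys: list[str]) -> list[tuple[str, str]]:
    """Split a template into ("text", ...) and ("var", ...) chunks."""
    marked = template
    for var in variable_keys:
        # Surround each known placeholder with the sentinel character.
        marked = marked.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")

    chunks: list[tuple[str, str]] = []
    for part in marked.split("Ω"):
        if not part:
            continue  # adjacent placeholders produce empty parts
        cleaned = part.replace("{{", "").replace("}}", "")
        if part.startswith("{{") and cleaned in variable_keys:
            chunks.append(("var", cleaned))
        else:
            chunks.append(("text", part))
    return chunks


chunks = split_template("Hello {{name}}, bye", ["name"])
```

A sentinel-based split like this mis-splits any template whose literal text already contains the sentinel character.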


@ -1,199 +0,0 @@
import logging
from collections.abc import Generator
from typing import cast
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunExceptionEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
for answer_node_id in self.generate_routes.answer_generate_route:
self.route_position[answer_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id or event.in_loop_id:
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = (
stream_out_answer_node_ids
)
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
else:
yield event
def reset(self) -> None:
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for answer_node_id in self.route_position:
# all depends on answer node id not in rest node ids
if event.route_node_state.node_id != answer_node_id and (
answer_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
)
):
continue
route_position = self.route_position[answer_node_id]
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
for route_chunk in route_chunks:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
break
value = self.variable_pool.get(value_selector)
if value is None:
break
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=list(value_selector),
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[answer_node_id] += 1

    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
        """
        Return the IDs of answer nodes that can stream this chunk out immediately.

        :param event: stream chunk event carrying the source variable selector
        :return: IDs of answer nodes whose next route chunk matches the event's selector
        """
        if not event.from_variable_selector:
            return []

        stream_output_value_selector = event.from_variable_selector
        stream_out_answer_node_ids = []
        for answer_node_id, route_position in self.route_position.items():
            if answer_node_id not in self.rest_node_ids:
                continue

            # If the current node reaches this answer node through its success branch,
            # drop it from the answer node's dependencies so the output can be streamed.
            answer_dependencies = self.generate_routes.answer_dependencies
            edge_mapping = self.graph.edge_mapping.get(event.node_id)
            success_edge = (
                next(
                    (
                        edge
                        for edge in edge_mapping
                        if edge.run_condition
                        and edge.run_condition.type == "branch_identify"
                        and edge.run_condition.branch_identify == "success-branch"
                    ),
                    None,
                )
                if edge_mapping
                else None
            )
            if (
                event.node_id in answer_dependencies[answer_node_id]
                and success_edge
                and success_edge.target_node_id == answer_node_id
            ):
                answer_dependencies[answer_node_id].remove(event.node_id)
            answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
            # Stream only when none of the answer node's dependencies are still pending.
            if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids):
                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
                    continue

                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]

                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
                    continue

                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                value_selector = route_chunk.value_selector

                # The next route chunk must reference exactly the variable being streamed.
                if value_selector != stream_output_value_selector:
                    continue

                stream_out_answer_node_ids.append(answer_node_id)

        return stream_out_answer_node_ids
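
The readiness check performed by `_get_stream_out_answer_node_ids` can be sketched in isolation. This is a minimal illustration under stated assumptions, not the engine's code: the function name `ready_answer_nodes`, the tuple-based route representation, and the sample node IDs are all hypothetical, and the success-branch dependency pruning is left out.

```python
def ready_answer_nodes(routes, route_position, rest_node_ids, dependencies, selector):
    """Return answer node IDs whose next route chunk is the variable `selector`.

    routes: answer node ID -> list of (kind, payload) tuples (hypothetical shape),
            where kind is "text" or "var".
    route_position: answer node ID -> index of the next chunk to emit.
    rest_node_ids: IDs of nodes that have not finished yet.
    dependencies: answer node ID -> IDs of nodes that must finish first.
    """
    ready = []
    for answer_id, position in route_position.items():
        # Skip answer nodes that are no longer pending.
        if answer_id not in rest_node_ids:
            continue
        # Every dependency must already have left the pending set.
        if any(dep in rest_node_ids for dep in dependencies.get(answer_id, [])):
            continue
        route = routes[answer_id]
        if position >= len(route):
            continue
        kind, payload = route[position]
        # Only a variable chunk that references the streamed selector qualifies.
        if kind != "var" or payload != selector:
            continue
        ready.append(answer_id)
    return ready


routes = {"answer": [("text", "Result: "), ("var", ["llm", "text"])]}
dependencies = {"answer": ["start"]}
rest = {"llm", "answer"}
selector = ["llm", "text"]

# Position 0 is a text chunk, so nothing can stream the variable yet.
before = ready_answer_nodes(routes, {"answer": 0}, rest, dependencies, selector)
# Position 1 is the matching variable chunk and "start" has finished.
after = ready_answer_nodes(routes, {"answer": 1}, rest, dependencies, selector)
# If "start" were still pending, the answer node would be blocked.
blocked = ready_answer_nodes(routes, {"answer": 1}, rest | {"start"}, dependencies, selector)
```

In this sketch an answer node only streams when its route cursor points at the very variable being produced, which mirrors the selector equality check in the method above.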


@ -1,109 +0,0 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph

logger = logging.getLogger(__name__)


class StreamProcessor(ABC):
    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        self.graph = graph
        self.variable_pool = variable_pool
        self.rest_node_ids = graph.node_ids.copy()

    @abstractmethod
    def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
        raise NotImplementedError

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
        finished_node_id = event.route_node_state.node_id
        if finished_node_id not in self.rest_node_ids:
            return

        # Remove the finished node from the pending set.
        self.rest_node_ids.remove(finished_node_id)

        run_result = event.route_node_state.node_run_result
        if not run_result:
            return

        if run_result.edge_source_handle:
            reachable_node_ids: list[str] = []
            unreachable_first_node_ids: list[str] = []
            if finished_node_id not in self.graph.edge_mapping:
                logger.warning("node %s has no edge mapping", finished_node_id)
                return
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (
                    edge.run_condition
                    and edge.run_condition.branch_identify
                    and run_result.edge_source_handle == edge.run_condition.branch_identify
                ):
                    # FIXME: branches can merge directly into an answer node, so removing
                    # nodes here could short-circuit that answer node. The code below is
                    # disabled for now; this has no effect on the answer node or the
                    # workflow. Re-enable it once a better solution exists.
                    # Issues: #11542 #9560 #10638 #10564
                    # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
                    # if "answer" in ids:
                    #     continue
                    # else:
                    #     reachable_node_ids.extend(ids)

                    # Pass branch_identify so that only nodes in the matching
                    # logical branch are collected.
                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
                    reachable_node_ids.extend(ids)
                else:
                    # If the finished node is inside a parallel branch but the edge's target
                    # is not, do not remove the target.
                    # Issues: #13626
                    if (
                        finished_node_id in self.graph.node_parallel_mapping
                        and edge.target_node_id not in self.graph.node_parallel_mapping
                    ):
                        continue
                    unreachable_first_node_ids.append(edge.target_node_id)
            unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
        if node_id not in self.rest_node_ids:
            self.rest_node_ids.append(node_id)
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue

            # Only follow edges that match branch_identify or carry no branch condition.
            if edge.run_condition and edge.run_condition.branch_identify:
                if not branch_identify or edge.run_condition.branch_identify != branch_identify:
                    continue

            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
        """
        Remove the node and its downstream nodes from the pending set, stopping at
        nodes that are also reachable (the merge point).
        """
        if node_id not in self.rest_node_ids:
            return

        if node_id in reachable_node_ids:
            return

        self.rest_node_ids.remove(node_id)
        self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))

        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id in reachable_node_ids:
                continue

            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
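
The branch filter used by `_fetch_node_ids_in_reachable_branch` can be tried on a toy graph. The sketch below is a simplified stand-in, not the removed implementation: the `Edge` dataclass, the `fetch_reachable` function, and the node IDs are hypothetical, and the root-node check and the `rest_node_ids` bookkeeping are omitted.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Edge:
    """Hypothetical edge: branch_identify is None for an unconditional edge."""

    target_node_id: str
    branch_identify: Optional[str] = None


def fetch_reachable(edge_mapping, node_id, branch_identify=None):
    """Collect downstream node IDs, following only edges on the chosen branch."""
    node_ids = []
    for edge in edge_mapping.get(node_id, []):
        # A conditional edge is followed only when it matches the chosen branch.
        if edge.branch_identify:
            if not branch_identify or edge.branch_identify != branch_identify:
                continue
        node_ids.append(edge.target_node_id)
        node_ids.extend(fetch_reachable(edge_mapping, edge.target_node_id, branch_identify))
    return node_ids


# An if/else node whose two branches merge into one answer node.
edge_mapping = {
    "if": [Edge("a", "true"), Edge("b", "false")],
    "a": [Edge("answer")],
    "b": [Edge("answer")],
}

true_branch = fetch_reachable(edge_mapping, "if", "true")
no_branch = fetch_reachable(edge_mapping, "if")
```

With the `"true"` handle, only `a` and the shared `answer` node are collected; without a handle, both conditional edges are skipped.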


@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum
+from enum import StrEnum, auto

 from pydantic import BaseModel, Field
@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel):
     Generate Route Chunk.
     """

-    class ChunkType(Enum):
-        VAR = "var"
-        TEXT = "text"
+    class ChunkType(StrEnum):
+        VAR = auto()
+        TEXT = auto()

     type: ChunkType = Field(..., description="generate route chunk type")


@ -1,11 +1,9 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .node import BaseNode
__all__ = [
"BaseIterationNodeData",
"BaseIterationState",
"BaseLoopNodeData",
"BaseLoopState",
"BaseNode",
"BaseNodeData",
]

Some files were not shown because too many files have changed in this diff.