refactor(api): rename dify_graph to graphon (#34095)

2026-03-25 21:58:56 +08:00
committed by GitHub
parent 7e9d00a5a6
commit 52e7492cbc
898 changed files with 2687 additions and 2687 deletions

135
api/graphon/README.md Normal file
View File

@ -0,0 +1,135 @@
# Workflow
## Project Overview
This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.
## Architecture
### Core Components
The graph engine follows a layered architecture with strict dependency rules:
1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution
- **Manager** - External control interface for stop/pause/resume commands
- **Worker** - Node execution runtime
- **Command Processing** - Handles control commands (abort, pause, resume)
- **Event Management** - Event propagation and layer notifications
- **Graph Traversal** - Edge processing and skip propagation
- **Response Coordinator** - Path tracking and session management
- **Layers** - Pluggable middleware (debug logging, execution limits)
- **Command Channels** - Communication channels (InMemory, Redis)
1. **Graph** (`graph/`) - Graph structure and runtime state
- **Graph Template** - Workflow definition
- **Edge** - Node connections with conditions
- **Runtime State Protocol** - State management interface
1. **Nodes** (`nodes/`) - Node implementations
- **Base** - Abstract node classes and variable parsing
- **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.
1. **Events** (`node_events/`) - Event system
- **Base** - Event protocols
- **Node Events** - Node lifecycle events
1. **Entities** (`entities/`) - Domain models
- **Variable Pool** - Variable storage
- **Graph Init Params** - Initialization configuration
## Key Design Patterns
### Command Channel Pattern
External workflow control via Redis or in-memory channels:
```python
# Send stop command to running workflow
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
channel.send_command(AbortCommand(reason="User requested"))
```
### Layer System
Extensible middleware for cross-cutting concerns:
```python
engine = GraphEngine(graph)
engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
can assume `graph_runtime_state` is available.
### Event-Driven Architecture
All node executions emit events for monitoring and integration:
- `NodeRunStartedEvent` - Node execution begins
- `NodeRunSucceededEvent` - Node completes successfully
- `NodeRunFailedEvent` - Node encounters error
- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle
### Variable Pool
Centralized variable storage with namespace isolation:
```python
# Variables scoped by node_id
pool.add(["node1", "output"], value)
result = pool.get(["node1", "output"])
```
## Import Architecture Rules
The codebase enforces strict layering via import-linter:
1. **Workflow Layers** (top to bottom):
- graph_engine → graph_events → graph → nodes → node_events → entities
1. **Graph Engine Internal Layers**:
- orchestration → command_processing → event_management → graph_traversal → domain
1. **Domain Isolation**:
- Domain models cannot import from infrastructure layers
1. **Command Channel Independence**:
- InMemory and Redis channels must remain independent
## Common Tasks
### Adding a New Node Type
1. Create node class in `nodes/<node_type>/`
1. Inherit from `BaseNode` or appropriate base class
1. Implement `_run()` method
1. Ensure the node module is importable under `nodes/<node_type>/`
1. Add tests in `tests/unit_tests/graphon/nodes/`
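A minimal sketch of steps 1–3 (the import path, base-class contract, and return shape are illustrative; see `nodes/base/` for the actual interfaces):

```python
# nodes/reverse_text/node.py — hypothetical node type for illustration
from graphon.nodes.base.node import Node


class ReverseTextNode(Node):
    def _run(self):
        # Read a configured input and produce an output; the concrete
        # NodeRunResult type and variable APIs are defined in nodes/base/.
        text = self.node_data.get("text", "")
        return {"outputs": {"result": text[::-1]}}
```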
### Implementing a Custom Layer
1. Create class inheriting from `Layer` base
1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
1. Add to engine via `engine.layer()`
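A sketch of a custom layer, assuming the `Layer` base class exposes the three lifecycle hooks above (import path and hook signatures are illustrative):

```python
import time

from graphon.graph_engine.layers import Layer  # illustrative import path


class TimingLayer(Layer):
    def on_graph_start(self):
        self._started = time.perf_counter()

    def on_event(self, event):
        pass  # inspect node events here if needed

    def on_graph_end(self, error=None):
        print(f"graph finished in {time.perf_counter() - self._started:.2f}s")
```

Attach it with `engine.layer(TimingLayer())` before running, alongside any built-in layers.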
### Debugging Workflow Execution
Enable debug logging layer:
```python
debug_layer = DebugLoggingLayer(
level="DEBUG",
include_inputs=True,
include_outputs=True
)
```

0
api/graphon/__init__.py Normal file
View File

View File

@ -0,0 +1,11 @@
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason
__all__ = [
"GraphInitParams",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowStartReason",
]

View File

@ -0,0 +1,178 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator
from graphon.entities.exc import DefaultValueTypeError
from graphon.enums import ErrorStrategy, NodeType
# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`.
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
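# Usage sketch (illustrative, not part of the module): string inputs are coerced
# according to the declared type before validation, e.g.
#
#     DefaultValue(key="n", type=DefaultValueType.NUMBER, value="3.5").value == 3.5
#     DefaultValue(key="xs", type=DefaultValueType.ARRAY_STRING, value='["a", "b"]').value == ["a", "b"]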
class BaseNodeData(ABC, BaseModel):
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
# `type` therefore accepts downstream string node kinds; unknown node implementations
# are rejected later when the node factory resolves the node registry.
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
# and persisted templates/workflows also carry undeclared compatibility keys such as
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
# here until graph parsing becomes discriminated by node type or those legacy payloads
# are normalized.
model_config = ConfigDict(extra="allow")
type: NodeType
title: str = ""
desc: str | None = None
version: str = "1"
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = Field(default_factory=RetryConfig)
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
def __getitem__(self, key: str) -> Any:
"""
Dict-style access without calling model_dump() on every lookup.
Prefer using model fields and Pydantic's extra storage.
"""
# First, check declared model fields
if key in self.__class__.model_fields:
return getattr(self, key)
# Then, check undeclared compatibility fields stored in Pydantic's extra dict.
extras = getattr(self, "__pydantic_extra__", None)
if extras is None:
extras = getattr(self, "model_extra", None)
if extras is not None and key in extras:
return extras[key]
raise KeyError(key)
def get(self, key: str, default: Any = None) -> Any:
"""
Dict-style .get() without calling model_dump() on every lookup.
"""
if key in self.__class__.model_fields:
return getattr(self, key)
extras = getattr(self, "__pydantic_extra__", None)
if extras is None:
extras = getattr(self, "model_extra", None)
if extras is not None and key in extras:
return extras.get(key, default)
return default
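# Usage sketch (illustrative): declared fields and legacy extra keys are both
# reachable through the dict-style accessors without a model_dump() call.
#
#     data = BaseNodeData(type="llm", title="LLM", selected=True)
#     data["title"]         # -> "LLM" (declared field)
#     data.get("selected")  # -> True (stored in Pydantic's extra dict)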

View File

@ -0,0 +1,10 @@
class BaseNodeError(ValueError):
"""Base class for node errors."""
pass
class DefaultValueTypeError(BaseNodeError):
"""Raised when the default value type is invalid."""
pass

View File

@ -0,0 +1,23 @@
from __future__ import annotations
import sys
from pydantic import TypeAdapter, with_config
from graphon.entities.base_node_data import BaseNodeData
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@with_config(extra="allow")
class NodeConfigDict(TypedDict):
id: str
# This is the permissive raw graph boundary. Node factories re-validate `data`
# with the concrete `NodeData` subtype after resolving the node implementation.
data: BaseNodeData
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
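# Usage sketch (illustrative): validating a raw node config at the graph boundary.
#
#     node = NodeConfigDictAdapter.validate_python(
#         {"id": "node1", "data": {"type": "llm", "title": "LLM"}}
#     )
#     node["data"].type  # -> "llm"; extra keys survive on both levels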

View File

@ -0,0 +1,24 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
DIFY_RUN_CONTEXT_KEY = "_dify"
class GraphInitParams(BaseModel):
"""GraphInitParams encapsulates the configurations and contextual information
that remain constant throughout a single execution of the graph engine.
A single execution spans from start until the run reaches a terminal state. For
instance, if a workflow is suspended and later resumed, it is still regarded as
a single execution, not two.
For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
"""
# init params
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
run_context: Mapping[str, Any] = Field(..., description="runtime context")
call_depth: int = Field(..., description="call depth")

View File

@ -0,0 +1,42 @@
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, TypeAlias
from pydantic import BaseModel, Field
from graphon.nodes.human_input.entities import FormInput, UserAction
class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()
class HumanInputRequired(BaseModel):
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
actions: list[UserAction] = Field(default_factory=list)
node_id: str
node_title: str
# `resolved_default_values` stores the resolved values of variable defaults: a mapping
# from output variable name to its resolved value.
#
# For example, suppose the form contains an input with output variable name `name` and placeholder type
# `VARIABLE`, whose selector is ["start", "name"]. When the HumanInputNode executes, the corresponding
# value of variable `start.name` in the variable pool is `John`. Thus, the resolved value of the output
# variable `name` is `John`, and `resolved_default_values` is `{"name": "John"}`.
#
# Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
class SchedulingPause(BaseModel):
TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
message: str
PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]
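# Usage sketch (illustrative): the `TYPE` field discriminates which model is produced.
#
#     from pydantic import TypeAdapter
#     reason = TypeAdapter(PauseReason).validate_python(
#         {"TYPE": "scheduled_pause", "message": "waiting for the next window"}
#     )
#     assert isinstance(reason, SchedulingPause)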

View File

@ -0,0 +1,71 @@
"""
Domain entities for workflow execution.
Models describe graph runtime state and avoid infrastructure-specific details.
"""
from __future__ import annotations
from collections.abc import Mapping
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, Field
from graphon.enums import WorkflowExecutionStatus, WorkflowType
class WorkflowExecution(BaseModel):
"""
Domain model for a workflow execution within the graph runtime.
"""
id_: str = Field(...)
workflow_id: str = Field(...)
workflow_version: str = Field(...)
workflow_type: WorkflowType = Field(...)
graph: Mapping[str, Any] = Field(...)
inputs: Mapping[str, Any] = Field(...)
outputs: Mapping[str, Any] | None = None
status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
error_message: str = Field(default="")
total_tokens: int = Field(default=0)
total_steps: int = Field(default=0)
exceptions_count: int = Field(default=0)
started_at: datetime = Field(...)
finished_at: datetime | None = None
@property
def elapsed_time(self) -> float:
"""
Calculate elapsed time in seconds.
If workflow is not finished, use current time.
"""
end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None)
return (end_time - self.started_at).total_seconds()
@classmethod
def new(
cls,
*,
id_: str,
workflow_id: str,
workflow_type: WorkflowType,
workflow_version: str,
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> WorkflowExecution:
return WorkflowExecution(
id_=id_,
workflow_id=workflow_id,
workflow_type=workflow_type,
workflow_version=workflow_version,
graph=graph,
inputs=inputs,
status=WorkflowExecutionStatus.RUNNING,
started_at=started_at,
)
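# Usage sketch (illustrative): `new()` always starts in RUNNING; `elapsed_time`
# falls back to the current time while `finished_at` is unset.
#
#     execution = WorkflowExecution.new(
#         id_="run-1", workflow_id="wf-1", workflow_type=WorkflowType.WORKFLOW,
#         workflow_version="1", graph={}, inputs={},
#         started_at=datetime.now(UTC).replace(tzinfo=None),
#     )
#     execution.elapsed_time  # seconds since started_at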

View File

@ -0,0 +1,141 @@
"""
Domain entities for workflow node execution.
These models capture node-level execution state for the graph runtime without
describing storage or application-layer concerns.
"""
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field, PrivateAttr
from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class WorkflowNodeExecution(BaseModel):
"""
Domain model for workflow node execution.
This model represents the graph-level record of a node execution and
contains only execution state relevant to the runtime.
"""
# --------- Core identification fields ---------
# Unique identifier for this execution record, used when persisting to storage.
# Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
id: str
# Optional secondary ID for cross-referencing purposes.
#
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
# In most scenarios, `id` should be used as the primary identifier.
node_execution_id: str | None = None
workflow_id: str # ID of the workflow this node belongs to
workflow_execution_id: str | None = None # ID of the workflow execution (null for single-step debugging)
# --------- Core identification fields end ---------
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: str | None = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, downstream response node)
title: str # Display title of the node
# Execution data
# The `inputs` and `outputs` fields hold the full content
inputs: Mapping[str, Any] | None = None # Input variables used by this node
process_data: Mapping[str, Any] | None = None # Intermediate processing data
outputs: Mapping[str, Any] | None = None # Output variables produced by this node
# Execution state
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
error: str | None = None # Error message if execution failed
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.)
# Timing information
created_at: datetime # When execution started
finished_at: datetime | None = None # When execution completed
_truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None)
_truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None)
_truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None)
def get_truncated_inputs(self) -> Mapping[str, Any] | None:
return self._truncated_inputs
def get_truncated_outputs(self) -> Mapping[str, Any] | None:
return self._truncated_outputs
def get_truncated_process_data(self) -> Mapping[str, Any] | None:
return self._truncated_process_data
def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None):
self._truncated_inputs = truncated_inputs
def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None):
self._truncated_outputs = truncated_outputs
def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None):
self._truncated_process_data = truncated_process_data
def get_response_inputs(self) -> Mapping[str, Any] | None:
inputs = self.get_truncated_inputs()
if inputs:
return inputs
return self.inputs
@property
def inputs_truncated(self):
return self._truncated_inputs is not None
@property
def outputs_truncated(self):
return self._truncated_outputs is not None
@property
def process_data_truncated(self):
return self._truncated_process_data is not None
def get_response_outputs(self) -> Mapping[str, Any] | None:
outputs = self.get_truncated_outputs()
if outputs is not None:
return outputs
return self.outputs
def get_response_process_data(self) -> Mapping[str, Any] | None:
process_data = self.get_truncated_process_data()
if process_data is not None:
return process_data
return self.process_data
def update_from_mapping(
self,
inputs: Mapping[str, Any] | None = None,
process_data: Mapping[str, Any] | None = None,
outputs: Mapping[str, Any] | None = None,
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
):
"""
Update the model from mappings.
Args:
inputs: The inputs to update
process_data: The process data to update
outputs: The outputs to update
metadata: The metadata to update
"""
if inputs is not None:
self.inputs = dict(inputs)
if process_data is not None:
self.process_data = dict(process_data)
if outputs is not None:
self.outputs = dict(outputs)
if metadata is not None:
self.metadata = dict(metadata)
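# Usage sketch (illustrative): truncated variants, once set, take precedence in
# the response accessors while the full payloads stay on the model.
#
#     execution.set_truncated_outputs({"text": "first 1 KiB ..."})
#     execution.outputs_truncated       # -> True
#     execution.get_response_outputs()  # -> the truncated mapping, not execution.outputs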

View File

@ -0,0 +1,8 @@
from enum import StrEnum
class WorkflowStartReason(StrEnum):
"""Reason for workflow start events across graph/queue/SSE layers."""
INITIAL = "initial" # First start of a workflow run.
RESUMPTION = "resumption" # Start triggered after resuming a paused run.

262
api/graphon/enums.py Normal file
View File

@ -0,0 +1,262 @@
from enum import StrEnum
from typing import ClassVar, TypeAlias
class NodeState(StrEnum):
"""State of a node or edge during workflow execution."""
UNKNOWN = "unknown"
TAKEN = "taken"
SKIPPED = "skipped"
NodeType: TypeAlias = str
class BuiltinNodeTypes:
"""Built-in node type string constants.
`node_type` values are plain strings throughout the graph runtime. This namespace
only exposes the built-in values shipped by `graphon`; downstream packages can
use additional strings without extending this class.
"""
START: ClassVar[NodeType] = "start"
END: ClassVar[NodeType] = "end"
ANSWER: ClassVar[NodeType] = "answer"
LLM: ClassVar[NodeType] = "llm"
KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval"
IF_ELSE: ClassVar[NodeType] = "if-else"
CODE: ClassVar[NodeType] = "code"
TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform"
QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier"
HTTP_REQUEST: ClassVar[NodeType] = "http-request"
TOOL: ClassVar[NodeType] = "tool"
DATASOURCE: ClassVar[NodeType] = "datasource"
VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner"
LOOP: ClassVar[NodeType] = "loop"
LOOP_START: ClassVar[NodeType] = "loop-start"
LOOP_END: ClassVar[NodeType] = "loop-end"
ITERATION: ClassVar[NodeType] = "iteration"
ITERATION_START: ClassVar[NodeType] = "iteration-start"
PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor"
VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner"
DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor"
LIST_OPERATOR: ClassVar[NodeType] = "list-operator"
AGENT: ClassVar[NodeType] = "agent"
HUMAN_INPUT: ClassVar[NodeType] = "human-input"
BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = (
BuiltinNodeTypes.START,
BuiltinNodeTypes.END,
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
BuiltinNodeTypes.IF_ELSE,
BuiltinNodeTypes.CODE,
BuiltinNodeTypes.TEMPLATE_TRANSFORM,
BuiltinNodeTypes.QUESTION_CLASSIFIER,
BuiltinNodeTypes.HTTP_REQUEST,
BuiltinNodeTypes.TOOL,
BuiltinNodeTypes.DATASOURCE,
BuiltinNodeTypes.VARIABLE_AGGREGATOR,
BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR,
BuiltinNodeTypes.LOOP,
BuiltinNodeTypes.LOOP_START,
BuiltinNodeTypes.LOOP_END,
BuiltinNodeTypes.ITERATION,
BuiltinNodeTypes.ITERATION_START,
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
BuiltinNodeTypes.VARIABLE_ASSIGNER,
BuiltinNodeTypes.DOCUMENT_EXTRACTOR,
BuiltinNodeTypes.LIST_OPERATOR,
BuiltinNodeTypes.AGENT,
BuiltinNodeTypes.HUMAN_INPUT,
)
class NodeExecutionType(StrEnum):
"""Node execution type classification."""
EXECUTABLE = "executable" # Regular nodes that execute and produce outputs
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
ROOT = "root" # Nodes that can serve as execution entry points
class ErrorStrategy(StrEnum):
FAIL_BRANCH = "fail-branch"
DEFAULT_VALUE = "default-value"
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
class WorkflowType(StrEnum):
"""
Workflow Type Enum for domain layer
"""
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
class WorkflowExecutionStatus(StrEnum):
# State diagram for the workflow status:
# (@) means start, (*) means end
#
# ┌------------------>------------------------->------------------->--------------┐
# | |
# | ┌-----------------------<--------------------┐ |
# ^ | | |
# | | ^ |
# | V | |
# ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V
# | Scheduled |------->| Running |---------------------->| paused | |
# └-----------┘ └-----------------------┘ └-----------┘ |
# | | | | | | |
# | | | | | | |
# ^ | | | V V |
# | | | | | ┌---------┐ |
# (@) | | | └------------------------>| Stopped |<----┘
# | | | └---------┘
# | | | |
# | | V V
# | | ┌-----------┐ |
# | | | Succeeded |------------->--------------┤
# | | └-----------┘ |
# | V V
# | +--------┐ |
# | | Failed |---------------------->----------------┤
# | └--------┘ |
# V V
# ┌---------------------┐ |
# | Partially Succeeded |---------------------->-----------------┘--------> (*)
# └---------------------┘
#
# Mermaid diagram:
#
# ---
# title: State diagram for Workflow run state
# ---
# stateDiagram-v2
# scheduled: Scheduled
# running: Running
# succeeded: Succeeded
# failed: Failed
# partial_succeeded: Partial Succeeded
# paused: Paused
# stopped: Stopped
#
#     [*] --> scheduled
# scheduled --> running: Start Execution
# running --> paused: Human input required
# paused --> running: human input added
# paused --> stopped: User stops execution
# running --> succeeded: Execution finishes without any error
# running --> failed: Execution finishes with errors
# running --> stopped: User stops execution
#     running --> partial_succeeded: Some errors occurred and were handled during execution
#
# scheduled --> stopped: User stops execution
#
# succeeded --> [*]
# failed --> [*]
# partial_succeeded --> [*]
# stopped --> [*]
# `SCHEDULED` means that the workflow is scheduled to run, but has not
# started running yet (e.g., due to worker saturation).
#
# This enum value is currently unused.
SCHEDULED = "scheduled"
# `RUNNING` means the workflow is executing.
RUNNING = "running"
# `SUCCEEDED` means the workflow execution finished without any error.
SUCCEEDED = "succeeded"
# `FAILED` means the workflow execution failed with errors.
FAILED = "failed"
# `STOPPED` means the workflow execution was stopped, either manually
# by the user or automatically by the Dify application (e.g., by the
# moderation mechanism).
STOPPED = "stopped"
# `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
# execution, but they were successfully handled (e.g., by using an error
# strategy such as "fail branch" or "default value").
PARTIAL_SUCCEEDED = "partial-succeeded"
# `PAUSED` indicates that the workflow execution is temporarily paused
# (e.g., awaiting human input) and is expected to resume later.
PAUSED = "paused"
def is_ended(self) -> bool:
return self in _END_STATE
@classmethod
def ended_values(cls) -> list[str]:
return [status.value for status in _END_STATE]
_END_STATE = frozenset(
[
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
WorkflowExecutionStatus.STOPPED,
]
)
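# Usage sketch (illustrative):
#
#     WorkflowExecutionStatus.PAUSED.is_ended()   # -> False
#     WorkflowExecutionStatus.STOPPED.is_ended()  # -> True
#     WorkflowExecutionStatus.ended_values()      # e.g. ["succeeded", "failed", ...]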
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
Values in this enum are persisted as execution metadata and must stay in sync
with every node that writes `NodeRunResult.metadata`.
"""
TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy"  # nodes in continue-on-error mode return this field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
TRIGGER_INFO = "trigger_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
class WorkflowNodeExecutionStatus(StrEnum):
PENDING = "pending" # Node is scheduled but not yet executing
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
STOPPED = "stopped"
PAUSED = "paused"
# Legacy statuses - kept for backward compatibility
RETRY = "retry" # Legacy: replaced by retry mechanism in error handling

16
api/graphon/errors.py Normal file
View File

@ -0,0 +1,16 @@
from graphon.nodes.base.node import Node
class WorkflowNodeRunFailedError(Exception):
def __init__(self, node: Node, err_msg: str):
self._node = node
self._error = err_msg
super().__init__(f"Node {node.title} run failed: {err_msg}")
@property
def node(self) -> Node:
return self._node
@property
def error(self) -> str:
return self._error

View File

@ -0,0 +1,22 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .file_factory import get_file_type_by_mime_type, standardize_file_type
from .models import (
File,
FileUploadConfig,
ImageConfig,
)
__all__ = [
"FILE_MODEL_IDENTITY",
"ArrayFileAttribute",
"File",
"FileAttribute",
"FileBelongsTo",
"FileTransferMethod",
"FileType",
"FileUploadConfig",
"ImageConfig",
"get_file_type_by_mime_type",
"standardize_file_type",
]

View File

@ -0,0 +1,48 @@
from collections.abc import Iterable
from typing import Any
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
# comparing `dify_model_identity` with constants throughout the codebase, extract
# this logic into a dedicated function. This would encapsulate the implementation
# details of how different variable types are identified.
FILE_MODEL_IDENTITY = "__dify__file__"
DEFAULT_MIME_TYPE = "application/octet-stream"
DEFAULT_EXTENSION = ".bin"
def _with_case_variants(extensions: Iterable[str]) -> frozenset[str]:
normalized = {extension.lower() for extension in extensions}
return frozenset(normalized | {extension.upper() for extension in normalized})
IMAGE_EXTENSIONS = _with_case_variants({"jpg", "jpeg", "png", "webp", "gif", "svg"})
VIDEO_EXTENSIONS = _with_case_variants({"mp4", "mov", "mpeg", "webm"})
AUDIO_EXTENSIONS = _with_case_variants({"mp3", "m4a", "wav", "amr", "mpga"})
DOCUMENT_EXTENSIONS = _with_case_variants(
{
"txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"vtt",
"properties",
"doc",
"docx",
"csv",
"eml",
"msg",
"ppt",
"pptx",
"xml",
"epub",
}
)
def maybe_file_object(o: Any) -> bool:
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY

57
api/graphon/file/enums.py Normal file
View File

@ -0,0 +1,57 @@
from enum import StrEnum
class FileType(StrEnum):
IMAGE = "image"
DOCUMENT = "document"
AUDIO = "audio"
VIDEO = "video"
CUSTOM = "custom"
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
DATASOURCE_FILE = "datasource_file"
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(StrEnum):
USER = "user"
ASSISTANT = "assistant"
@staticmethod
def value_of(value):
for member in FileBelongsTo:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileAttribute(StrEnum):
TYPE = "type"
SIZE = "size"
NAME = "name"
MIME_TYPE = "mime_type"
TRANSFER_METHOD = "transfer_method"
URL = "url"
EXTENSION = "extension"
RELATED_ID = "related_id"
class ArrayFileAttribute(StrEnum):
LENGTH = "length"

View File

@ -0,0 +1,39 @@
from .constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from .enums import FileType
def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
"""
Infer the actual file type from extension and mime type.
"""
guessed_type = None
if extension:
guessed_type = _get_file_type_by_extension(extension)
if guessed_type is None and mime_type:
guessed_type = get_file_type_by_mime_type(mime_type)
return guessed_type or FileType.CUSTOM
def _get_file_type_by_extension(extension: str) -> FileType | None:
normalized_extension = extension.lstrip(".")
if normalized_extension in IMAGE_EXTENSIONS:
return FileType.IMAGE
if normalized_extension in VIDEO_EXTENSIONS:
return FileType.VIDEO
if normalized_extension in AUDIO_EXTENSIONS:
return FileType.AUDIO
if normalized_extension in DOCUMENT_EXTENSIONS:
return FileType.DOCUMENT
return None
def get_file_type_by_mime_type(mime_type: str) -> FileType:
if "image" in mime_type:
return FileType.IMAGE
if "video" in mime_type:
return FileType.VIDEO
if "audio" in mime_type:
return FileType.AUDIO
if "text" in mime_type or "pdf" in mime_type:
return FileType.DOCUMENT
return FileType.CUSTOM
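# Usage sketch (illustrative): the extension wins over the mime type when both are given.
#
#     standardize_file_type(extension=".PNG")        # -> FileType.IMAGE
#     standardize_file_type(mime_type="audio/mpeg")  # -> FileType.AUDIO
#     standardize_file_type(extension=".zip", mime_type="application/zip")  # -> FileType.CUSTOM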

View File

@ -0,0 +1,129 @@
from __future__ import annotations
import base64
from collections.abc import Mapping
from graphon.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
TextPromptMessageContent,
VideoPromptMessageContent,
)
from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
from .runtime import get_workflow_file_runtime
def get_attr(*, file: File, attr: FileAttribute):
match attr:
case FileAttribute.TYPE:
return file.type.value
case FileAttribute.SIZE:
return file.size
case FileAttribute.NAME:
return file.filename
case FileAttribute.MIME_TYPE:
return file.mime_type
case FileAttribute.TRANSFER_METHOD:
return file.transfer_method.value
case FileAttribute.URL:
return _to_url(file)
case FileAttribute.EXTENSION:
return file.extension
case FileAttribute.RELATED_ID:
return file.related_id
def to_prompt_message_content(
f: File,
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> PromptMessageContentUnionTypes:
"""Convert a file to prompt message content."""
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
if f.type not in prompt_class_map:
return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
send_format = get_workflow_file_runtime().multimodal_send_format
params = {
"base64_data": _get_encoded_string(f) if send_format == "base64" else "",
"url": _to_url(f) if send_format == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
return prompt_class_map[f.type].model_validate(params)
def download(f: File, /) -> bytes:
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.DATASOURCE_FILE,
):
return _download_file_content(f)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
def _download_file_content(file: File, /) -> bytes:
"""Download and return a file from storage as bytes."""
return get_workflow_file_runtime().load_file_bytes(file=file)
def _get_encoded_string(f: File, /) -> str:
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
raise ValueError("Missing file remote_url")
response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
response.raise_for_status()
data = response.content
case FileTransferMethod.LOCAL_FILE:
data = _download_file_content(f)
case FileTransferMethod.TOOL_FILE:
data = _download_file_content(f)
case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f)
return base64.b64encode(data).decode("utf-8")
def _to_url(f: File, /):
url = f.generate_url()
if url is None:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
return url
class FileManager:
"""Adapter exposing file manager helpers behind FileManagerProtocol."""
def download(self, f: File, /) -> bytes:
return download(f)
file_manager = FileManager()

View File

@ -0,0 +1,48 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from .runtime import get_workflow_file_runtime
if TYPE_CHECKING:
from .models import File
def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None:
return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external)
def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str:
return get_workflow_file_runtime().resolve_upload_file_url(
upload_file_id=upload_file_id,
as_attachment=as_attachment,
for_external=for_external,
)
def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str:
return get_workflow_file_runtime().resolve_tool_file_url(
tool_file_id=tool_file_id,
extension=extension,
for_external=for_external,
)
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
return get_workflow_file_runtime().verify_preview_signature(
preview_kind="image",
file_id=upload_file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
return get_workflow_file_runtime().verify_preview_signature(
preview_kind="file",
file_id=upload_file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)

215
api/graphon/file/models.py Normal file
View File

@ -0,0 +1,215 @@
from __future__ import annotations
import base64
import json
from collections.abc import Mapping, Sequence
from typing import Any
from pydantic import BaseModel, Field, model_validator
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType
_FILE_REFERENCE_PREFIX = "dify-file-ref:"
def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``."""
return helpers.get_signed_tool_file_url(
tool_file_id=tool_file_id,
extension=extension,
for_external=for_external,
)
class ImageConfig(BaseModel):
"""
NOTE: This part of validation is deprecated, but is still used by the "Image Upload" app feature.
"""
number_limits: int = 0
transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
detail: ImagePromptMessageContent.DETAIL | None = None
class FileUploadConfig(BaseModel):
"""
File Upload Entity.
"""
image_config: ImageConfig | None = None
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
number_limits: int = 0
def _parse_reference(reference: str | None) -> tuple[str | None, str | None]:
"""Best-effort parser for record references and historical storage-key payloads."""
if not reference:
return None, None
if not reference.startswith(_FILE_REFERENCE_PREFIX):
return reference, None
encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX)
try:
payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode()))
except (ValueError, json.JSONDecodeError):
return reference, None
record_id = payload.get("record_id")
if not isinstance(record_id, str) or not record_id:
return reference, None
storage_key = payload.get("storage_key")
if not isinstance(storage_key, str):
storage_key = None
return record_id, storage_key
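# Illustrative round-trip for the reference format parsed above: a record id plus
# an optional storage key, as base64url-encoded JSON behind the prefix.
#
#     payload = {"record_id": "file-123", "storage_key": "upload_files/abc"}
#     encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
#     _parse_reference(_FILE_REFERENCE_PREFIX + encoded)  # -> ("file-123", "upload_files/abc")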
class File(BaseModel):
"""Graph-owned file reference.
The graph layer deliberately keeps only the metadata required to route,
serialize, and render files. Application ownership concerns such as
tenant/user/conversation identity stay in the workflow/storage layer.
"""
# NOTE: dify_model_identity is a special identifier used to distinguish between
# new and old data formats during serialization and deserialization.
dify_model_identity: str = FILE_MODEL_IDENTITY
id: str | None = None # message file id
type: FileType
transfer_method: FileTransferMethod
# If `transfer_method` is `FileTransferMethod.remote_url`, the
# `remote_url` attribute must not be `None`.
remote_url: str | None = None # remote url
# Opaque workflow-layer reference for files resolved outside ``graphon``.
# New payloads only carry the backing record id; historical payloads may
# still include storage_key and must remain readable.
reference: str | None = None
filename: str | None = None
extension: str | None = Field(default=None, description="File extension, should contain dot")
mime_type: str | None = None
size: int = -1
_storage_key: str
def __init__(
self,
*,
id: str | None = None,
tenant_id: str | None = None,
type: FileType,
transfer_method: FileTransferMethod,
remote_url: str | None = None,
reference: str | None = None,
related_id: str | None = None,
filename: str | None = None,
extension: str | None = None,
mime_type: str | None = None,
size: int = -1,
storage_key: str | None = None,
dify_model_identity: str | None = FILE_MODEL_IDENTITY,
url: str | None = None,
# Legacy compatibility fields - explicitly accept known extra fields
tool_file_id: str | None = None,
upload_file_id: str | None = None,
datasource_file_id: str | None = None,
):
legacy_record_id = related_id or tool_file_id or upload_file_id or datasource_file_id
normalized_reference = reference
if normalized_reference is None and legacy_record_id is not None:
normalized_reference = str(legacy_record_id)
_, parsed_storage_key = _parse_reference(normalized_reference)
super().__init__(
id=id,
type=type,
transfer_method=transfer_method,
remote_url=remote_url,
reference=normalized_reference,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
dify_model_identity=dify_model_identity,
url=url,
)
# Accept legacy constructor fields without promoting them back into the graph model.
_ = tenant_id
self._storage_key = storage_key or parsed_storage_key or ""
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
return {
**data,
"related_id": self.related_id,
"url": self.generate_url(),
}
@property
def markdown(self) -> str:
url = self.generate_url()
if self.type == FileType.IMAGE:
text = f"![{self.filename or ''}]({url})"
else:
text = f"[{self.filename or url}]({url})"
return text
def generate_url(self, for_external: bool = True) -> str | None:
return helpers.resolve_file_url(self, for_external=for_external)
def to_plugin_parameter(self) -> dict[str, Any]:
return {
"dify_model_identity": FILE_MODEL_IDENTITY,
"mime_type": self.mime_type,
"filename": self.filename,
"extension": self.extension,
"size": self.size,
"type": self.type,
"url": self.generate_url(for_external=False),
}
@model_validator(mode="after")
def validate_after(self) -> File:
match self.transfer_method:
case FileTransferMethod.REMOTE_URL:
if not self.remote_url:
raise ValueError("Missing file url")
if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
raise ValueError("Invalid file url")
case FileTransferMethod.LOCAL_FILE:
if not self.reference:
raise ValueError("Missing file reference")
case FileTransferMethod.TOOL_FILE:
if not self.reference:
raise ValueError("Missing file reference")
case FileTransferMethod.DATASOURCE_FILE:
if not self.reference:
raise ValueError("Missing file reference")
return self
@property
def related_id(self) -> str | None:
record_id, _ = _parse_reference(self.reference)
return record_id
@related_id.setter
def related_id(self, value: str | None) -> None:
self.reference = value
@property
def storage_key(self) -> str:
_, storage_key = _parse_reference(self.reference)
return storage_key or self._storage_key
@storage_key.setter
def storage_key(self, value: str) -> None:
self._storage_key = value

View File

@ -0,0 +1,56 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal, Protocol
if TYPE_CHECKING:
from .models import File
class HttpResponseProtocol(Protocol):
"""Subset of response behavior needed by workflow file helpers."""
@property
def content(self) -> bytes: ...
def raise_for_status(self) -> object: ...
class WorkflowFileRuntimeProtocol(Protocol):
"""Runtime dependencies required by ``graphon.file``.
Implementations are expected to be provided by integration layers (for example,
``core.app.workflow.file_runtime``) so the workflow package avoids importing
application infrastructure modules directly.
"""
@property
def multimodal_send_format(self) -> str: ...
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ...
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ...
def load_file_bytes(self, *, file: File) -> bytes: ...
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ...
def resolve_upload_file_url(
self,
*,
upload_file_id: str,
as_attachment: bool = False,
for_external: bool = True,
) -> str: ...
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ...
def verify_preview_signature(
self,
*,
preview_kind: Literal["image", "file"],
file_id: str,
timestamp: str,
nonce: str,
sign: str,
) -> bool: ...

View File

@ -0,0 +1,71 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal, NoReturn
from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
if TYPE_CHECKING:
from .models import File
class WorkflowFileRuntimeNotConfiguredError(RuntimeError):
"""Raised when workflow file runtime dependencies were not configured."""
class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
def _raise(self) -> NoReturn:
raise WorkflowFileRuntimeNotConfiguredError(
"workflow file runtime is not configured, call set_workflow_file_runtime(...) first"
)
@property
def multimodal_send_format(self) -> str:
self._raise()
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
self._raise()
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
self._raise()
def load_file_bytes(self, *, file: File) -> bytes:
self._raise()
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
self._raise()
def resolve_upload_file_url(
self,
*,
upload_file_id: str,
as_attachment: bool = False,
for_external: bool = True,
) -> str:
self._raise()
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
self._raise()
def verify_preview_signature(
self,
*,
preview_kind: Literal["image", "file"],
file_id: str,
timestamp: str,
nonce: str,
sign: str,
) -> bool:
self._raise()
_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime()
def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None:
global _runtime
_runtime = runtime
def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol:
return _runtime
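# Configuration sketch (illustrative): an integration layer installs a concrete
# runtime once at startup; any object satisfying the structural protocol works.
#
#     from core.app.workflow.file_runtime import WorkflowFileRuntime  # hypothetical module
#     set_workflow_file_runtime(WorkflowFileRuntime())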

View File

@ -0,0 +1,9 @@
from collections.abc import Callable
from typing import Any
_tool_file_manager_factory: Callable[[], Any] | None = None
def set_tool_file_manager_factory(factory: Callable[[], Any]):
global _tool_file_manager_factory
_tool_file_manager_factory = factory

View File

@ -0,0 +1,11 @@
from .edge import Edge
from .graph import Graph, GraphBuilder, NodeFactory
from .graph_template import GraphTemplate
__all__ = [
"Edge",
"Graph",
"GraphBuilder",
"GraphTemplate",
"NodeFactory",
]

15
api/graphon/graph/edge.py Normal file
View File

@ -0,0 +1,15 @@
import uuid
from dataclasses import dataclass, field
from graphon.enums import NodeState
@dataclass
class Edge:
"""Edge connecting two nodes in a workflow graph."""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
tail: str = "" # tail node id (source)
head: str = "" # head node id (target)
source_handle: str = "source" # source handle for conditional branching
state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state

438
api/graphon/graph/graph.py Normal file
View File

@ -0,0 +1,438 @@
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from pydantic import TypeAdapter
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState
from graphon.nodes.base.node import Node
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
class NodeFactory(Protocol):
"""
Protocol for creating Node instances from node data dictionaries.
This protocol decouples the Graph class from specific node mapping implementations,
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data.
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or no implementation exists for the resolved version
:raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation
"""
...
@final
class Graph:
"""Graph representation with nodes and edges for workflow execution."""
def __init__(
self,
*,
nodes: dict[str, Node] | None = None,
edges: dict[str, Edge] | None = None,
in_edges: dict[str, list[str]] | None = None,
out_edges: dict[str, list[str]] | None = None,
root_node: Node,
):
"""
Initialize Graph instance.
:param nodes: graph nodes mapping (node id: node object)
:param edges: graph edges mapping (edge id: edge object)
:param in_edges: incoming edges mapping (node id: list of edge ids)
:param out_edges: outgoing edges mapping (node id: list of edge ids)
:param root_node: root node object
"""
self.nodes = nodes or {}
self.edges = edges or {}
self.in_edges = in_edges or {}
self.out_edges = out_edges or {}
self.root_node = root_node
@classmethod
def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
node_configs_map: dict[str, NodeConfigDict] = {}
for node_config in node_configs:
node_configs_map[node_config["id"]] = node_config
return node_configs_map
@classmethod
def _build_edges(
cls, edge_configs: list[dict[str, object]]
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
"""
Build edge objects and mappings from edge configurations.
:param edge_configs: list of edge configurations
:return: tuple of (edges dict, in_edges dict, out_edges dict)
"""
edges: dict[str, Edge] = {}
in_edges: dict[str, list[str]] = defaultdict(list)
out_edges: dict[str, list[str]] = defaultdict(list)
edge_counter = 0
for edge_config in edge_configs:
source = edge_config.get("source")
target = edge_config.get("target")
if not isinstance(source, str) or not isinstance(target, str):
continue
# Create edge
edge_id = f"edge_{edge_counter}"
edge_counter += 1
source_handle = edge_config.get("sourceHandle", "source")
if not isinstance(source_handle, str):
continue
edge = Edge(
id=edge_id,
tail=source,
head=target,
source_handle=source_handle,
)
edges[edge_id] = edge
out_edges[source].append(edge_id)
in_edges[target].append(edge_id)
return edges, dict(in_edges), dict(out_edges)
@classmethod
def _create_node_instances(
cls,
node_configs_map: dict[str, NodeConfigDict],
node_factory: NodeFactory,
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
:param node_configs_map: mapping of node ID to node config
:param node_factory: factory for creating node instances
:return: mapping of node ID to node instance
"""
nodes: dict[str, Node] = {}
for node_id, node_config in node_configs_map.items():
try:
node_instance = node_factory.create_node(node_config)
except Exception:
logger.exception("Failed to create node instance for node_id %s", node_id)
raise
nodes[node_id] = node_instance
return nodes
@classmethod
def new(cls) -> GraphBuilder:
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@staticmethod
def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]:
"""
Remove editor-only nodes before `NodeConfigDict` validation.
Persisted note widgets use a top-level `type == "custom-note"` but leave
`data.type` empty because they are never executable graph nodes. Filter
them while configs are still raw dicts so Pydantic does not validate
their placeholder payloads against `BaseNodeData.type: NodeType`.
"""
filtered_node_configs: list[dict[str, object]] = []
for node_config in node_configs:
if node_config.get("type", "") == "custom-note":
continue
filtered_node_configs.append(dict(node_config))
return filtered_node_configs
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
:param nodes: mapping of node ID to node instance
"""
for node in nodes.values():
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
node.execution_type = NodeExecutionType.BRANCH
@classmethod
def _mark_inactive_root_branches(
cls,
nodes: dict[str, Node],
edges: dict[str, Edge],
in_edges: dict[str, list[str]],
out_edges: dict[str, list[str]],
active_root_id: str,
) -> None:
"""
Mark nodes and edges from inactive root branches as skipped.
Algorithm:
1. Mark inactive root nodes as skipped
2. For skipped nodes, mark all their outgoing edges as skipped
3. For each edge marked as skipped, check its target node:
- If ALL incoming edges are skipped, mark the node as skipped
- Otherwise, leave the node state unchanged
:param nodes: mapping of node ID to node instance
:param edges: mapping of edge ID to edge instance
:param in_edges: mapping of node ID to incoming edge IDs
:param out_edges: mapping of node ID to outgoing edge IDs
:param active_root_id: ID of the active root node
"""
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
top_level_roots: list[str] = [
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
]
# If there's only one root or the active root is not a top-level root, no marking needed
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
return
# Mark inactive root nodes as skipped
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
for root_id in inactive_roots:
if root_id in nodes:
nodes[root_id].state = NodeState.SKIPPED
# Recursively mark downstream nodes and edges
def mark_downstream(node_id: str) -> None:
"""Recursively mark downstream nodes and edges as skipped."""
if nodes[node_id].state != NodeState.SKIPPED:
return
# If this node is skipped, mark all its outgoing edges as skipped
out_edge_ids = out_edges.get(node_id, [])
for edge_id in out_edge_ids:
edge = edges[edge_id]
edge.state = NodeState.SKIPPED
# Check the target node of this edge
target_node = nodes[edge.head]
in_edge_ids = in_edges.get(target_node.id, [])
in_edge_states = [edges[eid].state for eid in in_edge_ids]
# If all incoming edges are skipped, mark the node as skipped
if all(state == NodeState.SKIPPED for state in in_edge_states):
target_node.state = NodeState.SKIPPED
# Recursively process downstream nodes
mark_downstream(target_node.id)
# Process each inactive root and its downstream nodes
for root_id in inactive_roots:
mark_downstream(root_id)
@classmethod
def init(
cls,
*,
graph_config: Mapping[str, object],
node_factory: NodeFactory,
root_node_id: str,
skip_validation: bool = False,
) -> Graph:
"""
Initialize a graph with an explicit execution entry point.
:param graph_config: graph config containing nodes and edges
:param node_factory: factory for creating node instances from config data
:param root_node_id: active root node id
:return: graph instance
"""
# Parse configs
edge_configs = graph_config.get("edges", [])
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
node_configs = cls._filter_canvas_only_nodes(node_configs)
node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Build edges
edges, in_edges, out_edges = cls._build_edges(edge_configs)
# Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory)
# Promote fail-branch nodes to branch execution type at graph level
cls._promote_fail_branch_nodes(nodes)
# Get root node instance
root_node = nodes[root_node_id]
# Mark inactive root branches as skipped
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
graph = cls(
nodes=nodes,
edges=edges,
in_edges=in_edges,
out_edges=out_edges,
root_node=root_node,
)
if not skip_validation:
# Validate the graph structure using built-in validators
get_graph_validator().validate(graph)
return graph
@property
def node_ids(self) -> list[str]:
"""
Get list of node IDs (compatibility property for existing code)
:return: list of node IDs
"""
return list(self.nodes.keys())
def get_outgoing_edges(self, node_id: str) -> list[Edge]:
"""
Get all outgoing edges from a node (V2 method)
:param node_id: node id
:return: list of outgoing edges
"""
edge_ids = self.out_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
def get_incoming_edges(self, node_id: str) -> list[Edge]:
"""
Get all incoming edges to a node (V2 method)
:param node_id: node id
:return: list of incoming edges
"""
edge_ids = self.in_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
@final
class GraphBuilder:
"""Fluent helper for constructing simple graphs, primarily for tests."""
def __init__(self, *, graph_cls: type[Graph]):
self._graph_cls = graph_cls
self._nodes: list[Node] = []
self._nodes_by_id: dict[str, Node] = {}
self._edges: list[Edge] = []
self._edge_counter = 0
def add_root(self, node: Node) -> GraphBuilder:
"""Register the root node. Must be called exactly once."""
if self._nodes:
raise ValueError("Root node has already been added")
self._register_node(node)
self._nodes.append(node)
return self
def add_node(
self,
node: Node,
*,
from_node_id: str | None = None,
source_handle: str = "source",
) -> GraphBuilder:
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
raise ValueError("Root node must be added before adding other nodes")
predecessor_id = from_node_id or self._nodes[-1].id
if predecessor_id not in self._nodes_by_id:
raise ValueError(f"Predecessor node '{predecessor_id}' not found")
predecessor = self._nodes_by_id[predecessor_id]
self._register_node(node)
self._nodes.append(node)
edge_id = f"edge_{self._edge_counter}"
self._edge_counter += 1
edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle)
self._edges.append(edge)
return self
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:
raise ValueError(f"Tail node '{tail}' not found")
if head not in self._nodes_by_id:
raise ValueError(f"Head node '{head}' not found")
edge_id = f"edge_{self._edge_counter}"
self._edge_counter += 1
edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle)
self._edges.append(edge)
return self
def build(self) -> Graph:
"""Materialize the graph instance from the accumulated nodes and edges."""
if not self._nodes:
raise ValueError("Cannot build an empty graph")
nodes = {node.id: node for node in self._nodes}
edges = {edge.id: edge for edge in self._edges}
in_edges: dict[str, list[str]] = defaultdict(list)
out_edges: dict[str, list[str]] = defaultdict(list)
for edge in self._edges:
out_edges[edge.tail].append(edge.id)
in_edges[edge.head].append(edge.id)
return self._graph_cls(
nodes=nodes,
edges=edges,
in_edges=dict(in_edges),
out_edges=dict(out_edges),
root_node=self._nodes[0],
)
def _register_node(self, node: Node) -> None:
if not node.id:
raise ValueError("Node must have a non-empty id")
if node.id in self._nodes_by_id:
raise ValueError(f"Duplicate node id detected: {node.id}")
self._nodes_by_id[node.id] = node
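# A minimal usage sketch, assuming `start`, `llm`, and `end` are Node test
# doubles with unique ids (names hypothetical):
#
#     graph = (
#         GraphBuilder(graph_cls=Graph)
#         .add_root(start)
#         .add_node(llm)
#         .add_node(end)
#         .build()
#     )
#
# connect(tail=..., head=...) can then add extra edges between nodes that are
# already registered, e.g. to merge a branch back into a shared successor.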


@ -0,0 +1,20 @@
from typing import Any
from pydantic import BaseModel, Field
class GraphTemplate(BaseModel):
"""
Graph Template for container nodes and subgraph expansion
According to GraphEngine V2 spec, GraphTemplate contains:
- nodes: mapping of node definitions
- edges: mapping of edge definitions
- root_ids: list of root node IDs
- output_selectors: list of output selectors for the template
"""
nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
root_ids: list[str] = Field(default_factory=list, description="root node IDs")
output_selectors: list[str] = Field(default_factory=list, description="output selectors")
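# Construction sketch (field values are illustrative only; unset fields fall
# back to their declared defaults):
#
#     template = GraphTemplate(
#         nodes={"start": {"type": "start"}},
#         root_ids=["start"],
#     )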


@ -0,0 +1,125 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@dataclass(frozen=True, slots=True)
class GraphValidationIssue:
"""Immutable value object describing a single validation issue."""
code: str
message: str
node_id: str | None = None
class GraphValidationError(ValueError):
"""Raised when graph validation fails."""
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
if not issues:
raise ValueError("GraphValidationError requires at least one issue.")
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
super().__init__(message)
class GraphValidationRule(Protocol):
"""Protocol that individual validation rules must satisfy."""
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
"""Validate the provided graph and return any discovered issues."""
...
@dataclass(frozen=True, slots=True)
class _EdgeEndpointValidator:
"""Ensures all edges reference existing nodes."""
missing_node_code: str = "MISSING_NODE"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
issues: list[GraphValidationIssue] = []
for edge in graph.edges.values():
if edge.tail not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
node_id=edge.tail,
)
)
if edge.head not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
node_id=edge.head,
)
)
return issues
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
issues: list[GraphValidationIssue] = []
if root_node.id not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' is missing from the node registry.",
node_id=root_node.id,
)
)
return issues
node_type = root_node.node_type
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
node_id=root_node.id,
)
)
return issues
@dataclass(frozen=True, slots=True)
class GraphValidator:
"""Coordinates execution of graph validation rules."""
rules: tuple[GraphValidationRule, ...]
def validate(self, graph: Graph) -> None:
"""Validate the graph against all configured rules."""
issues: list[GraphValidationIssue] = []
for rule in self.rules:
issues.extend(rule.validate(graph))
if issues:
raise GraphValidationError(issues)
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
)
def get_graph_validator() -> GraphValidator:
"""Construct the validator composed of default rules."""
return GraphValidator(_DEFAULT_RULES)
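# A custom rule only needs to satisfy the GraphValidationRule protocol. A
# hedged sketch that flags isolated non-root nodes (rule name hypothetical):
#
#     @dataclass(frozen=True, slots=True)
#     class _IsolatedNodeValidator:
#         code: str = "ISOLATED_NODE"
#
#         def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
#             return [
#                 GraphValidationIssue(self.code, f"Node '{nid}' has no edges.", nid)
#                 for nid in graph.nodes
#                 if nid != graph.root_node.id
#                 and not graph.get_incoming_edges(nid)
#                 and not graph.get_outgoing_edges(nid)
#             ]
#
#     GraphValidator(rules=(*_DEFAULT_RULES, _IsolatedNodeValidator())).validate(graph)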


@ -0,0 +1,4 @@
from .config import GraphEngineConfig
from .graph_engine import GraphEngine
__all__ = ["GraphEngine", "GraphEngineConfig"]


@ -0,0 +1,15 @@
import time
def get_timestamp() -> float:
"""Retrieve a timestamp as a float point numer representing the number of seconds
since the Unix epoch.
This function is primarily used to measure the execution time of the workflow engine.
Since workflow execution may be paused and resumed on a different machine,
`time.perf_counter` cannot be used as it is inconsistent across machines.
To address this, the function uses the wall clock as the time source.
However, it assumes that the clocks of all servers are properly synchronized.
"""
return time.time()


@ -0,0 +1,33 @@
# Command Channels
Channel implementations for external workflow control.
## Components
### InMemoryChannel
Thread-safe in-memory queue for single-process deployments.
- `fetch_commands()` - Get pending commands
- `send_command()` - Add command to queue
### RedisChannel
Redis-based queue for distributed deployments.
- `fetch_commands()` - Get commands with JSON deserialization
- `send_command()` - Store commands with TTL
## Usage
```python
# Local execution
channel = InMemoryChannel()
channel.send_command(AbortCommand(graph_id="workflow-123"))
# Distributed execution
redis_channel = RedisChannel(
redis_client=redis_client,
channel_key="workflow:123:commands"
)
```
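On the engine side, the same channel is drained between scheduler ticks. A minimal polling sketch (the `stop_workers` handler is hypothetical):

```python
# Drain pending commands and react to aborts
for command in channel.fetch_commands():
    if isinstance(command, AbortCommand):
        stop_workers(command.reason)  # hypothetical shutdown hook
```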


@ -0,0 +1,6 @@
"""Command channel implementations for GraphEngine."""
from .in_memory_channel import InMemoryChannel
from .redis_channel import RedisChannel
__all__ = ["InMemoryChannel", "RedisChannel"]


@ -0,0 +1,53 @@
"""
In-memory implementation of CommandChannel for local/testing scenarios.
This implementation uses a thread-safe queue for command communication
within a single process. Each instance handles commands for one workflow execution.
"""
from queue import Empty, Queue
from typing import final
from ..entities.commands import GraphEngineCommand
@final
class InMemoryChannel:
"""
In-memory command channel implementation using a thread-safe queue.
Each instance is dedicated to a single GraphEngine/workflow execution.
Suitable for local development, testing, and single-instance deployments.
"""
def __init__(self) -> None:
"""Initialize the in-memory channel with a single queue."""
self._queue: Queue[GraphEngineCommand] = Queue()
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch all pending commands from the queue.
Returns:
List of pending commands (drains the queue)
"""
commands: list[GraphEngineCommand] = []
# Drain all available commands from the queue
while not self._queue.empty():
try:
command = self._queue.get_nowait()
commands.append(command)
except Empty:
break
return commands
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to this channel's queue.
Args:
command: The command to send
"""
self._queue.put(command)


@ -0,0 +1,153 @@
"""
Redis-based implementation of CommandChannel for distributed scenarios.
This implementation uses Redis lists for command queuing, supporting
multi-instance deployments and cross-server communication.
Each instance uses a unique key for its command queue.
"""
import json
from contextlib import AbstractContextManager
from typing import Any, Protocol, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
class RedisPipelineProtocol(Protocol):
"""Minimal Redis pipeline contract used by the command channel."""
def lrange(self, name: str, start: int, end: int) -> Any: ...
def delete(self, *names: str) -> Any: ...
def execute(self) -> list[Any]: ...
def rpush(self, name: str, *values: str) -> Any: ...
def expire(self, name: str, time: int) -> Any: ...
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
def get(self, name: str) -> Any: ...
class RedisClientProtocol(Protocol):
"""Redis client contract required by the command channel."""
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
@final
class RedisChannel:
"""
Redis-based command channel implementation for distributed systems.
Each instance uses a unique Redis key for its command queue.
Commands are JSON-serialized for transport.
"""
def __init__(
self,
redis_client: RedisClientProtocol,
channel_key: str,
command_ttl: int = 3600,
) -> None:
"""
Initialize the Redis channel.
Args:
redis_client: Redis client instance
channel_key: Unique key for this channel's command queue
command_ttl: TTL for command keys in seconds (default: 3600)
"""
self._redis = redis_client
self._key = channel_key
self._command_ttl = command_ttl
self._pending_key = f"{channel_key}:pending"
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch all pending commands from Redis.
Returns:
List of pending commands (drains the Redis list)
"""
if not self._has_pending_commands():
return []
commands: list[GraphEngineCommand] = []
# Use pipeline for atomic operations
with self._redis.pipeline() as pipe:
# Get all commands and clear the list atomically
pipe.lrange(self._key, 0, -1)
pipe.delete(self._key)
results = pipe.execute()
# Parse commands from JSON
if results[0]:
for command_json in results[0]:
try:
command_data = json.loads(command_json)
command = self._deserialize_command(command_data)
if command:
commands.append(command)
except (json.JSONDecodeError, ValueError):
# Skip invalid commands
continue
return commands
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to Redis.
Args:
command: The command to send
"""
command_json = json.dumps(command.model_dump())
# Push to list and set expiry
with self._redis.pipeline() as pipe:
pipe.rpush(self._key, command_json)
pipe.expire(self._key, self._command_ttl)
pipe.set(self._pending_key, "1", ex=self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
"""
Deserialize a command from dictionary data.
Args:
data: Command data dictionary
Returns:
Deserialized command or None if invalid
"""
command_type_value = data.get("command_type")
if not isinstance(command_type_value, str):
return None
try:
command_type = CommandType(command_type_value)
if command_type == CommandType.ABORT:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
if command_type == CommandType.UPDATE_VARIABLES:
return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
except (ValueError, TypeError):
return None
def _has_pending_commands(self) -> bool:
"""
Check and consume the pending marker to avoid unnecessary list reads.
Returns:
True if commands should be fetched from Redis.
"""
with self._redis.pipeline() as pipe:
pipe.get(self._pending_key)
pipe.delete(self._pending_key)
pending_value, _ = pipe.execute()
return pending_value is not None


@ -0,0 +1,16 @@
"""
Command processing subsystem for graph engine.
This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
"UpdateVariablesCommandHandler",
]


@ -0,0 +1,56 @@
import logging
from typing import final
from typing_extensions import override
from graphon.entities.pause_reason import SchedulingPause
from graphon.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@final
class AbortCommandHandler(CommandHandler):
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, AbortCommand)
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
execution.abort(command.reason or "User requested abort")
@final
class PauseCommandHandler(CommandHandler):
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, PauseCommand)
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
# Wrap the command's string reason in a SchedulingPause
pause_reason = SchedulingPause(message=command.reason)
execution.pause(pause_reason)
@final
class UpdateVariablesCommandHandler(CommandHandler):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, UpdateVariablesCommand)
for update in command.updates:
try:
variable = update.value
self._variable_pool.add(variable.selector, variable)
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
except ValueError as exc:
logger.warning(
"Skipping invalid variable selector %s for workflow %s: %s",
getattr(update.value, "selector", None),
execution.workflow_id,
exc,
)


@ -0,0 +1,79 @@
"""
Main command processor for handling external commands.
"""
import logging
from typing import Protocol, final
from ..domain.graph_execution import GraphExecution
from ..entities.commands import GraphEngineCommand
from ..protocols.command_channel import CommandChannel
logger = logging.getLogger(__name__)
class CommandHandler(Protocol):
"""Protocol for command handlers."""
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
@final
class CommandProcessor:
"""
Processes external commands sent to the engine.
This polls the command channel and dispatches commands to
appropriate handlers.
"""
def __init__(
self,
command_channel: CommandChannel,
graph_execution: GraphExecution,
) -> None:
"""
Initialize the command processor.
Args:
command_channel: Channel for receiving commands
graph_execution: Graph execution aggregate
"""
self._command_channel = command_channel
self._graph_execution = graph_execution
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
"""
Register a handler for a command type.
Args:
command_type: Type of command to handle
handler: Handler for the command
"""
self._handlers[command_type] = handler
def process_commands(self) -> None:
"""Check for and process any pending commands."""
try:
commands = self._command_channel.fetch_commands()
for command in commands:
self._handle_command(command)
except Exception as e:
logger.warning("Error processing commands: %s", e)
def _handle_command(self, command: GraphEngineCommand) -> None:
"""
Handle a single command.
Args:
command: The command to handle
"""
handler = self._handlers.get(type(command))
if handler:
try:
handler.handle(command, self._graph_execution)
except Exception:
logger.exception("Error handling command %s", command.__class__.__name__)
else:
logger.warning("No handler registered for command: %s", command.__class__.__name__)


@ -0,0 +1,16 @@
"""
GraphEngine configuration models.
"""
from pydantic import BaseModel, ConfigDict
class GraphEngineConfig(BaseModel):
"""Configuration for GraphEngine worker pool scaling."""
model_config = ConfigDict(frozen=True)
min_workers: int = 1
max_workers: int = 5
scale_up_threshold: int = 3
scale_down_idle_time: float = 5.0
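# Sizing sketch (values illustrative; this reads scale_up_threshold as the
# queue backlog that triggers growth, which is an assumption about its semantics):
#
#     config = GraphEngineConfig(min_workers=2, max_workers=10, scale_up_threshold=3)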


@ -0,0 +1,14 @@
"""
Domain models for graph engine.
This package contains the core domain entities, value objects, and aggregates
that represent the business concepts of workflow graph execution.
"""
from .graph_execution import GraphExecution
from .node_execution import NodeExecution
__all__ = [
"GraphExecution",
"NodeExecution",
]


@ -0,0 +1,242 @@
"""GraphExecution aggregate root managing the overall graph execution state."""
from __future__ import annotations
from dataclasses import dataclass, field
from importlib import import_module
from typing import Literal
from pydantic import BaseModel, Field
from graphon.entities.pause_reason import PauseReason
from graphon.enums import NodeState
from graphon.runtime.graph_runtime_state import GraphExecutionProtocol
from .node_execution import NodeExecution
class GraphExecutionErrorState(BaseModel):
"""Serializable representation of an execution error."""
module: str = Field(description="Module containing the exception class")
qualname: str = Field(description="Qualified name of the exception class")
message: str | None = Field(default=None, description="Exception message string")
class NodeExecutionState(BaseModel):
"""Serializable representation of a node execution entity."""
node_id: str
state: NodeState = Field(default=NodeState.UNKNOWN)
retry_count: int = Field(default=0)
execution_id: str | None = Field(default=None)
error: str | None = Field(default=None)
class GraphExecutionState(BaseModel):
"""Pydantic model describing serialized GraphExecution state."""
type: Literal["GraphExecution"] = Field(default="GraphExecution")
version: str = Field(default="1.0")
workflow_id: str
started: bool = Field(default=False)
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reasons: list[PauseReason] = Field(default_factory=list)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
"""Convert an exception into its serializable representation."""
if error is None:
return None
return GraphExecutionErrorState(
module=error.__class__.__module__,
qualname=error.__class__.__qualname__,
message=str(error),
)
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
"""Locate an exception class from its module and qualified name."""
module = import_module(module_name)
attr: object = module
for part in qualname.split("."):
attr = getattr(attr, part)
if isinstance(attr, type) and issubclass(attr, Exception):
return attr
raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
"""Reconstruct an exception instance from serialized data."""
if state is None:
return None
try:
exception_class = _resolve_exception_class(state.module, state.qualname)
if state.message is None:
return exception_class()
return exception_class(state.message)
except Exception:
# Fallback to RuntimeError when reconstruction fails
if state.message is None:
return RuntimeError(state.qualname)
return RuntimeError(state.message)
@dataclass
class GraphExecution:
"""
Aggregate root for graph execution.
This manages the overall execution state of a workflow graph,
coordinating between multiple node executions.
"""
workflow_id: str
started: bool = False
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reasons: list[PauseReason] = field(default_factory=list)
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
def start(self) -> None:
"""Mark the graph execution as started."""
if self.started:
raise RuntimeError("Graph execution already started")
self.started = True
def complete(self) -> None:
"""Mark the graph execution as completed."""
if not self.started:
raise RuntimeError("Cannot complete execution that hasn't started")
if self.completed:
raise RuntimeError("Graph execution already completed")
self.completed = True
def abort(self, reason: str) -> None:
"""Abort the graph execution."""
self.aborted = True
self.error = RuntimeError(f"Aborted: {reason}")
def pause(self, reason: PauseReason) -> None:
"""Pause the graph execution without marking it complete."""
if self.completed:
raise RuntimeError("Cannot pause execution that has completed")
if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted")
self.paused = True
self.pause_reasons.append(reason)
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
self.error = error
self.completed = True
def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
"""Get or create a node execution entity."""
if node_id not in self.node_executions:
self.node_executions[node_id] = NodeExecution(node_id=node_id)
return self.node_executions[node_id]
@property
def is_running(self) -> bool:
"""Check if the execution is currently running."""
return self.started and not self.completed and not self.aborted and not self.paused
@property
def is_paused(self) -> bool:
"""Check if the execution is currently paused."""
return self.paused
@property
def has_error(self) -> bool:
"""Check if the execution has encountered an error."""
return self.error is not None
@property
def error_message(self) -> str | None:
"""Get the error message if an error exists."""
if not self.error:
return None
return str(self.error)
def dumps(self) -> str:
"""Serialize the aggregate state into a JSON string."""
node_states = [
NodeExecutionState(
node_id=node_id,
state=node_execution.state,
retry_count=node_execution.retry_count,
execution_id=node_execution.execution_id,
error=node_execution.error,
)
for node_id, node_execution in sorted(self.node_executions.items())
]
state = GraphExecutionState(
workflow_id=self.workflow_id,
started=self.started,
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
pause_reasons=self.pause_reasons,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""Restore aggregate state from a serialized JSON string."""
state = GraphExecutionState.model_validate_json(data)
if state.type != "GraphExecution":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported serialized version: {state.version}")
if self.workflow_id != state.workflow_id:
raise ValueError("Serialized workflow_id does not match aggregate identity")
self.started = state.started
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
self.pause_reasons = state.pause_reasons
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {
item.node_id: NodeExecution(
node_id=item.node_id,
state=item.state,
retry_count=item.retry_count,
execution_id=item.execution_id,
error=item.error,
)
for item in state.node_executions
}
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1
# Static conformance check: GraphExecution must satisfy GraphExecutionProtocol.
_: GraphExecutionProtocol = GraphExecution(workflow_id="")
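# Serialization round-trip sketch: loads() requires the restoring aggregate to
# carry the same workflow_id that was serialized:
#
#     execution = GraphExecution(workflow_id="wf-1")
#     execution.start()
#     snapshot = execution.dumps()
#     restored = GraphExecution(workflow_id="wf-1")
#     restored.loads(snapshot)
#     assert restored.started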


@ -0,0 +1,45 @@
"""
NodeExecution entity representing a node's execution state.
"""
from dataclasses import dataclass
from graphon.enums import NodeState
@dataclass
class NodeExecution:
"""
Entity representing the execution state of a single node.
This is a mutable entity that tracks the runtime state of a node
during graph execution.
"""
node_id: str
state: NodeState = NodeState.UNKNOWN
retry_count: int = 0
execution_id: str | None = None
error: str | None = None
def mark_started(self, execution_id: str) -> None:
"""Mark the node as started with an execution ID."""
self.state = NodeState.TAKEN
self.execution_id = execution_id
def mark_taken(self) -> None:
"""Mark the node as successfully completed."""
self.state = NodeState.TAKEN
self.error = None
def mark_failed(self, error: str) -> None:
"""Mark the node as failed with an error."""
self.error = error
def mark_skipped(self) -> None:
"""Mark the node as skipped."""
self.state = NodeState.SKIPPED
def increment_retry(self) -> None:
"""Increment the retry count for this node."""
self.retry_count += 1
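# Lifecycle sketch as driven by the engine's event handlers:
#
#     ne = NodeExecution(node_id="llm")
#     ne.mark_started(execution_id="exec-1")  # on NodeRunStartedEvent
#     ne.mark_failed("timeout")               # on NodeRunFailedEvent
#     ne.increment_retry()                    # on NodeRunRetryEvent
#     ne.mark_taken()                         # on NodeRunSucceededEvent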


@ -0,0 +1,56 @@
"""
GraphEngine command entities for external control.
This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from graphon.variables.variables import Variable
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = auto()
PAUSE = auto()
UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
"""Base class for all GraphEngine commands."""
command_type: CommandType = Field(..., description="Type of command")
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")
class AbortCommand(GraphEngineCommand):
"""Command to abort a running workflow execution."""
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for abort")
class PauseCommand(GraphEngineCommand):
"""Command to pause a running workflow execution."""
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="Reason for pause")
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
"""Command to update a group of variables in the variable pool."""
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")


@ -0,0 +1,213 @@
"""
Main error handler that coordinates error strategies.
"""
import logging
import time
from typing import TYPE_CHECKING, final
from graphon.enums import (
ErrorStrategy as ErrorStrategyEnum,
)
from graphon.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from graphon.graph import Graph
from graphon.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetryEvent,
)
from graphon.node_events import NodeRunResult
if TYPE_CHECKING:
from .domain import GraphExecution
logger = logging.getLogger(__name__)
@final
class ErrorHandler:
"""
Coordinates error handling strategies for node failures.
This acts as a facade for the various error strategies,
selecting and applying the appropriate strategy based on
node configuration.
"""
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
"""
Initialize the error handler.
Args:
graph: The workflow graph
graph_execution: The graph execution state
"""
self._graph = graph
self._graph_execution = graph_execution
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
"""
Handle a node failure event.
Selects and applies the appropriate error strategy based on
the node's configuration.
Args:
event: The node failure event
Returns:
Optional new event to process, or None to abort
"""
node = self._graph.nodes[event.node_id]
# Get retry count from NodeExecution
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
retry_count = node_execution.retry_count
# First check if retry is configured and not exhausted
if node.retry and retry_count < node.retry_config.max_retries:
result = self._handle_retry(event, retry_count)
if result:
# Retry count will be incremented when NodeRunRetryEvent is handled
return result
# Apply configured error strategy
strategy = node.error_strategy
match strategy:
case None:
return self._handle_abort(event)
case ErrorStrategyEnum.FAIL_BRANCH:
return self._handle_fail_branch(event)
case ErrorStrategyEnum.DEFAULT_VALUE:
return self._handle_default_value(event)
def _handle_abort(self, event: NodeRunFailedEvent) -> None:
"""
Handle error by aborting execution.
This is the default strategy when no other strategy is specified.
It stops the entire graph execution when a node fails.
Args:
event: The failure event
Returns:
None - signals abortion
"""
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
# Return None to signal that execution should stop
def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int) -> NodeRunRetryEvent | None:
"""
Handle error by retrying the node.
This strategy re-attempts node execution up to a configured
maximum number of retries with configurable intervals.
Args:
event: The failure event
retry_count: Current retry attempt count
Returns:
NodeRunRetryEvent if retry should occur, None otherwise
"""
node = self._graph.nodes[event.node_id]
# Check if we've exceeded max retries
if not node.retry or retry_count >= node.retry_config.max_retries:
return None
# Wait for retry interval
time.sleep(node.retry_config.retry_interval_seconds)
# Create retry event
return NodeRunRetryEvent(
id=event.id,
node_title=node.title,
node_id=event.node_id,
node_type=event.node_type,
node_run_result=event.node_run_result,
start_at=event.start_at,
error=event.error,
retry_index=retry_count + 1,
)
def _handle_fail_branch(self, event: NodeRunFailedEvent) -> NodeRunExceptionEvent:
"""
Handle error by taking the fail branch.
This strategy converts failures to exceptions and routes execution
through a designated fail-branch edge.
Args:
event: The failure event
Returns:
NodeRunExceptionEvent to continue via fail branch
"""
outputs = {
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,
}
return NodeRunExceptionEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
edge_source_handle="fail-branch",
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH,
},
),
error=event.error,
)
def _handle_default_value(self, event: NodeRunFailedEvent) -> NodeRunExceptionEvent:
"""
Handle error by using default values.
This strategy allows nodes to fail gracefully by providing
predefined default output values.
Args:
event: The failure event
Returns:
NodeRunExceptionEvent with default values
"""
node = self._graph.nodes[event.node_id]
outputs = {
**node.default_value_dict,
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,
}
return NodeRunExceptionEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
finished_at=event.finished_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE,
},
),
error=event.error,
)


@ -0,0 +1,14 @@
"""
Event management subsystem for graph engine.
This package handles event routing, collection, and emission for
workflow graph execution events.
"""
from .event_handlers import EventHandler
from .event_manager import EventManager
__all__ = [
"EventHandler",
"EventManager",
]


@ -0,0 +1,367 @@
"""
Event handler implementations for different event types.
"""
import logging
from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState
from graphon.graph import Graph
from graphon.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
NodeRunVariableUpdatedEvent,
)
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.runtime import GraphRuntimeState
from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from ..error_handler import ErrorHandler
from ..graph_state_manager import GraphStateManager
from ..graph_traversal import EdgeProcessor
from .event_manager import EventManager
logger = logging.getLogger(__name__)
@final
class EventHandler:
"""
Registry of event handlers for different event types.
This centralizes the business logic for handling specific events,
keeping it separate from the routing and collection infrastructure.
"""
def __init__(
self,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator,
event_collector: "EventManager",
edge_processor: "EdgeProcessor",
state_manager: "GraphStateManager",
error_handler: "ErrorHandler",
) -> None:
"""
Initialize the event handler registry.
Args:
graph: The workflow graph
graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator
event_collector: Event manager for collecting events
edge_processor: Edge processor for edge traversal
state_manager: Unified state manager
error_handler: Error handler
"""
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._event_collector = event_collector
self._edge_processor = edge_processor
self._state_manager = state_manager
self._error_handler = error_handler
def dispatch(self, event: GraphNodeEventBase) -> None:
"""
Handle any node event by dispatching to the appropriate handler.
Args:
event: The event to handle
"""
if isinstance(event, NodeRunVariableUpdatedEvent):
self._dispatch(event)
return
# Events in loops or iterations are always collected
if event.in_loop_id or event.in_iteration_id:
self._event_collector.collect(event)
return
return self._dispatch(event)
@singledispatchmethod
def _dispatch(self, event: GraphNodeEventBase) -> None:
self._event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__)
@_dispatch.register(NodeRunIterationStartedEvent)
@_dispatch.register(NodeRunIterationNextEvent)
@_dispatch.register(NodeRunIterationSucceededEvent)
@_dispatch.register(NodeRunIterationFailedEvent)
@_dispatch.register(NodeRunLoopStartedEvent)
@_dispatch.register(NodeRunLoopNextEvent)
@_dispatch.register(NodeRunLoopSucceededEvent)
@_dispatch.register(NodeRunLoopFailedEvent)
@_dispatch.register(NodeRunAgentLogEvent)
@_dispatch.register(NodeRunRetrieverResourceEvent)
def _(self, event: GraphNodeEventBase) -> None:
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStartedEvent) -> None:
"""
Handle node started event.
Args:
event: The node started event
"""
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)
self._graph_runtime_state.increment_node_run_steps()
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event only for the first attempt; retries remain silent
if is_initial_attempt:
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None:
"""
Handle stream chunk event with full processing.
Args:
event: The stream chunk event
"""
# Process with response coordinator
streaming_events = list(self._response_coordinator.intercept_event(event))
# Collect all events
for stream_event in streaming_events:
self._event_collector.collect(stream_event)
@_dispatch.register
def _(self, event: NodeRunVariableUpdatedEvent) -> None:
"""
Apply a node-requested variable mutation before downstream observers run.
The event is collected like other node events so parent/container engines can
forward the updated payload to outer layers, including persistence listeners.
"""
self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable)
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunSucceededEvent) -> None:
"""
Handle node success by coordinating subsystems.
This method coordinates between different subsystems to process
node completion, handle edges, and trigger downstream execution.
Args:
event: The node succeeded event
"""
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event)
for stream_event in streaming_events:
self._event_collector.collect(stream_event)
# Process edges and get ready nodes
node = self._graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
# Collect streaming events from edge processing
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
# Enqueue ready nodes
if self._graph_execution.is_paused:
for node_id in ready_nodes:
self._graph_runtime_state.register_deferred_node(node_id)
else:
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._state_manager.finish_execution(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
# Collect the event
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunPauseRequestedEvent) -> None:
"""Handle pause requests emitted by nodes."""
pause_reason = event.reason
self._graph_execution.pause(pause_reason)
self._state_manager.finish_execution(event.node_id)
if event.node_id in self._graph.nodes:
self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
self._graph_runtime_state.register_paused_node(event.node_id)
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunFailedEvent) -> None:
"""
Handle node failure using error handler.
Args:
event: The node failed event
"""
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
self._accumulate_node_usage(event.node_run_result.llm_usage)
result = self._error_handler.handle_node_failure(event)
if result:
# Process the resulting event (retry, exception, etc.)
self.dispatch(result)
else:
# Abort execution
self._graph_execution.fail(RuntimeError(event.error))
self._event_collector.collect(event)
self._state_manager.finish_execution(event.node_id)
@_dispatch.register
def _(self, event: NodeRunExceptionEvent) -> None:
"""
Handle node exception event (fail-branch strategy).
Args:
event: The node exception event
"""
# Node continues via fail-branch/default-value, treat as completion
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
node = self._graph.nodes[event.node_id]
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update response outputs if applicable
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)
self._state_manager.finish_execution(event.node_id)
# Collect the exception event for observers
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None:
"""
Handle node retry event.
Args:
event: The node retry event
"""
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
# Finish the previous attempt before re-queuing the node
self._state_manager.finish_execution(event.node_id)
# Emit retry event for observers
self._event_collector.collect(event)
# Re-queue node for execution
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
"""Accumulate token usage into the shared runtime state."""
if usage.total_tokens <= 0:
return
self._graph_runtime_state.add_tokens(usage.total_tokens)
current_usage = self._graph_runtime_state.llm_usage
if current_usage.total_tokens == 0:
self._graph_runtime_state.llm_usage = usage
else:
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
"""
Store node outputs in the variable pool.
Args:
node_id: ID of the node that produced the outputs
outputs: Mapping of output names to output values
"""
for variable_name, variable_value in outputs.items():
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
"""Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in outputs.items():
if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "")
if existing:
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
else:
self._graph_runtime_state.set_output("answer", value)
else:
self._graph_runtime_state.set_output(key, value)
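# Merge rule illustration: successive response outputs {"answer": "Hello, "}
# and {"answer": "world"} accumulate to {"answer": "Hello, world"}, while any
# other key is overwritten by the most recent value.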


@ -0,0 +1,186 @@
"""
Unified event manager for collecting and emitting events.
"""
import logging
import threading
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import final
from graphon.graph_events import GraphEngineEvent
from ..layers.base import GraphEngineLayer
_logger = logging.getLogger(__name__)
@final
class ReadWriteLock:
"""
A read-write lock implementation that allows multiple concurrent readers
but only one writer at a time.
"""
def __init__(self) -> None:
self._read_ready = threading.Condition(threading.RLock())
self._readers = 0
def acquire_read(self) -> None:
"""Acquire a read lock."""
_ = self._read_ready.acquire()
try:
self._readers += 1
finally:
self._read_ready.release()
def release_read(self) -> None:
"""Release a read lock."""
_ = self._read_ready.acquire()
try:
self._readers -= 1
if self._readers == 0:
self._read_ready.notify_all()
finally:
self._read_ready.release()
def acquire_write(self) -> None:
"""Acquire a write lock."""
_ = self._read_ready.acquire()
while self._readers > 0:
_ = self._read_ready.wait()
def release_write(self) -> None:
"""Release a write lock."""
self._read_ready.release()
@contextmanager
def read_lock(self):
"""Return a context manager for read locking."""
self.acquire_read()
try:
yield
finally:
self.release_read()
@contextmanager
def write_lock(self):
"""Return a context manager for write locking."""
self.acquire_write()
try:
yield
finally:
self.release_write()
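# Usage sketch: many readers may hold the lock concurrently, while a writer
# blocks until the reader count drains to zero (`shared_events` hypothetical):
#
#     lock = ReadWriteLock()
#     with lock.read_lock():
#         snapshot = list(shared_events)
#     with lock.write_lock():
#         shared_events.append(event)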
@final
class EventManager:
"""
Unified event manager that collects, buffers, and emits events.
This class combines event collection with event emission, providing
thread-safe event management with support for notifying layers and
streaming events to external consumers.
"""
def __init__(self) -> None:
"""Initialize the event manager."""
self._events: list[GraphEngineEvent] = []
self._lock = ReadWriteLock()
self._layers: list[GraphEngineLayer] = []
self._execution_complete = threading.Event()
def set_layers(self, layers: list[GraphEngineLayer]) -> None:
"""
Set the layers to notify on event collection.
Args:
layers: List of layers to notify
"""
self._layers = layers
def notify_layers(self, event: GraphEngineEvent) -> None:
"""Notify registered layers about an event without buffering it."""
self._notify_layers(event)
def collect(self, event: GraphEngineEvent) -> None:
"""
Thread-safe method to collect an event.
Args:
event: The event to collect
"""
with self._lock.write_lock():
self._events.append(event)
# NOTE: `_notify_layers` is intentionally called outside the critical section
# to minimize lock contention and avoid blocking other readers or writers.
#
# The public `notify_layers` method also does not use a write lock,
# so protecting `_notify_layers` with a lock here is unnecessary.
self._notify_layers(event)
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""
Get new events starting from a specific index.
Args:
start_index: The index to start from
Returns:
List of new events
"""
with self._lock.read_lock():
return list(self._events[start_index:])
def _event_count(self) -> int:
"""
Get the current count of collected events.
Returns:
Number of collected events
"""
with self._lock.read_lock():
return len(self._events)
def mark_complete(self) -> None:
"""Mark execution as complete to stop the event emission generator."""
self._execution_complete.set()
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
"""
Generator that yields events as they're collected.
Yields:
GraphEngineEvent instances as they're processed
"""
yielded_count = 0
while not self._execution_complete.is_set() or yielded_count < self._event_count():
# Get new events since last yield
new_events = self._get_new_events(yielded_count)
# Yield any new events
for event in new_events:
yield event
yielded_count += 1
# Small sleep to avoid busy waiting
if not self._execution_complete.is_set() and not new_events:
time.sleep(0.001)
def _notify_layers(self, event: GraphEngineEvent) -> None:
"""
Notify all layers of an event.
Layer exceptions are caught and logged to prevent disrupting collection.
Args:
event: The event to send to layers
"""
for layer in self._layers:
try:
layer.on_event(event)
except Exception:
_logger.exception("Error in layer on_event, layer_type=%s", type(layer))


@ -0,0 +1,377 @@
"""
GraphEngine - Main orchestrator for queue-based workflow execution.
This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
from __future__ import annotations
import logging
import queue
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from graphon.entities.workflow_start_reason import WorkflowStartReason
from graphon.enums import NodeExecutionType
from graphon.graph import Graph
from graphon.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
from graphon.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from graphon.runtime.graph_runtime_state import GraphProtocol
from .command_processing import (
AbortCommandHandler,
CommandProcessor,
PauseCommandHandler,
UpdateVariablesCommandHandler,
)
from .config import GraphEngineConfig
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .worker_management import WorkerPool
if TYPE_CHECKING:
from graphon.entities import GraphInitParams
from graphon.graph_engine.domain.graph_execution import GraphExecution
from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator
logger = logging.getLogger(__name__)
_DEFAULT_CONFIG = GraphEngineConfig()
@final
class GraphEngine:
"""
Queue-based graph execution engine.
Uses a modular architecture that delegates responsibilities to specialized
subsystems, following Domain-Driven Design and SOLID principles.
"""
def __init__(
self,
workflow_id: str,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel,
config: GraphEngineConfig = _DEFAULT_CONFIG,
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
self._layers: list[GraphEngineLayer] = []
self._child_engine_builder = child_engine_builder
if child_engine_builder is not None:
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
self._graph_execution.workflow_id = workflow_id
# === Execution Queues ===
self._ready_queue = self._graph_runtime_state.ready_queue
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
# === State Management ===
# Unified state manager handles all node state transitions and queue operations
self._state_manager = GraphStateManager(self._graph, self._ready_queue)
# === Response Coordination ===
# Coordinates response streaming from response nodes
self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)
# === Event Management ===
# Event manager handles both collection and emission of events
self._event_manager = EventManager()
# === Error Handling ===
# Centralized error handler for graph execution errors
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
# === Graph Traversal Components ===
# Propagates skip status through the graph when conditions aren't met
self._skip_propagator = SkipPropagator(
graph=self._graph,
state_manager=self._state_manager,
)
# Processes edges to determine next nodes after execution
# Also handles conditional branching and route selection
self._edge_processor = EdgeProcessor(
graph=self._graph,
state_manager=self._state_manager,
response_coordinator=self._response_coordinator,
skip_propagator=self._skip_propagator,
)
# === Command Processing ===
# Processes external commands (e.g., abort requests)
self._command_processor = CommandProcessor(
command_channel=self._command_channel,
graph_execution=self._graph_execution,
)
# Register command handlers
abort_handler = AbortCommandHandler()
self._command_processor.register_handler(AbortCommand, abort_handler)
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
# === Worker Pool Setup ===
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
execution_context=self._graph_runtime_state.execution_context,
config=self._config,
)
# === Orchestration ===
# Coordinates the overall execution lifecycle
self._execution_coordinator = ExecutionCoordinator(
graph_execution=self._graph_execution,
state_manager=self._state_manager,
command_processor=self._command_processor,
worker_pool=self._worker_pool,
)
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
# Dispatches events and manages execution flow
self._dispatcher = Dispatcher(
event_queue=self._event_queue,
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
)
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
self._validate_graph_state_consistency()
def _validate_graph_state_consistency(self) -> None:
"""Validate that all nodes share the same GraphRuntimeState."""
expected_state_id = id(self._graph_runtime_state)
for node in self._graph.nodes.values():
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def _bind_layer_context(
self,
layer: GraphEngineLayer,
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)
return self
def request_abort(self, reason: str | None = None) -> None:
"""Queue an abort command for this engine."""
self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort"))
def create_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
root_node_id: str,
variable_pool: VariablePool | None = None,
) -> GraphEngine:
return self._graph_runtime_state.create_child_engine(
workflow_id=workflow_id,
graph_init_params=graph_init_params,
root_node_id=root_node_id,
variable_pool=variable_pool,
)
def run(self) -> Generator[GraphEngineEvent, None, None]:
"""
Execute the graph using the modular architecture.
Returns:
Generator yielding GraphEngineEvent instances
"""
try:
# Initialize layers
self._initialize_layers()
is_resume = self._graph_execution.started
if not is_resume:
self._graph_execution.start()
else:
self._graph_execution.paused = False
self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent(
reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
)
self._event_manager.notify_layers(start_event)
yield start_event
# Start subsystems
self._start_execution(resume=is_resume)
# Yield events as they occur
yield from self._event_manager.emit_events()
# Handle completion
if self._graph_execution.is_paused:
pause_reasons = self._graph_execution.pause_reasons
                assert pause_reasons, "pause_reasons should not be empty when execution is paused."
paused_event = GraphRunPausedEvent(
reasons=pause_reasons,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)
yield paused_event
elif self._graph_execution.aborted:
abort_reason = "Workflow execution aborted by user command"
if self._graph_execution.error:
abort_reason = str(self._graph_execution.error)
aborted_event = GraphRunAbortedEvent(
reason=abort_reason,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(aborted_event)
yield aborted_event
elif self._graph_execution.has_error:
if self._graph_execution.error:
raise self._graph_execution.error
else:
outputs = self._graph_runtime_state.outputs
exceptions_count = self._graph_execution.exceptions_count
if exceptions_count > 0:
partial_event = GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
self._event_manager.notify_layers(partial_event)
yield partial_event
else:
succeeded_event = GraphRunSucceededEvent(
outputs=outputs,
)
self._event_manager.notify_layers(succeeded_event)
yield succeeded_event
except Exception as e:
failed_event = GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
self._event_manager.notify_layers(failed_event)
yield failed_event
raise
finally:
self._stop_execution()
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
for layer in self._layers:
try:
layer.on_graph_start()
except Exception:
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
# Register response nodes
for node in self._graph.nodes.values():
if node.execution_type == NodeExecutionType.RESPONSE:
self._response_coordinator.register(node.id)
if not resume:
# Enqueue root node
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
else:
seen_nodes: set[str] = set()
for node_id in paused_nodes + deferred_nodes:
if node_id in seen_nodes:
continue
seen_nodes.add(node_id)
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Start dispatcher
self._dispatcher.start()
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)
except Exception:
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
# Public property accessors for attributes that need external access
@property
def graph_runtime_state(self) -> GraphRuntimeState:
"""Get the graph runtime state."""
return self._graph_runtime_state
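# Usage sketch (the graph, runtime state, and command channel are assumed to be
# built elsewhere; names here are illustrative):
#
#     engine = GraphEngine(
#         workflow_id="wf-1",
#         graph=graph,
#         graph_runtime_state=graph_runtime_state,
#         command_channel=command_channel,
#     )
#     engine.layer(DebugLoggingLayer(level="INFO"))
#     for event in engine.run():
#         ...  # run() is a generator; events stream as execution progresses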

View File

@ -0,0 +1,290 @@
"""
Graph state manager that combines node, edge, and execution tracking.
"""
import threading
from collections.abc import Sequence
from typing import TypedDict, final
from graphon.enums import NodeState
from graphon.graph import Edge, Graph
from .ready_queue import ReadyQueue
class EdgeStateAnalysis(TypedDict):
"""Analysis result for edge states."""
has_unknown: bool
has_taken: bool
all_skipped: bool
@final
class GraphStateManager:
    """Thread-safe manager combining node, edge, and execution-tracking state."""
def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
"""
Initialize the state manager.
Args:
graph: The workflow graph
ready_queue: Queue for nodes ready to execute
"""
self._graph = graph
self._ready_queue = ready_queue
self._lock = threading.RLock()
# Execution tracking state
self._executing_nodes: set[str] = set()
# ============= Node State Operations =============
def enqueue_node(self, node_id: str) -> None:
"""
Mark a node as TAKEN and add it to the ready queue.
This combines the state transition and enqueueing operations
that always occur together when preparing a node for execution.
Args:
node_id: The ID of the node to enqueue
"""
with self._lock:
self._graph.nodes[node_id].state = NodeState.TAKEN
self._ready_queue.put(node_id)
def mark_node_skipped(self, node_id: str) -> None:
"""
Mark a node as SKIPPED.
Args:
node_id: The ID of the node to skip
"""
with self._lock:
self._graph.nodes[node_id].state = NodeState.SKIPPED
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
        A node is ready when none of its incoming edges is UNKNOWN and at
        least one is TAKEN; nodes with no incoming edges are always ready.
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
with self._lock:
# Get all incoming edges to this node
incoming_edges = self._graph.get_incoming_edges(node_id)
# If no incoming edges, node is always ready
if not incoming_edges:
return True
# If any edge is UNKNOWN, node is not ready
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
return False
# Node is ready if at least one edge is TAKEN
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
def get_node_state(self, node_id: str) -> NodeState:
"""
Get the current state of a node.
Args:
node_id: The ID of the node
Returns:
The current node state
"""
with self._lock:
return self._graph.nodes[node_id].state
# ============= Edge State Operations =============
def mark_edge_taken(self, edge_id: str) -> None:
"""
Mark an edge as TAKEN.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self._graph.edges[edge_id].state = NodeState.TAKEN
def mark_edge_skipped(self, edge_id: str) -> None:
"""
Mark an edge as SKIPPED.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self._graph.edges[edge_id].state = NodeState.SKIPPED
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
"""
Analyze the states of edges and return summary flags.
Args:
edges: List of edges to analyze
Returns:
Analysis result with state flags
"""
with self._lock:
states = {edge.state for edge in edges}
return EdgeStateAnalysis(
has_unknown=NodeState.UNKNOWN in states,
has_taken=NodeState.TAKEN in states,
all_skipped=states == {NodeState.SKIPPED} if states else True,
)
def get_edge_state(self, edge_id: str) -> NodeState:
"""
Get the current state of an edge.
Args:
edge_id: The ID of the edge
Returns:
The current edge state
"""
with self._lock:
return self._graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
A tuple of (selected_edges, unselected_edges)
"""
with self._lock:
outgoing_edges = self._graph.get_outgoing_edges(node_id)
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
selected_edges.append(edge)
else:
unselected_edges.append(edge)
return selected_edges, unselected_edges
# ============= Execution Tracking Operations =============
def start_execution(self, node_id: str) -> None:
"""
Mark a node as executing.
Args:
node_id: The ID of the node starting execution
"""
with self._lock:
self._executing_nodes.add(node_id)
def finish_execution(self, node_id: str) -> None:
"""
Mark a node as no longer executing.
Args:
node_id: The ID of the node finishing execution
"""
with self._lock:
self._executing_nodes.discard(node_id)
def is_executing(self, node_id: str) -> bool:
"""
Check if a node is currently executing.
Args:
node_id: The ID of the node to check
Returns:
True if the node is executing
"""
with self._lock:
return node_id in self._executing_nodes
def get_executing_count(self) -> int:
"""
Get the count of currently executing nodes.
Returns:
Number of executing nodes
"""
# This count is a best-effort snapshot and can change concurrently.
# Only use it for pause-drain checks where scheduling is already frozen.
with self._lock:
return len(self._executing_nodes)
def get_executing_nodes(self) -> set[str]:
"""
Get a copy of the set of executing node IDs.
Returns:
Set of node IDs currently executing
"""
with self._lock:
return self._executing_nodes.copy()
def clear_executing(self) -> None:
"""Clear all executing nodes."""
with self._lock:
self._executing_nodes.clear()
# ============= Composite Operations =============
def is_execution_complete(self) -> bool:
"""
Check if graph execution is complete.
Execution is complete when:
- Ready queue is empty
- No nodes are executing
Returns:
True if execution is complete
"""
with self._lock:
return self._ready_queue.empty() and len(self._executing_nodes) == 0
def get_queue_depth(self) -> int:
"""
Get the current depth of the ready queue.
Returns:
Number of nodes in the ready queue
"""
return self._ready_queue.qsize()
def get_execution_stats(self) -> dict[str, int]:
"""
Get execution statistics.
Returns:
Dictionary with execution statistics
"""
with self._lock:
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
return {
"queue_depth": self._ready_queue.qsize(),
"executing": len(self._executing_nodes),
"taken_nodes": taken_nodes,
"skipped_nodes": skipped_nodes,
"unknown_nodes": unknown_nodes,
}
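# Scheduling sketch (mirrors how EdgeProcessor and SkipPropagator use this
# manager): check readiness, then enqueue and start tracking in one pass.
#
#     if state_manager.is_node_ready(node_id):
#         state_manager.enqueue_node(node_id)
#         state_manager.start_execution(node_id)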

View File

@ -0,0 +1,14 @@
"""
Graph traversal subsystem for graph engine.
This package handles graph navigation, edge processing,
and skip propagation logic.
"""
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator
__all__ = [
"EdgeProcessor",
"SkipPropagator",
]

View File

@ -0,0 +1,201 @@
"""
Edge processing logic for graph traversal.
"""
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from graphon.enums import NodeExecutionType
from graphon.graph import Edge, Graph
from graphon.graph_events import NodeRunStreamChunkEvent
from ..graph_state_manager import GraphStateManager
from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from .skip_propagator import SkipPropagator
@final
class EdgeProcessor:
"""
Processes edges during graph execution.
This handles marking edges as taken or skipped, notifying
the response coordinator, triggering downstream node execution,
and managing branch node logic.
"""
def __init__(
self,
graph: Graph,
state_manager: GraphStateManager,
response_coordinator: ResponseStreamCoordinator,
skip_propagator: "SkipPropagator",
) -> None:
"""
Initialize the edge processor.
Args:
graph: The workflow graph
state_manager: Unified state manager
response_coordinator: Response stream coordinator
skip_propagator: Propagator for skip states
"""
self._graph = graph
self._state_manager = state_manager
self._response_coordinator = response_coordinator
self._skip_propagator = skip_propagator
def process_node_success(
self, node_id: str, selected_handle: str | None = None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges after a node succeeds.
Args:
node_id: The ID of the succeeded node
selected_handle: For branch nodes, the selected edge handle
Returns:
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
"""
node = self._graph.nodes[node_id]
if node.execution_type == NodeExecutionType.BRANCH:
return self._process_branch_node_edges(node_id, selected_handle)
else:
return self._process_non_branch_node_edges(node_id)
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for non-branch nodes (mark all as TAKEN).
Args:
node_id: The ID of the succeeded node
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
"""
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_branch_node_edges(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for branch nodes.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no edge was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge")
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Process unselected edges first (mark as skipped)
for edge in unselected_edges:
self._process_skipped_edge(edge)
# Process selected edges
for edge in selected_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Mark edge as taken and check downstream node.
Args:
edge: The edge to process
Returns:
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
"""
# Mark edge as taken
self._state_manager.mark_edge_taken(edge.id)
# Notify response coordinator and get streaming events
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes: list[str] = []
if self._state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, streaming_events
def _process_skipped_edge(self, edge: Edge) -> None:
"""
Mark edge as skipped.
Args:
edge: The edge to skip
"""
self._state_manager.mark_edge_skipped(edge.id)
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected branch
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no branch was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges into selected and unselected
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self._skip_propagator.skip_branch_paths(unselected_edges)
# Process selected edges and get ready nodes and streaming events
return self.process_node_success(node_id, selected_handle)
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
"""
Validate that a branch selection is valid.
Args:
node_id: The ID of the branch node
selected_handle: The handle to validate
Returns:
True if the selection is valid
"""
outgoing_edges = self._graph.get_outgoing_edges(node_id)
valid_handles = {edge.source_handle for edge in outgoing_edges}
return selected_handle in valid_handles
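# Worked example: for an if/else node whose selected handle is "true",
# handle_branch_completion() first skips every unselected ("false") path via
# SkipPropagator, then marks the "true" edges TAKEN, releasing any downstream
# node whose incoming edges are all resolved with at least one TAKEN.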

View File

@ -0,0 +1,96 @@
"""
Skip state propagation through the graph.
"""
from collections.abc import Sequence
from typing import final
from graphon.graph import Edge, Graph
from ..graph_state_manager import GraphStateManager
@final
class SkipPropagator:
"""
Propagates skip states through the graph.
When a node is skipped, this ensures all downstream nodes
that depend solely on it are also skipped.
"""
def __init__(
self,
graph: Graph,
state_manager: GraphStateManager,
) -> None:
"""
Initialize the skip propagator.
Args:
graph: The workflow graph
state_manager: Unified state manager
"""
self._graph = graph
self._state_manager = state_manager
def propagate_skip_from_edge(self, edge_id: str) -> None:
"""
Recursively propagate skip state from a skipped edge.
Rules:
- If a node has any UNKNOWN incoming edges, stop processing
- If all incoming edges are SKIPPED, skip the node and its edges
- If any incoming edge is TAKEN, the node may still execute
Args:
edge_id: The ID of the skipped edge to start from
"""
downstream_node_id = self._graph.edges[edge_id].head
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
# Analyze edge states
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
# Stop if there are unknown edges (not yet processed)
if edge_states["has_unknown"]:
return
# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
if edge_states["all_skipped"]:
self._propagate_skip_to_node(downstream_node_id)
def _propagate_skip_to_node(self, node_id: str) -> None:
"""
Mark a node and all its outgoing edges as skipped.
Args:
node_id: The ID of the node to skip
"""
# Mark node as skipped
self._state_manager.mark_node_skipped(node_id)
# Mark all outgoing edges as skipped and propagate
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
self._state_manager.mark_edge_skipped(edge.id)
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
"""
Skip all paths from unselected branch edges.
Args:
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:
self._state_manager.mark_edge_skipped(edge.id)
self.propagate_skip_from_edge(edge.id)
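# Worked example of the rules above: if node C's only incoming edges come from
# an unselected branch, they all end up SKIPPED, so C is skipped and the skip
# cascades through C's outgoing edges. The cascade stops at any node that still
# has an UNKNOWN incoming edge (not yet decided) or a TAKEN one (the node may
# still execute, so it is enqueued instead).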

View File

@ -0,0 +1,55 @@
# Layers
Pluggable middleware for engine extensions.
## Components
### GraphEngineLayer (base)
Abstract base class for layers.
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
### DebugLoggingLayer
Comprehensive execution logging.
- Configurable detail levels
- Tracks execution statistics
- Truncates long values
## Usage
```python
debug_layer = DebugLoggingLayer(
    level="INFO",
    include_outputs=True,
)
engine = GraphEngine(
    workflow_id=workflow_id,
    graph=graph,
    graph_runtime_state=graph_runtime_state,
    command_channel=command_channel,
)
engine.layer(debug_layer)
for event in engine.run():
    ...  # run() is a generator; events stream as you consume it
```
`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.
## Custom Layers
```python
class MetricsLayer(GraphEngineLayer):
    def on_graph_start(self) -> None: ...
    def on_graph_end(self, error) -> None: ...

    def on_event(self, event) -> None:
        if isinstance(event, NodeRunSucceededEvent):
            self.metrics[event.node_id] = event.elapsed_time
```

All three lifecycle hooks are abstract, so even observe-only layers must define
them; initialize `self.metrics` in `__init__` before use.
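Layers can also steer execution through the injected command channel, mirroring
what `ExecutionLimitsLayer` does internally. A minimal sketch (the `BudgetLayer`
name and threshold are illustrative):

```python
from graphon.graph_engine.entities.commands import AbortCommand
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, NodeRunStartedEvent


class BudgetLayer(GraphEngineLayer):
    """Abort the run once too many nodes have started (illustrative)."""

    def __init__(self, max_nodes: int) -> None:
        super().__init__()
        self.max_nodes = max_nodes
        self.seen = 0

    def on_graph_start(self) -> None: ...

    def on_event(self, event: GraphEngineEvent) -> None:
        if isinstance(event, NodeRunStartedEvent):
            self.seen += 1
            if self.seen > self.max_nodes and self.command_channel:
                self.command_channel.send_command(
                    AbortCommand(reason=f"Node budget exceeded: {self.seen}")
                )

    def on_graph_end(self, error: Exception | None) -> None: ...
```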
## Configuration
**DebugLoggingLayer Options:**
- `level` - Log level (INFO, DEBUG, ERROR)
- `include_inputs/outputs` - Log data values
- `max_value_length` - Truncate long values

View File

@ -0,0 +1,16 @@
"""
Layer system for GraphEngine extensibility.
This module provides the layer infrastructure for extending GraphEngine functionality
with middleware-like components that can observe events and interact with execution.
"""
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
]

View File

@ -0,0 +1,128 @@
"""
Base layer class for GraphEngine extensions.
This module provides the abstract base class for implementing layers that can
intercept and respond to GraphEngine events.
"""
from abc import ABC, abstractmethod
from graphon.graph_engine.protocols.command_channel import CommandChannel
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase
from graphon.nodes.base.node import Node
from graphon.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayerNotInitializedError(Exception):
"""Raised when a layer's runtime state is accessed before initialization."""
def __init__(self, layer_name: str | None = None) -> None:
name = layer_name or "GraphEngineLayer"
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
Layers are middleware-like components that can:
- Observe all events emitted by the GraphEngine
- Access the graph runtime state
- Send commands to control execution
Subclasses should override the constructor to accept configuration parameters,
then implement the three lifecycle methods.
"""
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
@property
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
if self._graph_runtime_state is None:
raise GraphEngineLayerNotInitializedError(type(self).__name__)
return self._graph_runtime_state
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine to inject the read-only runtime state and command channel.
This is invoked when the layer is registered with a `GraphEngine` instance.
Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
This is called after the engine has been initialized but before any nodes
are executed. Layers can use this to set up resources or log start information.
"""
pass
@abstractmethod
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
This method receives all events generated during graph execution, including:
- Graph lifecycle events (start, success, failure)
- Node execution events (start, success, failure, retry)
- Stream events for response nodes
- Container events (iteration, loop)
Args:
event: The event emitted by the engine
"""
pass
@abstractmethod
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.
This is called after all nodes have been executed or when execution is
aborted. Layers can use this to clean up resources or log final state.
Args:
error: The exception that caused execution to fail, or None if successful
"""
pass
def on_node_run_start(self, node: Node) -> None:
"""
Called immediately before a node begins execution.
Layers can override to inject behavior (e.g., start spans) prior to node execution.
The node's execution ID is available via `node._node_execution_id` and will be
consistent with all events emitted by this node execution.
Args:
node: The node instance about to be executed
"""
return
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
"""
Called after a node finishes execution.
The node's execution ID is available via `node._node_execution_id` and matches
the `id` field in all events emitted by this node execution.
Args:
node: The node instance that just finished execution
error: Exception instance if the node failed, otherwise None
result_event: The final result event from node execution (succeeded/failed/paused), if any
"""
return

View File

@ -0,0 +1,247 @@
"""
Debug logging layer for GraphEngine.
This module provides a layer that logs all events and state changes during
graph execution for debugging purposes.
"""
import logging
from collections.abc import Mapping
from typing import Any, final
from typing_extensions import override
from graphon.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from .base import GraphEngineLayer
@final
class DebugLoggingLayer(GraphEngineLayer):
"""
A layer that provides comprehensive logging of GraphEngine execution.
This layer logs all events with configurable detail levels, helping developers
debug workflow execution and understand the flow of events.
"""
def __init__(
self,
level: str = "INFO",
include_inputs: bool = False,
include_outputs: bool = True,
include_process_data: bool = False,
logger_name: str = "GraphEngine.Debug",
max_value_length: int = 500,
) -> None:
"""
Initialize the debug logging layer.
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
include_inputs: Whether to log node input values
include_outputs: Whether to log node output values
include_process_data: Whether to log node process data
logger_name: Name of the logger to use
max_value_length: Maximum length of logged values (truncated if longer)
"""
super().__init__()
self.level = level
self.include_inputs = include_inputs
self.include_outputs = include_outputs
self.include_process_data = include_process_data
self.max_value_length = max_value_length
# Set up logger
self.logger = logging.getLogger(logger_name)
log_level = getattr(logging, level.upper(), logging.INFO)
self.logger.setLevel(log_level)
# Track execution stats
self.node_count = 0
self.success_count = 0
self.failure_count = 0
self.retry_count = 0
def _truncate_value(self, value: Any) -> str:
"""Truncate long values for logging."""
str_value = str(value)
if len(str_value) > self.max_value_length:
return str_value[: self.max_value_length] + "... (truncated)"
return str_value
def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
"""Format a dictionary or mapping for logging with truncation."""
if not data:
return "{}"
formatted_items: list[str] = []
for key, value in data.items():
formatted_value = self._truncate_value(value)
formatted_items.append(f" {key}: {formatted_value}")
return "{\n" + ",\n".join(formatted_items) + "\n}"
@override
def on_graph_start(self) -> None:
"""Log graph execution start."""
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""Log individual events based on their type."""
event_class = event.__class__.__name__
# Graph-level events
if isinstance(event, GraphRunStartedEvent):
self.logger.debug("Graph run started event")
elif isinstance(event, GraphRunSucceededEvent):
self.logger.info("✅ Graph run succeeded")
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self.logger.warning("⚠️ Graph run partially succeeded")
if event.exceptions_count > 0:
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0:
self.logger.error(" Total exceptions: %s", event.exceptions_count)
elif isinstance(event, GraphRunAbortedEvent):
self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
if event.outputs:
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
# Node-level events
        # Check NodeRunRetryEvent before NodeRunStartedEvent, since Retry subclasses Started.
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStartedEvent):
self.node_count += 1
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
if self.include_inputs and event.node_run_result.inputs:
self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs))
elif isinstance(event, NodeRunSucceededEvent):
self.success_count += 1
self.logger.info("✅ Node succeeded: %s", event.node_id)
if self.include_outputs and event.node_run_result.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs))
if self.include_process_data and event.node_run_result.process_data:
self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data))
elif isinstance(event, NodeRunFailedEvent):
self.failure_count += 1
self.logger.error("❌ Node failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
if event.node_run_result.error:
self.logger.error(" Details: %s", event.node_run_result.error)
elif isinstance(event, NodeRunExceptionEvent):
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
self.logger.warning(" Error: %s", event.error)
elif isinstance(event, NodeRunStreamChunkEvent):
# Log stream chunks at debug level to avoid spam
final_indicator = " (FINAL)" if event.is_final else ""
self.logger.debug(
"📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
)
# Iteration events
elif isinstance(event, NodeRunIterationStartedEvent):
self.logger.info("🔁 Iteration started: %s", event.node_id)
elif isinstance(event, NodeRunIterationNextEvent):
self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index)
elif isinstance(event, NodeRunIterationSucceededEvent):
self.logger.info("✅ Iteration succeeded: %s", event.node_id)
if self.include_outputs and event.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, NodeRunIterationFailedEvent):
self.logger.error("❌ Iteration failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
# Loop events
elif isinstance(event, NodeRunLoopStartedEvent):
self.logger.info("🔄 Loop started: %s", event.node_id)
elif isinstance(event, NodeRunLoopNextEvent):
self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index)
elif isinstance(event, NodeRunLoopSucceededEvent):
self.logger.info("✅ Loop succeeded: %s", event.node_id)
if self.include_outputs and event.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, NodeRunLoopFailedEvent):
self.logger.error("❌ Loop failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
else:
# Log unknown events at debug level
self.logger.debug("Event: %s", event_class)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Log graph execution end with summary statistics."""
self.logger.info("=" * 80)
if error:
self.logger.error("🔴 GRAPH EXECUTION FAILED")
self.logger.error(" Error: %s", error)
else:
self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")
# Log execution statistics
self.logger.info("Execution Statistics:")
self.logger.info(" Total nodes executed: %s", self.node_count)
self.logger.info(" Successful nodes: %s", self.success_count)
self.logger.info(" Failed nodes: %s", self.failure_count)
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.include_outputs and self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)

View File

@ -0,0 +1,150 @@
"""
Execution limits layer for GraphEngine.
This layer monitors workflow execution to enforce limits on:
- Maximum execution steps
- Maximum execution time
When limits are exceeded, the layer automatically aborts execution.
"""
import logging
import time
from enum import StrEnum
from typing import final
from typing_extensions import override
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import (
GraphEngineEvent,
NodeRunStartedEvent,
)
from graphon.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
class LimitType(StrEnum):
"""Types of execution limits that can be exceeded."""
STEP_LIMIT = "step_limit"
TIME_LIMIT = "time_limit"
@final
class ExecutionLimitsLayer(GraphEngineLayer):
"""
Layer that enforces execution limits for workflows.
Monitors:
- Step count: Tracks number of node executions
- Time limit: Monitors total execution time
Automatically aborts execution when limits are exceeded.
"""
def __init__(self, max_steps: int, max_time: int) -> None:
"""
Initialize the execution limits layer.
Args:
max_steps: Maximum number of execution steps allowed
max_time: Maximum execution time in seconds allowed
"""
super().__init__()
self.max_steps = max_steps
self.max_time = max_time
# Runtime tracking
self.start_time: float | None = None
self.step_count = 0
self.logger = logging.getLogger(__name__)
# State tracking
self._execution_started = False
self._execution_ended = False
self._abort_sent = False # Track if abort command has been sent
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self.start_time = time.time()
self.step_count = 0
self._execution_started = True
self._execution_ended = False
self._abort_sent = False
self.logger.debug("Execution limits monitoring started")
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
Monitors execution progress and enforces limits.
"""
if not self._execution_started or self._execution_ended or self._abort_sent:
return
# Track step count for node execution events
if isinstance(event, NodeRunStartedEvent):
self.step_count += 1
self.logger.debug("Step %d started: %s", self.step_count, event.node_id)
# Check step limit when node execution completes
if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
if self._reached_step_limitation():
self._send_abort_command(LimitType.STEP_LIMIT)
if self._reached_time_limitation():
self._send_abort_command(LimitType.TIME_LIMIT)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._execution_started and not self._execution_ended:
self._execution_ended = True
if self.start_time:
total_time = time.time() - self.start_time
self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)
def _reached_step_limitation(self) -> bool:
"""Check if step count limit has been exceeded."""
return self.step_count > self.max_steps
def _reached_time_limitation(self) -> bool:
"""Check if time limit has been exceeded."""
return self.start_time is not None and (time.time() - self.start_time) > self.max_time
def _send_abort_command(self, limit_type: LimitType) -> None:
"""
Send abort command due to limit violation.
Args:
limit_type: Type of limit exceeded
"""
if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
return
# Format detailed reason message
if limit_type == LimitType.STEP_LIMIT:
reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
self.logger.warning("Execution limit exceeded: %s", reason)
try:
# Send abort command to the engine
abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
self.command_channel.send_command(abort_command)
# Mark that abort has been sent to prevent duplicate commands
self._abort_sent = True
self.logger.debug("Abort command sent to engine")
except Exception:
self.logger.exception("Failed to send abort command")
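# Usage sketch (limit values are illustrative):
#
#     engine.layer(ExecutionLimitsLayer(max_steps=500, max_time=600))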

View File

@ -0,0 +1,79 @@
"""
GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Callers must provide a Redis client dependency from outside the workflow package.
"""
import logging
from collections.abc import Sequence
from typing import final
from graphon.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
from graphon.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
logger = logging.getLogger(__name__)
@final
class GraphEngineManager:
"""
Manager for sending control commands to GraphEngine instances.
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
"""
_redis_client: RedisClientProtocol
def __init__(self, redis_client: RedisClientProtocol) -> None:
self._redis_client = redis_client
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
"""
Send a stop command to a running workflow.
Args:
task_id: The task ID of the workflow to stop
reason: Optional reason for stopping (defaults to "User requested stop")
"""
abort_command = AbortCommand(reason=reason or "User requested stop")
self._send_command(task_id, abort_command)
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
"""Send a pause command to a running workflow."""
pause_command = PauseCommand(reason=reason or "User requested pause")
self._send_command(task_id, pause_command)
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
self._send_command(task_id, update_command)
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
if not task_id:
return
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(self._redis_client, channel_key)
try:
channel.send_command(command)
except Exception:
            # Swallow the error if Redis is unavailable (logged below);
            # the legacy control mechanisms will still work.
logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id)
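# Usage sketch (any RedisClientProtocol implementation works; the task id is
# illustrative):
#
#     manager = GraphEngineManager(redis_client)
#     manager.send_stop_command("task-123", reason="Cancelled from API")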

View File

@ -0,0 +1,14 @@
"""
Orchestration subsystem for graph engine.
This package coordinates the overall execution flow between
different subsystems.
"""
from .dispatcher import Dispatcher
from .execution_coordinator import ExecutionCoordinator
__all__ = [
"Dispatcher",
"ExecutionCoordinator",
]

View File

@ -0,0 +1,143 @@
"""
Main dispatcher for processing events from workers.
"""
import logging
import queue
import threading
import time
from typing import TYPE_CHECKING, final
from graphon.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunSucceededEvent,
)
from ..event_management import EventManager
from .execution_coordinator import ExecutionCoordinator
if TYPE_CHECKING:
from ..event_management import EventHandler
logger = logging.getLogger(__name__)
@final
class Dispatcher:
"""
Main dispatcher that processes events from the event queue.
This runs in a separate thread and coordinates event processing
with timeout and completion detection.
"""
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def __init__(
self,
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
event_emitter: EventManager | None = None,
) -> None:
"""
Initialize the dispatcher.
Args:
event_queue: Queue of events from workers
event_handler: Event handler registry for processing events
execution_coordinator: Coordinator for execution flow
event_emitter: Optional event manager to signal completion
"""
self._event_queue = event_queue
self._event_handler = event_handler
self._execution_coordinator = execution_coordinator
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._start_time: float | None = None
def start(self) -> None:
"""Start the dispatcher thread."""
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=2.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
self._process_commands()
paused = False
while not self._stop_event.is_set():
if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete:
break
if self._execution_coordinator.paused:
paused = True
break
self._execution_coordinator.check_scaling()
try:
event = self._event_queue.get(timeout=0.1)
self._event_handler.dispatch(event)
self._event_queue.task_done()
self._process_commands(event)
except queue.Empty:
time.sleep(0.1)
self._process_commands()
if paused:
self._drain_events_until_idle()
else:
self._drain_event_queue()
except Exception as e:
logger.exception("Dispatcher error")
self._execution_coordinator.mark_failed(e)
finally:
self._execution_coordinator.mark_complete()
# Signal the event emitter that execution is complete
if self._event_emitter:
self._event_emitter.mark_complete()
    def _process_commands(self, event: GraphNodeEventBase | None = None) -> None:
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
self._execution_coordinator.process_commands()
def _drain_event_queue(self) -> None:
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
def _drain_events_until_idle(self) -> None:
while not self._stop_event.is_set():
try:
event = self._event_queue.get(timeout=0.1)
self._event_handler.dispatch(event)
self._event_queue.task_done()
self._process_commands(event)
except queue.Empty:
if not self._execution_coordinator.has_executing_nodes():
break
self._drain_event_queue()

View File

@ -0,0 +1,104 @@
"""
Execution coordinator for managing overall workflow execution.
"""
from typing import final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..graph_state_manager import GraphStateManager
from ..worker_management import WorkerPool
@final
class ExecutionCoordinator:
"""
Coordinates overall execution flow between subsystems.
This provides high-level coordination methods used by the
dispatcher to manage execution state.
"""
def __init__(
self,
graph_execution: GraphExecution,
state_manager: GraphStateManager,
command_processor: CommandProcessor,
worker_pool: WorkerPool,
) -> None:
"""
Initialize the execution coordinator.
Args:
graph_execution: Graph execution aggregate
state_manager: Unified state manager
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self._graph_execution = graph_execution
self._state_manager = state_manager
self._command_processor = command_processor
self._worker_pool = worker_pool
def process_commands(self) -> None:
"""Process any pending commands."""
self._command_processor.process_commands()
def check_scaling(self) -> None:
"""Check and perform worker scaling if needed."""
self._worker_pool.check_and_scale()
    @property
    def execution_complete(self) -> bool:
        """Whether the ready queue is drained and no nodes are executing."""
        return self._state_manager.is_execution_complete()

    @property
    def aborted(self) -> bool:
        """Whether execution was aborted or failed with an error."""
        return self._graph_execution.aborted or self._graph_execution.has_error
@property
def paused(self) -> bool:
"""Expose whether the underlying graph execution is paused."""
return self._graph_execution.is_paused
def mark_complete(self) -> None:
"""Mark execution as complete."""
if self._graph_execution.is_paused:
return
if not self._graph_execution.completed:
self._graph_execution.complete()
def mark_failed(self, error: Exception) -> None:
"""
Mark execution as failed.
Args:
error: The error that caused failure
"""
self._graph_execution.fail(error)
def handle_pause_if_needed(self) -> None:
"""If the execution has been paused, stop workers immediately."""
if not self._graph_execution.is_paused:
return
self._worker_pool.stop()
self._state_manager.clear_executing()
def handle_abort_if_needed(self) -> None:
"""If the execution has been aborted, stop workers immediately."""
if not self._graph_execution.aborted:
return
self._worker_pool.stop()
self._state_manager.clear_executing()
def has_executing_nodes(self) -> bool:
"""Return True if any nodes are currently marked as executing."""
# This check is only safe once execution has already paused.
# Before pause, executing state can change concurrently, which makes the result unreliable.
if not self._graph_execution.is_paused:
raise AssertionError("has_executing_nodes should only be called after execution is paused")
return self._state_manager.get_executing_count() > 0

View File

@ -0,0 +1,41 @@
"""
CommandChannel protocol for GraphEngine command communication.
This protocol defines the interface for sending and receiving commands
to/from a GraphEngine instance, supporting both local and distributed scenarios.
"""
from typing import Protocol
from ..entities.commands import GraphEngineCommand
class CommandChannel(Protocol):
"""
Protocol for bidirectional command communication with GraphEngine.
Since each GraphEngine instance processes only one workflow execution,
this channel is dedicated to that single execution.
"""
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch pending commands for this GraphEngine instance.
Called by GraphEngine to poll for commands that need to be processed.
Returns:
List of pending commands (may be empty)
"""
...
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to be processed by this GraphEngine instance.
Called by external systems to send control commands to the running workflow.
Args:
command: The command to send
"""
...
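# Minimal conforming sketch (illustrative; the shipped implementations live in
# graph_engine.command_channels):
#
#     class ListChannel:
#         def __init__(self) -> None:
#             self._pending: list[GraphEngineCommand] = []
#
#         def fetch_commands(self) -> list[GraphEngineCommand]:
#             drained, self._pending = self._pending, []
#             return drained
#
#         def send_command(self, command: GraphEngineCommand) -> None:
#             self._pending.append(command)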

View File

@ -0,0 +1,12 @@
"""
Ready queue implementations for GraphEngine.
This package contains the protocol and implementations for managing
the queue of nodes ready for execution.
"""
from .factory import create_ready_queue_from_state
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueue, ReadyQueueState
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]

View File

@ -0,0 +1,37 @@
"""
Factory for creating ReadyQueue instances from serialized state.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueueState
if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
"""
Create a ReadyQueue instance from a serialized state.
Args:
        state: The serialized queue state as a ReadyQueueState model
Returns:
A ReadyQueue instance initialized with the given state
Raises:
ValueError: If the queue type is unknown or version is unsupported
"""
if state.type == "InMemoryReadyQueue":
if state.version != "1.0":
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
queue = InMemoryReadyQueue()
# Always pass as JSON string to loads()
queue.loads(state.model_dump_json())
return queue
else:
raise ValueError(f"Unknown ready queue type: {state.type}")

View File

@ -0,0 +1,140 @@
"""
In-memory implementation of the ReadyQueue protocol.
This implementation wraps Python's standard queue.Queue and adds
serialization capabilities for state storage.
"""
import queue
from typing import final
from .protocol import ReadyQueue, ReadyQueueState
@final
class InMemoryReadyQueue(ReadyQueue):
"""
In-memory ready queue implementation with serialization support.
This implementation uses Python's queue.Queue internally and provides
methods to serialize and restore the queue state.
"""
def __init__(self, maxsize: int = 0) -> None:
"""
Initialize the in-memory ready queue.
Args:
maxsize: Maximum size of the queue (0 for unlimited)
"""
self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize)
def put(self, item: str) -> None:
"""
Add a node ID to the ready queue.
Args:
item: The node ID to add to the queue
"""
self._queue.put(item)
def get(self, timeout: float | None = None) -> str:
"""
Retrieve and remove a node ID from the queue.
Args:
timeout: Maximum time to wait for an item (None for blocking)
Returns:
The node ID retrieved from the queue
Raises:
queue.Empty: If timeout expires and no item is available
"""
if timeout is None:
return self._queue.get(block=True)
return self._queue.get(timeout=timeout)
def task_done(self) -> None:
"""
Indicate that a previously retrieved task is complete.
Used by worker threads to signal task completion for
join() synchronization.
"""
self._queue.task_done()
def empty(self) -> bool:
"""
Check if the queue is empty.
Returns:
True if the queue has no items, False otherwise
"""
return self._queue.empty()
def qsize(self) -> int:
"""
Get the approximate size of the queue.
Returns:
The approximate number of items in the queue
"""
return self._queue.qsize()
def dumps(self) -> str:
"""
Serialize the queue state to a JSON string for storage.
Returns:
A JSON string containing the serialized queue state
"""
        # Snapshot all items by draining the queue, then restoring it below
items: list[str] = []
temp_items: list[str] = []
# Drain the queue temporarily to get all items
while not self._queue.empty():
try:
item = self._queue.get_nowait()
temp_items.append(item)
items.append(item)
except queue.Empty:
break
# Put items back in the same order
for item in temp_items:
self._queue.put(item)
state = ReadyQueueState(
type="InMemoryReadyQueue",
version="1.0",
items=items,
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""
Restore the queue state from a JSON string.
Args:
data: The JSON string containing the serialized queue state to restore
"""
state = ReadyQueueState.model_validate_json(data)
if state.type != "InMemoryReadyQueue":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported version: {state.version}")
# Clear the current queue
while not self._queue.empty():
try:
self._queue.get_nowait()
except queue.Empty:
break
# Restore items
for item in state.items:
self._queue.put(item)
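# Round-trip sketch: dumps() snapshots the pending node IDs, loads() restores
# them in order.
#
#     q = InMemoryReadyQueue()
#     q.put("node-1")
#     q.put("node-2")
#     restored = InMemoryReadyQueue()
#     restored.loads(q.dumps())  # restored.get() -> "node-1", then "node-2"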

View File

@ -0,0 +1,104 @@
"""
ReadyQueue protocol for GraphEngine node execution queue.
This protocol defines the interface for managing the queue of nodes ready
for execution, supporting both in-memory and persistent storage scenarios.
"""
from collections.abc import Sequence
from typing import Protocol
from pydantic import BaseModel, Field
class ReadyQueueState(BaseModel):
"""
Pydantic model for serialized ready queue state.
This defines the structure of the data returned by dumps()
and expected by loads() for ready queue serialization.
"""
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
version: str = Field(description="Serialization format version")
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
class ReadyQueue(Protocol):
"""
Protocol for managing nodes ready for execution in GraphEngine.
This protocol defines the interface that any ready queue implementation
must provide, enabling both in-memory queues and persistent queues
that can be serialized for state storage.
"""
def put(self, item: str) -> None:
"""
Add a node ID to the ready queue.
Args:
item: The node ID to add to the queue
"""
...
def get(self, timeout: float | None = None) -> str:
"""
Retrieve and remove a node ID from the queue.
Args:
timeout: Maximum time to wait for an item (None for blocking)
Returns:
The node ID retrieved from the queue
Raises:
queue.Empty: If timeout expires and no item is available
"""
...
def task_done(self) -> None:
"""
Indicate that a previously retrieved task is complete.
Used by worker threads to signal task completion for
join() synchronization.
"""
...
def empty(self) -> bool:
"""
Check if the queue is empty.
Returns:
True if the queue has no items, False otherwise
"""
...
def qsize(self) -> int:
"""
Get the approximate size of the queue.
Returns:
The approximate number of items in the queue
"""
...
def dumps(self) -> str:
"""
Serialize the queue state to a JSON string for storage.
Returns:
A JSON string containing the serialized queue state
that can be persisted and later restored
"""
...
def loads(self, data: str) -> None:
"""
Restore the queue state from a JSON string.
Args:
data: The JSON string containing the serialized queue state to restore
"""
...

View File

@ -0,0 +1,10 @@
"""
ResponseStreamCoordinator - Coordinates streaming output from response nodes
This component manages response streaming sessions and ensures ordered streaming
of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
__all__ = ["ResponseStreamCoordinator"]

View File

@ -0,0 +1,697 @@
"""
Main ResponseStreamCoordinator implementation.
This module contains the public ResponseStreamCoordinator class that manages
response streaming sessions and ensures ordered streaming of responses.
"""
import logging
from collections import deque
from collections.abc import Sequence
from threading import RLock
from typing import Literal, TypeAlias, final
from uuid import uuid4
from pydantic import BaseModel, Field
from graphon.enums import NodeExecutionType, NodeState
from graphon.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from graphon.nodes.base.template import TextSegment, VariableSegment
from graphon.runtime import VariablePool
from graphon.runtime.graph_runtime_state import GraphProtocol
from .path import Path
from .session import ResponseSession
logger = logging.getLogger(__name__)
# Type definitions
NodeID: TypeAlias = str
EdgeID: TypeAlias = str
class ResponseSessionState(BaseModel):
"""Serializable representation of a response session."""
node_id: str
index: int = Field(default=0, ge=0)
class StreamBufferState(BaseModel):
"""Serializable representation of buffered stream chunks."""
selector: tuple[str, ...]
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
class StreamPositionState(BaseModel):
"""Serializable representation for stream read positions."""
selector: tuple[str, ...]
position: int = Field(default=0, ge=0)
class ResponseStreamCoordinatorState(BaseModel):
"""Serialized snapshot of ResponseStreamCoordinator."""
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
version: str = Field(default="1.0")
response_nodes: Sequence[str] = Field(default_factory=list)
active_session: ResponseSessionState | None = None
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
node_execution_ids: dict[str, str] = Field(default_factory=dict)
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
@final
class ResponseStreamCoordinator:
"""
Manages response streaming sessions without relying on global state.
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
"""
Initialize coordinator with variable pool.
Args:
variable_pool: VariablePool instance for accessing node variables
graph: Graph instance for looking up node information
"""
self._variable_pool = variable_pool
self._graph = graph
self._active_session: ResponseSession | None = None
self._waiting_sessions: deque[ResponseSession] = deque()
self._lock = RLock()
# Internal stream management (replacing OutputRegistry)
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
self._stream_positions: dict[tuple[str, ...], int] = {}
self._closed_streams: set[tuple[str, ...]] = set()
# Track response nodes
self._response_nodes: set[NodeID] = set()
# Store paths for each response node
self._paths_maps: dict[NodeID, list[Path]] = {}
# Track node execution IDs and types for proper event forwarding
self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id
# Track response sessions to ensure only one per node
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
def register(self, response_node_id: NodeID) -> None:
with self._lock:
if response_node_id in self._response_nodes:
return
self._response_nodes.add(response_node_id)
# Build and save paths map for this response node
paths_map = self._build_paths_map(response_node_id)
self._paths_maps[response_node_id] = paths_map
# Create and store response session for this node
response_node = self._graph.nodes[response_node_id]
session = ResponseSession.from_node(response_node)
self._response_sessions[response_node_id] = session
def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
"""Track the execution ID for a node when it starts executing.
Args:
node_id: The ID of the node
execution_id: The execution ID from NodeRunStartedEvent
"""
with self._lock:
self._node_execution_ids[node_id] = execution_id
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
"""Get the execution ID for a node, creating one if it doesn't exist.
Args:
node_id: The ID of the node
Returns:
The execution ID for the node
"""
with self._lock:
if node_id not in self._node_execution_ids:
self._node_execution_ids[node_id] = str(uuid4())
return self._node_execution_ids[node_id]
def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
"""
Build a paths map for a response node by finding all paths from root node
to the response node, recording branch edges along each path.
Args:
response_node_id: ID of the response node to analyze
Returns:
List of Path objects, where each path contains branch edge IDs
"""
# Get root node ID
root_node_id = self._graph.root_node.id
# If root is the response node, return empty path
if root_node_id == response_node_id:
return [Path()]
# Extract variable selectors from the response node's template
response_node = self._graph.nodes[response_node_id]
response_session = ResponseSession.from_node(response_node)
template = response_session.template
# Collect all variable selectors from the template
variable_selectors: set[tuple[str, ...]] = set()
for segment in template.segments:
if isinstance(segment, VariableSegment):
variable_selectors.add(tuple(segment.selector[:2]))
# Step 1: Find all complete paths from root to response node
all_complete_paths: list[list[EdgeID]] = []
def find_paths(
current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
) -> None:
"""Recursively find all paths from current node to target node."""
if current_node_id == target_node_id:
# Found a complete path, store it
all_complete_paths.append(current_path.copy())
return
# Mark as visited to avoid cycles
visited.add(current_node_id)
# Explore outgoing edges
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
for edge in outgoing_edges:
edge_id = edge.id
next_node_id = edge.head
# Skip if already visited in this path
if next_node_id not in visited:
# Add edge to path and recurse
new_path = current_path + [edge_id]
find_paths(next_node_id, target_node_id, new_path, visited.copy())
# Start searching from root node
find_paths(root_node_id, response_node_id, [], set())
# Step 2: For each complete path, filter edges based on node blocking behavior
filtered_paths: list[Path] = []
for path in all_complete_paths:
blocking_edges: list[str] = []
for edge_id in path:
edge = self._graph.edges[edge_id]
source_node = self._graph.nodes[edge.tail]
# Check if node is a branch, container, or response node
if source_node.execution_type in {
NodeExecutionType.BRANCH,
NodeExecutionType.CONTAINER,
NodeExecutionType.RESPONSE,
} or source_node.blocks_variable_output(variable_selectors):
blocking_edges.append(edge_id)
# Keep the path even if it's empty
filtered_paths.append(Path(edges=blocking_edges))
return filtered_paths
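# Illustrative example (hypothetical graph, not part of this module): with
# edges e0: start->if_else, e1: if_else->answer (true branch), e2: if_else->llm,
# e3: llm->answer (false branch), and only if_else being a BRANCH node, the two
# root-to-answer paths [e0, e1] and [e0, e2, e3] filter down to
#     [Path(edges=["e1"]), Path(edges=["e2"])]
# since only edges leaving if_else can block streaming. When on_edge_taken("e1")
# fires, the first path empties and the answer node becomes deterministically
# reachable.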
def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
"""
Handle when an edge is taken (selected by a branch node).
This method updates the paths for all response nodes by removing
the taken edge. If any response node has an empty path after removal,
it means the node is now deterministically reachable and should start.
Args:
edge_id: The ID of the edge that was taken
Returns:
List of events to emit from starting new sessions
"""
events: list[NodeRunStreamChunkEvent] = []
with self._lock:
# Check each response node in order
for response_node_id in self._response_nodes:
if response_node_id not in self._paths_maps:
continue
paths = self._paths_maps[response_node_id]
has_reachable_path = False
# Update each path by removing the taken edge
for path in paths:
# Remove the taken edge from this path
path.remove_edge(edge_id)
# Check if this path is now empty (node is reachable)
if path.is_empty():
has_reachable_path = True
# If node is now reachable (has empty path), start/queue session
if has_reachable_path:
# Pass the node_id to the activation method
# The method will handle checking and removing from map
events.extend(self._active_or_queue_session(response_node_id))
return events
def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
"""
Start a session immediately if no active session, otherwise queue it.
Only activates sessions that exist in the _response_sessions map.
Args:
node_id: The ID of the response node to activate
Returns:
List of events from flush attempt if session started immediately
"""
events: list[NodeRunStreamChunkEvent] = []
# Get the session from our map (only activate if it exists)
session = self._response_sessions.get(node_id)
if not session:
return events
# Remove from map to ensure it won't be activated again
del self._response_sessions[node_id]
if self._active_session is None:
self._active_session = session
# Try to flush immediately
events.extend(self.try_flush())
else:
# Queue the session if another is active
self._waiting_sessions.append(session)
return events
def intercept_event(
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
) -> Sequence[NodeRunStreamChunkEvent]:
with self._lock:
if isinstance(event, NodeRunStreamChunkEvent):
self._append_stream_chunk(event.selector, event)
if event.is_final:
self._close_stream(event.selector)
return self.try_flush()
else:
# Skipped because we share the same variable pool.
#
# for variable_name, variable_value in event.node_run_result.outputs.items():
# self._variable_pool.add((event.node_id, variable_name), variable_value)
return self.try_flush()
def _create_stream_chunk_event(
self,
node_id: str,
execution_id: str,
selector: Sequence[str],
chunk: str,
is_final: bool = False,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
# Use the active response node for special selectors
response_node = self._graph.nodes[self._active_session.node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=selector,
chunk=chunk,
is_final=is_final,
)
# Standard case: selector refers to an actual node
node = self._graph.nodes[node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=node.id,
node_type=node.node_type,
selector=selector,
chunk=chunk,
is_final=is_final,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
"""Process a variable segment. Returns (events, is_complete).
Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""
is_complete = False
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)
# Stream all available chunks
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector, # Keep original selector
chunk=event.chunk,
is_final=event.is_final,
)
events.append(updated_event)
else:
# Regular node selector - use event as is
events.append(event)
# Check if this is the last chunk by looking ahead
stream_closed = self._is_stream_closed(segment.selector)
# Check if stream is closed to determine if segment is complete
if stream_closed:
is_complete = True
elif value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
)
)
is_complete = True
return events, is_complete
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns (events, is_complete)."""
assert self._active_session is not None
current_response_node = self._graph.nodes[self._active_session.node_id]
# Use get_or_create_execution_id to ensure we have a consistent ID
execution_id = self._get_or_create_execution_id(current_response_node.id)
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
event = self._create_stream_chunk_event(
node_id=current_response_node.id,
execution_id=execution_id,
selector=[current_response_node.id, "answer"], # FIXME(-LAN-)
chunk=segment.text,
is_final=is_last_segment,
)
return [event]
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
with self._lock:
if not self._active_session:
return []
template = self._active_session.template
response_node_id = self._active_session.node_id
events: list[NodeRunStreamChunkEvent] = []
# Process segments sequentially from current index
while self._active_session.index < len(template.segments):
segment = template.segments[self._active_session.index]
if isinstance(segment, VariableSegment):
# Check if the source node for this variable is skipped
# Only check for actual nodes, not special selectors (sys, env, conversation)
source_selector_prefix = segment.selector[0] if segment.selector else ""
if source_selector_prefix in self._graph.nodes:
source_node = self._graph.nodes[source_selector_prefix]
if source_node.state == NodeState.SKIPPED:
# Skip this variable segment if the source node is skipped
self._active_session.index += 1
continue
segment_events, is_complete = self._process_variable_segment(segment)
events.extend(segment_events)
# Only advance index if this variable segment is complete
if is_complete:
self._active_session.index += 1
else:
# Wait for more data
break
else:
segment_events = self._process_text_segment(segment)
events.extend(segment_events)
self._active_session.index += 1
if self._active_session.is_complete():
# End current session and get events from starting next session
next_session_events = self.end_session(response_node_id)
events.extend(next_session_events)
return events
def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
"""
End the active session for a response node.
Automatically starts the next waiting session if available.
Args:
node_id: ID of the response node ending its session
Returns:
List of events from starting the next session
"""
with self._lock:
events: list[NodeRunStreamChunkEvent] = []
if self._active_session and self._active_session.node_id == node_id:
self._active_session = None
# Try to start next waiting session
if self._waiting_sessions:
next_session = self._waiting_sessions.popleft()
self._active_session = next_session
# Immediately try to flush any available segments
events = self.try_flush()
return events
# ============= Internal Stream Management Methods =============
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.
Args:
selector: List of strings identifying the stream location
event: The NodeRunStreamChunkEvent to append
Raises:
ValueError: If the stream is already closed
"""
key = tuple(selector)
if key in self._closed_streams:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
if key not in self._stream_buffers:
self._stream_buffers[key] = []
self._stream_positions[key] = 0
self._stream_buffers[key].append(event)
def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
"""
Pop the next unread stream chunk from the buffer.
Args:
selector: List of strings identifying the stream location
Returns:
The next event, or None if no unread events available
"""
key = tuple(selector)
if key not in self._stream_buffers:
return None
position = self._stream_positions.get(key, 0)
buffer = self._stream_buffers[key]
if position >= len(buffer):
return None
event = buffer[position]
self._stream_positions[key] = position + 1
return event
def _has_unread_stream(self, selector: Sequence[str]) -> bool:
"""
Check if the stream has unread events.
Args:
selector: List of strings identifying the stream location
Returns:
True if there are unread events, False otherwise
"""
key = tuple(selector)
if key not in self._stream_buffers:
return False
position = self._stream_positions.get(key, 0)
return position < len(self._stream_buffers[key])
def _close_stream(self, selector: Sequence[str]) -> None:
"""
Mark a stream as closed (no more chunks can be appended).
Args:
selector: List of strings identifying the stream location
"""
key = tuple(selector)
self._closed_streams.add(key)
def _is_stream_closed(self, selector: Sequence[str]) -> bool:
"""
Check if a stream is closed.
Args:
selector: List of strings identifying the stream location
Returns:
True if the stream is closed, False otherwise
"""
key = tuple(selector)
return key in self._closed_streams
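# Cursor semantics of the internal buffer (illustrative):
#   _append_stream_chunk(("llm", "text"), chunk_a)
#   _append_stream_chunk(("llm", "text"), chunk_b)
#   _pop_stream_chunk(("llm", "text"))   -> chunk_a  (read position 0 -> 1)
#   _has_unread_stream(("llm", "text"))  -> True     (chunk_b still unread)
#   _close_stream(("llm", "text"))                   (further appends raise)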
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
"""Convert an in-memory session into its serializable form."""
if session is None:
return None
return ResponseSessionState(node_id=session.node_id, index=session.index)
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
"""Rebuild a response session from serialized data."""
node = self._graph.nodes.get(session_state.node_id)
if node is None:
raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
session = ResponseSession.from_node(node)
session.index = session_state.index
return session
def dumps(self) -> str:
"""Serialize coordinator state to JSON."""
with self._lock:
state = ResponseStreamCoordinatorState(
response_nodes=sorted(self._response_nodes),
active_session=self._serialize_session(self._active_session),
waiting_sessions=[
session_state
for session in list(self._waiting_sessions)
if (session_state := self._serialize_session(session)) is not None
],
pending_sessions=[
session_state
for _, session in sorted(self._response_sessions.items())
if (session_state := self._serialize_session(session)) is not None
],
node_execution_ids=dict(sorted(self._node_execution_ids.items())),
paths_map={
node_id: [path.edges.copy() for path in paths]
for node_id, paths in sorted(self._paths_maps.items())
},
stream_buffers=[
StreamBufferState(
selector=selector,
events=[event.model_copy(deep=True) for event in events],
)
for selector, events in sorted(self._stream_buffers.items())
],
stream_positions=[
StreamPositionState(selector=selector, position=position)
for selector, position in sorted(self._stream_positions.items())
],
closed_streams=sorted(self._closed_streams),
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""Restore coordinator state from JSON."""
state = ResponseStreamCoordinatorState.model_validate_json(data)
if state.type != "ResponseStreamCoordinator":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported serialized version: {state.version}")
with self._lock:
self._response_nodes = set(state.response_nodes)
self._paths_maps = {
node_id: [Path(edges=list(path_edges)) for path_edges in paths]
for node_id, paths in state.paths_map.items()
}
self._node_execution_ids = dict(state.node_execution_ids)
self._stream_buffers = {
tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
for buffer in state.stream_buffers
}
self._stream_positions = {
tuple(position.selector): position.position for position in state.stream_positions
}
for selector in self._stream_buffers:
self._stream_positions.setdefault(selector, 0)
self._closed_streams = {tuple(selector) for selector in state.closed_streams}
self._waiting_sessions = deque(
self._session_from_state(session_state) for session_state in state.waiting_sessions
)
self._response_sessions = {
session_state.node_id: self._session_from_state(session_state)
for session_state in state.pending_sessions
}
self._active_session = self._session_from_state(state.active_session) if state.active_session else None
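A hedged sketch of persisting the coordinator across a pause/resume cycle; `store`, `task_id`, and the storage key are assumptions, while the constructor and `dumps()`/`loads()` calls match the class above:

```python
# On pause: snapshot the coordinator state.
snapshot = coordinator.dumps()
store.save(f"workflow:{task_id}:coordinator", snapshot)  # hypothetical store/key

# On resume: rebuild over the same graph and variable pool, then replay.
coordinator = ResponseStreamCoordinator(variable_pool, graph)
coordinator.loads(store.load(f"workflow:{task_id}:coordinator"))
```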

View File

@ -0,0 +1,35 @@
"""
Internal path representation for response coordinator.
This module contains the private Path class used internally by ResponseStreamCoordinator
to track execution paths to response nodes.
"""
from dataclasses import dataclass, field
from typing import TypeAlias
EdgeID: TypeAlias = str
@dataclass
class Path:
"""
Represents a path of branch edges that must be taken to reach a response node.
Note: This is an internal class not exposed in the public API.
"""
edges: list[EdgeID] = field(default_factory=list[EdgeID])
def contains_edge(self, edge_id: EdgeID) -> bool:
"""Check if this path contains the given edge."""
return edge_id in self.edges
def remove_edge(self, edge_id: EdgeID) -> None:
"""Remove the given edge from this path in place."""
if self.contains_edge(edge_id):
self.edges.remove(edge_id)
def is_empty(self) -> bool:
"""Check if the path has no edges (node is reachable)."""
return len(self.edges) == 0
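A toy run of the in-place edge removal:

```python
path = Path(edges=["e1", "e2"])
path.remove_edge("e1")
assert not path.is_empty()  # "e2" still blocks the response node
path.remove_edge("e2")
assert path.is_empty()      # the response node is now reachable
```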

View File

@ -0,0 +1,66 @@
"""
Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol, cast
from graphon.nodes.base.template import Template
from graphon.runtime.graph_runtime_state import NodeProtocol
class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
"""Structural contract required from nodes that can open a response session."""
def get_streaming_template(self) -> Template: ...
@dataclass
class ResponseSession:
"""
Represents an active response streaming session.
Note: This is an internal class not exposed in the public API.
"""
node_id: str
template: Template # Template object from the response node
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: NodeProtocol) -> ResponseSession:
"""
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which
graph nodes should be treated as response-capable before they reach this factory.
Args:
node: Node from the materialized workflow graph.
Returns:
ResponseSession configured with the node's streaming template
Raises:
TypeError: If node does not implement the response-session streaming contract.
"""
response_node = cast(_ResponseSessionNodeProtocol, node)
try:
template = response_node.get_streaming_template()
except AttributeError as exc:
raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc
return cls(
node_id=node.id,
template=template,
)
def is_complete(self) -> bool:
"""Check if all segments in the template have been processed."""
return self.index >= len(self.template.segments)

View File

@ -0,0 +1,204 @@
"""
Worker - Thread implementation for queue-based node execution
Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
import queue
import threading
import time
from collections.abc import Sequence
from contextlib import AbstractContextManager, nullcontext
from datetime import UTC, datetime
from typing import final
from typing_extensions import override
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.graph_engine.layers.base import GraphEngineLayer
from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
from .ready_queue import ReadyQueue
@final
class Worker(threading.Thread):
"""
Worker thread that executes nodes from the ready queue.
Workers continuously pull node IDs from the ready_queue, execute the
corresponding nodes, and push the resulting events to the event_queue
for the dispatcher to process.
"""
def __init__(
self,
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
worker_id: int = 0,
execution_context: AbstractContextManager[object] | None = None,
) -> None:
"""
Initialize worker thread.
Args:
ready_queue: Ready queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
execution_context: Optional execution context for context preservation
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
self._execution_context = execution_context
self._stop_event = threading.Event()
self._layers = layers
self._last_task_time = time.time()
self._current_node_started_at: datetime | None = None
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
@property
def is_idle(self) -> bool:
"""Check if the worker is currently idle."""
# Worker is idle if it hasn't processed a task recently (within 0.2 seconds)
return (time.time() - self._last_task_time) > 0.2
@property
def idle_duration(self) -> float:
"""Get the duration in seconds since the worker last processed a task."""
return time.time() - self._last_task_time
@property
def worker_id(self) -> int:
"""Get the worker's ID."""
return self._worker_id
@override
def run(self) -> None:
"""
Main worker loop.
Continuously pulls node IDs from ready_queue, executes them,
and pushes events to event_queue until stopped.
"""
while not self._stop_event.is_set():
# Try to get a node ID from the ready queue (with timeout)
try:
node_id = self._ready_queue.get(timeout=0.1)
except queue.Empty:
continue
self._last_task_time = time.time()
node = self._graph.nodes[node_id]
try:
self._current_node_started_at = None
self._execute_node(node)
self._ready_queue.task_done()
except Exception as e:
self._event_queue.put(
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
)
finally:
self._current_node_started_at = None
def _execute_node(self, node: Node) -> None:
"""
Execute a single node and handle its events.
Args:
node: The node instance to execute
"""
node.ensure_execution_id()
error: Exception | None = None
result_event: GraphNodeEventBase | None = None
# Execute the node, preserving the caller's context when one was provided.
context = self._execution_context if self._execution_context is not None else nullcontext()
with context:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
self._current_node_started_at = event.start_at
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error, result_event)
def _invoke_node_run_start_hooks(self, node: Node) -> None:
"""Invoke on_node_run_start hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_start(node)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _invoke_node_run_end_hooks(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
"""Invoke on_node_run_end hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_end(node, error, result_event)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _build_fallback_failure_event(
self, node: Node, error: Exception, *, started_at: datetime | None = None
) -> NodeRunFailedEvent:
"""Build a failed event when worker-level execution aborts before a node emits its own result event."""
failure_time = datetime.now(UTC).replace(tzinfo=None)
error_message = str(error)
return NodeRunFailedEvent(
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
error=error_message,
start_at=started_at or failure_time,
finished_at=failure_time,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
error_type=type(error).__name__,
),
)
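A minimal wiring sketch for a single worker; `ready_q` (any `ReadyQueue` implementation) and `graph` (a materialized `Graph`) are assumed to be constructed elsewhere:

```python
import queue

from graphon.graph_events import GraphNodeEventBase

event_q: queue.Queue[GraphNodeEventBase] = queue.Queue()

worker = Worker(
    ready_queue=ready_q,  # assumed: any ReadyQueue implementation
    event_queue=event_q,
    graph=graph,          # assumed: a pre-built Graph
    layers=[],            # no layer hooks in this sketch
    worker_id=0,
)
worker.start()
ready_q.put("node_1")     # picked up within the 0.1s polling interval
# ... the dispatcher drains event_q ...
worker.stop()
worker.join(timeout=2.0)
```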

View File

@ -0,0 +1,12 @@
"""
Worker management subsystem for graph engine.
This package manages the worker pool, including creation,
scaling, and activity tracking.
"""
from .worker_pool import WorkerPool
__all__ = [
"WorkerPool",
]

View File

@ -0,0 +1,277 @@
"""
Simple worker pool that consolidates functionality.
This is a simpler implementation that merges WorkerPool, ActivityTracker,
DynamicScaler, and WorkerFactory into a single class.
"""
import logging
import queue
import threading
from contextlib import AbstractContextManager
from typing import final
from graphon.graph import Graph
from graphon.graph_events import GraphNodeEventBase
from ..config import GraphEngineConfig
from ..layers.base import GraphEngineLayer
from ..ready_queue import ReadyQueue
from ..worker import Worker
logger = logging.getLogger(__name__)
@final
class WorkerPool:
"""
Simple worker pool with integrated management.
This class consolidates all worker management functionality into
a single, simpler implementation without excessive abstraction.
"""
def __init__(
self,
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
config: GraphEngineConfig,
execution_context: AbstractContextManager[object] | None = None,
) -> None:
"""
Initialize the simple worker pool.
Args:
ready_queue: Ready queue for nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
config: GraphEngine worker pool configuration
execution_context: Optional execution context for context preservation
"""
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._execution_context = execution_context
self._layers = layers
self._config = config
# Worker management
self._workers: list[Worker] = []
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
# No longer tracking worker states with callbacks to avoid lock contention
def start(self, initial_count: int | None = None) -> None:
"""
Start the worker pool.
Args:
initial_count: Number of workers to start with (auto-calculated if None)
"""
with self._lock:
if self._running:
return
self._running = True
# Calculate initial worker count
if initial_count is None:
node_count = len(self._graph.nodes)
if node_count < 10:
initial_count = self._config.min_workers
elif node_count < 50:
initial_count = min(self._config.min_workers + 1, self._config.max_workers)
else:
initial_count = min(self._config.min_workers + 2, self._config.max_workers)
logger.debug(
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
initial_count,
node_count,
self._config.min_workers,
self._config.max_workers,
)
# Create initial workers
for _ in range(initial_count):
self._create_worker()
def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False
worker_count = len(self._workers)
if worker_count > 0:
logger.debug("Stopping worker pool: %d workers", worker_count)
# Stop all workers
for worker in self._workers:
worker.stop()
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=2.0)
self._workers.clear()
def _create_worker(self) -> None:
"""Create and start a new worker."""
worker_id = self._worker_counter
self._worker_counter += 1
worker = Worker(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
execution_context=self._execution_context,
)
worker.start()
self._workers.append(worker)
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
"""Remove a specific worker from the pool."""
# Stop the worker
worker.stop()
# Wait for it to finish
if worker.is_alive():
worker.join(timeout=2.0)
# Remove from list
if worker in self._workers:
self._workers.remove(worker)
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
"""
Try to scale up workers if needed.
Args:
queue_depth: Current queue depth
current_count: Current number of workers
Returns:
True if scaled up, False otherwise
"""
if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers:
old_count = current_count
self._create_worker()
logger.debug(
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
old_count,
len(self._workers),
queue_depth,
self._config.scale_up_threshold,
)
return True
return False
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
"""
Try to scale down workers if we have excess capacity.
Args:
queue_depth: Current queue depth
current_count: Current number of workers
active_count: Number of active workers
idle_count: Number of idle workers
Returns:
True if scaled down, False otherwise
"""
# Skip if we're at minimum or have no idle workers
if current_count <= self._config.min_workers or idle_count == 0:
return False
# Check if we have excess capacity
has_excess_capacity = (
queue_depth <= active_count # Active workers can handle current queue
or idle_count > active_count # More idle than active workers
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
)
if not has_excess_capacity:
return False
# Find and remove idle workers that have been idle long enough
workers_to_remove: list[tuple[Worker, int]] = []
for worker in self._workers:
# Check if worker is idle and has exceeded idle time threshold
if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time:
# Don't remove if it would leave us unable to handle the queue
remaining_workers = current_count - len(workers_to_remove) - 1
if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2):
workers_to_remove.append((worker, worker.worker_id))
# Only remove one worker per check to avoid aggressive scaling
break
# Remove idle workers if any found
if workers_to_remove:
old_count = current_count
for worker, worker_id in workers_to_remove:
self._remove_worker(worker, worker_id)
logger.debug(
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
"queue_depth=%d, active=%d, idle=%d)",
old_count,
len(self._workers),
len(workers_to_remove),
self._config.scale_down_idle_time,
queue_depth,
active_count,
idle_count - len(workers_to_remove),
)
return True
return False
def check_and_scale(self) -> None:
"""Check and perform scaling if needed."""
with self._lock:
if not self._running:
return
current_count = len(self._workers)
queue_depth = self._ready_queue.qsize()
# Count active vs idle workers by querying their state directly
idle_count = sum(1 for worker in self._workers if worker.is_idle)
active_count = current_count - idle_count
# Try to scale up if queue is backing up
self._try_scale_up(queue_depth, current_count)
# Try to scale down if we have excess capacity
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self._workers)
def get_status(self) -> dict[str, int]:
"""
Get pool status information.
Returns:
Dictionary with status information
"""
with self._lock:
return {
"total_workers": len(self._workers),
"queue_depth": self._ready_queue.qsize(),
"min_workers": self._config.min_workers,
"max_workers": self._config.max_workers,
}
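A hedged lifecycle sketch; the pre-built `ready_q`, `event_q`, `graph`, and `config` (a `GraphEngineConfig` carrying the min/max worker bounds and scaling thresholds used above) are assumptions:

```python
pool = WorkerPool(
    ready_queue=ready_q,
    event_queue=event_q,
    graph=graph,
    layers=[],
    config=config,
)
pool.start()              # initial worker count derived from graph size
pool.check_and_scale()    # call periodically: scales up on queue depth, down on idle time
print(pool.get_status())  # e.g. {'total_workers': 2, 'queue_depth': 0, ...}
pool.stop()
```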

View File

@ -0,0 +1,84 @@
# Agent events
from .agent import NodeRunAgentLogEvent
# Base events
from .base import (
BaseGraphEvent,
GraphEngineEvent,
GraphNodeEventBase,
)
# Graph events
from .graph import (
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
# Iteration events
from .iteration import (
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
)
# Loop events
from .loop import (
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
)
# Node events
from .node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
NodeRunVariableUpdatedEvent,
is_node_result_event,
)
__all__ = [
"BaseGraphEvent",
"GraphEngineEvent",
"GraphNodeEventBase",
"GraphRunAbortedEvent",
"GraphRunFailedEvent",
"GraphRunPartialSucceededEvent",
"GraphRunPausedEvent",
"GraphRunStartedEvent",
"GraphRunSucceededEvent",
"NodeRunAgentLogEvent",
"NodeRunExceptionEvent",
"NodeRunFailedEvent",
"NodeRunHumanInputFormFilledEvent",
"NodeRunHumanInputFormTimeoutEvent",
"NodeRunIterationFailedEvent",
"NodeRunIterationNextEvent",
"NodeRunIterationStartedEvent",
"NodeRunIterationSucceededEvent",
"NodeRunLoopFailedEvent",
"NodeRunLoopNextEvent",
"NodeRunLoopStartedEvent",
"NodeRunLoopSucceededEvent",
"NodeRunPauseRequestedEvent",
"NodeRunRetrieverResourceEvent",
"NodeRunRetryEvent",
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"NodeRunVariableUpdatedEvent",
"is_node_result_event",
]

View File

@ -0,0 +1,17 @@
from collections.abc import Mapping
from typing import Any
from pydantic import Field
from .base import GraphAgentNodeEventBase
class NodeRunAgentLogEvent(GraphAgentNodeEventBase):
message_id: str = Field(..., description="message id")
label: str = Field(..., description="label")
node_execution_id: str = Field(..., description="node execution id")
parent_id: str | None = Field(..., description="parent id")
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Mapping[str, object] = Field(default_factory=dict)

View File

@ -0,0 +1,31 @@
from pydantic import BaseModel, Field
from graphon.enums import NodeType
from graphon.node_events import NodeRunResult
class GraphEngineEvent(BaseModel):
pass
class BaseGraphEvent(GraphEngineEvent):
pass
class GraphNodeEventBase(GraphEngineEvent):
id: str = Field(..., description="node execution id")
node_id: str
node_type: NodeType
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
node_run_result: NodeRunResult = Field(default_factory=NodeRunResult)
class GraphAgentNodeEventBase(GraphNodeEventBase):
pass

View File

@ -0,0 +1,57 @@
from pydantic import Field
from graphon.entities.pause_reason import PauseReason
from graphon.entities.workflow_start_reason import WorkflowStartReason
from graphon.graph_events import BaseGraphEvent
class GraphRunStartedEvent(BaseGraphEvent):
# Reason is emitted for workflow start events and is always set.
reason: WorkflowStartReason = Field(
default=WorkflowStartReason.INITIAL,
description="reason for workflow start",
)
class GraphRunSucceededEvent(BaseGraphEvent):
"""Event emitted when a run completes successfully with final outputs."""
outputs: dict[str, object] = Field(
default_factory=dict,
description="Final workflow outputs keyed by output selector.",
)
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: int = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
"""Event emitted when a run finishes with partial success and failures."""
exceptions_count: int = Field(..., description="exception count")
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs that were materialised before failures occurred.",
)
class GraphRunAbortedEvent(BaseGraphEvent):
"""Event emitted when a graph run is aborted by user command."""
reason: str | None = Field(default=None, description="reason for abort")
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs produced before the abort was requested.",
)
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""
reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",
)

View File

View File

@ -0,0 +1,40 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import Field
from .base import GraphNodeEventBase
class NodeRunIterationStartedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
class NodeRunIterationNextEvent(GraphNodeEventBase):
node_title: str
index: int = Field(..., description="index")
pre_iteration_output: Any = None
class NodeRunIterationSucceededEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
class NodeRunIterationFailedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -0,0 +1,40 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import Field
from .base import GraphNodeEventBase
class NodeRunLoopStartedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
class NodeRunLoopNextEvent(GraphNodeEventBase):
node_title: str
index: int = Field(..., description="index")
pre_loop_output: Any = None
class NodeRunLoopSucceededEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
class NodeRunLoopFailedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -0,0 +1,106 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from pydantic import Field
from graphon.entities.pause_reason import PauseReason
from graphon.variables.variables import Variable
from .base import GraphNodeEventBase
class NodeRunStartedEvent(GraphNodeEventBase):
node_title: str
predecessor_node_id: str | None = None
start_at: datetime = Field(..., description="node start time")
extras: dict[str, object] = Field(default_factory=dict)
# FIXME(-LAN-): only for ToolNode
provider_type: str = ""
provider_id: str = ""
class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(GraphNodeEventBase):
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunVariableUpdatedEvent(GraphNodeEventBase):
"""Request that the engine apply a variable update before downstream observers continue."""
variable: Variable = Field(..., description="Updated variable payload to apply.")
class NodeRunFailedEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunExceptionEvent(GraphNodeEventBase):
error: str = Field(..., description="error")
start_at: datetime = Field(..., description="node start time")
finished_at: datetime | None = Field(default=None, description="node finish time")
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
"""Emitted when a HumanInput form is submitted and before the node finishes."""
node_title: str = Field(..., description="HumanInput node title")
rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
action_id: str = Field(..., description="User action identifier chosen in the form.")
action_text: str = Field(..., description="Display text of the chosen action button.")
class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
"""Emitted when a HumanInput form times out."""
node_title: str = Field(..., description="HumanInput node title")
expiration_time: datetime = Field(..., description="Form expiration time")
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: PauseReason = Field(..., description="pause reason")
def is_node_result_event(event: GraphNodeEventBase) -> bool:
"""
Check if an event is a final result event from node execution.
A result event indicates the completion of a node execution and contains
runtime information such as inputs, outputs, or error details.
Args:
event: The event to check
Returns:
True if the event is a node result event (succeeded/failed/paused), False otherwise
"""
return isinstance(
event,
(
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
),
)
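For example, a dispatcher can use `is_node_result_event` to separate terminal node events from streaming ones (`handle_result` and `forward_chunk` are hypothetical):

```python
def route_event(event: GraphNodeEventBase) -> None:
    if is_node_result_event(event):
        handle_result(event)   # succeeded / failed / pause requested
    elif isinstance(event, NodeRunStreamChunkEvent):
        forward_chunk(event)   # incremental streaming output
```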

View File

@ -0,0 +1,51 @@
# Model Runtime
This module provides the invocation and authentication interfaces for the various models, and offers Dify unified provider information and credential form rules.
- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to add model support horizontally;
- On the other hand, providers and models defined in the backend are displayed directly in the frontend, with no frontend changes required.
## Features
- Supports capability invocation for 6 model types
  - `LLM` - text completion, chat, and token pre-counting capabilities
  - `Text Embedding Model` - text embedding and token pre-counting capabilities
  - `Rerank Model` - segment reranking capability
  - `Speech-to-text Model` - speech-to-text capability
  - `Text-to-speech Model` - text-to-speech capability
  - `Moderation` - moderation capability
- Model provider display

  Displays the list of all supported providers, including provider names, icons, supported model types, predefined model lists, configuration methods, and credential form rules.

- Selectable model list display

  After configuring provider/model credentials, the dropdown (application orchestration interface/default model) shows the available LLM list. Greyed-out items are predefined models from providers without configured credentials, so users can see which models are supported.

  This list also returns the configurable parameter information and rules for each LLM. These parameters are all defined in the backend, so different models can expose different supported parameters.

- Provider/model credential authentication

  The provider list returns the configuration information for the credential form, which can be authenticated through the Runtime's interface.
## Structure
Model Runtime is divided into three layers:
- The outermost layer is the factory method

  It provides methods for obtaining all providers and all model lists, getting provider instances, and authenticating provider/model credentials.

- The second layer is the provider layer

  It provides the current provider's model list, model instance retrieval, provider credential authentication, and provider configuration rules, **allowing horizontal expansion** to support different providers.

- The bottom layer is the model layer

  It offers direct invocation of the various model types, predefined model configuration information, retrieval of predefined/remote model lists, and model credential authentication. Individual models provide additional special methods, such as the LLM's token pre-counting and cost information methods, **allowing horizontal expansion** for different models under the same provider (within the supported model types).
## Documentation
For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
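A pseudo-sketch of the three layers in use; every name below (`model_factory`, `get_provider_instance`, `validate_credentials`, `get_model_instance`, `invoke`) is illustrative only and does not claim to be this module's actual API:

```python
# Factory layer: discover a provider and authenticate its credentials.
provider = model_factory.get_provider_instance("openai")       # hypothetical
provider.validate_credentials({"api_key": "sk-..."})           # hypothetical

# Provider layer: obtain a model instance of a given type.
llm = provider.get_model_instance(model_type="llm", model="gpt-4o")  # hypothetical

# Model layer: direct invocation with per-call parameters.
result = llm.invoke(                                           # hypothetical
    prompt_messages=[...],
    model_parameters={"temperature": 0.7, "max_tokens": 512},
)
```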

View File

@ -0,0 +1,64 @@
# Model Runtime

This module provides the invocation and authentication interfaces for the various models, and offers Dify unified provider information and credential form rules.

- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to add model support horizontally;
- On the other hand, providers and models defined in the backend are displayed directly in the frontend, with no frontend changes required.

## Features

- Supports capability invocation for 6 model types
  - `LLM` - text completion, chat, and token pre-counting capabilities
  - `Text Embedding Model` - text embedding and token pre-counting capabilities
  - `Rerank Model` - segment reranking capability
  - `Speech-to-text Model` - speech-to-text capability
  - `Text-to-speech Model` - text-to-speech capability
  - `Moderation` - moderation capability
- Model provider display

  Displays the list of all supported providers, including provider names, icons, supported model types, predefined model lists, configuration methods, and credential form rules.

- Selectable model list display

  After configuring provider/model credentials, the dropdown (application orchestration interface/default model) shows the available LLM list. Greyed-out items are predefined models from providers without configured credentials, so users can see which models are supported.

  This list also returns the configurable parameter information and rules for each LLM. These parameters are all defined in the backend; unlike the previous five fixed parameters, each model can expose its own supported parameters.

- Provider/model credential authentication

  The provider list returns the configuration information for the credential form, which can be authenticated through the interfaces provided by the Runtime.

## Structure

Model Runtime is divided into three layers:

- The outermost layer is the factory method

  It provides methods for obtaining all providers and all model lists, getting provider instances, and authenticating provider/model credentials.

- The second layer is the provider layer

  It provides the current provider's model list, model instance retrieval, provider credential authentication, and provider configuration rules, **allowing horizontal expansion** to support different providers.

  For provider/model credentials, there are two cases:

  - Centralized providers such as OpenAI require authentication credentials such as **api_key**.
  - Locally deployed providers such as [**Xinference**](https://github.com/xorbitsai/inference) require address credentials such as **server_url**, and sometimes model credentials such as **model_uid**. Once these credentials are defined at the provider layer, they can be displayed directly in the frontend without modifying frontend logic.

  Once credentials are configured, DifyRuntime's external interface can directly return the **Schema** (credential form rules) required by the corresponding provider, so new providers/models can be supported without changing the frontend logic.

- The bottom layer is the model layer

  It offers direct invocation of the various model types, predefined model configuration information, retrieval of predefined/remote model lists, and model credential authentication. Individual models provide additional special methods, such as the LLM's token pre-counting and cost information methods, **allowing horizontal expansion** for different models under the same provider (within the supported model types).

  Here we first need to distinguish model parameters from model credentials:

  - Model parameters (**defined at this layer**) change frequently and are adjusted at any time, such as the LLM's **max_tokens** and **temperature**. They are tuned by users in the frontend, so the backend must define the parameter rules for the frontend to display and adjust. In DifyRuntime their parameter name is generally **model_parameters: dict[str, any]**.
  - Model credentials (**defined at the provider layer**) rarely change once configured, such as **api_key** and **server_url**. In DifyRuntime their parameter name is generally **credentials: dict[str, any]**; the provider layer's credentials are passed directly to this layer and do not need to be defined separately.

## Documentation

For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).

View File

View File

@ -0,0 +1,159 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
class Callback(ABC):
"""
Base class for callbacks.
Only for LLM.
"""
raise_error: bool = False
@abstractmethod
def on_before_invoke(
self,
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
invocation_context: Mapping[str, object] | None = None,
):
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: optional end-user identifier for the invocation
:param invocation_context: opaque request metadata for the current invocation
"""
raise NotImplementedError()
@abstractmethod
def on_new_chunk(
self,
llm_instance: AIModel,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
invocation_context: Mapping[str, object] | None = None,
):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: optional end-user identifier for the invocation
:param invocation_context: opaque request metadata for the current invocation
"""
raise NotImplementedError()
@abstractmethod
def on_after_invoke(
self,
llm_instance: AIModel,
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
invocation_context: Mapping[str, object] | None = None,
):
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: optional end-user identifier for the invocation
:param invocation_context: opaque request metadata for the current invocation
"""
raise NotImplementedError()
@abstractmethod
def on_invoke_error(
self,
llm_instance: AIModel,
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
invocation_context: Mapping[str, object] | None = None,
):
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: optional end-user identifier for the invocation
:param invocation_context: opaque request metadata for the current invocation
"""
raise NotImplementedError()
def print_text(self, text: str, color: str | None = None, end: str = ""):
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(text_to_print, end=end)
def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
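
As an illustrative sketch (not part of this module), a concrete subclass only needs to implement the four hooks; signatures are collapsed to `*args/**kwargs` for brevity, where a real implementation would mirror the abstract signatures above:

```python
from graphon.model_runtime.callbacks.base_callback import Callback


class ChunkCountingCallback(Callback):
    """Counts streamed chunks and reports the total after the invocation."""

    def __init__(self) -> None:
        self.chunks = 0

    def on_before_invoke(self, *args, **kwargs):
        self.chunks = 0

    def on_new_chunk(self, *args, **kwargs):
        self.chunks += 1

    def on_after_invoke(self, *args, **kwargs):
        self.print_text(f"\nreceived {self.chunks} chunks\n", color="green")

    def on_invoke_error(self, *args, **kwargs):
        pass  # leave error reporting to other callbacks
```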

View File

@ -0,0 +1,180 @@
import json
import logging
import sys
from collections.abc import Mapping, Sequence
from typing import cast
from graphon.model_runtime.callbacks.base_callback import Callback
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class LoggingCallback(Callback):
def on_before_invoke(
self,
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
invocation_context: Mapping[str, object] | None = None,
):
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: optional end-user identifier for the invocation
:param invocation_context: opaque request metadata for the current invocation
"""
self.print_text("\n[on_llm_before_invoke]\n", color="blue")
self.print_text(f"Model: {model}\n", color="blue")
self.print_text("Parameters:\n", color="blue")
for key, value in model_parameters.items():
self.print_text(f"\t{key}: {value}\n", color="blue")
if stop:
self.print_text(f"\tstop: {stop}\n", color="blue")
if tools:
self.print_text("\tTools:\n", color="blue")
for tool in tools:
self.print_text(f"\t\t{tool.name}\n", color="blue")
self.print_text(f"Stream: {stream}\n", color="blue")
if user:
self.print_text(f"User: {user}\n", color="blue")
if invocation_context:
self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue")
self.print_text("Prompt messages:\n", color="blue")
for prompt_message in prompt_messages:
if prompt_message.name:
self.print_text(f"\tname: {prompt_message.name}\n", color="blue")
self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")
if stream:
self.print_text("\n[on_llm_new_chunk]")
def on_new_chunk(
self,
llm_instance: AIModel,
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
invocation_context: Mapping[str, object] | None = None,
):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: optional end-user identifier for the invocation
:param invocation_context: opaque request metadata for the current invocation
"""
_ = user, invocation_context
sys.stdout.write(cast(str, chunk.delta.message.content))
sys.stdout.flush()
def on_after_invoke(
self,
llm_instance: AIModel,
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
invocation_context: Mapping[str, object] | None = None,
):
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: optional end-user identifier for the invocation
:param invocation_context: opaque request metadata for the current invocation
"""
_ = user, invocation_context
self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
self.print_text(f"Content: {result.message.content}\n", color="yellow")
if result.message.tool_calls:
self.print_text("Tool calls:\n", color="yellow")
for tool_call in result.message.tool_calls:
self.print_text(f"\t{tool_call.id}\n", color="yellow")
self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow")
self.print_text(f"Model: {result.model}\n", color="yellow")
self.print_text(f"Usage: {result.usage}\n", color="yellow")
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow")
def on_invoke_error(
self,
llm_instance: AIModel,
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
user: str | None = None,
invocation_context: Mapping[str, object] | None = None,
):
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: optional end-user identifier for the invocation
:param invocation_context: opaque request metadata for the current invocation
"""
_ = user, invocation_context
self.print_text("\n[on_llm_invoke_error]\n", color="red")
logger.exception(ex)

View File

@ -0,0 +1,43 @@
from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from .model_entities import ModelPropertyKey
__all__ = [
"AssistantPromptMessage",
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"ImagePromptMessageContent",
"LLMMode",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
"LLMUsage",
"ModelPropertyKey",
"MultiModalPromptMessageContent",
"PromptMessage",
"PromptMessageContent",
"PromptMessageContentType",
"PromptMessageRole",
"PromptMessageTool",
"SystemPromptMessage",
"TextPromptMessageContent",
"ToolPromptMessage",
"UserPromptMessage",
"VideoPromptMessageContent",
]

View File

@ -0,0 +1,16 @@
from pydantic import BaseModel, model_validator
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
zh_Hans: str | None = None
en_US: str
@model_validator(mode="after")
def _(self):
if not self.zh_Hans:
self.zh_Hans = self.en_US
return self
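
A quick sketch of the fallback behavior:

```python
# zh_Hans falls back to en_US when omitted.
assert I18nObject(en_US="Temperature").zh_Hans == "Temperature"

# An explicit zh_Hans value is kept as-is.
assert I18nObject(en_US="Temperature", zh_Hans="温度").zh_Hans == "温度"
```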

View File

@ -0,0 +1,130 @@
from graphon.model_runtime.entities.model_entities import DefaultParameterName
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.TEMPERATURE: {
"label": {
"en_US": "Temperature",
"zh_Hans": "温度",
},
"type": "float",
"help": {
"en_US": "Controls randomness. Lower temperature results in less random completions."
" As the temperature approaches zero, the model will become deterministic and repetitive."
" Higher temperature results in more random completions.",
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。"
"较高的温度会导致更多的随机完成。",
},
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.TOP_P: {
"label": {
"en_US": "Top P",
"zh_Hans": "Top P",
},
"type": "float",
"help": {
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
" are considered.",
"zh_Hans": "通过核心采样控制多样性0.5 表示考虑了一半的所有可能性加权选项。",
},
"required": False,
"default": 1.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.TOP_K: {
"label": {
"en_US": "Top K",
"zh_Hans": "Top K",
},
"type": "int",
"help": {
"en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
"zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。",
},
"required": False,
"default": 50,
"min": 1,
"max": 100,
"precision": 0,
},
DefaultParameterName.PRESENCE_PENALTY: {
"label": {
"en_US": "Presence Penalty",
"zh_Hans": "存在惩罚",
},
"type": "float",
"help": {
"en_US": "Applies a penalty to the log-probability of tokens already in the text.",
"zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
},
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.FREQUENCY_PENALTY: {
"label": {
"en_US": "Frequency Penalty",
"zh_Hans": "频率惩罚",
},
"type": "float",
"help": {
"en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
"zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
},
"required": False,
"default": 0.0,
"min": 0.0,
"max": 1.0,
"precision": 2,
},
DefaultParameterName.MAX_TOKENS: {
"label": {
"en_US": "Max Tokens",
"zh_Hans": "最大 Token 数",
},
"type": "int",
"help": {
"en_US": "Specifies the upper limit on the length of generated results."
" If the generated results are truncated, you can increase this parameter.",
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
},
"required": False,
"default": 64,
"min": 1,
"max": 2048,
"precision": 0,
},
DefaultParameterName.RESPONSE_FORMAT: {
"label": {
"en_US": "Response Format",
"zh_Hans": "回复格式",
},
"type": "string",
"help": {
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
" such as JSON, XML, etc.",
"zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等",
},
"required": False,
"options": ["JSON", "XML"],
},
DefaultParameterName.JSON_SCHEMA: {
"label": {
"en_US": "JSON Schema",
},
"type": "text",
"help": {
"en_US": "Set a response json schema will ensure LLM to adhere it.",
"zh_Hans": "设置返回的 json schemallm 将按照它返回",
},
"required": False,
},
}
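
A sketch of hydrating a `ParameterRule` from this template; the import paths mirror the modules above, and pydantic coerces the plain `label`/`help` dicts into `I18nObject`:

```python
from graphon.model_runtime.entities.model_entities import (
    DefaultParameterName,
    ParameterRule,
)

template = PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE]
rule = ParameterRule(
    name=DefaultParameterName.TEMPERATURE.value,
    use_template=DefaultParameterName.TEMPERATURE.value,
    **template,  # label/help dicts are coerced into I18nObject
)
assert rule.max == 1.0 and rule.precision == 2
```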

View File

@ -0,0 +1,219 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from decimal import Decimal
from enum import StrEnum
from typing import Any, TypedDict, Union
from pydantic import BaseModel, Field
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from graphon.model_runtime.entities.model_entities import ModelUsage, PriceInfo
class LLMMode(StrEnum):
"""
Enum class for large language model mode.
"""
COMPLETION = "completion"
CHAT = "chat"
class LLMUsageMetadata(TypedDict, total=False):
"""
TypedDict for LLM usage metadata.
All fields are optional.
"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_unit_price: Union[float, str]
completion_unit_price: Union[float, str]
total_price: Union[float, str]
currency: str
prompt_price_unit: Union[float, str]
completion_price_unit: Union[float, str]
prompt_price: Union[float, str]
completion_price: Union[float, str]
latency: float
time_to_first_token: float
time_to_generate: float
class LLMUsage(ModelUsage):
"""
Model class for llm usage.
"""
prompt_tokens: int
prompt_unit_price: Decimal
prompt_price_unit: Decimal
prompt_price: Decimal
completion_tokens: int
completion_unit_price: Decimal
completion_price_unit: Decimal
completion_price: Decimal
total_tokens: int
total_price: Decimal
currency: str
latency: float
time_to_first_token: float | None = None
time_to_generate: float | None = None
@classmethod
def empty_usage(cls):
return cls(
prompt_tokens=0,
prompt_unit_price=Decimal("0.0"),
prompt_price_unit=Decimal("0.0"),
prompt_price=Decimal("0.0"),
completion_tokens=0,
completion_unit_price=Decimal("0.0"),
completion_price_unit=Decimal("0.0"),
completion_price=Decimal("0.0"),
total_tokens=0,
total_price=Decimal("0.0"),
currency="USD",
latency=0.0,
time_to_first_token=None,
time_to_generate=None,
)
@classmethod
def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
"""
Create LLMUsage instance from metadata dictionary with default values.
Args:
metadata: TypedDict containing usage metadata
Returns:
LLMUsage instance with values from metadata or defaults
"""
prompt_tokens = metadata.get("prompt_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0)
total_tokens = metadata.get("total_tokens", 0)
# If total_tokens is not provided but prompt and completion tokens are,
# calculate total_tokens
if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
total_tokens = prompt_tokens + completion_tokens
return cls(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
total_price=Decimal(str(metadata.get("total_price", 0))),
currency=metadata.get("currency", "USD"),
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
completion_price=Decimal(str(metadata.get("completion_price", 0))),
latency=metadata.get("latency", 0.0),
time_to_first_token=metadata.get("time_to_first_token"),
time_to_generate=metadata.get("time_to_generate"),
)
def plus(self, other: LLMUsage) -> LLMUsage:
"""
Add two LLMUsage instances together.
:param other: Another LLMUsage instance to add
:return: A new LLMUsage instance with summed values
"""
if self.total_tokens == 0:
return other
else:
return LLMUsage(
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
prompt_unit_price=other.prompt_unit_price,
prompt_price_unit=other.prompt_price_unit,
prompt_price=self.prompt_price + other.prompt_price,
completion_tokens=self.completion_tokens + other.completion_tokens,
completion_unit_price=other.completion_unit_price,
completion_price_unit=other.completion_price_unit,
completion_price=self.completion_price + other.completion_price,
total_tokens=self.total_tokens + other.total_tokens,
total_price=self.total_price + other.total_price,
currency=other.currency,
latency=self.latency + other.latency,
time_to_first_token=other.time_to_first_token,
time_to_generate=other.time_to_generate,
)
def __add__(self, other: LLMUsage) -> LLMUsage:
"""
Overload the + operator to add two LLMUsage instances.
:param other: Another LLMUsage instance to add
:return: A new LLMUsage instance with summed values
"""
return self.plus(other)
class LLMResult(BaseModel):
"""
Model class for llm result.
"""
id: str | None = None
model: str
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
message: AssistantPromptMessage
usage: LLMUsage
system_fingerprint: str | None = None
reasoning_content: str | None = None
class LLMStructuredOutput(BaseModel):
"""
Model class for llm structured output.
"""
structured_output: Mapping[str, Any] | None = None
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
"""
Model class for llm result with structured output.
"""
class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
"""
index: int
message: AssistantPromptMessage
usage: LLMUsage | None = None
finish_reason: str | None = None
class LLMResultChunk(BaseModel):
"""
Model class for llm result chunk.
"""
model: str
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
system_fingerprint: str | None = None
delta: LLMResultChunkDelta
class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
"""
Model class for llm result chunk with structured output.
"""
class NumTokensResult(PriceInfo):
"""
Model class for number of tokens result.
"""
tokens: int
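
Because `plus` short-circuits when the accumulator has zero tokens, an `empty_usage()` seed simply adopts the first real usage; a sketch:

```python
a = LLMUsage.from_metadata({"prompt_tokens": 10, "completion_tokens": 5})
b = LLMUsage.from_metadata({"prompt_tokens": 3, "completion_tokens": 2})

total = LLMUsage.empty_usage() + a + b  # seed adopts `a`, then adds `b`
assert total.total_tokens == 20
```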

View File

@ -0,0 +1,279 @@
from __future__ import annotations
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, Union
from pydantic import BaseModel, Field, field_serializer, field_validator
class PromptMessageRole(StrEnum):
"""
Enum class for prompt message.
"""
SYSTEM = auto()
USER = auto()
ASSISTANT = auto()
TOOL = auto()
@classmethod
def value_of(cls, value: str) -> PromptMessageRole:
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid prompt message type value {value}")
class PromptMessageTool(BaseModel):
"""
Model class for prompt message tool.
"""
name: str
description: str
parameters: dict
class PromptMessageFunction(BaseModel):
"""
Model class for prompt message function.
"""
type: str = "function"
function: PromptMessageTool
class PromptMessageContentType(StrEnum):
"""
Enum class for prompt message content type.
"""
TEXT = auto()
IMAGE = auto()
AUDIO = auto()
VIDEO = auto()
DOCUMENT = auto()
class PromptMessageContent(ABC, BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
class TextPromptMessageContent(PromptMessageContent):
"""
Model class for text prompt message content.
"""
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
data: str
class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""
format: str = Field(default=..., description="the format of multi-modal file")
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
filename: str = Field(default="", description="the filename of multi-modal file")
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
Model class for image prompt message content.
"""
class DETAIL(StrEnum):
LOW = auto()
HIGH = auto()
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
PromptMessageContentUnionTypes = Annotated[
Union[
TextPromptMessageContent,
ImagePromptMessageContent,
DocumentPromptMessageContent,
AudioPromptMessageContent,
VideoPromptMessageContent,
],
Field(discriminator="type"),
]
CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
PromptMessageContentType.TEXT: TextPromptMessageContent,
PromptMessageContentType.IMAGE: ImagePromptMessageContent,
PromptMessageContentType.AUDIO: AudioPromptMessageContent,
PromptMessageContentType.VIDEO: VideoPromptMessageContent,
PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
}
class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
content: str | list[PromptMessageContentUnionTypes] | None = None
name: str | None = None
def is_empty(self) -> bool:
"""
Check if prompt message is empty.
:return: True if prompt message is empty, False otherwise
"""
return not self.content
def get_text_content(self) -> str:
"""
Get text content from prompt message.
:return: Text content as string, empty string if no text content
"""
if isinstance(self.content, str):
return self.content
elif isinstance(self.content, list):
text_parts = []
for item in self.content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return "".join(text_parts)
else:
return ""
@field_validator("content", mode="before")
@classmethod
def validate_content(cls, v):
if isinstance(v, list):
prompts = []
for prompt in v:
if isinstance(prompt, PromptMessageContent):
if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
elif isinstance(prompt, dict):
prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
else:
raise ValueError(f"invalid prompt message {prompt}")
prompts.append(prompt)
return prompts
return v
@field_serializer("content")
def serialize_content(
self, content: Union[str, Sequence[PromptMessageContent]] | None
) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None:
if content is None or isinstance(content, str):
return content
if isinstance(content, list):
return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
return content
class UserPromptMessage(PromptMessage):
"""
Model class for user prompt message.
"""
role: PromptMessageRole = PromptMessageRole.USER
class AssistantPromptMessage(PromptMessage):
"""
Model class for assistant prompt message.
"""
class ToolCall(BaseModel):
"""
Model class for assistant prompt message tool call.
"""
class ToolCallFunction(BaseModel):
"""
Model class for assistant prompt message tool call function.
"""
name: str
arguments: str
id: str
type: str
function: ToolCallFunction
@field_validator("id", mode="before")
@classmethod
def transform_id_to_str(cls, value) -> str:
if not isinstance(value, str):
return str(value)
else:
return value
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = []
def is_empty(self) -> bool:
"""
Check if prompt message is empty.
:return: True if prompt message is empty, False otherwise
"""
return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):
"""
Model class for system prompt message.
"""
role: PromptMessageRole = PromptMessageRole.SYSTEM
class ToolPromptMessage(PromptMessage):
"""
Model class for tool prompt message.
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str
def is_empty(self) -> bool:
"""
Check if prompt message is empty.
:return: True if prompt message is empty, False otherwise
"""
return super().is_empty() and not self.tool_call_id
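
A small usage sketch (the base64 payload is a placeholder):

```python
msg = UserPromptMessage(
    content=[
        TextPromptMessageContent(data="Describe this image: "),
        ImagePromptMessageContent(
            format="png", mime_type="image/png", base64_data="iVBORw0..."
        ),
    ]
)
assert msg.get_text_content() == "Describe this image: "
assert not msg.is_empty()
```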

View File

@ -0,0 +1,242 @@
from __future__ import annotations
from decimal import Decimal
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, ConfigDict, model_validator
from graphon.model_runtime.entities.common_entities import I18nObject
class ModelType(StrEnum):
"""
Enum class for model type.
"""
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
TTS = auto()
@classmethod
def value_of(cls, origin_model_type: str) -> ModelType:
"""
Get model type from origin model type.
:return: model type
"""
if origin_model_type in {"text-generation", cls.LLM}:
return cls.LLM
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
return cls.TEXT_EMBEDDING
elif origin_model_type in {"reranking", cls.RERANK}:
return cls.RERANK
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
return cls.SPEECH2TEXT
elif origin_model_type in {"tts", cls.TTS}:
return cls.TTS
elif origin_model_type == cls.MODERATION:
return cls.MODERATION
else:
raise ValueError(f"invalid origin model type {origin_model_type}")
def to_origin_model_type(self) -> str:
"""
Get origin model type from model type.
:return: origin model type
"""
if self == self.LLM:
return "text-generation"
elif self == self.TEXT_EMBEDDING:
return "embeddings"
elif self == self.RERANK:
return "reranking"
elif self == self.SPEECH2TEXT:
return "speech2text"
elif self == self.TTS:
return "tts"
elif self == self.MODERATION:
return "moderation"
else:
raise ValueError(f"invalid model type {self}")
class FetchFrom(StrEnum):
"""
Enum class for fetch from.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
class ModelFeature(StrEnum):
"""
Enum class for llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()
STRUCTURED_OUTPUT = "structured-output"
class DefaultParameterName(StrEnum):
"""
Enum class for parameter template variable.
"""
TEMPERATURE = auto()
TOP_P = auto()
TOP_K = auto()
PRESENCE_PENALTY = auto()
FREQUENCY_PENALTY = auto()
MAX_TOKENS = auto()
RESPONSE_FORMAT = auto()
JSON_SCHEMA = auto()
@classmethod
def value_of(cls, value: Any) -> DefaultParameterName:
"""
Get parameter name from value.
:param value: parameter value
:return: parameter name
"""
for name in cls:
if name.value == value:
return name
raise ValueError(f"invalid parameter name {value}")
class ParameterType(StrEnum):
"""
Enum class for parameter type.
"""
FLOAT = auto()
INT = auto()
STRING = auto()
BOOLEAN = auto()
TEXT = auto()
class ModelPropertyKey(StrEnum):
"""
Enum class for model property key.
"""
MODE = auto()
CONTEXT_SIZE = auto()
MAX_CHUNKS = auto()
FILE_UPLOAD_LIMIT = auto()
SUPPORTED_FILE_EXTENSIONS = auto()
MAX_CHARACTERS_PER_CHUNK = auto()
DEFAULT_VOICE = auto()
VOICES = auto()
WORD_LIMIT = auto()
AUDIO_TYPE = auto()
MAX_WORKERS = auto()
class ProviderModel(BaseModel):
"""
Model class for provider model.
"""
model: str
label: I18nObject
model_type: ModelType
features: list[ModelFeature] | None = None
fetch_from: FetchFrom
model_properties: dict[ModelPropertyKey, Any]
deprecated: bool = False
model_config = ConfigDict(protected_namespaces=())
@property
def support_structure_output(self) -> bool:
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
class ParameterRule(BaseModel):
"""
Model class for parameter rule.
"""
name: str
use_template: str | None = None
label: I18nObject
type: ParameterType
help: I18nObject | None = None
required: bool = False
default: Any | None = None
min: float | None = None
max: float | None = None
precision: int | None = None
options: list[str] = []
class PriceConfig(BaseModel):
"""
Model class for pricing info.
"""
input: Decimal
output: Decimal | None = None
unit: Decimal
currency: str
class AIModelEntity(ProviderModel):
"""
Model class for AI model.
"""
parameter_rules: list[ParameterRule] = []
pricing: PriceConfig | None = None
@model_validator(mode="after")
def validate_model(self):
supported_schema_keys = ["json_schema"]
schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
if not schema_key:
return self
if self.features is None:
self.features = [ModelFeature.STRUCTURED_OUTPUT]
else:
if ModelFeature.STRUCTURED_OUTPUT not in self.features:
self.features.append(ModelFeature.STRUCTURED_OUTPUT)
return self
class ModelUsage(BaseModel):
pass
class PriceType(StrEnum):
"""
Enum class for price type.
"""
INPUT = auto()
OUTPUT = auto()
class PriceInfo(BaseModel):
"""
Model class for price info.
"""
unit_price: Decimal
unit: Decimal
total_amount: Decimal
currency: str
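
A sketch of the `validate_model` hook in action (model and label names are illustrative): declaring a `json_schema` parameter rule implicitly marks the model as supporting structured output:

```python
entity = AIModelEntity(
    model="example-model",
    label=I18nObject(en_US="Example Model"),
    model_type=ModelType.LLM,
    fetch_from=FetchFrom.PREDEFINED_MODEL,
    model_properties={},
    parameter_rules=[
        ParameterRule(
            name="json_schema",
            label=I18nObject(en_US="JSON Schema"),
            type=ParameterType.TEXT,
        )
    ],
)
assert entity.support_structure_output
```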

View File

@ -0,0 +1,179 @@
from collections.abc import Sequence
from enum import StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
class ConfigurateMethod(StrEnum):
"""
Enum class for configurate method of provider model.
"""
PREDEFINED_MODEL = "predefined-model"
CUSTOMIZABLE_MODEL = "customizable-model"
class FormType(StrEnum):
"""
Enum class for form type.
"""
TEXT_INPUT = "text-input"
SECRET_INPUT = "secret-input"
SELECT = auto()
RADIO = auto()
SWITCH = auto()
class FormShowOnObject(BaseModel):
"""
Model class for form show on.
"""
variable: str
value: str
class FormOption(BaseModel):
"""
Model class for form option.
"""
label: I18nObject
value: str
show_on: list[FormShowOnObject] = []
@model_validator(mode="after")
def _(self):
if not self.label:
self.label = I18nObject(en_US=self.value)
return self
class CredentialFormSchema(BaseModel):
"""
Model class for credential form schema.
"""
variable: str
label: I18nObject
type: FormType
required: bool = True
default: str | None = None
options: list[FormOption] | None = None
placeholder: I18nObject | None = None
max_length: int = 0
show_on: list[FormShowOnObject] = []
class ProviderCredentialSchema(BaseModel):
"""
Model class for provider credential schema.
"""
credential_form_schemas: list[CredentialFormSchema]
class FieldModelSchema(BaseModel):
label: I18nObject
placeholder: I18nObject | None = None
class ModelCredentialSchema(BaseModel):
"""
Model class for model credential schema.
"""
model: FieldModelSchema
credential_form_schemas: list[CredentialFormSchema]
class SimpleProviderEntity(BaseModel):
"""
Simplified provider schema exposed to callers.
`provider` is the canonical runtime identifier. `provider_name` is an optional
compatibility alias for short-name lookups and is empty when no alias exists.
"""
provider: str
provider_name: str = ""
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
class ProviderHelpEntity(BaseModel):
"""
Model class for provider help.
"""
title: I18nObject
url: I18nObject
class ProviderEntity(BaseModel):
"""
Runtime-native provider schema.
`provider` is the canonical runtime identifier. `provider_name` is a
compatibility alias for callers that still resolve providers by short name and
is empty when no alias exists.
"""
provider: str
provider_name: str = ""
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
models: list[AIModelEntity] = Field(default_factory=list)
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
# position from plugin _position.yaml
position: dict[str, list[str]] | None = {}
@field_validator("models", mode="before")
@classmethod
def validate_models(cls, v):
# normalize falsy values (None, empty) to an empty list
if not v:
return []
return v
def to_simple_provider(self) -> SimpleProviderEntity:
"""
Convert to simple provider.
:return: simple provider
"""
return SimpleProviderEntity(
provider=self.provider,
provider_name=self.provider_name,
label=self.label,
icon_small=self.icon_small,
supported_model_types=self.supported_model_types,
models=self.models,
)
class ProviderConfig(BaseModel):
"""
Model class for provider config.
"""
provider: str
credentials: dict
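
A sketch of narrowing a full provider schema for display; field values are illustrative:

```python
provider = ProviderEntity(
    provider="openai",  # canonical runtime identifier
    label=I18nObject(en_US="OpenAI"),
    supported_model_types=[ModelType.LLM],
    configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
simple = provider.to_simple_provider()
assert simple.provider == "openai" and simple.models == []
```

Note that `to_simple_provider` does not forward `icon_small_dark`, so it stays `None` on the simplified entity.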

View File

@ -0,0 +1,27 @@
from typing import TypedDict
from pydantic import BaseModel
class MultimodalRerankInput(TypedDict):
content: str
content_type: str
class RerankDocument(BaseModel):
"""
Model class for rerank document.
"""
index: int
text: str
score: float
class RerankResult(BaseModel):
"""
Model class for rerank result.
"""
model: str
docs: list[RerankDocument]
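
A sketch of assembling a result sorted by relevance (the model name is illustrative):

```python
docs = [
    RerankDocument(index=1, text="second passage", score=0.2),
    RerankDocument(index=0, text="first passage", score=0.9),
]
result = RerankResult(
    model="rerank-v1",
    docs=sorted(docs, key=lambda d: d.score, reverse=True),
)
assert result.docs[0].index == 0
```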

View File

@ -0,0 +1,47 @@
from decimal import Decimal
from enum import StrEnum, auto
from pydantic import BaseModel
from graphon.model_runtime.entities.model_entities import ModelUsage
class EmbeddingInputType(StrEnum):
"""Embedding request input variants understood by the model runtime."""
DOCUMENT = auto()
QUERY = auto()
class EmbeddingUsage(ModelUsage):
"""
Model class for embedding usage.
"""
tokens: int
total_tokens: int
unit_price: Decimal
price_unit: Decimal
total_price: Decimal
currency: str
latency: float
class EmbeddingResult(BaseModel):
"""
Model class for text embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage
class FileEmbeddingResult(BaseModel):
"""
Model class for file embedding result.
"""
model: str
embeddings: list[list[float]]
usage: EmbeddingUsage

View File

@ -0,0 +1,41 @@
class InvokeError(ValueError):
"""Base class for all LLM exceptions."""
description: str | None = None
def __init__(self, description: str | None = None):
if description is not None:
self.description = description
def __str__(self):
return self.description or self.__class__.__name__
class InvokeConnectionError(InvokeError):
"""Raised when the Invoke returns connection error."""
description = "Connection Error"
class InvokeServerUnavailableError(InvokeError):
"""Raised when the Invoke returns server unavailable error."""
description = "Server Unavailable Error"
class InvokeRateLimitError(InvokeError):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"
class InvokeAuthorizationError(InvokeError):
"""Raised when the Invoke returns authorization error."""
description = "Incorrect model credentials provided, please check and try again. "
class InvokeBadRequestError(InvokeError):
"""Raised when the Invoke returns bad request."""
description = "Bad Request Error"

View File

@ -0,0 +1,6 @@
class CredentialsValidateFailedError(ValueError):
"""
Credentials validate failed error
"""
pass

View File

@ -0,0 +1,3 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]

Some files were not shown because too many files have changed in this diff