Mirror of https://github.com/langgenius/dify.git (synced 2026-05-06 02:18:08 +08:00)
refactor(api): rename dify_graph to graphon (#34095)
135  api/graphon/README.md  Normal file
@@ -0,0 +1,135 @@
# Workflow

## Project Overview

This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.

## Architecture

### Core Components

The graph engine follows a layered architecture with strict dependency rules:

1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution

   - **Manager** - External control interface for stop/pause/resume commands
   - **Worker** - Node execution runtime
   - **Command Processing** - Handles control commands (abort, pause, resume)
   - **Event Management** - Event propagation and layer notifications
   - **Graph Traversal** - Edge processing and skip propagation
   - **Response Coordinator** - Path tracking and session management
   - **Layers** - Pluggable middleware (debug logging, execution limits)
   - **Command Channels** - Communication channels (InMemory, Redis)

2. **Graph** (`graph/`) - Graph structure and runtime state

   - **Graph Template** - Workflow definition
   - **Edge** - Node connections with conditions
   - **Runtime State Protocol** - State management interface

3. **Nodes** (`nodes/`) - Node implementations

   - **Base** - Abstract node classes and variable parsing
   - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.

4. **Events** (`node_events/`) - Event system

   - **Base** - Event protocols
   - **Node Events** - Node lifecycle events

5. **Entities** (`entities/`) - Domain models

   - **Variable Pool** - Variable storage
   - **Graph Init Params** - Initialization configuration
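How these pieces fit together at runtime, as a minimal sketch only (the `Graph.init(...)` call and the exact `engine.run()` signature are assumptions to verify against `graph_engine/`):

```python
# Hypothetical wiring: build a graph, attach the engine, and drain its event stream.
graph = Graph.init(graph_config=workflow_config, graph_init_params=init_params)  # assumed constructor
engine = GraphEngine(graph)

for event in engine.run():  # assumed: run() yields graph/node events as execution progresses
    print(type(event).__name__)
```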
## Key Design Patterns

### Command Channel Pattern

External workflow control via Redis or in-memory channels:

```python
# Send stop command to running workflow
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
channel.send_command(AbortCommand(reason="User requested"))
```

### Layer System

Extensible middleware for cross-cutting concerns:

```python
engine = GraphEngine(graph)
engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```

`engine.layer()` binds the read-only runtime state before execution, so layer hooks can assume `graph_runtime_state` is available.

### Event-Driven Architecture

All node executions emit events for monitoring and integration:

- `NodeRunStartedEvent` - Node execution begins
- `NodeRunSucceededEvent` - Node completes successfully
- `NodeRunFailedEvent` - Node encounters error
- `GraphRunStartedEvent` / `GraphRunCompletedEvent` - Workflow lifecycle
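A layer is a natural place to consume these events. The sketch below counts events by class name; the `Layer` import path and the `on_event()` hook are assumptions that should be checked against the "Implementing a Custom Layer" section and the layers package:

```python
from collections import Counter

from graphon.graph_engine.layers import Layer  # assumed import path


class EventCounterLayer(Layer):
    """Counts every event the engine emits, keyed by event class name."""

    def __init__(self) -> None:
        self.counts: Counter[str] = Counter()

    def on_event(self, event) -> None:
        # Each emitted event (node start/success/failure, graph lifecycle) passes through here.
        self.counts[type(event).__name__] += 1
```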
### Variable Pool

Centralized variable storage with namespace isolation:

```python
# Variables scoped by node_id
pool.add(["node1", "output"], value)
result = pool.get(["node1", "output"])
```

## Import Architecture Rules

The codebase enforces strict layering via import-linter:

1. **Workflow Layers** (top to bottom):

   - graph_engine → graph_events → graph → nodes → node_events → entities

2. **Graph Engine Internal Layers**:

   - orchestration → command_processing → event_management → graph_traversal → domain

3. **Domain Isolation**:

   - Domain models cannot import from infrastructure layers

4. **Command Channel Independence**:

   - InMemory and Redis channels must remain independent

## Common Tasks

### Adding a New Node Type

1. Create the node class in `nodes/<node_type>/`
2. Inherit from `BaseNode` or an appropriate base class (see the sketch below)
3. Implement the `_run()` method
4. Ensure the node module is importable under `nodes/<node_type>/`
5. Add tests in `tests/unit_tests/graphon/nodes/`
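A minimal node skeleton is sketched below; the import paths, the registration attribute, and the `NodeRunResult` fields are assumptions to check against `nodes/base/` before copying:

```python
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult  # assumed result type and location
from graphon.nodes.base.node import BaseNode   # assumed base-class location


class EchoNode(BaseNode):
    """Hypothetical node that copies its single input straight to its output."""

    node_type = "echo"  # assumed registration attribute

    def _run(self) -> NodeRunResult:
        text = self.node_data.get("text", "")
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs={"text": text},
            outputs={"text": text},
        )
```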
### Implementing a Custom Layer

1. Create a class inheriting from the `Layer` base
2. Override the lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
3. Add it to the engine via `engine.layer()` (see the sketch below)
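Those three steps, as a sketch only (the `Layer` import path and the exact hook signatures are assumptions to verify against the layers package under `graph_engine/`):

```python
import time

from graphon.graph_engine.layers import Layer  # assumed import path


class TimingLayer(Layer):
    """Hypothetical layer that reports total wall-clock time for a run."""

    def on_graph_start(self) -> None:
        self._started = time.monotonic()

    def on_event(self, event) -> None:
        pass  # this layer ignores individual node events

    def on_graph_end(self, error=None) -> None:
        elapsed = time.monotonic() - self._started
        print(f"graph finished in {elapsed:.2f}s (error={error})")


engine.layer(TimingLayer())
```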
### Debugging Workflow Execution

Enable the debug logging layer:

```python
debug_layer = DebugLoggingLayer(
    level="DEBUG",
    include_inputs=True,
    include_outputs=True,
)
```
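Attach it like any other layer before starting the run, mirroring the Layer System example above:

```python
engine.layer(debug_layer)
```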
0  api/graphon/__init__.py  Normal file
11  api/graphon/entities/__init__.py  Normal file
@@ -0,0 +1,11 @@
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason

__all__ = [
    "GraphInitParams",
    "WorkflowExecution",
    "WorkflowNodeExecution",
    "WorkflowStartReason",
]
178  api/graphon/entities/base_node_data.py  Normal file
@@ -0,0 +1,178 @@
from __future__ import annotations

import json
from abc import ABC
from builtins import type as type_
from enum import StrEnum
from typing import Any, Union

from pydantic import BaseModel, ConfigDict, Field, model_validator

from graphon.entities.exc import DefaultValueTypeError
from graphon.enums import ErrorStrategy, NodeType

# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`.
_NumberType = Union[int, float]


class RetryConfig(BaseModel):
    """node retry config"""

    max_retries: int = 0  # max retry times
    retry_interval: int = 0  # retry interval in milliseconds
    retry_enabled: bool = False  # whether retry is enabled

    @property
    def retry_interval_seconds(self) -> float:
        return self.retry_interval / 1000


class DefaultValueType(StrEnum):
    STRING = "string"
    NUMBER = "number"
    OBJECT = "object"
    ARRAY_NUMBER = "array[number]"
    ARRAY_STRING = "array[string]"
    ARRAY_OBJECT = "array[object]"
    ARRAY_FILES = "array[file]"


class DefaultValue(BaseModel):
    value: Any = None
    type: DefaultValueType
    key: str

    @staticmethod
    def _parse_json(value: str):
        """Unified JSON parsing handler"""
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")

    @staticmethod
    def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
        """Unified array type validation"""
        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

    @staticmethod
    def _convert_number(value: str) -> float:
        """Unified number conversion handler"""
        try:
            return float(value)
        except ValueError:
            raise DefaultValueTypeError(f"Cannot convert to number: {value}")

    @model_validator(mode="after")
    def validate_value_type(self) -> DefaultValue:
        # Type validation configuration
        type_validators: dict[DefaultValueType, dict[str, Any]] = {
            DefaultValueType.STRING: {
                "type": str,
                "converter": lambda x: x,
            },
            DefaultValueType.NUMBER: {
                "type": _NumberType,
                "converter": self._convert_number,
            },
            DefaultValueType.OBJECT: {
                "type": dict,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_NUMBER: {
                "type": list,
                "element_type": _NumberType,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_STRING: {
                "type": list,
                "element_type": str,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_OBJECT: {
                "type": list,
                "element_type": dict,
                "converter": self._parse_json,
            },
        }

        validator: dict[str, Any] = type_validators.get(self.type, {})
        if not validator:
            if self.type == DefaultValueType.ARRAY_FILES:
                # Handle files type
                return self
            raise DefaultValueTypeError(f"Unsupported type: {self.type}")

        # Handle string input cases
        if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
            self.value = validator["converter"](self.value)

        # Validate base type
        if not isinstance(self.value, validator["type"]):
            raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")

        # Validate array element types
        if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
            raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")

        return self


class BaseNodeData(ABC, BaseModel):
    # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
    # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
    # `type` therefore accepts downstream string node kinds; unknown node implementations
    # are rejected later when the node factory resolves the node registry.
    # At that boundary, node-specific fields are still "extra" relative to this shared DTO,
    # and persisted templates/workflows also carry undeclared compatibility keys such as
    # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive
    # here until graph parsing becomes discriminated by node type or those legacy payloads
    # are normalized.
    model_config = ConfigDict(extra="allow")

    type: NodeType
    title: str = ""
    desc: str | None = None
    version: str = "1"
    error_strategy: ErrorStrategy | None = None
    default_value: list[DefaultValue] | None = None
    retry_config: RetryConfig = Field(default_factory=RetryConfig)

    @property
    def default_value_dict(self) -> dict[str, Any]:
        if self.default_value:
            return {item.key: item.value for item in self.default_value}
        return {}

    def __getitem__(self, key: str) -> Any:
        """
        Dict-style access without calling model_dump() on every lookup.
        Prefer using model fields and Pydantic's extra storage.
        """
        # First, check declared model fields
        if key in self.__class__.model_fields:
            return getattr(self, key)

        # Then, check undeclared compatibility fields stored in Pydantic's extra dict.
        extras = getattr(self, "__pydantic_extra__", None)
        if extras is None:
            extras = getattr(self, "model_extra", None)
        if extras is not None and key in extras:
            return extras[key]

        raise KeyError(key)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Dict-style .get() without calling model_dump() on every lookup.
        """
        if key in self.__class__.model_fields:
            return getattr(self, key)

        extras = getattr(self, "__pydantic_extra__", None)
        if extras is None:
            extras = getattr(self, "model_extra", None)
        if extras is not None and key in extras:
            return extras.get(key, default)

        return default
10  api/graphon/entities/exc.py  Normal file
@@ -0,0 +1,10 @@
class BaseNodeError(ValueError):
    """Base class for node errors."""

    pass


class DefaultValueTypeError(BaseNodeError):
    """Raised when the default value type is invalid."""

    pass
23  api/graphon/entities/graph_config.py  Normal file
@@ -0,0 +1,23 @@
from __future__ import annotations

import sys

from pydantic import TypeAdapter, with_config

from graphon.entities.base_node_data import BaseNodeData

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict


@with_config(extra="allow")
class NodeConfigDict(TypedDict):
    id: str
    # This is the permissive raw graph boundary. Node factories re-validate `data`
    # with the concrete `NodeData` subtype after resolving the node implementation.
    data: BaseNodeData


NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
24  api/graphon/entities/graph_init_params.py  Normal file
@@ -0,0 +1,24 @@
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, Field

DIFY_RUN_CONTEXT_KEY = "_dify"


class GraphInitParams(BaseModel):
    """GraphInitParams encapsulates the configurations and contextual information
    that remain constant throughout a single execution of the graph engine.

    A single execution is defined as follows: as long as the execution has not reached
    its conclusion, it is considered one execution. For instance, if a workflow is suspended
    and later resumed, it is still regarded as a single execution, not two.

    For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
    """

    # init params
    workflow_id: str = Field(..., description="workflow id")
    graph_config: Mapping[str, Any] = Field(..., description="graph config")
    run_context: Mapping[str, Any] = Field(..., description="runtime context")
    call_depth: int = Field(..., description="call depth")
42  api/graphon/entities/pause_reason.py  Normal file
@@ -0,0 +1,42 @@
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, TypeAlias

from pydantic import BaseModel, Field

from graphon.nodes.human_input.entities import FormInput, UserAction


class PauseReasonType(StrEnum):
    HUMAN_INPUT_REQUIRED = auto()
    SCHEDULED_PAUSE = auto()


class HumanInputRequired(BaseModel):
    TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
    form_id: str
    form_content: str
    inputs: list[FormInput] = Field(default_factory=list)
    actions: list[UserAction] = Field(default_factory=list)
    node_id: str
    node_title: str

    # `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from
    # `output_variable_name` to the corresponding resolved value.
    #
    # For example, suppose the form contains an input with output variable name `name` and placeholder type
    # `VARIABLE`, whose selector is ["start", "name"]. When the HumanInputNode is executed, the corresponding
    # value of variable `start.name` in the variable pool is `John`. Thus, the resolved value of the output
    # variable `name` is `John`, and `resolved_default_values` is `{"name": "John"}`.
    #
    # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
    resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)


class SchedulingPause(BaseModel):
    TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE

    message: str


PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")]
71  api/graphon/entities/workflow_execution.py  Normal file
@@ -0,0 +1,71 @@
"""
Domain entities for workflow execution.

Models describe graph runtime state and avoid infrastructure-specific details.
"""

from __future__ import annotations

from collections.abc import Mapping
from datetime import UTC, datetime
from typing import Any

from pydantic import BaseModel, Field

from graphon.enums import WorkflowExecutionStatus, WorkflowType


class WorkflowExecution(BaseModel):
    """
    Domain model for a workflow execution within the graph runtime.
    """

    id_: str = Field(...)
    workflow_id: str = Field(...)
    workflow_version: str = Field(...)
    workflow_type: WorkflowType = Field(...)
    graph: Mapping[str, Any] = Field(...)

    inputs: Mapping[str, Any] = Field(...)
    outputs: Mapping[str, Any] | None = None

    status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
    error_message: str = Field(default="")
    total_tokens: int = Field(default=0)
    total_steps: int = Field(default=0)
    exceptions_count: int = Field(default=0)

    started_at: datetime = Field(...)
    finished_at: datetime | None = None

    @property
    def elapsed_time(self) -> float:
        """
        Calculate elapsed time in seconds.
        If workflow is not finished, use current time.
        """
        end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None)
        return (end_time - self.started_at).total_seconds()

    @classmethod
    def new(
        cls,
        *,
        id_: str,
        workflow_id: str,
        workflow_type: WorkflowType,
        workflow_version: str,
        graph: Mapping[str, Any],
        inputs: Mapping[str, Any],
        started_at: datetime,
    ) -> WorkflowExecution:
        return WorkflowExecution(
            id_=id_,
            workflow_id=workflow_id,
            workflow_type=workflow_type,
            workflow_version=workflow_version,
            graph=graph,
            inputs=inputs,
            status=WorkflowExecutionStatus.RUNNING,
            started_at=started_at,
        )
141  api/graphon/entities/workflow_node_execution.py  Normal file
@@ -0,0 +1,141 @@
"""
Domain entities for workflow node execution.

These models capture node-level execution state for the graph runtime without
describing storage or application-layer concerns.
"""

from collections.abc import Mapping
from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field, PrivateAttr

from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class WorkflowNodeExecution(BaseModel):
    """
    Domain model for workflow node execution.

    This model represents the graph-level record of a node execution and
    contains only execution state relevant to the runtime.
    """

    # --------- Core identification fields ---------

    # Unique identifier for this execution record, used when persisting to storage.
    # Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
    id: str

    # Optional secondary ID for cross-referencing purposes.
    #
    # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
    # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
    # In most scenarios, `id` should be used as the primary identifier.
    node_execution_id: str | None = None
    workflow_id: str  # ID of the workflow this node belongs to
    workflow_execution_id: str | None = None  # ID of the workflow execution (null for single-step debugging)
    # --------- Core identification fields end ---------

    # Execution positioning and flow
    index: int  # Sequence number for ordering in trace visualization
    predecessor_node_id: str | None = None  # ID of the node that executed before this one
    node_id: str  # ID of the node being executed
    node_type: NodeType  # Type of node (e.g., start, llm, downstream response node)
    title: str  # Display title of the node

    # Execution data
    # The `inputs` and `outputs` fields hold the full content
    inputs: Mapping[str, Any] | None = None  # Input variables used by this node
    process_data: Mapping[str, Any] | None = None  # Intermediate processing data
    outputs: Mapping[str, Any] | None = None  # Output variables produced by this node

    # Execution state
    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING  # Current execution status
    error: str | None = None  # Error message if execution failed
    elapsed_time: float = Field(default=0.0)  # Time taken for execution in seconds

    # Additional metadata
    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None  # Execution metadata (tokens, cost, etc.)

    # Timing information
    created_at: datetime  # When execution started
    finished_at: datetime | None = None  # When execution completed

    _truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None)
    _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None)
    _truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None)

    def get_truncated_inputs(self) -> Mapping[str, Any] | None:
        return self._truncated_inputs

    def get_truncated_outputs(self) -> Mapping[str, Any] | None:
        return self._truncated_outputs

    def get_truncated_process_data(self) -> Mapping[str, Any] | None:
        return self._truncated_process_data

    def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None):
        self._truncated_inputs = truncated_inputs

    def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None):
        self._truncated_outputs = truncated_outputs

    def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None):
        self._truncated_process_data = truncated_process_data

    def get_response_inputs(self) -> Mapping[str, Any] | None:
        inputs = self.get_truncated_inputs()
        if inputs:
            return inputs
        return self.inputs

    @property
    def inputs_truncated(self):
        return self._truncated_inputs is not None

    @property
    def outputs_truncated(self):
        return self._truncated_outputs is not None

    @property
    def process_data_truncated(self):
        return self._truncated_process_data is not None

    def get_response_outputs(self) -> Mapping[str, Any] | None:
        outputs = self.get_truncated_outputs()
        if outputs is not None:
            return outputs
        return self.outputs

    def get_response_process_data(self) -> Mapping[str, Any] | None:
        process_data = self.get_truncated_process_data()
        if process_data is not None:
            return process_data
        return self.process_data

    def update_from_mapping(
        self,
        inputs: Mapping[str, Any] | None = None,
        process_data: Mapping[str, Any] | None = None,
        outputs: Mapping[str, Any] | None = None,
        metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
    ):
        """
        Update the model from mappings.

        Args:
            inputs: The inputs to update
            process_data: The process data to update
            outputs: The outputs to update
            metadata: The metadata to update
        """
        if inputs is not None:
            self.inputs = dict(inputs)
        if process_data is not None:
            self.process_data = dict(process_data)
        if outputs is not None:
            self.outputs = dict(outputs)
        if metadata is not None:
            self.metadata = dict(metadata)
8  api/graphon/entities/workflow_start_reason.py  Normal file
@@ -0,0 +1,8 @@
from enum import StrEnum


class WorkflowStartReason(StrEnum):
    """Reason for workflow start events across graph/queue/SSE layers."""

    INITIAL = "initial"  # First start of a workflow run.
    RESUMPTION = "resumption"  # Start triggered after resuming a paused run.
262  api/graphon/enums.py  Normal file
@@ -0,0 +1,262 @@
from enum import StrEnum
from typing import ClassVar, TypeAlias


class NodeState(StrEnum):
    """State of a node or edge during workflow execution."""

    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"


NodeType: TypeAlias = str


class BuiltinNodeTypes:
    """Built-in node type string constants.

    `node_type` values are plain strings throughout the graph runtime. This namespace
    only exposes the built-in values shipped by `graphon`; downstream packages can
    use additional strings without extending this class.
    """

    START: ClassVar[NodeType] = "start"
    END: ClassVar[NodeType] = "end"
    ANSWER: ClassVar[NodeType] = "answer"
    LLM: ClassVar[NodeType] = "llm"
    KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval"
    IF_ELSE: ClassVar[NodeType] = "if-else"
    CODE: ClassVar[NodeType] = "code"
    TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform"
    QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier"
    HTTP_REQUEST: ClassVar[NodeType] = "http-request"
    TOOL: ClassVar[NodeType] = "tool"
    DATASOURCE: ClassVar[NodeType] = "datasource"
    VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator"
    LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner"
    LOOP: ClassVar[NodeType] = "loop"
    LOOP_START: ClassVar[NodeType] = "loop-start"
    LOOP_END: ClassVar[NodeType] = "loop-end"
    ITERATION: ClassVar[NodeType] = "iteration"
    ITERATION_START: ClassVar[NodeType] = "iteration-start"
    PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor"
    VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner"
    DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor"
    LIST_OPERATOR: ClassVar[NodeType] = "list-operator"
    AGENT: ClassVar[NodeType] = "agent"
    HUMAN_INPUT: ClassVar[NodeType] = "human-input"


BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = (
    BuiltinNodeTypes.START,
    BuiltinNodeTypes.END,
    BuiltinNodeTypes.ANSWER,
    BuiltinNodeTypes.LLM,
    BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
    BuiltinNodeTypes.IF_ELSE,
    BuiltinNodeTypes.CODE,
    BuiltinNodeTypes.TEMPLATE_TRANSFORM,
    BuiltinNodeTypes.QUESTION_CLASSIFIER,
    BuiltinNodeTypes.HTTP_REQUEST,
    BuiltinNodeTypes.TOOL,
    BuiltinNodeTypes.DATASOURCE,
    BuiltinNodeTypes.VARIABLE_AGGREGATOR,
    BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR,
    BuiltinNodeTypes.LOOP,
    BuiltinNodeTypes.LOOP_START,
    BuiltinNodeTypes.LOOP_END,
    BuiltinNodeTypes.ITERATION,
    BuiltinNodeTypes.ITERATION_START,
    BuiltinNodeTypes.PARAMETER_EXTRACTOR,
    BuiltinNodeTypes.VARIABLE_ASSIGNER,
    BuiltinNodeTypes.DOCUMENT_EXTRACTOR,
    BuiltinNodeTypes.LIST_OPERATOR,
    BuiltinNodeTypes.AGENT,
    BuiltinNodeTypes.HUMAN_INPUT,
)


class NodeExecutionType(StrEnum):
    """Node execution type classification."""

    EXECUTABLE = "executable"  # Regular nodes that execute and produce outputs
    RESPONSE = "response"  # Response nodes that stream outputs (Answer, End)
    BRANCH = "branch"  # Nodes that can choose different branches (if-else, question-classifier)
    CONTAINER = "container"  # Container nodes that manage subgraphs (iteration, loop, graph)
    ROOT = "root"  # Nodes that can serve as execution entry points


class ErrorStrategy(StrEnum):
    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"


class FailBranchSourceHandle(StrEnum):
    FAILED = "fail-branch"
    SUCCESS = "success-branch"


class WorkflowType(StrEnum):
    """
    Workflow Type Enum for domain layer
    """

    WORKFLOW = "workflow"
    CHAT = "chat"
    RAG_PIPELINE = "rag-pipeline"


class WorkflowExecutionStatus(StrEnum):
    # State diagram for the workflow status (Mermaid):
    #
    # ---
    # title: State diagram for Workflow run state
    # ---
    # stateDiagram-v2
    #     scheduled: Scheduled
    #     running: Running
    #     succeeded: Succeeded
    #     failed: Failed
    #     partial_succeeded: Partial Succeeded
    #     paused: Paused
    #     stopped: Stopped
    #
    #     [*] --> scheduled
    #     scheduled --> running: Start Execution
    #     running --> paused: Human input required
    #     paused --> running: Human input added
    #     paused --> stopped: User stops execution
    #     running --> succeeded: Execution finishes without any error
    #     running --> failed: Execution finishes with errors
    #     running --> stopped: User stops execution
    #     running --> partial_succeeded: Some errors occurred and were handled during execution
    #
    #     scheduled --> stopped: User stops execution
    #
    #     succeeded --> [*]
    #     failed --> [*]
    #     partial_succeeded --> [*]
    #     stopped --> [*]

    # `SCHEDULED` means that the workflow is scheduled to run, but has not
    # started running yet (maybe due to possible worker saturation).
    #
    # This enum value is currently unused.
    SCHEDULED = "scheduled"

    # `RUNNING` means the workflow is executing.
    RUNNING = "running"

    # `SUCCEEDED` means the execution of the workflow succeeded without any error.
    SUCCEEDED = "succeeded"

    # `FAILED` means the execution of the workflow failed with one or more errors.
    FAILED = "failed"

    # `STOPPED` means the execution of the workflow was stopped, either manually
    # by the user, or automatically by the Dify application (e.g., the moderation
    # mechanism).
    STOPPED = "stopped"

    # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
    # execution, but they were successfully handled (e.g., by using an error
    # strategy such as "fail branch" or "default value").
    PARTIAL_SUCCEEDED = "partial-succeeded"

    # `PAUSED` indicates that the workflow execution is temporarily paused
    # (e.g., awaiting human input) and is expected to resume later.
    PAUSED = "paused"

    def is_ended(self) -> bool:
        return self in _END_STATE

    @classmethod
    def ended_values(cls) -> list[str]:
        return [status.value for status in _END_STATE]


_END_STATE = frozenset(
    [
        WorkflowExecutionStatus.SUCCEEDED,
        WorkflowExecutionStatus.FAILED,
        WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
        WorkflowExecutionStatus.STOPPED,
    ]
)


class WorkflowNodeExecutionMetadataKey(StrEnum):
    """
    Node Run Metadata Key.

    Values in this enum are persisted as execution metadata and must stay in sync
    with every node that writes `NodeRunResult.metadata`.
    """

    TOTAL_TOKENS = "total_tokens"
    TOTAL_PRICE = "total_price"
    CURRENCY = "currency"
    TOOL_INFO = "tool_info"
    AGENT_LOG = "agent_log"
    ITERATION_ID = "iteration_id"
    ITERATION_INDEX = "iteration_index"
    LOOP_ID = "loop_id"
    LOOP_INDEX = "loop_index"
    PARALLEL_ID = "parallel_id"
    PARALLEL_START_NODE_ID = "parallel_start_node_id"
    PARENT_PARALLEL_ID = "parent_parallel_id"
    PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
    PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
    ITERATION_DURATION_MAP = "iteration_duration_map"  # single iteration duration if iteration node runs
    LOOP_DURATION_MAP = "loop_duration_map"  # single loop duration if loop node runs
    ERROR_STRATEGY = "error_strategy"  # nodes in continue-on-error mode return this field
    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
    DATASOURCE_INFO = "datasource_info"
    TRIGGER_INFO = "trigger_info"
    COMPLETED_REASON = "completed_reason"  # completed reason for loop node


class WorkflowNodeExecutionStatus(StrEnum):
    PENDING = "pending"  # Node is scheduled but not yet executing
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    EXCEPTION = "exception"
    STOPPED = "stopped"
    PAUSED = "paused"

    # Legacy statuses - kept for backward compatibility
    RETRY = "retry"  # Legacy: replaced by retry mechanism in error handling
16  api/graphon/errors.py  Normal file
@@ -0,0 +1,16 @@
from graphon.nodes.base.node import Node


class WorkflowNodeRunFailedError(Exception):
    def __init__(self, node: Node, err_msg: str):
        self._node = node
        self._error = err_msg
        super().__init__(f"Node {node.title} run failed: {err_msg}")

    @property
    def node(self) -> Node:
        return self._node

    @property
    def error(self) -> str:
        return self._error
22  api/graphon/file/__init__.py  Normal file
@@ -0,0 +1,22 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .file_factory import get_file_type_by_mime_type, standardize_file_type
from .models import (
    File,
    FileUploadConfig,
    ImageConfig,
)

__all__ = [
    "FILE_MODEL_IDENTITY",
    "ArrayFileAttribute",
    "File",
    "FileAttribute",
    "FileBelongsTo",
    "FileTransferMethod",
    "FileType",
    "FileUploadConfig",
    "ImageConfig",
    "get_file_type_by_mime_type",
    "standardize_file_type",
]
48  api/graphon/file/constants.py  Normal file
@@ -0,0 +1,48 @@
from collections.abc import Iterable
from typing import Any

# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
# comparing `dify_model_identity` with constants throughout the codebase, extract
# this logic into a dedicated function. This would encapsulate the implementation
# details of how different variable types are identified.
FILE_MODEL_IDENTITY = "__dify__file__"
DEFAULT_MIME_TYPE = "application/octet-stream"
DEFAULT_EXTENSION = ".bin"


def _with_case_variants(extensions: Iterable[str]) -> frozenset[str]:
    normalized = {extension.lower() for extension in extensions}
    return frozenset(normalized | {extension.upper() for extension in normalized})


IMAGE_EXTENSIONS = _with_case_variants({"jpg", "jpeg", "png", "webp", "gif", "svg"})
VIDEO_EXTENSIONS = _with_case_variants({"mp4", "mov", "mpeg", "webm"})
AUDIO_EXTENSIONS = _with_case_variants({"mp3", "m4a", "wav", "amr", "mpga"})
DOCUMENT_EXTENSIONS = _with_case_variants(
    {
        "txt",
        "markdown",
        "md",
        "mdx",
        "pdf",
        "html",
        "htm",
        "xlsx",
        "xls",
        "vtt",
        "properties",
        "doc",
        "docx",
        "csv",
        "eml",
        "msg",
        "ppt",
        "pptx",
        "xml",
        "epub",
    }
)


def maybe_file_object(o: Any) -> bool:
    return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
57  api/graphon/file/enums.py  Normal file
@@ -0,0 +1,57 @@
from enum import StrEnum


class FileType(StrEnum):
    IMAGE = "image"
    DOCUMENT = "document"
    AUDIO = "audio"
    VIDEO = "video"
    CUSTOM = "custom"

    @staticmethod
    def value_of(value):
        for member in FileType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileTransferMethod(StrEnum):
    REMOTE_URL = "remote_url"
    LOCAL_FILE = "local_file"
    TOOL_FILE = "tool_file"
    DATASOURCE_FILE = "datasource_file"

    @staticmethod
    def value_of(value):
        for member in FileTransferMethod:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileBelongsTo(StrEnum):
    USER = "user"
    ASSISTANT = "assistant"

    @staticmethod
    def value_of(value):
        for member in FileBelongsTo:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileAttribute(StrEnum):
    TYPE = "type"
    SIZE = "size"
    NAME = "name"
    MIME_TYPE = "mime_type"
    TRANSFER_METHOD = "transfer_method"
    URL = "url"
    EXTENSION = "extension"
    RELATED_ID = "related_id"


class ArrayFileAttribute(StrEnum):
    LENGTH = "length"
39  api/graphon/file/file_factory.py  Normal file
@@ -0,0 +1,39 @@
from .constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from .enums import FileType


def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
    """
    Infer the actual file type from extension and mime type.
    """
    guessed_type = None
    if extension:
        guessed_type = _get_file_type_by_extension(extension)
    if guessed_type is None and mime_type:
        guessed_type = get_file_type_by_mime_type(mime_type)
    return guessed_type or FileType.CUSTOM


def _get_file_type_by_extension(extension: str) -> FileType | None:
    normalized_extension = extension.lstrip(".")
    if normalized_extension in IMAGE_EXTENSIONS:
        return FileType.IMAGE
    if normalized_extension in VIDEO_EXTENSIONS:
        return FileType.VIDEO
    if normalized_extension in AUDIO_EXTENSIONS:
        return FileType.AUDIO
    if normalized_extension in DOCUMENT_EXTENSIONS:
        return FileType.DOCUMENT
    return None


def get_file_type_by_mime_type(mime_type: str) -> FileType:
    if "image" in mime_type:
        return FileType.IMAGE
    if "video" in mime_type:
        return FileType.VIDEO
    if "audio" in mime_type:
        return FileType.AUDIO
    if "text" in mime_type or "pdf" in mime_type:
        return FileType.DOCUMENT
    return FileType.CUSTOM
129  api/graphon/file/file_manager.py  Normal file
@@ -0,0 +1,129 @@
from __future__ import annotations

import base64
from collections.abc import Mapping

from graphon.model_runtime.entities import (
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    TextPromptMessageContent,
    VideoPromptMessageContent,
)
from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes

from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
from .runtime import get_workflow_file_runtime


def get_attr(*, file: File, attr: FileAttribute):
    match attr:
        case FileAttribute.TYPE:
            return file.type.value
        case FileAttribute.SIZE:
            return file.size
        case FileAttribute.NAME:
            return file.filename
        case FileAttribute.MIME_TYPE:
            return file.mime_type
        case FileAttribute.TRANSFER_METHOD:
            return file.transfer_method.value
        case FileAttribute.URL:
            return _to_url(file)
        case FileAttribute.EXTENSION:
            return file.extension
        case FileAttribute.RELATED_ID:
            return file.related_id


def to_prompt_message_content(
    f: File,
    /,
    *,
    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> PromptMessageContentUnionTypes:
    """Convert a file to prompt message content."""
    if f.extension is None:
        raise ValueError("Missing file extension")
    if f.mime_type is None:
        raise ValueError("Missing file mime_type")

    prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
        FileType.IMAGE: ImagePromptMessageContent,
        FileType.AUDIO: AudioPromptMessageContent,
        FileType.VIDEO: VideoPromptMessageContent,
        FileType.DOCUMENT: DocumentPromptMessageContent,
    }

    if f.type not in prompt_class_map:
        return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")

    send_format = get_workflow_file_runtime().multimodal_send_format
    params = {
        "base64_data": _get_encoded_string(f) if send_format == "base64" else "",
        "url": _to_url(f) if send_format == "url" else "",
        "format": f.extension.removeprefix("."),
        "mime_type": f.mime_type,
        "filename": f.filename or "",
    }
    if f.type == FileType.IMAGE:
        params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

    return prompt_class_map[f.type].model_validate(params)


def download(f: File, /) -> bytes:
    if f.transfer_method in (
        FileTransferMethod.TOOL_FILE,
        FileTransferMethod.LOCAL_FILE,
        FileTransferMethod.DATASOURCE_FILE,
    ):
        return _download_file_content(f)
    elif f.transfer_method == FileTransferMethod.REMOTE_URL:
        if f.remote_url is None:
            raise ValueError("Missing file remote_url")
        response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
        response.raise_for_status()
        return response.content
    raise ValueError(f"unsupported transfer method: {f.transfer_method}")


def _download_file_content(file: File, /) -> bytes:
    """Download and return a file from storage as bytes."""
    return get_workflow_file_runtime().load_file_bytes(file=file)


def _get_encoded_string(f: File, /) -> str:
    match f.transfer_method:
        case FileTransferMethod.REMOTE_URL:
            if f.remote_url is None:
                raise ValueError("Missing file remote_url")
            response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True)
            response.raise_for_status()
            data = response.content
        case FileTransferMethod.LOCAL_FILE:
            data = _download_file_content(f)
        case FileTransferMethod.TOOL_FILE:
            data = _download_file_content(f)
        case FileTransferMethod.DATASOURCE_FILE:
            data = _download_file_content(f)

    return base64.b64encode(data).decode("utf-8")


def _to_url(f: File, /):
    url = f.generate_url()
    if url is None:
        raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
    return url


class FileManager:
    """Adapter exposing file manager helpers behind FileManagerProtocol."""

    def download(self, f: File, /) -> bytes:
        return download(f)


file_manager = FileManager()
48  api/graphon/file/helpers.py  Normal file
@@ -0,0 +1,48 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from .runtime import get_workflow_file_runtime

if TYPE_CHECKING:
    from .models import File


def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None:
    return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external)


def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str:
    return get_workflow_file_runtime().resolve_upload_file_url(
        upload_file_id=upload_file_id,
        as_attachment=as_attachment,
        for_external=for_external,
    )


def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str:
    return get_workflow_file_runtime().resolve_tool_file_url(
        tool_file_id=tool_file_id,
        extension=extension,
        for_external=for_external,
    )


def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    return get_workflow_file_runtime().verify_preview_signature(
        preview_kind="image",
        file_id=upload_file_id,
        timestamp=timestamp,
        nonce=nonce,
        sign=sign,
    )


def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    return get_workflow_file_runtime().verify_preview_signature(
        preview_kind="file",
        file_id=upload_file_id,
        timestamp=timestamp,
        nonce=nonce,
        sign=sign,
    )
215  api/graphon/file/models.py  Normal file
@@ -0,0 +1,215 @@
from __future__ import annotations

import base64
import json
from collections.abc import Mapping, Sequence
from typing import Any

from pydantic import BaseModel, Field, model_validator

from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent

from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType

_FILE_REFERENCE_PREFIX = "dify-file-ref:"


def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str:
    """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``."""
    return helpers.get_signed_tool_file_url(
        tool_file_id=tool_file_id,
        extension=extension,
        for_external=for_external,
    )


class ImageConfig(BaseModel):
    """
    NOTE: This part of validation is deprecated, but still used in the app feature "Image Upload".
    """

    number_limits: int = 0
    transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    detail: ImagePromptMessageContent.DETAIL | None = None


class FileUploadConfig(BaseModel):
    """
    File Upload Entity.
    """

    image_config: ImageConfig | None = None
    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    number_limits: int = 0


def _parse_reference(reference: str | None) -> tuple[str | None, str | None]:
    """Best-effort parser for record references and historical storage-key payloads."""
    if not reference:
        return None, None

    if not reference.startswith(_FILE_REFERENCE_PREFIX):
        return reference, None

    encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX)
    try:
        payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode()))
    except (ValueError, json.JSONDecodeError):
        return reference, None

    record_id = payload.get("record_id")
    if not isinstance(record_id, str) or not record_id:
        return reference, None

    storage_key = payload.get("storage_key")
    if not isinstance(storage_key, str):
        storage_key = None

    return record_id, storage_key


class File(BaseModel):
    """Graph-owned file reference.

    The graph layer deliberately keeps only the metadata required to route,
    serialize, and render files. Application ownership concerns such as
    tenant/user/conversation identity stay in the workflow/storage layer.
    """

    # NOTE: dify_model_identity is a special identifier used to distinguish between
    # new and old data formats during serialization and deserialization.
    dify_model_identity: str = FILE_MODEL_IDENTITY

    id: str | None = None  # message file id
    type: FileType
    transfer_method: FileTransferMethod
    # If `transfer_method` is `FileTransferMethod.remote_url`, the
    # `remote_url` attribute must not be `None`.
    remote_url: str | None = None  # remote url
    # Opaque workflow-layer reference for files resolved outside ``graphon``.
    # New payloads only carry the backing record id; historical payloads may
    # still include storage_key and must remain readable.
    reference: str | None = None
    filename: str | None = None
    extension: str | None = Field(default=None, description="File extension, should contain dot")
    mime_type: str | None = None
    size: int = -1
    _storage_key: str

    def __init__(
        self,
        *,
        id: str | None = None,
        tenant_id: str | None = None,
        type: FileType,
        transfer_method: FileTransferMethod,
        remote_url: str | None = None,
        reference: str | None = None,
        related_id: str | None = None,
        filename: str | None = None,
        extension: str | None = None,
        mime_type: str | None = None,
        size: int = -1,
        storage_key: str | None = None,
        dify_model_identity: str | None = FILE_MODEL_IDENTITY,
        url: str | None = None,
        # Legacy compatibility fields - explicitly accept known extra fields
        tool_file_id: str | None = None,
        upload_file_id: str | None = None,
        datasource_file_id: str | None = None,
    ):
        legacy_record_id = related_id or tool_file_id or upload_file_id or datasource_file_id
        normalized_reference = reference
        if normalized_reference is None and legacy_record_id is not None:
            normalized_reference = str(legacy_record_id)
        _, parsed_storage_key = _parse_reference(normalized_reference)

        super().__init__(
            id=id,
            type=type,
            transfer_method=transfer_method,
            remote_url=remote_url,
            reference=normalized_reference,
            filename=filename,
            extension=extension,
            mime_type=mime_type,
            size=size,
            dify_model_identity=dify_model_identity,
            url=url,
        )
        # Accept legacy constructor fields without promoting them back into the graph model.
        _ = tenant_id
        self._storage_key = storage_key or parsed_storage_key or ""

    def to_dict(self) -> Mapping[str, str | int | None]:
        data = self.model_dump(mode="json")
        return {
            **data,
            "related_id": self.related_id,
            "url": self.generate_url(),
        }

    @property
    def markdown(self) -> str:
        url = self.generate_url()
        if self.type == FileType.IMAGE:
            text = f"![{self.filename or ''}]({url})"
        else:
            text = f"[{self.filename or url}]({url})"

        return text

    def generate_url(self, for_external: bool = True) -> str | None:
        return helpers.resolve_file_url(self, for_external=for_external)

    def to_plugin_parameter(self) -> dict[str, Any]:
        return {
            "dify_model_identity": FILE_MODEL_IDENTITY,
            "mime_type": self.mime_type,
            "filename": self.filename,
            "extension": self.extension,
            "size": self.size,
            "type": self.type,
            "url": self.generate_url(for_external=False),
        }

    @model_validator(mode="after")
    def validate_after(self) -> File:
        match self.transfer_method:
            case FileTransferMethod.REMOTE_URL:
                if not self.remote_url:
                    raise ValueError("Missing file url")
                if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
                    raise ValueError("Invalid file url")
            case FileTransferMethod.LOCAL_FILE:
                if not self.reference:
                    raise ValueError("Missing file reference")
            case FileTransferMethod.TOOL_FILE:
                if not self.reference:
                    raise ValueError("Missing file reference")
            case FileTransferMethod.DATASOURCE_FILE:
                if not self.reference:
                    raise ValueError("Missing file reference")
        return self

    @property
    def related_id(self) -> str | None:
        record_id, _ = _parse_reference(self.reference)
        return record_id

    @related_id.setter
    def related_id(self, value: str | None) -> None:
        self.reference = value

    @property
    def storage_key(self) -> str:
        _, storage_key = _parse_reference(self.reference)
        return storage_key or self._storage_key

    @storage_key.setter
    def storage_key(self, value: str) -> None:
        self._storage_key = value
56  api/graphon/file/protocols.py  Normal file
@@ -0,0 +1,56 @@
from __future__ import annotations

from collections.abc import Generator
from typing import TYPE_CHECKING, Literal, Protocol

if TYPE_CHECKING:
    from .models import File


class HttpResponseProtocol(Protocol):
    """Subset of response behavior needed by workflow file helpers."""

    @property
    def content(self) -> bytes: ...

    def raise_for_status(self) -> object: ...


class WorkflowFileRuntimeProtocol(Protocol):
    """Runtime dependencies required by ``graphon.file``.

    Implementations are expected to be provided by integration layers (for example,
    ``core.app.workflow.file_runtime``) so the workflow package avoids importing
    application infrastructure modules directly.
    """

    @property
    def multimodal_send_format(self) -> str: ...

    def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ...

    def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ...

    def load_file_bytes(self, *, file: File) -> bytes: ...

    def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ...

    def resolve_upload_file_url(
        self,
        *,
        upload_file_id: str,
        as_attachment: bool = False,
        for_external: bool = True,
    ) -> str: ...

    def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ...

    def verify_preview_signature(
        self,
        *,
        preview_kind: Literal["image", "file"],
        file_id: str,
        timestamp: str,
        nonce: str,
        sign: str,
    ) -> bool: ...
71
api/graphon/file/runtime.py
Normal file
71
api/graphon/file/runtime.py
Normal file
@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Literal, NoReturn
|
||||
|
||||
from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import File
|
||||
|
||||
|
||||
class WorkflowFileRuntimeNotConfiguredError(RuntimeError):
|
||||
"""Raised when workflow file runtime dependencies were not configured."""
|
||||
|
||||
|
||||
class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
def _raise(self) -> NoReturn:
|
||||
raise WorkflowFileRuntimeNotConfiguredError(
|
||||
"workflow file runtime is not configured, call set_workflow_file_runtime(...) first"
|
||||
)
|
||||
|
||||
@property
|
||||
def multimodal_send_format(self) -> str:
|
||||
self._raise()
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
self._raise()
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
self._raise()
|
||||
|
||||
def load_file_bytes(self, *, file: File) -> bytes:
|
||||
self._raise()
|
||||
|
||||
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
|
||||
self._raise()
|
||||
|
||||
def resolve_upload_file_url(
|
||||
self,
|
||||
*,
|
||||
upload_file_id: str,
|
||||
as_attachment: bool = False,
|
||||
for_external: bool = True,
|
||||
) -> str:
|
||||
self._raise()
|
||||
|
||||
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
self._raise()
|
||||
|
||||
def verify_preview_signature(
|
||||
self,
|
||||
*,
|
||||
preview_kind: Literal["image", "file"],
|
||||
file_id: str,
|
||||
timestamp: str,
|
||||
nonce: str,
|
||||
sign: str,
|
||||
) -> bool:
|
||||
self._raise()
|
||||
|
||||
|
||||
_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime()
|
||||
|
||||
|
||||
def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None:
|
||||
global _runtime
|
||||
_runtime = runtime
|
||||
|
||||
|
||||
def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol:
|
||||
return _runtime
|
||||
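For reference, a minimal sketch of wiring this at application bootstrap. The stub class and the call site below are illustrative only (the real implementation is expected to come from an integration layer such as `core.app.workflow.file_runtime`), and it assumes the package is importable as `graphon`:

```python
from collections.abc import Generator
from typing import Literal

from graphon.file.protocols import HttpResponseProtocol
from graphon.file.runtime import set_workflow_file_runtime


class _StubWorkflowFileRuntime:
    """Illustrative stand-in that structurally satisfies WorkflowFileRuntimeProtocol."""

    @property
    def multimodal_send_format(self) -> str:
        return "url"

    def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
        raise NotImplementedError

    def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
        raise NotImplementedError

    def load_file_bytes(self, *, file) -> bytes:
        raise NotImplementedError

    def resolve_file_url(self, *, file, for_external: bool = True) -> str | None:
        return None

    def resolve_upload_file_url(
        self, *, upload_file_id: str, as_attachment: bool = False, for_external: bool = True
    ) -> str:
        raise NotImplementedError

    def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
        raise NotImplementedError

    def verify_preview_signature(
        self, *, preview_kind: Literal["image", "file"], file_id: str, timestamp: str, nonce: str, sign: str
    ) -> bool:
        return False


# Register once during application bootstrap, before any File URL resolution happens.
set_workflow_file_runtime(_StubWorkflowFileRuntime())
```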
9
api/graphon/file/tool_file_parser.py
Normal file
9
api/graphon/file/tool_file_parser.py
Normal file
@ -0,0 +1,9 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
_tool_file_manager_factory: Callable[[], Any] | None = None
|
||||
|
||||
|
||||
def set_tool_file_manager_factory(factory: Callable[[], Any]):
|
||||
global _tool_file_manager_factory
|
||||
_tool_file_manager_factory = factory
|
||||
11
api/graphon/graph/__init__.py
Normal file
11
api/graphon/graph/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
from .edge import Edge
|
||||
from .graph import Graph, GraphBuilder, NodeFactory
|
||||
from .graph_template import GraphTemplate
|
||||
|
||||
__all__ = [
|
||||
"Edge",
|
||||
"Graph",
|
||||
"GraphBuilder",
|
||||
"GraphTemplate",
|
||||
"NodeFactory",
|
||||
]
|
||||
15
api/graphon/graph/edge.py
Normal file
15
api/graphon/graph/edge.py
Normal file
@ -0,0 +1,15 @@
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from graphon.enums import NodeState
|
||||
|
||||
|
||||
@dataclass
|
||||
class Edge:
|
||||
"""Edge connecting two nodes in a workflow graph."""
|
||||
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
tail: str = "" # tail node id (source)
|
||||
head: str = "" # head node id (target)
|
||||
source_handle: str = "source" # source handle for conditional branching
|
||||
state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state
|
||||
438
api/graphon/graph/graph.py
Normal file
438
api/graphon/graph/graph.py
Normal file
@ -0,0 +1,438 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
from .edge import Edge
|
||||
from .validation import get_graph_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
|
||||
|
||||
|
||||
class NodeFactory(Protocol):
|
||||
"""
|
||||
Protocol for creating Node instances from node data dictionaries.
|
||||
|
||||
This protocol decouples the Graph class from specific node mapping implementations,
|
||||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or no implementation exists for the resolved version
|
||||
:raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@final
|
||||
class Graph:
|
||||
"""Graph representation with nodes and edges for workflow execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nodes: dict[str, Node] | None = None,
|
||||
edges: dict[str, Edge] | None = None,
|
||||
in_edges: dict[str, list[str]] | None = None,
|
||||
out_edges: dict[str, list[str]] | None = None,
|
||||
root_node: Node,
|
||||
):
|
||||
"""
|
||||
Initialize Graph instance.
|
||||
|
||||
:param nodes: graph nodes mapping (node id: node object)
|
||||
:param edges: graph edges mapping (edge id: edge object)
|
||||
:param in_edges: incoming edges mapping (node id: list of edge ids)
|
||||
:param out_edges: outgoing edges mapping (node id: list of edge ids)
|
||||
:param root_node: root node object
|
||||
"""
|
||||
self.nodes = nodes or {}
|
||||
self.edges = edges or {}
|
||||
self.in_edges = in_edges or {}
|
||||
self.out_edges = out_edges or {}
|
||||
self.root_node = root_node
|
||||
|
||||
@classmethod
|
||||
def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
|
||||
"""
|
||||
Parse node configurations and build a mapping of node IDs to configs.
|
||||
|
||||
:param node_configs: list of node configuration dictionaries
|
||||
:return: mapping of node ID to node config
|
||||
"""
|
||||
node_configs_map: dict[str, NodeConfigDict] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_configs_map[node_config["id"]] = node_config
|
||||
|
||||
return node_configs_map
|
||||
|
||||
@classmethod
|
||||
def _build_edges(
|
||||
cls, edge_configs: list[dict[str, object]]
|
||||
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
|
||||
"""
|
||||
Build edge objects and mappings from edge configurations.
|
||||
|
||||
:param edge_configs: list of edge configurations
|
||||
:return: tuple of (edges dict, in_edges dict, out_edges dict)
|
||||
"""
|
||||
edges: dict[str, Edge] = {}
|
||||
in_edges: dict[str, list[str]] = defaultdict(list)
|
||||
out_edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
edge_counter = 0
|
||||
for edge_config in edge_configs:
|
||||
source = edge_config.get("source")
|
||||
target = edge_config.get("target")
|
||||
|
||||
if not isinstance(source, str) or not isinstance(target, str):
|
||||
continue
|
||||
|
||||
# Create edge
|
||||
edge_id = f"edge_{edge_counter}"
|
||||
edge_counter += 1
|
||||
|
||||
source_handle = edge_config.get("sourceHandle", "source")
|
||||
if not isinstance(source_handle, str):
|
||||
continue
|
||||
|
||||
edge = Edge(
|
||||
id=edge_id,
|
||||
tail=source,
|
||||
head=target,
|
||||
source_handle=source_handle,
|
||||
)
|
||||
|
||||
edges[edge_id] = edge
|
||||
out_edges[source].append(edge_id)
|
||||
in_edges[target].append(edge_id)
|
||||
|
||||
return edges, dict(in_edges), dict(out_edges)
|
||||
|
||||
@classmethod
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, NodeConfigDict],
|
||||
node_factory: NodeFactory,
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
Create node instances from configurations using the node factory.
|
||||
|
||||
:param node_configs_map: mapping of node ID to node config
|
||||
:param node_factory: factory for creating node instances
|
||||
:return: mapping of node ID to node instance
|
||||
"""
|
||||
nodes: dict[str, Node] = {}
|
||||
|
||||
for node_id, node_config in node_configs_map.items():
|
||||
try:
|
||||
node_instance = node_factory.create_node(node_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to create node instance for node_id %s", node_id)
|
||||
raise
|
||||
nodes[node_id] = node_instance
|
||||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def new(cls) -> GraphBuilder:
|
||||
"""Create a fluent builder for assembling a graph programmatically."""
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@staticmethod
|
||||
def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]:
|
||||
"""
|
||||
Remove editor-only nodes before `NodeConfigDict` validation.
|
||||
|
||||
Persisted note widgets use a top-level `type == "custom-note"` but leave
|
||||
`data.type` empty because they are never executable graph nodes. Filter
|
||||
them while configs are still raw dicts so Pydantic does not validate
|
||||
their placeholder payloads against `BaseNodeData.type: NodeType`.
|
||||
"""
|
||||
filtered_node_configs: list[dict[str, object]] = []
|
||||
for node_config in node_configs:
|
||||
if node_config.get("type", "") == "custom-note":
|
||||
continue
|
||||
filtered_node_configs.append(dict(node_config))
|
||||
return filtered_node_configs
|
||||
|
||||
@classmethod
|
||||
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
|
||||
"""
|
||||
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
"""
|
||||
for node in nodes.values():
|
||||
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
node.execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
nodes: dict[str, Node],
|
||||
edges: dict[str, Edge],
|
||||
in_edges: dict[str, list[str]],
|
||||
out_edges: dict[str, list[str]],
|
||||
active_root_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Mark nodes and edges from inactive root branches as skipped.
|
||||
|
||||
Algorithm:
|
||||
1. Mark inactive root nodes as skipped
|
||||
2. For skipped nodes, mark all their outgoing edges as skipped
|
||||
3. For each edge marked as skipped, check its target node:
|
||||
- If ALL incoming edges are skipped, mark the node as skipped
|
||||
- Otherwise, leave the node state unchanged
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
:param edges: mapping of edge ID to edge instance
|
||||
:param in_edges: mapping of node ID to incoming edge IDs
|
||||
:param out_edges: mapping of node ID to outgoing edge IDs
|
||||
:param active_root_id: ID of the active root node
|
||||
"""
|
||||
# Find all top-level root nodes (nodes declared with the ROOT execution type)
|
||||
top_level_roots: list[str] = [
|
||||
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
|
||||
]
|
||||
|
||||
# If there's only one root or the active root is not a top-level root, no marking needed
|
||||
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
|
||||
return
|
||||
|
||||
# Mark inactive root nodes as skipped
|
||||
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
|
||||
for root_id in inactive_roots:
|
||||
if root_id in nodes:
|
||||
nodes[root_id].state = NodeState.SKIPPED
|
||||
|
||||
# Recursively mark downstream nodes and edges
|
||||
def mark_downstream(node_id: str) -> None:
|
||||
"""Recursively mark downstream nodes and edges as skipped."""
|
||||
if nodes[node_id].state != NodeState.SKIPPED:
|
||||
return
|
||||
# If this node is skipped, mark all its outgoing edges as skipped
|
||||
out_edge_ids = out_edges.get(node_id, [])
|
||||
for edge_id in out_edge_ids:
|
||||
edge = edges[edge_id]
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
# Check the target node of this edge
|
||||
target_node = nodes[edge.head]
|
||||
in_edge_ids = in_edges.get(target_node.id, [])
|
||||
in_edge_states = [edges[eid].state for eid in in_edge_ids]
|
||||
|
||||
# If all incoming edges are skipped, mark the node as skipped
|
||||
if all(state == NodeState.SKIPPED for state in in_edge_states):
|
||||
target_node.state = NodeState.SKIPPED
|
||||
# Recursively process downstream nodes
|
||||
mark_downstream(target_node.id)
|
||||
|
||||
# Process each inactive root and its downstream nodes
|
||||
for root_id in inactive_roots:
|
||||
mark_downstream(root_id)
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, object],
|
||||
node_factory: NodeFactory,
|
||||
root_node_id: str,
|
||||
skip_validation: bool = False,
|
||||
) -> Graph:
|
||||
"""
|
||||
Initialize a graph with an explicit execution entry point.
|
||||
|
||||
:param graph_config: graph config containing nodes and edges
|
||||
:param node_factory: factory for creating node instances from config data
|
||||
:param root_node_id: active root node id
|
||||
:return: graph instance
|
||||
"""
|
||||
# Parse configs
|
||||
edge_configs = graph_config.get("edges", [])
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
node_configs = cls._filter_canvas_only_nodes(node_configs)
|
||||
node_configs = _ListNodeConfigDict.validate_python(node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
if root_node_id not in node_configs_map:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
|
||||
# Build edges
|
||||
edges, in_edges, out_edges = cls._build_edges(edge_configs)
|
||||
|
||||
# Create node instances
|
||||
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
||||
|
||||
# Promote fail-branch nodes to branch execution type at graph level
|
||||
cls._promote_fail_branch_nodes(nodes)
|
||||
|
||||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
# Mark inactive root branches as skipped
|
||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||
|
||||
# Create and return the graph
|
||||
graph = cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=in_edges,
|
||||
out_edges=out_edges,
|
||||
root_node=root_node,
|
||||
)
|
||||
|
||||
if not skip_validation:
|
||||
# Validate the graph structure using built-in validators
|
||||
get_graph_validator().validate(graph)
|
||||
|
||||
return graph
|
||||
|
||||
@property
|
||||
def node_ids(self) -> list[str]:
|
||||
"""
|
||||
Get list of node IDs (compatibility property for existing code)
|
||||
|
||||
:return: list of node IDs
|
||||
"""
|
||||
return list(self.nodes.keys())
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> list[Edge]:
|
||||
"""
|
||||
Get all outgoing edges from a node (V2 method)
|
||||
|
||||
:param node_id: node id
|
||||
:return: list of outgoing edges
|
||||
"""
|
||||
edge_ids = self.out_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
||||
def get_incoming_edges(self, node_id: str) -> list[Edge]:
|
||||
"""
|
||||
Get all incoming edges to a node (V2 method)
|
||||
|
||||
:param node_id: node id
|
||||
:return: list of incoming edges
|
||||
"""
|
||||
edge_ids = self.in_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
||||
|
||||
@final
|
||||
class GraphBuilder:
|
||||
"""Fluent helper for constructing simple graphs, primarily for tests."""
|
||||
|
||||
def __init__(self, *, graph_cls: type[Graph]):
|
||||
self._graph_cls = graph_cls
|
||||
self._nodes: list[Node] = []
|
||||
self._nodes_by_id: dict[str, Node] = {}
|
||||
self._edges: list[Edge] = []
|
||||
self._edge_counter = 0
|
||||
|
||||
def add_root(self, node: Node) -> GraphBuilder:
|
||||
"""Register the root node. Must be called exactly once."""
|
||||
|
||||
if self._nodes:
|
||||
raise ValueError("Root node has already been added")
|
||||
self._register_node(node)
|
||||
self._nodes.append(node)
|
||||
return self
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
node: Node,
|
||||
*,
|
||||
from_node_id: str | None = None,
|
||||
source_handle: str = "source",
|
||||
) -> GraphBuilder:
|
||||
"""Append a node and connect it from the specified predecessor."""
|
||||
|
||||
if not self._nodes:
|
||||
raise ValueError("Root node must be added before adding other nodes")
|
||||
|
||||
predecessor_id = from_node_id or self._nodes[-1].id
|
||||
if predecessor_id not in self._nodes_by_id:
|
||||
raise ValueError(f"Predecessor node '{predecessor_id}' not found")
|
||||
|
||||
predecessor = self._nodes_by_id[predecessor_id]
|
||||
self._register_node(node)
|
||||
self._nodes.append(node)
|
||||
|
||||
edge_id = f"edge_{self._edge_counter}"
|
||||
self._edge_counter += 1
|
||||
edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle)
|
||||
self._edges.append(edge)
|
||||
|
||||
return self
|
||||
|
||||
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
|
||||
"""Connect two existing nodes without adding a new node."""
|
||||
|
||||
if tail not in self._nodes_by_id:
|
||||
raise ValueError(f"Tail node '{tail}' not found")
|
||||
if head not in self._nodes_by_id:
|
||||
raise ValueError(f"Head node '{head}' not found")
|
||||
|
||||
edge_id = f"edge_{self._edge_counter}"
|
||||
self._edge_counter += 1
|
||||
edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle)
|
||||
self._edges.append(edge)
|
||||
|
||||
return self
|
||||
|
||||
def build(self) -> Graph:
|
||||
"""Materialize the graph instance from the accumulated nodes and edges."""
|
||||
|
||||
if not self._nodes:
|
||||
raise ValueError("Cannot build an empty graph")
|
||||
|
||||
nodes = {node.id: node for node in self._nodes}
|
||||
edges = {edge.id: edge for edge in self._edges}
|
||||
in_edges: dict[str, list[str]] = defaultdict(list)
|
||||
out_edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for edge in self._edges:
|
||||
out_edges[edge.tail].append(edge.id)
|
||||
in_edges[edge.head].append(edge.id)
|
||||
|
||||
return self._graph_cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=dict(in_edges),
|
||||
out_edges=dict(out_edges),
|
||||
root_node=self._nodes[0],
|
||||
)
|
||||
|
||||
def _register_node(self, node: Node) -> None:
|
||||
if not node.id:
|
||||
raise ValueError("Node must have a non-empty id")
|
||||
if node.id in self._nodes_by_id:
|
||||
raise ValueError(f"Duplicate node id detected: {node.id}")
|
||||
self._nodes_by_id[node.id] = node
|
||||
20
api/graphon/graph/graph_template.py
Normal file
20
api/graphon/graph/graph_template.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GraphTemplate(BaseModel):
|
||||
"""
|
||||
Graph Template for container nodes and subgraph expansion
|
||||
|
||||
According to GraphEngine V2 spec, GraphTemplate contains:
|
||||
- nodes: mapping of node definitions
|
||||
- edges: mapping of edge definitions
|
||||
- root_ids: list of root node IDs
|
||||
- output_selectors: list of output selectors for the template
|
||||
"""
|
||||
|
||||
nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
|
||||
edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
|
||||
root_ids: list[str] = Field(default_factory=list, description="root node IDs")
|
||||
output_selectors: list[str] = Field(default_factory=list, description="output selectors")
|
||||
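Since `GraphTemplate` is a plain Pydantic model, container nodes can assemble one directly. A minimal sketch with illustrative node ids and payloads (the field contents are assumptions, not a documented schema):

```python
from graphon.graph.graph_template import GraphTemplate

template = GraphTemplate(
    nodes={"iter-start": {"data": {"type": "iteration-start", "title": "Start"}}},
    edges={},
    root_ids=["iter-start"],
    output_selectors=["iter-start.output"],
)
```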
125
api/graphon/graph/validation.py
Normal file
125
api/graphon/graph/validation.py
Normal file
@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Graph
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GraphValidationIssue:
|
||||
"""Immutable value object describing a single validation issue."""
|
||||
|
||||
code: str
|
||||
message: str
|
||||
node_id: str | None = None
|
||||
|
||||
|
||||
class GraphValidationError(ValueError):
|
||||
"""Raised when graph validation fails."""
|
||||
|
||||
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
|
||||
if not issues:
|
||||
raise ValueError("GraphValidationError requires at least one issue.")
|
||||
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
|
||||
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class GraphValidationRule(Protocol):
|
||||
"""Protocol that individual validation rules must satisfy."""
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
"""Validate the provided graph and return any discovered issues."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _EdgeEndpointValidator:
|
||||
"""Ensures all edges reference existing nodes."""
|
||||
|
||||
missing_node_code: str = "MISSING_NODE"
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
issues: list[GraphValidationIssue] = []
|
||||
for edge in graph.edges.values():
|
||||
if edge.tail not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.missing_node_code,
|
||||
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
|
||||
node_id=edge.tail,
|
||||
)
|
||||
)
|
||||
if edge.head not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.missing_node_code,
|
||||
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
|
||||
node_id=edge.head,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _RootNodeValidator:
|
||||
"""Validates root node invariants."""
|
||||
|
||||
invalid_root_code: str = "INVALID_ROOT"
|
||||
container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START)
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
root_node = graph.root_node
|
||||
issues: list[GraphValidationIssue] = []
|
||||
if root_node.id not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.invalid_root_code,
|
||||
message=f"Root node '{root_node.id}' is missing from the node registry.",
|
||||
node_id=root_node.id,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
node_type = root_node.node_type
|
||||
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.invalid_root_code,
|
||||
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
|
||||
node_id=root_node.id,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GraphValidator:
|
||||
"""Coordinates execution of graph validation rules."""
|
||||
|
||||
rules: tuple[GraphValidationRule, ...]
|
||||
|
||||
def validate(self, graph: Graph) -> None:
|
||||
"""Validate the graph against all configured rules."""
|
||||
issues: list[GraphValidationIssue] = []
|
||||
for rule in self.rules:
|
||||
issues.extend(rule.validate(graph))
|
||||
|
||||
if issues:
|
||||
raise GraphValidationError(issues)
|
||||
|
||||
|
||||
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
||||
_EdgeEndpointValidator(),
|
||||
_RootNodeValidator(),
|
||||
)
|
||||
|
||||
|
||||
def get_graph_validator() -> GraphValidator:
|
||||
"""Construct the validator composed of default rules."""
|
||||
return GraphValidator(_DEFAULT_RULES)
|
||||
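Validation rules are plain objects satisfying the `GraphValidationRule` protocol, so composing a custom validator is just a matter of instantiating `GraphValidator` with a different rule tuple. A minimal sketch (the `_MaxNodeCountValidator` rule below is illustrative, not part of the module):

```python
from collections.abc import Sequence
from dataclasses import dataclass

from graphon.graph.validation import GraphValidationIssue, GraphValidator


@dataclass(frozen=True, slots=True)
class _MaxNodeCountValidator:
    """Illustrative rule: flag graphs that exceed a node budget."""

    max_nodes: int = 100
    code: str = "TOO_MANY_NODES"

    def validate(self, graph) -> Sequence[GraphValidationIssue]:
        if len(graph.nodes) <= self.max_nodes:
            return []
        return [
            GraphValidationIssue(
                code=self.code,
                message=f"Graph has {len(graph.nodes)} nodes, limit is {self.max_nodes}.",
            )
        ]


# Combine with whatever rules are needed by constructing GraphValidator directly;
# validator.validate(graph) raises GraphValidationError when any rule reports issues.
validator = GraphValidator(rules=(_MaxNodeCountValidator(max_nodes=50),))
```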
4
api/graphon/graph_engine/__init__.py
Normal file
4
api/graphon/graph_engine/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .config import GraphEngineConfig
|
||||
from .graph_engine import GraphEngine
|
||||
|
||||
__all__ = ["GraphEngine", "GraphEngineConfig"]
|
||||
15
api/graphon/graph_engine/_engine_utils.py
Normal file
15
api/graphon/graph_engine/_engine_utils.py
Normal file
@ -0,0 +1,15 @@
|
||||
import time
|
||||
|
||||
|
||||
def get_timestamp() -> float:
|
||||
"""Retrieve a timestamp as a float point numer representing the number of seconds
|
||||
since the Unix epoch.
|
||||
|
||||
This function is primarily used to measure the execution time of the workflow engine.
|
||||
Since workflow execution may be paused and resumed on a different machine,
|
||||
`time.perf_counter` cannot be used as it is inconsistent across machines.
|
||||
|
||||
To address this, the function uses the wall clock as the time source.
|
||||
However, it assumes that the clocks of all servers are properly synchronized.
|
||||
"""
|
||||
return round(time.time())
|
||||
33
api/graphon/graph_engine/command_channels/README.md
Normal file
33
api/graphon/graph_engine/command_channels/README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# Command Channels
|
||||
|
||||
Channel implementations for external workflow control.
|
||||
|
||||
## Components
|
||||
|
||||
### InMemoryChannel
|
||||
|
||||
Thread-safe in-memory queue for single-process deployments.
|
||||
|
||||
- `fetch_commands()` - Get pending commands
|
||||
- `send_command()` - Add command to queue
|
||||
|
||||
### RedisChannel
|
||||
|
||||
Redis-based queue for distributed deployments.
|
||||
|
||||
- `fetch_commands()` - Get commands with JSON deserialization
|
||||
- `send_command()` - Store commands with TTL
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# Local execution
|
||||
channel = InMemoryChannel()
|
||||
channel.send_command(AbortCommand(reason="user requested stop"))
|
||||
|
||||
# Distributed execution
|
||||
redis_channel = RedisChannel(
|
||||
redis_client=redis_client,
|
||||
channel_key="workflow:123:commands"
|
||||
)
|
||||
```
|
||||
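On the engine side, commands are drained opportunistically between scheduling ticks. A minimal polling sketch, continuing the snippet above (so `channel` is whichever implementation was constructed):

```python
# Drain everything that arrived since the last check; each call empties the queue.
for command in channel.fetch_commands():
    if isinstance(command, AbortCommand):
        ...  # route to the engine's abort handling
```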
6
api/graphon/graph_engine/command_channels/__init__.py
Normal file
6
api/graphon/graph_engine/command_channels/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Command channel implementations for GraphEngine."""
|
||||
|
||||
from .in_memory_channel import InMemoryChannel
|
||||
from .redis_channel import RedisChannel
|
||||
|
||||
__all__ = ["InMemoryChannel", "RedisChannel"]
|
||||
53
api/graphon/graph_engine/command_channels/in_memory_channel.py
Normal file
53
api/graphon/graph_engine/command_channels/in_memory_channel.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
In-memory implementation of CommandChannel for local/testing scenarios.
|
||||
|
||||
This implementation uses a thread-safe queue for command communication
|
||||
within a single process. Each instance handles commands for one workflow execution.
|
||||
"""
|
||||
|
||||
from queue import Queue
|
||||
from typing import final
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryChannel:
|
||||
"""
|
||||
In-memory command channel implementation using a thread-safe queue.
|
||||
|
||||
Each instance is dedicated to a single GraphEngine/workflow execution.
|
||||
Suitable for local development, testing, and single-instance deployments.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the in-memory channel with a single queue."""
|
||||
self._queue: Queue[GraphEngineCommand] = Queue()
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch all pending commands from the queue.
|
||||
|
||||
Returns:
|
||||
List of pending commands (drains the queue)
|
||||
"""
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Drain all available commands from the queue
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
command = self._queue.get_nowait()
|
||||
commands.append(command)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
return commands
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to this channel's queue.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
self._queue.put(command)
|
||||
153
api/graphon/graph_engine/command_channels/redis_channel.py
Normal file
153
api/graphon/graph_engine/command_channels/redis_channel.py
Normal file
@ -0,0 +1,153 @@
|
||||
"""
|
||||
Redis-based implementation of CommandChannel for distributed scenarios.
|
||||
|
||||
This implementation uses Redis lists for command queuing, supporting
|
||||
multi-instance deployments and cross-server communication.
|
||||
Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Protocol, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
|
||||
class RedisPipelineProtocol(Protocol):
|
||||
"""Minimal Redis pipeline contract used by the command channel."""
|
||||
|
||||
def lrange(self, name: str, start: int, end: int) -> Any: ...
|
||||
def delete(self, *names: str) -> Any: ...
|
||||
def execute(self) -> list[Any]: ...
|
||||
def rpush(self, name: str, *values: str) -> Any: ...
|
||||
def expire(self, name: str, time: int) -> Any: ...
|
||||
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
|
||||
def get(self, name: str) -> Any: ...
|
||||
|
||||
|
||||
class RedisClientProtocol(Protocol):
|
||||
"""Redis client contract required by the command channel."""
|
||||
|
||||
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
|
||||
|
||||
|
||||
@final
|
||||
class RedisChannel:
|
||||
"""
|
||||
Redis-based command channel implementation for distributed systems.
|
||||
|
||||
Each instance uses a unique Redis key for its command queue.
|
||||
Commands are JSON-serialized for transport.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: RedisClientProtocol,
|
||||
channel_key: str,
|
||||
command_ttl: int = 3600,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Redis channel.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance
|
||||
channel_key: Unique key for this channel's command queue
|
||||
command_ttl: TTL for command keys in seconds (default: 3600)
|
||||
"""
|
||||
self._redis = redis_client
|
||||
self._key = channel_key
|
||||
self._command_ttl = command_ttl
|
||||
self._pending_key = f"{channel_key}:pending"
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch all pending commands from Redis.
|
||||
|
||||
Returns:
|
||||
List of pending commands (drains the Redis list)
|
||||
"""
|
||||
if not self._has_pending_commands():
|
||||
return []
|
||||
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Use pipeline for atomic operations
|
||||
with self._redis.pipeline() as pipe:
|
||||
# Get all commands and clear the list atomically
|
||||
pipe.lrange(self._key, 0, -1)
|
||||
pipe.delete(self._key)
|
||||
results = pipe.execute()
|
||||
|
||||
# Parse commands from JSON
|
||||
if results[0]:
|
||||
for command_json in results[0]:
|
||||
try:
|
||||
command_data = json.loads(command_json)
|
||||
command = self._deserialize_command(command_data)
|
||||
if command:
|
||||
commands.append(command)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Skip invalid commands
|
||||
continue
|
||||
|
||||
return commands
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to Redis.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
command_json = json.dumps(command.model_dump())
|
||||
|
||||
# Push to list and set expiry
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.rpush(self._key, command_json)
|
||||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.set(self._pending_key, "1", ex=self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||
"""
|
||||
Deserialize a command from dictionary data.
|
||||
|
||||
Args:
|
||||
data: Command data dictionary
|
||||
|
||||
Returns:
|
||||
Deserialized command or None if invalid
|
||||
"""
|
||||
command_type_value = data.get("command_type")
|
||||
if not isinstance(command_type_value, str):
|
||||
return None
|
||||
|
||||
try:
|
||||
command_type = CommandType(command_type_value)
|
||||
|
||||
if command_type == CommandType.ABORT:
|
||||
return AbortCommand.model_validate(data)
|
||||
if command_type == CommandType.PAUSE:
|
||||
return PauseCommand.model_validate(data)
|
||||
if command_type == CommandType.UPDATE_VARIABLES:
|
||||
return UpdateVariablesCommand.model_validate(data)
|
||||
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def _has_pending_commands(self) -> bool:
|
||||
"""
|
||||
Check and consume the pending marker to avoid unnecessary list reads.
|
||||
|
||||
Returns:
|
||||
True if commands should be fetched from Redis.
|
||||
"""
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.get(self._pending_key)
|
||||
pipe.delete(self._pending_key)
|
||||
pending_value, _ = pipe.execute()
|
||||
|
||||
return pending_value is not None
|
||||
16
api/graphon/graph_engine/command_processing/__init__.py
Normal file
16
api/graphon/graph_engine/command_processing/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Command processing subsystem for graph engine.
|
||||
|
||||
This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
"PauseCommandHandler",
|
||||
"UpdateVariablesCommandHandler",
|
||||
]
|
||||
56
api/graphon/graph_engine/command_processing/command_handlers.py
Normal file
56
api/graphon/graph_engine/command_processing/command_handlers.py
Normal file
@ -0,0 +1,56 @@
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from graphon.entities.pause_reason import SchedulingPause
|
||||
from graphon.runtime import VariablePool
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, AbortCommand)
|
||||
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.abort(command.reason or "User requested abort")
|
||||
|
||||
|
||||
@final
|
||||
class PauseCommandHandler(CommandHandler):
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, PauseCommand)
|
||||
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
|
||||
# Wrap the plain string reason in a structured SchedulingPause
|
||||
reason = command.reason
|
||||
pause_reason = SchedulingPause(message=reason)
|
||||
execution.pause(pause_reason)
|
||||
|
||||
|
||||
@final
|
||||
class UpdateVariablesCommandHandler(CommandHandler):
|
||||
def __init__(self, variable_pool: VariablePool) -> None:
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, UpdateVariablesCommand)
|
||||
for update in command.updates:
|
||||
try:
|
||||
variable = update.value
|
||||
self._variable_pool.add(variable.selector, variable)
|
||||
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Skipping invalid variable selector %s for workflow %s: %s",
|
||||
getattr(update.value, "selector", None),
|
||||
execution.workflow_id,
|
||||
exc,
|
||||
)
|
||||
79
api/graphon/graph_engine/command_processing/command_processor.py
Normal file
79
api/graphon/graph_engine/command_processing/command_processor.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""
|
||||
Main command processor for handling external commands.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Protocol, final
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
from ..protocols.command_channel import CommandChannel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommandHandler(Protocol):
|
||||
"""Protocol for command handlers."""
|
||||
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
|
||||
|
||||
|
||||
@final
|
||||
class CommandProcessor:
|
||||
"""
|
||||
Processes external commands sent to the engine.
|
||||
|
||||
This polls the command channel and dispatches commands to
|
||||
appropriate handlers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command_channel: CommandChannel,
|
||||
graph_execution: GraphExecution,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the command processor.
|
||||
|
||||
Args:
|
||||
command_channel: Channel for receiving commands
|
||||
graph_execution: Graph execution aggregate
|
||||
"""
|
||||
self._command_channel = command_channel
|
||||
self._graph_execution = graph_execution
|
||||
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
|
||||
|
||||
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
|
||||
"""
|
||||
Register a handler for a command type.
|
||||
|
||||
Args:
|
||||
command_type: Type of command to handle
|
||||
handler: Handler for the command
|
||||
"""
|
||||
self._handlers[command_type] = handler
|
||||
|
||||
def process_commands(self) -> None:
|
||||
"""Check for and process any pending commands."""
|
||||
try:
|
||||
commands = self._command_channel.fetch_commands()
|
||||
for command in commands:
|
||||
self._handle_command(command)
|
||||
except Exception as e:
|
||||
logger.warning("Error processing commands: %s", e)
|
||||
|
||||
def _handle_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Handle a single command.
|
||||
|
||||
Args:
|
||||
command: The command to handle
|
||||
"""
|
||||
handler = self._handlers.get(type(command))
|
||||
if handler:
|
||||
try:
|
||||
handler.handle(command, self._graph_execution)
|
||||
except Exception:
|
||||
logger.exception("Error handling command %s", command.__class__.__name__)
|
||||
else:
|
||||
logger.warning("No handler registered for command: %s", command.__class__.__name__)
|
||||
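A minimal sketch of how the command-processing pieces compose, using the in-memory channel (the workflow id and abort reason below are illustrative; it assumes the package is importable as `graphon`):

```python
from graphon.graph_engine.command_channels import InMemoryChannel
from graphon.graph_engine.command_processing import AbortCommandHandler, CommandProcessor
from graphon.graph_engine.domain import GraphExecution
from graphon.graph_engine.entities.commands import AbortCommand

execution = GraphExecution(workflow_id="wf-demo")
execution.start()

channel = InMemoryChannel()
processor = CommandProcessor(command_channel=channel, graph_execution=execution)
processor.register_handler(AbortCommand, AbortCommandHandler())

# A client (or another process, via RedisChannel) asks the run to stop...
channel.send_command(AbortCommand(reason="user clicked stop"))

# ...and the engine picks it up on its next poll.
processor.process_commands()
assert execution.aborted
```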
16
api/graphon/graph_engine/config.py
Normal file
16
api/graphon/graph_engine/config.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
GraphEngine configuration models.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class GraphEngineConfig(BaseModel):
|
||||
"""Configuration for GraphEngine worker pool scaling."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
min_workers: int = 1
|
||||
max_workers: int = 5
|
||||
scale_up_threshold: int = 3
|
||||
scale_down_idle_time: float = 5.0
|
||||
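Because the config is a frozen model, per-run tuning means constructing a new instance rather than mutating one. A short sketch with illustrative values:

```python
from graphon.graph_engine import GraphEngineConfig

config = GraphEngineConfig(min_workers=2, max_workers=8, scale_up_threshold=5)

# frozen=True means attribute assignment raises; derive a variant instead.
burst_config = config.model_copy(update={"max_workers": 16})
```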
14
api/graphon/graph_engine/domain/__init__.py
Normal file
14
api/graphon/graph_engine/domain/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Domain models for graph engine.
|
||||
|
||||
This package contains the core domain entities, value objects, and aggregates
|
||||
that represent the business concepts of workflow graph execution.
|
||||
"""
|
||||
|
||||
from .graph_execution import GraphExecution
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
__all__ = [
|
||||
"GraphExecution",
|
||||
"NodeExecution",
|
||||
]
|
||||
242
api/graphon/graph_engine/domain/graph_execution.py
Normal file
242
api/graphon/graph_engine/domain/graph_execution.py
Normal file
@ -0,0 +1,242 @@
|
||||
"""GraphExecution aggregate root managing the overall graph execution state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from importlib import import_module
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphon.entities.pause_reason import PauseReason
|
||||
from graphon.enums import NodeState
|
||||
from graphon.runtime.graph_runtime_state import GraphExecutionProtocol
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
|
||||
class GraphExecutionErrorState(BaseModel):
|
||||
"""Serializable representation of an execution error."""
|
||||
|
||||
module: str = Field(description="Module containing the exception class")
|
||||
qualname: str = Field(description="Qualified name of the exception class")
|
||||
message: str | None = Field(default=None, description="Exception message string")
|
||||
|
||||
|
||||
class NodeExecutionState(BaseModel):
|
||||
"""Serializable representation of a node execution entity."""
|
||||
|
||||
node_id: str
|
||||
state: NodeState = Field(default=NodeState.UNKNOWN)
|
||||
retry_count: int = Field(default=0)
|
||||
execution_id: str | None = Field(default=None)
|
||||
error: str | None = Field(default=None)
|
||||
|
||||
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Pydantic model describing serialized GraphExecution state."""
|
||||
|
||||
type: Literal["GraphExecution"] = Field(default="GraphExecution")
|
||||
version: str = Field(default="1.0")
|
||||
workflow_id: str
|
||||
started: bool = Field(default=False)
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
paused: bool = Field(default=False)
|
||||
pause_reasons: list[PauseReason] = Field(default_factory=list)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
|
||||
|
||||
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||
"""Convert an exception into its serializable representation."""
|
||||
|
||||
if error is None:
|
||||
return None
|
||||
|
||||
return GraphExecutionErrorState(
|
||||
module=error.__class__.__module__,
|
||||
qualname=error.__class__.__qualname__,
|
||||
message=str(error),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
|
||||
"""Locate an exception class from its module and qualified name."""
|
||||
|
||||
module = import_module(module_name)
|
||||
attr: object = module
|
||||
for part in qualname.split("."):
|
||||
attr = getattr(attr, part)
|
||||
|
||||
if isinstance(attr, type) and issubclass(attr, Exception):
|
||||
return attr
|
||||
|
||||
raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")
|
||||
|
||||
|
||||
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
|
||||
"""Reconstruct an exception instance from serialized data."""
|
||||
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
exception_class = _resolve_exception_class(state.module, state.qualname)
|
||||
if state.message is None:
|
||||
return exception_class()
|
||||
return exception_class(state.message)
|
||||
except Exception:
|
||||
# Fallback to RuntimeError when reconstruction fails
|
||||
if state.message is None:
|
||||
return RuntimeError(state.qualname)
|
||||
return RuntimeError(state.message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphExecution:
|
||||
"""
|
||||
Aggregate root for graph execution.
|
||||
|
||||
This manages the overall execution state of a workflow graph,
|
||||
coordinating between multiple node executions.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
started: bool = False
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
paused: bool = False
|
||||
pause_reasons: list[PauseReason] = field(default_factory=list)
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark the graph execution as started."""
|
||||
if self.started:
|
||||
raise RuntimeError("Graph execution already started")
|
||||
self.started = True
|
||||
|
||||
def complete(self) -> None:
|
||||
"""Mark the graph execution as completed."""
|
||||
if not self.started:
|
||||
raise RuntimeError("Cannot complete execution that hasn't started")
|
||||
if self.completed:
|
||||
raise RuntimeError("Graph execution already completed")
|
||||
self.completed = True
|
||||
|
||||
def abort(self, reason: str) -> None:
|
||||
"""Abort the graph execution."""
|
||||
self.aborted = True
|
||||
self.error = RuntimeError(f"Aborted: {reason}")
|
||||
|
||||
def pause(self, reason: PauseReason) -> None:
|
||||
"""Pause the graph execution without marking it complete."""
|
||||
if self.completed:
|
||||
raise RuntimeError("Cannot pause execution that has completed")
|
||||
if self.aborted:
|
||||
raise RuntimeError("Cannot pause execution that has been aborted")
|
||||
self.paused = True
|
||||
self.pause_reasons.append(reason)
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Mark the graph execution as failed."""
|
||||
self.error = error
|
||||
self.completed = True
|
||||
|
||||
def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
|
||||
"""Get or create a node execution entity."""
|
||||
if node_id not in self.node_executions:
|
||||
self.node_executions[node_id] = NodeExecution(node_id=node_id)
|
||||
return self.node_executions[node_id]
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the execution is currently running."""
|
||||
return self.started and not self.completed and not self.aborted and not self.paused
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
"""Check if the execution is currently paused."""
|
||||
return self.paused
|
||||
|
||||
@property
|
||||
def has_error(self) -> bool:
|
||||
"""Check if the execution has encountered an error."""
|
||||
return self.error is not None
|
||||
|
||||
@property
|
||||
def error_message(self) -> str | None:
|
||||
"""Get the error message if an error exists."""
|
||||
if not self.error:
|
||||
return None
|
||||
return str(self.error)
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the aggregate state into a JSON string."""
|
||||
|
||||
node_states = [
|
||||
NodeExecutionState(
|
||||
node_id=node_id,
|
||||
state=node_execution.state,
|
||||
retry_count=node_execution.retry_count,
|
||||
execution_id=node_execution.execution_id,
|
||||
error=node_execution.error,
|
||||
)
|
||||
for node_id, node_execution in sorted(self.node_executions.items())
|
||||
]
|
||||
|
||||
state = GraphExecutionState(
|
||||
workflow_id=self.workflow_id,
|
||||
started=self.started,
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
paused=self.paused,
|
||||
pause_reasons=self.pause_reasons,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
)
|
||||
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore aggregate state from a serialized JSON string."""
|
||||
|
||||
state = GraphExecutionState.model_validate_json(data)
|
||||
|
||||
if state.type != "GraphExecution":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||
|
||||
if self.workflow_id != state.workflow_id:
|
||||
raise ValueError("Serialized workflow_id does not match aggregate identity")
|
||||
|
||||
self.started = state.started
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.paused = state.paused
|
||||
self.pause_reasons = state.pause_reasons
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
item.node_id: NodeExecution(
|
||||
node_id=item.node_id,
|
||||
state=item.state,
|
||||
retry_count=item.retry_count,
|
||||
execution_id=item.execution_id,
|
||||
error=item.error,
|
||||
)
|
||||
for item in state.node_executions
|
||||
}
|
||||
|
||||
def record_node_failure(self) -> None:
|
||||
"""Increment the count of node failures encountered during execution."""
|
||||
self.exceptions_count += 1
|
||||
|
||||
|
||||
_: GraphExecutionProtocol = GraphExecution(workflow_id="")
|
||||
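The `dumps`/`loads` pair is what lets a paused run be resumed in another process or on another machine. A minimal round-trip sketch (the workflow id and pause message are illustrative):

```python
from graphon.entities.pause_reason import SchedulingPause
from graphon.graph_engine.domain import GraphExecution

execution = GraphExecution(workflow_id="wf-demo")
execution.start()
execution.get_or_create_node_execution("node-1").mark_taken()
execution.pause(SchedulingPause(message="waiting for a free worker"))

snapshot = execution.dumps()  # JSON string, safe to persist in a database or Redis

# Later, possibly elsewhere: rebuild the aggregate with the same identity and restore it.
resumed = GraphExecution(workflow_id="wf-demo")
resumed.loads(snapshot)
assert resumed.is_paused and "node-1" in resumed.node_executions
```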
45
api/graphon/graph_engine/domain/node_execution.py
Normal file
45
api/graphon/graph_engine/domain/node_execution.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""
|
||||
NodeExecution entity representing a node's execution state.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from graphon.enums import NodeState
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeExecution:
|
||||
"""
|
||||
Entity representing the execution state of a single node.
|
||||
|
||||
This is a mutable entity that tracks the runtime state of a node
|
||||
during graph execution.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
state: NodeState = NodeState.UNKNOWN
|
||||
retry_count: int = 0
|
||||
execution_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
def mark_started(self, execution_id: str) -> None:
|
||||
"""Mark the node as started with an execution ID."""
|
||||
self.state = NodeState.TAKEN
|
||||
self.execution_id = execution_id
|
||||
|
||||
def mark_taken(self) -> None:
|
||||
"""Mark the node as successfully completed."""
|
||||
self.state = NodeState.TAKEN
|
||||
self.error = None
|
||||
|
||||
def mark_failed(self, error: str) -> None:
|
||||
"""Mark the node as failed with an error."""
|
||||
self.error = error
|
||||
|
||||
def mark_skipped(self) -> None:
|
||||
"""Mark the node as skipped."""
|
||||
self.state = NodeState.SKIPPED
|
||||
|
||||
def increment_retry(self) -> None:
|
||||
"""Increment the retry count for this node."""
|
||||
self.retry_count += 1
|
||||
0
api/graphon/graph_engine/entities/__init__.py
Normal file
0
api/graphon/graph_engine/entities/__init__.py
Normal file
56
api/graphon/graph_engine/entities/commands.py
Normal file
56
api/graphon/graph_engine/entities/commands.py
Normal file
@ -0,0 +1,56 @@
|
||||
"""
|
||||
GraphEngine command entities for external control.
|
||||
|
||||
This module defines command types that can be sent to a running GraphEngine
|
||||
instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphon.variables.variables import Variable
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
"""Types of commands that can be sent to GraphEngine."""
|
||||
|
||||
ABORT = auto()
|
||||
PAUSE = auto()
|
||||
UPDATE_VARIABLES = auto()
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
"""Base class for all GraphEngine commands."""
|
||||
|
||||
command_type: CommandType = Field(..., description="Type of command")
|
||||
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")
|
||||
|
||||
|
||||
class AbortCommand(GraphEngineCommand):
|
||||
"""Command to abort a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
|
||||
reason: str | None = Field(default=None, description="Optional reason for abort")
|
||||
|
||||
|
||||
class PauseCommand(GraphEngineCommand):
|
||||
"""Command to pause a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str = Field(default="unknown reason", description="reason for pause")
|
||||
|
||||
|
||||
class VariableUpdate(BaseModel):
|
||||
"""Represents a single variable update instruction."""
|
||||
|
||||
value: Variable = Field(description="New variable value")
|
||||
|
||||
|
||||
class UpdateVariablesCommand(GraphEngineCommand):
|
||||
"""Command to update a group of variables in the variable pool."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
|
||||
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
|
||||
213
api/graphon/graph_engine/error_handler.py
Normal file
213
api/graphon/graph_engine/error_handler.py
Normal file
@ -0,0 +1,213 @@
|
||||
"""
|
||||
Main error handler that coordinates error strategies.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from graphon.enums import (
|
||||
ErrorStrategy as ErrorStrategyEnum,
|
||||
)
|
||||
from graphon.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetryEvent,
|
||||
)
|
||||
from graphon.node_events import NodeRunResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .domain import GraphExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class ErrorHandler:
|
||||
"""
|
||||
Coordinates error handling strategies for node failures.
|
||||
|
||||
This acts as a facade for the various error strategies,
|
||||
selecting and applying the appropriate strategy based on
|
||||
node configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
|
||||
"""
|
||||
Initialize the error handler.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
graph_execution: The graph execution state
|
||||
"""
|
||||
self._graph = graph
|
||||
self._graph_execution = graph_execution
|
||||
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
Selects and applies the appropriate error strategy based on
|
||||
the node's configuration.
|
||||
|
||||
Args:
|
||||
event: The node failure event
|
||||
|
||||
Returns:
|
||||
Optional new event to process, or None to abort
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
# Get retry count from NodeExecution
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
retry_count = node_execution.retry_count
|
||||
|
||||
# First check if retry is configured and not exhausted
|
||||
if node.retry and retry_count < node.retry_config.max_retries:
|
||||
result = self._handle_retry(event, retry_count)
|
||||
if result:
|
||||
# Retry count will be incremented when NodeRunRetryEvent is handled
|
||||
return result
|
||||
|
||||
# Apply configured error strategy
|
||||
strategy = node.error_strategy
|
||||
|
||||
match strategy:
|
||||
case None:
|
||||
return self._handle_abort(event)
|
||||
case ErrorStrategyEnum.FAIL_BRANCH:
|
||||
return self._handle_fail_branch(event)
|
||||
case ErrorStrategyEnum.DEFAULT_VALUE:
|
||||
return self._handle_default_value(event)
|
||||
|
||||
def _handle_abort(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by aborting execution.
|
||||
|
||||
This is the default strategy when no other strategy is specified.
|
||||
It stops the entire graph execution when a node fails.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
None, signaling that graph execution should stop
|
||||
"""
|
||||
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
|
||||
# Return None to signal that execution should stop
|
||||
|
||||
def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int):
|
||||
"""
|
||||
Handle error by retrying the node.
|
||||
|
||||
This strategy re-attempts node execution up to a configured
|
||||
maximum number of retries with configurable intervals.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
retry_count: Current retry attempt count
|
||||
|
||||
Returns:
|
||||
NodeRunRetryEvent if retry should occur, None otherwise
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
# Check if we've exceeded max retries
|
||||
if not node.retry or retry_count >= node.retry_config.max_retries:
|
||||
return None
|
||||
|
||||
# Wait for retry interval
|
||||
time.sleep(node.retry_config.retry_interval_seconds)
|
||||
|
||||
# Create retry event
|
||||
return NodeRunRetryEvent(
|
||||
id=event.id,
|
||||
node_title=node.title,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_run_result=event.node_run_result,
|
||||
start_at=event.start_at,
|
||||
error=event.error,
|
||||
retry_index=retry_count + 1,
|
||||
)
|
||||
|
||||
def _handle_fail_branch(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by taking the fail branch.
|
||||
|
||||
This strategy converts failures to exceptions and routes execution
|
||||
through a designated fail-branch edge.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
NodeRunExceptionEvent to continue via fail branch
|
||||
"""
|
||||
outputs = {
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
}
|
||||
|
||||
return NodeRunExceptionEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
edge_source_handle="fail-branch",
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH,
|
||||
},
|
||||
),
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_default_value(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by using default values.
|
||||
|
||||
This strategy allows nodes to fail gracefully by providing
|
||||
predefined default output values.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
NodeRunExceptionEvent with default values
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
outputs = {
|
||||
**node.default_value_dict,
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
}
|
||||
|
||||
return NodeRunExceptionEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE,
|
||||
},
|
||||
),
|
||||
error=event.error,
|
||||
)
|
||||
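A compact illustration of the precedence that `handle_node_failure` applies: retry first when configured and not exhausted, then the node's configured error strategy, with abort as the default. `FakeNode` and `Strategy` below are simplified stand-ins for the real node model and enums.

```python
# Sketch of the error-handling precedence: retry, then fail-branch / default-value,
# otherwise abort the whole graph run.
from dataclasses import dataclass
from enum import Enum, auto


class Strategy(Enum):
    FAIL_BRANCH = auto()
    DEFAULT_VALUE = auto()


@dataclass
class FakeNode:
    retry: bool
    max_retries: int
    error_strategy: Strategy | None


def decide(node: FakeNode, retry_count: int) -> str:
    if node.retry and retry_count < node.max_retries:
        return "retry"
    if node.error_strategy is None:
        return "abort"          # default: stop the whole graph run
    if node.error_strategy is Strategy.FAIL_BRANCH:
        return "fail-branch"    # continue via the fail-branch edge
    return "default-value"      # continue with predefined outputs


print(decide(FakeNode(retry=True, max_retries=3, error_strategy=None), retry_count=1))   # retry
print(decide(FakeNode(retry=True, max_retries=3, error_strategy=None), retry_count=3))   # abort
print(decide(FakeNode(retry=False, max_retries=0, error_strategy=Strategy.FAIL_BRANCH), 0))  # fail-branch
```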
14
api/graphon/graph_engine/event_management/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Event management subsystem for graph engine.
|
||||
|
||||
This package handles event routing, collection, and emission for
|
||||
workflow graph execution events.
|
||||
"""
|
||||
|
||||
from .event_handlers import EventHandler
|
||||
from .event_manager import EventManager
|
||||
|
||||
__all__ = [
|
||||
"EventHandler",
|
||||
"EventManager",
|
||||
]
|
||||
367
api/graphon/graph_engine/event_management/event_handlers.py
Normal file
@ -0,0 +1,367 @@
|
||||
"""
|
||||
Event handler implementations for different event types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunVariableUpdatedEvent,
|
||||
)
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..error_handler import ErrorHandler
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..graph_traversal import EdgeProcessor
|
||||
from .event_manager import EventManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class EventHandler:
|
||||
"""
|
||||
Registry of event handlers for different event types.
|
||||
|
||||
This centralizes the business logic for handling specific events,
|
||||
keeping it separate from the routing and collection infrastructure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: "EventManager",
|
||||
edge_processor: "EdgeProcessor",
|
||||
state_manager: "GraphStateManager",
|
||||
error_handler: "ErrorHandler",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the event handler registry.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
graph_runtime_state: Runtime state with variable pool
|
||||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Event manager for collecting events
|
||||
edge_processor: Edge processor for edge traversal
|
||||
state_manager: Unified state manager
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._event_collector = event_collector
|
||||
self._edge_processor = edge_processor
|
||||
self._state_manager = state_manager
|
||||
self._error_handler = error_handler
|
||||
|
||||
def dispatch(self, event: GraphNodeEventBase) -> None:
|
||||
"""
|
||||
Handle any node event by dispatching to the appropriate handler.
|
||||
|
||||
Args:
|
||||
event: The event to handle
|
||||
"""
|
||||
if isinstance(event, NodeRunVariableUpdatedEvent):
|
||||
self._dispatch(event)
|
||||
return
|
||||
|
||||
# Events in loops or iterations are always collected
|
||||
if event.in_loop_id or event.in_iteration_id:
|
||||
self._event_collector.collect(event)
|
||||
return
|
||||
return self._dispatch(event)
|
||||
|
||||
@singledispatchmethod
|
||||
def _dispatch(self, event: GraphNodeEventBase) -> None:
|
||||
self._event_collector.collect(event)
|
||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||
|
||||
@_dispatch.register(NodeRunIterationStartedEvent)
|
||||
@_dispatch.register(NodeRunIterationNextEvent)
|
||||
@_dispatch.register(NodeRunIterationSucceededEvent)
|
||||
@_dispatch.register(NodeRunIterationFailedEvent)
|
||||
@_dispatch.register(NodeRunLoopStartedEvent)
|
||||
@_dispatch.register(NodeRunLoopNextEvent)
|
||||
@_dispatch.register(NodeRunLoopSucceededEvent)
|
||||
@_dispatch.register(NodeRunLoopFailedEvent)
|
||||
@_dispatch.register(NodeRunAgentLogEvent)
|
||||
@_dispatch.register(NodeRunRetrieverResourceEvent)
|
||||
def _(self, event: GraphNodeEventBase) -> None:
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Handle node started event.
|
||||
|
||||
Args:
|
||||
event: The node started event
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
self._graph_runtime_state.increment_node_run_steps()
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event only for the first attempt; retries remain silent
|
||||
if is_initial_attempt:
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Handle stream chunk event with full processing.
|
||||
|
||||
Args:
|
||||
event: The stream chunk event
|
||||
"""
|
||||
# Process with response coordinator
|
||||
streaming_events = list(self._response_coordinator.intercept_event(event))
|
||||
|
||||
# Collect all events
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunVariableUpdatedEvent) -> None:
|
||||
"""
|
||||
Apply a node-requested variable mutation before downstream observers run.
|
||||
|
||||
The event is collected like other node events so parent/container engines can
|
||||
forward the updated payload to outer layers, including persistence listeners.
|
||||
"""
|
||||
self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Handle node success by coordinating subsystems.
|
||||
|
||||
This method coordinates between different subsystems to process
|
||||
node completion, handle edges, and trigger downstream execution.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
# Process edges and get ready nodes
|
||||
node = self._graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
|
||||
# Collect streaming events from edge processing
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
if self._graph_execution.is_paused:
|
||||
for node_id in ready_nodes:
|
||||
self._graph_runtime_state.register_deferred_node(node_id)
|
||||
else:
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
"""Handle pause requests emitted by nodes."""
|
||||
|
||||
pause_reason = event.reason
|
||||
self._graph_execution.pause(pause_reason)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
if event.node_id in self._graph.nodes:
|
||||
self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
|
||||
self._graph_runtime_state.register_paused_node(event.node_id)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Handle node failure using error handler.
|
||||
|
||||
Args:
|
||||
event: The node failed event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.dispatch(result)
|
||||
else:
|
||||
# Abort execution
|
||||
self._graph_execution.fail(RuntimeError(event.error))
|
||||
self._event_collector.collect(event)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
Handle node exception event (fail-branch strategy).
|
||||
|
||||
Args:
|
||||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch/default-value, treat as completion
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
|
||||
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update response outputs if applicable
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Collect the exception event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunRetryEvent) -> None:
|
||||
"""
|
||||
Handle node retry event.
|
||||
|
||||
Args:
|
||||
event: The node retry event
|
||||
"""
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
# Finish the previous attempt before re-queuing the node
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Emit retry event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
# Re-queue node for execution
|
||||
self._state_manager.enqueue_node(event.node_id)
|
||||
self._state_manager.start_execution(event.node_id)
|
||||
|
||||
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
|
||||
"""Accumulate token usage into the shared runtime state."""
|
||||
if usage.total_tokens <= 0:
|
||||
return
|
||||
|
||||
self._graph_runtime_state.add_tokens(usage.total_tokens)
|
||||
|
||||
current_usage = self._graph_runtime_state.llm_usage
|
||||
if current_usage.total_tokens == 0:
|
||||
self._graph_runtime_state.llm_usage = usage
|
||||
else:
|
||||
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||
|
||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node whose outputs are being stored
outputs: Mapping of output variable names to their values
|
||||
"""
|
||||
for variable_name, variable_value in outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
|
||||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.get_output("answer", "")
|
||||
if existing:
|
||||
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
|
||||
else:
|
||||
self._graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
self._graph_runtime_state.set_output(key, value)
|
||||
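The handler above leans on `functools.singledispatchmethod` for per-event-type routing with a generic fallback; the self-contained sketch below shows the same pattern with illustrative event classes (not the engine's real events).

```python
# Per-event-type dispatch with a generic fallback, as used by EventHandler.
from dataclasses import dataclass
from functools import singledispatchmethod


@dataclass
class BaseEvent:
    node_id: str


@dataclass
class StartedEvent(BaseEvent):
    pass


@dataclass
class SucceededEvent(BaseEvent):
    outputs: dict | None = None


class Handler:
    @singledispatchmethod
    def dispatch(self, event: BaseEvent) -> str:
        # Fallback: unknown events are still collected, just not specially handled
        return f"collected unhandled {type(event).__name__}"

    @dispatch.register
    def _(self, event: StartedEvent) -> str:
        return f"node {event.node_id} started"

    @dispatch.register
    def _(self, event: SucceededEvent) -> str:
        return f"node {event.node_id} succeeded with {event.outputs}"


handler = Handler()
print(handler.dispatch(StartedEvent("llm_1")))
print(handler.dispatch(SucceededEvent("llm_1", {"text": "hi"})))
print(handler.dispatch(BaseEvent("unknown_node")))
```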
186
api/graphon/graph_engine/event_management/event_manager.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""
|
||||
Unified event manager for collecting and emitting events.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import final
|
||||
|
||||
from graphon.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import GraphEngineLayer
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock implementation that allows multiple concurrent readers
|
||||
but only one writer at a time.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._read_ready = threading.Condition(threading.RLock())
|
||||
self._readers = 0
|
||||
|
||||
def acquire_read(self) -> None:
|
||||
"""Acquire a read lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers += 1
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._read_ready.notify_all()
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def acquire_write(self) -> None:
|
||||
"""Acquire a write lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
while self._readers > 0:
|
||||
_ = self._read_ready.wait()
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release a write lock."""
|
||||
self._read_ready.release()
|
||||
|
||||
@contextmanager
|
||||
def read_lock(self):
|
||||
"""Return a context manager for read locking."""
|
||||
self.acquire_read()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release_read()
|
||||
|
||||
@contextmanager
|
||||
def write_lock(self):
|
||||
"""Return a context manager for write locking."""
|
||||
self.acquire_write()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release_write()
|
||||
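A usage sketch for the lock above, assuming `ReadWriteLock` is in scope (the import path is omitted here): readers may overlap, while a writer waits until all readers have drained.

```python
# Hypothetical usage of the ReadWriteLock context managers; assumes the class
# defined above is importable in the current scope.
import threading

lock = ReadWriteLock()  # assumption: class from the surrounding module
shared: list[int] = []


def reader() -> None:
    with lock.read_lock():
        _ = list(shared)  # snapshot while holding the read lock


def writer(value: int) -> None:
    with lock.write_lock():
        shared.append(value)  # exclusive access: no readers are active here


threads = [threading.Thread(target=reader) for _ in range(4)]
threads.append(threading.Thread(target=writer, args=(42,)))
for t in threads:
    t.start()
for t in threads:
    t.join()
print(shared)  # [42]
```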
|
||||
|
||||
@final
|
||||
class EventManager:
|
||||
"""
|
||||
Unified event manager that collects, buffers, and emits events.
|
||||
|
||||
This class combines event collection with event emission, providing
|
||||
thread-safe event management with support for notifying layers and
|
||||
streaming events to external consumers.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the event manager."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = ReadWriteLock()
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def set_layers(self, layers: list[GraphEngineLayer]) -> None:
|
||||
"""
|
||||
Set the layers to notify on event collection.
|
||||
|
||||
Args:
|
||||
layers: List of layers to notify
|
||||
"""
|
||||
self._layers = layers
|
||||
|
||||
def notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""Notify registered layers about an event without buffering it."""
|
||||
self._notify_layers(event)
|
||||
|
||||
def collect(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Thread-safe method to collect an event.
|
||||
|
||||
Args:
|
||||
event: The event to collect
|
||||
"""
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
|
||||
# NOTE: `_notify_layers` is intentionally called outside the critical section
|
||||
# to minimize lock contention and avoid blocking other readers or writers.
|
||||
#
|
||||
# The public `notify_layers` method also does not use a write lock,
|
||||
# so protecting `_notify_layers` with a lock here is unnecessary.
|
||||
self._notify_layers(event)
|
||||
|
||||
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
Get new events starting from a specific index.
|
||||
|
||||
Args:
|
||||
start_index: The index to start from
|
||||
|
||||
Returns:
|
||||
List of new events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return list(self._events[start_index:])
|
||||
|
||||
def _event_count(self) -> int:
|
||||
"""
|
||||
Get the current count of collected events.
|
||||
|
||||
Returns:
|
||||
Number of collected events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return len(self._events)
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete to stop the event emission generator."""
|
||||
self._execution_complete.set()
|
||||
|
||||
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generator that yields events as they're collected.
|
||||
|
||||
Yields:
|
||||
GraphEngineEvent instances as they're processed
|
||||
"""
|
||||
yielded_count = 0
|
||||
|
||||
while not self._execution_complete.is_set() or yielded_count < self._event_count():
|
||||
# Get new events since last yield
|
||||
new_events = self._get_new_events(yielded_count)
|
||||
|
||||
# Yield any new events
|
||||
for event in new_events:
|
||||
yield event
|
||||
yielded_count += 1
|
||||
|
||||
# Small sleep to avoid busy waiting
|
||||
if not self._execution_complete.is_set() and not new_events:
|
||||
time.sleep(0.001)
|
||||
|
||||
def _notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Notify all layers of an event.
|
||||
|
||||
Layer exceptions are caught and logged to prevent disrupting collection.
|
||||
|
||||
Args:
|
||||
event: The event to send to layers
|
||||
"""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_event(event)
|
||||
except Exception:
|
||||
_logger.exception("Error in layer on_event, layer_type=%s", type(layer))
|
||||
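A self-contained sketch of the collect/emit hand-off implemented above: a worker thread appends events while the consumer streams them until completion is signalled. `ToyEventManager` mirrors the pattern only; it is not the real class.

```python
# Producer/consumer hand-off: collect() from a worker thread, emit_events() in the
# main thread, mark_complete() to end the stream.
import threading
import time
from collections.abc import Generator


class ToyEventManager:
    def __init__(self) -> None:
        self._events: list[str] = []
        self._lock = threading.Lock()
        self._done = threading.Event()

    def collect(self, event: str) -> None:
        with self._lock:
            self._events.append(event)

    def mark_complete(self) -> None:
        self._done.set()

    def emit_events(self) -> Generator[str, None, None]:
        yielded = 0
        while not self._done.is_set() or yielded < len(self._events):
            with self._lock:
                new = self._events[yielded:]
            for event in new:
                yield event
                yielded += 1
            if not new and not self._done.is_set():
                time.sleep(0.001)  # avoid busy waiting, as in the real manager


manager = ToyEventManager()


def worker() -> None:
    for i in range(3):
        manager.collect(f"event-{i}")
        time.sleep(0.01)
    manager.mark_complete()


threading.Thread(target=worker).start()
print(list(manager.emit_events()))  # ['event-0', 'event-1', 'event-2']
```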
377
api/graphon/graph_engine/graph_engine.py
Normal file
@ -0,0 +1,377 @@
|
||||
"""
|
||||
QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution.
|
||||
|
||||
This engine uses a modular architecture with separated packages following
|
||||
Domain-Driven Design principles for improved maintainability and testability.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from graphon.entities.workflow_start_reason import WorkflowStartReason
|
||||
from graphon.enums import NodeExecutionType
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
|
||||
from graphon.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from graphon.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .command_processing import (
|
||||
AbortCommandHandler,
|
||||
CommandProcessor,
|
||||
PauseCommandHandler,
|
||||
UpdateVariablesCommandHandler,
|
||||
)
|
||||
from .config import GraphEngineConfig
|
||||
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
|
||||
from .error_handler import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_state_manager import GraphStateManager
|
||||
from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import GraphEngineLayer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.graph_engine.domain.graph_execution import GraphExecution
|
||||
from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DEFAULT_CONFIG = GraphEngineConfig()
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngine:
|
||||
"""
|
||||
Queue-based graph execution engine.
|
||||
|
||||
Uses a modular architecture that delegates responsibilities to specialized
|
||||
subsystems, following Domain-Driven Design and SOLID principles.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_id: str,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
command_channel: CommandChannel,
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
self._config = config
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
self._child_engine_builder = child_engine_builder
|
||||
if child_engine_builder is not None:
|
||||
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
||||
self._graph_execution.workflow_id = workflow_id
|
||||
|
||||
# === Execution Queues ===
|
||||
self._ready_queue = self._graph_runtime_state.ready_queue
|
||||
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
# === State Management ===
|
||||
# Unified state manager handles all node state transitions and queue operations
|
||||
self._state_manager = GraphStateManager(self._graph, self._ready_queue)
|
||||
|
||||
# === Response Coordination ===
|
||||
# Coordinates response streaming from response nodes
|
||||
self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)
|
||||
|
||||
# === Event Management ===
|
||||
# Event manager handles both collection and emission of events
|
||||
self._event_manager = EventManager()
|
||||
|
||||
# === Error Handling ===
|
||||
# Centralized error handler for graph execution errors
|
||||
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
|
||||
|
||||
# === Graph Traversal Components ===
|
||||
# Propagates skip status through the graph when conditions aren't met
|
||||
self._skip_propagator = SkipPropagator(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
)
|
||||
|
||||
# Processes edges to determine next nodes after execution
|
||||
# Also handles conditional branching and route selection
|
||||
self._edge_processor = EdgeProcessor(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
response_coordinator=self._response_coordinator,
|
||||
skip_propagator=self._skip_propagator,
|
||||
)
|
||||
|
||||
# === Command Processing ===
|
||||
# Processes external commands (e.g., abort requests)
|
||||
self._command_processor = CommandProcessor(
|
||||
command_channel=self._command_channel,
|
||||
graph_execution=self._graph_execution,
|
||||
)
|
||||
|
||||
# Register command handlers
|
||||
abort_handler = AbortCommandHandler()
|
||||
self._command_processor.register_handler(AbortCommand, abort_handler)
|
||||
|
||||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
|
||||
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
|
||||
|
||||
# === Worker Pool Setup ===
|
||||
# Create worker pool for parallel node execution
|
||||
self._worker_pool = WorkerPool(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
layers=self._layers,
|
||||
execution_context=self._graph_runtime_state.execution_context,
|
||||
config=self._config,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
# Coordinates the overall execution lifecycle
|
||||
self._execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self._graph_execution,
|
||||
state_manager=self._state_manager,
|
||||
command_processor=self._command_processor,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
# === Event Handler Registry ===
|
||||
# Central registry for handling all node execution events
|
||||
self._event_handler_registry = EventHandler(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_manager,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
# Dispatches events and manages execution flow
|
||||
self._dispatcher = Dispatcher(
|
||||
event_queue=self._event_queue,
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
)
|
||||
|
||||
# === Validation ===
|
||||
# Ensure all nodes share the same GraphRuntimeState instance
|
||||
self._validate_graph_state_consistency()
|
||||
|
||||
def _validate_graph_state_consistency(self) -> None:
|
||||
"""Validate that all nodes share the same GraphRuntimeState."""
|
||||
expected_state_id = id(self._graph_runtime_state)
|
||||
for node in self._graph.nodes.values():
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
def _bind_layer_context(
|
||||
self,
|
||||
layer: GraphEngineLayer,
|
||||
) -> None:
|
||||
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
|
||||
|
||||
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
|
||||
"""Add a layer for extending functionality."""
|
||||
self._layers.append(layer)
|
||||
self._bind_layer_context(layer)
|
||||
return self
|
||||
|
||||
def request_abort(self, reason: str | None = None) -> None:
|
||||
"""Queue an abort command for this engine."""
|
||||
self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort"))
|
||||
|
||||
def create_child_engine(
|
||||
self,
|
||||
*,
|
||||
workflow_id: str,
|
||||
graph_init_params: GraphInitParams,
|
||||
root_node_id: str,
|
||||
variable_pool: VariablePool | None = None,
|
||||
) -> GraphEngine:
|
||||
return self._graph_runtime_state.create_child_engine(
|
||||
workflow_id=workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
root_node_id=root_node_id,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Execute the graph using the modular architecture.
|
||||
|
||||
Returns:
|
||||
Generator yielding GraphEngineEvent instances
|
||||
"""
|
||||
try:
|
||||
# Initialize layers
|
||||
self._initialize_layers()
|
||||
|
||||
is_resume = self._graph_execution.started
|
||||
if not is_resume:
|
||||
self._graph_execution.start()
|
||||
else:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reasons = []
|
||||
|
||||
start_event = GraphRunStartedEvent(
|
||||
reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
|
||||
)
|
||||
self._event_manager.notify_layers(start_event)
|
||||
yield start_event
|
||||
|
||||
# Start subsystems
|
||||
self._start_execution(resume=is_resume)
|
||||
|
||||
# Yield events as they occur
|
||||
yield from self._event_manager.emit_events()
|
||||
|
||||
# Handle completion
|
||||
if self._graph_execution.is_paused:
|
||||
pause_reasons = self._graph_execution.pause_reasons
|
||||
assert pause_reasons, "pause_reasons should not be empty when execution is paused."
|
||||
# Ensure we have a valid PauseReason for the event
|
||||
paused_event = GraphRunPausedEvent(
|
||||
reasons=pause_reasons,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(paused_event)
|
||||
yield paused_event
|
||||
elif self._graph_execution.aborted:
|
||||
abort_reason = "Workflow execution aborted by user command"
|
||||
if self._graph_execution.error:
|
||||
abort_reason = str(self._graph_execution.error)
|
||||
aborted_event = GraphRunAbortedEvent(
|
||||
reason=abort_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(aborted_event)
|
||||
yield aborted_event
|
||||
elif self._graph_execution.has_error:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
else:
|
||||
outputs = self._graph_runtime_state.outputs
|
||||
exceptions_count = self._graph_execution.exceptions_count
|
||||
if exceptions_count > 0:
|
||||
partial_event = GraphRunPartialSucceededEvent(
|
||||
exceptions_count=exceptions_count,
|
||||
outputs=outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(partial_event)
|
||||
yield partial_event
|
||||
else:
|
||||
succeeded_event = GraphRunSucceededEvent(
|
||||
outputs=outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(succeeded_event)
|
||||
yield succeeded_event
|
||||
|
||||
except Exception as e:
|
||||
failed_event = GraphRunFailedEvent(
|
||||
error=str(e),
|
||||
exceptions_count=self._graph_execution.exceptions_count,
|
||||
)
|
||||
self._event_manager.notify_layers(failed_event)
|
||||
yield failed_event
|
||||
raise
|
||||
|
||||
finally:
|
||||
self._stop_execution()
|
||||
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self._event_manager.set_layers(self._layers)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception:
|
||||
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
paused_nodes: list[str] = []
|
||||
deferred_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
|
||||
|
||||
# Start worker pool (it calculates initial workers internally)
|
||||
self._worker_pool.start()
|
||||
|
||||
# Register response nodes
|
||||
for node in self._graph.nodes.values():
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._response_coordinator.register(node.id)
|
||||
|
||||
if not resume:
|
||||
# Enqueue root node
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
else:
|
||||
seen_nodes: set[str] = set()
|
||||
for node_id in paused_nodes + deferred_nodes:
|
||||
if node_id in seen_nodes:
|
||||
continue
|
||||
seen_nodes.add(node_id)
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Start dispatcher
|
||||
self._dispatcher.start()
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
except Exception:
|
||||
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
|
||||
|
||||
# Public property accessors for attributes that need external access
|
||||
@property
|
||||
def graph_runtime_state(self) -> GraphRuntimeState:
|
||||
"""Get the graph runtime state."""
|
||||
return self._graph_runtime_state
|
||||
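A hedged usage sketch for driving the engine: construction of `graph`, `graph_runtime_state`, and `command_channel` is elided because their factories live elsewhere in the package, so treat this as an outline rather than a runnable script. The event class names match the imports in this file.

```python
# Outline only: graph, graph_runtime_state, and command_channel are assumed to be
# built by the surrounding application code.
from graphon.graph_events import (
    GraphRunFailedEvent,
    GraphRunPausedEvent,
    GraphRunSucceededEvent,
)

engine = GraphEngine(
    workflow_id="wf-123",
    graph=graph,                          # assumption: built elsewhere
    graph_runtime_state=graph_runtime_state,  # assumption: built elsewhere
    command_channel=command_channel,      # e.g. an in-memory or Redis channel
)

for event in engine.run():
    if isinstance(event, GraphRunSucceededEvent):
        print("outputs:", event.outputs)
    elif isinstance(event, GraphRunPausedEvent):
        print("paused:", event.reasons)
    elif isinstance(event, GraphRunFailedEvent):
        print("failed:", event.error)
```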
290
api/graphon/graph_engine/graph_state_manager.py
Normal file
@ -0,0 +1,290 @@
|
||||
"""
|
||||
Graph state manager that combines node, edge, and execution tracking.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, final
|
||||
|
||||
from graphon.enums import NodeState
|
||||
from graphon.graph import Edge, Graph
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
|
||||
class EdgeStateAnalysis(TypedDict):
|
||||
"""Analysis result for edge states."""
|
||||
|
||||
has_unknown: bool
|
||||
has_taken: bool
|
||||
all_skipped: bool
|
||||
|
||||
|
||||
@final
|
||||
class GraphStateManager:
|
||||
def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
|
||||
"""
|
||||
Initialize the state manager.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
ready_queue: Queue for nodes ready to execute
|
||||
"""
|
||||
self._graph = graph
|
||||
self._ready_queue = ready_queue
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Execution tracking state
|
||||
self._executing_nodes: set[str] = set()
|
||||
|
||||
# ============= Node State Operations =============
|
||||
|
||||
def enqueue_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as TAKEN and add it to the ready queue.
|
||||
|
||||
This combines the state transition and enqueueing operations
|
||||
that always occur together when preparing a node for execution.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to enqueue
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.TAKEN
|
||||
self._ready_queue.put(node_id)
|
||||
|
||||
def mark_node_skipped(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as SKIPPED.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.SKIPPED
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is ready to be executed.
|
||||
|
||||
A node is ready when all its incoming edges from taken branches
|
||||
have been satisfied.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
with self._lock:
|
||||
# Get all incoming edges to this node
|
||||
incoming_edges = self._graph.get_incoming_edges(node_id)
|
||||
|
||||
# If no incoming edges, node is always ready
|
||||
if not incoming_edges:
|
||||
return True
|
||||
|
||||
# If any edge is UNKNOWN, node is not ready
|
||||
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
|
||||
return False
|
||||
|
||||
# Node is ready if at least one edge is TAKEN
|
||||
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
|
||||
|
||||
def get_node_state(self, node_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The current node state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.nodes[node_id].state
|
||||
|
||||
# ============= Edge State Operations =============
|
||||
|
||||
def mark_edge_taken(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as TAKEN.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.TAKEN
|
||||
|
||||
def mark_edge_skipped(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as SKIPPED.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.SKIPPED
|
||||
|
||||
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
|
||||
"""
|
||||
Analyze the states of edges and return summary flags.
|
||||
|
||||
Args:
|
||||
edges: List of edges to analyze
|
||||
|
||||
Returns:
|
||||
Analysis result with state flags
|
||||
"""
|
||||
with self._lock:
|
||||
states = {edge.state for edge in edges}
|
||||
|
||||
return EdgeStateAnalysis(
|
||||
has_unknown=NodeState.UNKNOWN in states,
|
||||
has_taken=NodeState.TAKEN in states,
|
||||
all_skipped=states == {NodeState.SKIPPED} if states else True,
|
||||
)
|
||||
|
||||
def get_edge_state(self, edge_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of an edge.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge
|
||||
|
||||
Returns:
|
||||
The current edge state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
A tuple of (selected_edges, unselected_edges)
|
||||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
selected_edges.append(edge)
|
||||
else:
|
||||
unselected_edges.append(edge)
|
||||
|
||||
return selected_edges, unselected_edges
|
||||
|
||||
# ============= Execution Tracking Operations =============
|
||||
|
||||
def start_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node starting execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.add(node_id)
|
||||
|
||||
def finish_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as no longer executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node finishing execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.discard(node_id)
|
||||
|
||||
def is_executing(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is currently executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is executing
|
||||
"""
|
||||
with self._lock:
|
||||
return node_id in self._executing_nodes
|
||||
|
||||
def get_executing_count(self) -> int:
|
||||
"""
|
||||
Get the count of currently executing nodes.
|
||||
|
||||
Returns:
|
||||
Number of executing nodes
|
||||
"""
|
||||
# This count is a best-effort snapshot and can change concurrently.
|
||||
# Only use it for pause-drain checks where scheduling is already frozen.
|
||||
with self._lock:
|
||||
return len(self._executing_nodes)
|
||||
|
||||
def get_executing_nodes(self) -> set[str]:
|
||||
"""
|
||||
Get a copy of the set of executing node IDs.
|
||||
|
||||
Returns:
|
||||
Set of node IDs currently executing
|
||||
"""
|
||||
with self._lock:
|
||||
return self._executing_nodes.copy()
|
||||
|
||||
def clear_executing(self) -> None:
|
||||
"""Clear all executing nodes."""
|
||||
with self._lock:
|
||||
self._executing_nodes.clear()
|
||||
|
||||
# ============= Composite Operations =============
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
Check if graph execution is complete.
|
||||
|
||||
Execution is complete when:
|
||||
- Ready queue is empty
|
||||
- No nodes are executing
|
||||
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
with self._lock:
|
||||
return self._ready_queue.empty() and len(self._executing_nodes) == 0
|
||||
|
||||
def get_queue_depth(self) -> int:
|
||||
"""
|
||||
Get the current depth of the ready queue.
|
||||
|
||||
Returns:
|
||||
Number of nodes in the ready queue
|
||||
"""
|
||||
return self._ready_queue.qsize()
|
||||
|
||||
def get_execution_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get execution statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with execution statistics
|
||||
"""
|
||||
with self._lock:
|
||||
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
|
||||
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
|
||||
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
|
||||
|
||||
return {
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"executing": len(self._executing_nodes),
|
||||
"taken_nodes": taken_nodes,
|
||||
"skipped_nodes": skipped_nodes,
|
||||
"unknown_nodes": unknown_nodes,
|
||||
}
|
||||
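The readiness rule in `is_node_ready` can be stated compactly: with incoming edges present, a node runs only once no edge is still UNKNOWN and at least one edge is TAKEN. A standalone illustration, with edge states as plain strings:

```python
# Readiness rule: no UNKNOWN incoming edges, and at least one TAKEN edge.
def is_node_ready(incoming_edge_states: list[str]) -> bool:
    if not incoming_edge_states:
        return True  # nodes with no dependencies are always ready
    if any(state == "UNKNOWN" for state in incoming_edge_states):
        return False  # an upstream branch has not been decided yet
    return any(state == "TAKEN" for state in incoming_edge_states)


print(is_node_ready([]))                        # True
print(is_node_ready(["TAKEN", "SKIPPED"]))      # True
print(is_node_ready(["UNKNOWN", "TAKEN"]))      # False
print(is_node_ready(["SKIPPED", "SKIPPED"]))    # False (node will be skipped)
```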
14
api/graphon/graph_engine/graph_traversal/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Graph traversal subsystem for graph engine.
|
||||
|
||||
This package handles graph navigation, edge processing,
|
||||
and skip propagation logic.
|
||||
"""
|
||||
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
__all__ = [
|
||||
"EdgeProcessor",
|
||||
"SkipPropagator",
|
||||
]
|
||||
201
api/graphon/graph_engine/graph_traversal/edge_processor.py
Normal file
@ -0,0 +1,201 @@
|
||||
"""
|
||||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from graphon.enums import NodeExecutionType
|
||||
from graphon.graph import Edge, Graph
|
||||
from graphon.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class EdgeProcessor:
|
||||
"""
|
||||
Processes edges during graph execution.
|
||||
|
||||
This handles marking edges as taken or skipped, notifying
|
||||
the response coordinator, triggering downstream node execution,
|
||||
and managing branch node logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
skip_propagator: "SkipPropagator",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the edge processor.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
skip_propagator: Propagator for skip states
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
self._response_coordinator = response_coordinator
|
||||
self._skip_propagator = skip_propagator
|
||||
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges after a node succeeds.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
selected_handle: For branch nodes, the selected edge handle
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
|
||||
"""
|
||||
node = self._graph.nodes[node_id]
|
||||
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
return self._process_branch_node_edges(node_id, selected_handle)
|
||||
else:
|
||||
return self._process_non_branch_node_edges(node_id)
|
||||
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for non-branch nodes (mark all as TAKEN).
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_branch_node_edges(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for branch nodes.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no edge was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} did not select any edge")
|
||||
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Process unselected edges first (mark as skipped)
|
||||
for edge in unselected_edges:
|
||||
self._process_skipped_edge(edge)
|
||||
|
||||
# Process selected edges
|
||||
for edge in selected_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Mark edge as taken and check downstream node.
|
||||
|
||||
Args:
|
||||
edge: The edge to process
|
||||
|
||||
Returns:
|
||||
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
|
||||
"""
|
||||
# Mark edge as taken
|
||||
self._state_manager.mark_edge_taken(edge.id)
|
||||
|
||||
# Notify response coordinator and get streaming events
|
||||
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes: list[str] = []
|
||||
if self._state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, streaming_events
|
||||
|
||||
def _process_skipped_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
Mark edge as skipped.
|
||||
|
||||
Args:
|
||||
edge: The edge to skip
|
||||
"""
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
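A toy sketch of the branch split performed by `categorize_branch_edges` and consumed above: edges matching the selected handle are taken, the rest are skipped and become candidates for skip propagation. `ToyEdge` is illustrative only, not the real `Edge` model.

```python
# Branch completion: partition outgoing edges by the selected source handle.
from dataclasses import dataclass


@dataclass
class ToyEdge:
    id: str
    source_handle: str
    head: str  # downstream node id


def split_branch_edges(
    edges: list[ToyEdge], selected_handle: str
) -> tuple[list[ToyEdge], list[ToyEdge]]:
    selected = [e for e in edges if e.source_handle == selected_handle]
    unselected = [e for e in edges if e.source_handle != selected_handle]
    return selected, unselected


edges = [
    ToyEdge("e1", "true", "send_email"),
    ToyEdge("e2", "false", "log_only"),
]
taken, skipped = split_branch_edges(edges, selected_handle="true")
print([e.head for e in taken])    # ['send_email'] -> marked TAKEN, may become ready
print([e.head for e in skipped])  # ['log_only']   -> marked SKIPPED, skip propagates
```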
96
api/graphon/graph_engine/graph_traversal/skip_propagator.py
Normal file
@ -0,0 +1,96 @@
|
||||
"""
|
||||
Skip state propagation through the graph.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from graphon.graph import Edge, Graph
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
|
||||
|
||||
@final
|
||||
class SkipPropagator:
|
||||
"""
|
||||
Propagates skip states through the graph.
|
||||
|
||||
When a node is skipped, this ensures all downstream nodes
|
||||
that depend solely on it are also skipped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the skip propagator.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
|
||||
def propagate_skip_from_edge(self, edge_id: str) -> None:
|
||||
"""
|
||||
Recursively propagate skip state from a skipped edge.
|
||||
|
||||
Rules:
|
||||
- If a node has any UNKNOWN incoming edges, stop processing
|
||||
- If all incoming edges are SKIPPED, skip the node and its edges
|
||||
- If any incoming edge is TAKEN, the node may still execute
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the skipped edge to start from
|
||||
"""
|
||||
downstream_node_id = self._graph.edges[edge_id].head
|
||||
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
|
||||
|
||||
# Analyze edge states
|
||||
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
|
||||
|
||||
# Stop if there are unknown edges (not yet processed)
|
||||
if edge_states["has_unknown"]:
|
||||
return
|
||||
|
||||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self._state_manager.enqueue_node(downstream_node_id)
|
||||
self._state_manager.start_execution(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
if edge_states["all_skipped"]:
|
||||
self._propagate_skip_to_node(downstream_node_id)
|
||||
|
||||
def _propagate_skip_to_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node and all its outgoing edges as skipped.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
# Mark node as skipped
|
||||
self._state_manager.mark_node_skipped(node_id)
|
||||
|
||||
# Mark all outgoing edges as skipped and propagate
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
for edge in outgoing_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
|
||||
"""
|
||||
Skip all paths from unselected branch edges.
|
||||
|
||||
Args:
|
||||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
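A wiring sketch for the propagator, assuming `graph`, `state_manager`, and `unselected_edges` already exist (they are produced by the engine and the branch-handling code above):

```python
# Sketch only: graph, state_manager and unselected_edges are assumed to exist.
propagator = SkipPropagator(graph=graph, state_manager=state_manager)

# Mark every unselected branch edge as skipped and propagate the skip to any
# downstream node that can no longer be reached through a taken edge.
propagator.skip_branch_paths(unselected_edges)
```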
55
api/graphon/graph_engine/layers/README.md
Normal file
@ -0,0 +1,55 @@
# Layers

Pluggable middleware for engine extensions.

## Components

### GraphEngineLayer (base)

Abstract base class for layers.

- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook

### DebugLoggingLayer

Comprehensive execution logging.

- Configurable detail levels
- Tracks execution statistics
- Truncates long values

## Usage

```python
debug_layer = DebugLoggingLayer(
    level="INFO",
    include_outputs=True,
)

engine = GraphEngine(graph)
engine.layer(debug_layer)
engine.run()
```

`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.

## Custom Layers

```python
class MetricsLayer(GraphEngineLayer):
    def on_event(self, event):
        if isinstance(event, NodeRunSucceededEvent):
            self.metrics[event.node_id] = event.elapsed_time
```

Concrete layers must implement all three abstract hooks (`on_graph_start`,
`on_event`, `on_graph_end`); `MetricsLayer` above only shows the event hook
for brevity.

## Configuration

**DebugLoggingLayer Options:**

- `level` - Log level (INFO, DEBUG, ERROR)
- `include_inputs/outputs` - Log data values
- `max_value_length` - Truncate long values
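**ExecutionLimitsLayer Options** (a usage sketch based on the constructor in this commit; both arguments are required):

- `max_steps` - Maximum number of node executions before the run is aborted
- `max_time` - Maximum execution time in seconds before the run is aborted

```python
engine.layer(ExecutionLimitsLayer(max_steps=500, max_time=600))
```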
16
api/graphon/graph_engine/layers/__init__.py
Normal file
@ -0,0 +1,16 @@
"""
Layer system for GraphEngine extensibility.

This module provides the layer infrastructure for extending GraphEngine functionality
with middleware-like components that can observe events and interact with execution.
"""

from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer

__all__ = [
    "DebugLoggingLayer",
    "ExecutionLimitsLayer",
    "GraphEngineLayer",
]
128
api/graphon/graph_engine/layers/base.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
Base layer class for GraphEngine extensions.
|
||||
|
||||
This module provides the abstract base class for implementing layers that can
|
||||
intercept and respond to GraphEngine events.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from graphon.graph_engine.protocols.command_channel import CommandChannel
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.runtime import ReadOnlyGraphRuntimeState
|
||||
|
||||
|
||||
class GraphEngineLayerNotInitializedError(Exception):
|
||||
"""Raised when a layer's runtime state is accessed before initialization."""
|
||||
|
||||
def __init__(self, layer_name: str | None = None) -> None:
|
||||
name = layer_name or "GraphEngineLayer"
|
||||
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
"""
|
||||
Abstract base class for GraphEngine layers.
|
||||
|
||||
Layers are middleware-like components that can:
|
||||
- Observe all events emitted by the GraphEngine
|
||||
- Access the graph runtime state
|
||||
- Send commands to control execution
|
||||
|
||||
Subclasses should override the constructor to accept configuration parameters,
|
||||
then implement the three lifecycle methods.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
@property
|
||||
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
|
||||
if self._graph_runtime_state is None:
|
||||
raise GraphEngineLayerNotInitializedError(type(self).__name__)
|
||||
return self._graph_runtime_state
|
||||
|
||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
Initialize the layer with engine dependencies.
|
||||
|
||||
Called by GraphEngine to inject the read-only runtime state and command channel.
|
||||
This is invoked when the layer is registered with a `GraphEngine` instance.
|
||||
Implementations should be idempotent.
|
||||
Args:
|
||||
graph_runtime_state: Read-only view of the runtime state
|
||||
command_channel: Channel for sending commands to the engine
|
||||
"""
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
|
||||
This is called after the engine has been initialized but before any nodes
|
||||
are executed. Layers can use this to set up resources or log start information.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
This method receives all events generated during graph execution, including:
|
||||
- Graph lifecycle events (start, success, failure)
|
||||
- Node execution events (start, success, failure, retry)
|
||||
- Stream events for response nodes
|
||||
- Container events (iteration, loop)
|
||||
|
||||
Args:
|
||||
event: The event emitted by the engine
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
This is called after all nodes have been executed or when execution is
|
||||
aborted. Layers can use this to clean up resources or log final state.
|
||||
|
||||
Args:
|
||||
error: The exception that caused execution to fail, or None if successful
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_node_run_start(self, node: Node) -> None:
|
||||
"""
|
||||
Called immediately before a node begins execution.
|
||||
|
||||
Layers can override to inject behavior (e.g., start spans) prior to node execution.
|
||||
The node's execution ID is available via `node._node_execution_id` and will be
|
||||
consistent with all events emitted by this node execution.
|
||||
|
||||
Args:
|
||||
node: The node instance about to be executed
|
||||
"""
|
||||
return
|
||||
|
||||
def on_node_run_end(
|
||||
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Called after a node finishes execution.
|
||||
|
||||
The node's execution ID is available via `node._node_execution_id` and matches
|
||||
the `id` field in all events emitted by this node execution.
|
||||
|
||||
Args:
|
||||
node: The node instance that just finished execution
|
||||
error: Exception instance if the node failed, otherwise None
|
||||
result_event: The final result event from node execution (succeeded/failed/paused), if any
|
||||
"""
|
||||
return
|
||||
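Because `on_graph_start`, `on_event`, and `on_graph_end` are abstract, a concrete layer must implement all three. A minimal sketch of a do-nothing layer:

```python
from typing_extensions import override

from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent


class NoOpLayer(GraphEngineLayer):
    """Minimal concrete layer, shown only to illustrate the required hooks."""

    @override
    def on_graph_start(self) -> None:
        pass

    @override
    def on_event(self, event: GraphEngineEvent) -> None:
        pass

    @override
    def on_graph_end(self, error: Exception | None) -> None:
        pass
```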
247
api/graphon/graph_engine/layers/debug_logging.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""
|
||||
Debug logging layer for GraphEngine.
|
||||
|
||||
This module provides a layer that logs all events and state changes during
|
||||
graph execution for debugging purposes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from graphon.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
class DebugLoggingLayer(GraphEngineLayer):
|
||||
"""
|
||||
A layer that provides comprehensive logging of GraphEngine execution.
|
||||
|
||||
This layer logs all events with configurable detail levels, helping developers
|
||||
debug workflow execution and understand the flow of events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
level: str = "INFO",
|
||||
include_inputs: bool = False,
|
||||
include_outputs: bool = True,
|
||||
include_process_data: bool = False,
|
||||
logger_name: str = "GraphEngine.Debug",
|
||||
max_value_length: int = 500,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the debug logging layer.
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
include_inputs: Whether to log node input values
|
||||
include_outputs: Whether to log node output values
|
||||
include_process_data: Whether to log node process data
|
||||
logger_name: Name of the logger to use
|
||||
max_value_length: Maximum length of logged values (truncated if longer)
|
||||
"""
|
||||
super().__init__()
|
||||
self.level = level
|
||||
self.include_inputs = include_inputs
|
||||
self.include_outputs = include_outputs
|
||||
self.include_process_data = include_process_data
|
||||
self.max_value_length = max_value_length
|
||||
|
||||
# Set up logger
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
log_level = getattr(logging, level.upper(), logging.INFO)
|
||||
self.logger.setLevel(log_level)
|
||||
|
||||
# Track execution stats
|
||||
self.node_count = 0
|
||||
self.success_count = 0
|
||||
self.failure_count = 0
|
||||
self.retry_count = 0
|
||||
|
||||
def _truncate_value(self, value: Any) -> str:
|
||||
"""Truncate long values for logging."""
|
||||
str_value = str(value)
|
||||
if len(str_value) > self.max_value_length:
|
||||
return str_value[: self.max_value_length] + "... (truncated)"
|
||||
return str_value
|
||||
|
||||
def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
|
||||
"""Format a dictionary or mapping for logging with truncation."""
|
||||
if not data:
|
||||
return "{}"
|
||||
|
||||
formatted_items: list[str] = []
|
||||
for key, value in data.items():
|
||||
formatted_value = self._truncate_value(value)
|
||||
formatted_items.append(f" {key}: {formatted_value}")
|
||||
|
||||
return "{\n" + ",\n".join(formatted_items) + "\n}"
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Log graph execution start."""
|
||||
self.logger.info("=" * 80)
|
||||
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
||||
self.logger.info("=" * 80)
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""Log individual events based on their type."""
|
||||
event_class = event.__class__.__name__
|
||||
|
||||
# Graph-level events
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.logger.debug("Graph run started event")
|
||||
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.logger.info("✅ Graph run succeeded")
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.logger.warning("⚠️ Graph run partially succeeded")
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.logger.error("❌ Graph run failed: %s", event.error)
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.error(" Total exceptions: %s", event.exceptions_count)
|
||||
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
|
||||
if event.outputs:
|
||||
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
# Node-level events
|
||||
# Check NodeRunRetryEvent before NodeRunStartedEvent, because Retry subclasses Started
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
self.retry_count += 1
|
||||
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
|
||||
self.logger.warning(" Previous error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.node_count += 1
|
||||
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
|
||||
|
||||
if self.include_inputs and event.node_run_result.inputs:
|
||||
self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs))
|
||||
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.success_count += 1
|
||||
self.logger.info("✅ Node succeeded: %s", event.node_id)
|
||||
|
||||
if self.include_outputs and event.node_run_result.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs))
|
||||
|
||||
if self.include_process_data and event.node_run_result.process_data:
|
||||
self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data))
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.failure_count += 1
|
||||
self.logger.error("❌ Node failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
if event.node_run_result.error:
|
||||
self.logger.error(" Details: %s", event.node_run_result.error)
|
||||
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
|
||||
self.logger.warning(" Error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
# Log stream chunks at debug level to avoid spam
|
||||
final_indicator = " (FINAL)" if event.is_final else ""
|
||||
self.logger.debug(
|
||||
"📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
|
||||
)
|
||||
|
||||
# Iteration events
|
||||
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||
self.logger.info("🔁 Iteration started: %s", event.node_id)
|
||||
|
||||
elif isinstance(event, NodeRunIterationNextEvent):
|
||||
self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index)
|
||||
|
||||
elif isinstance(event, NodeRunIterationSucceededEvent):
|
||||
self.logger.info("✅ Iteration succeeded: %s", event.node_id)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, NodeRunIterationFailedEvent):
|
||||
self.logger.error("❌ Iteration failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
# Loop events
|
||||
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||
self.logger.info("🔄 Loop started: %s", event.node_id)
|
||||
|
||||
elif isinstance(event, NodeRunLoopNextEvent):
|
||||
self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index)
|
||||
|
||||
elif isinstance(event, NodeRunLoopSucceededEvent):
|
||||
self.logger.info("✅ Loop succeeded: %s", event.node_id)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, NodeRunLoopFailedEvent):
|
||||
self.logger.error("❌ Loop failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
else:
|
||||
# Log unknown events at debug level
|
||||
self.logger.debug("Event: %s", event_class)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Log graph execution end with summary statistics."""
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if error:
|
||||
self.logger.error("🔴 GRAPH EXECUTION FAILED")
|
||||
self.logger.error(" Error: %s", error)
|
||||
else:
|
||||
self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")
|
||||
|
||||
# Log execution statistics
|
||||
self.logger.info("Execution Statistics:")
|
||||
self.logger.info(" Total nodes executed: %s", self.node_count)
|
||||
self.logger.info(" Successful nodes: %s", self.success_count)
|
||||
self.logger.info(" Failed nodes: %s", self.failure_count)
|
||||
self.logger.info(" Node retries: %s", self.retry_count)
|
||||
|
||||
# Log final state if available
|
||||
if self.include_outputs and self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
|
||||
self.logger.info("=" * 80)
|
||||
150
api/graphon/graph_engine/layers/execution_limits.py
Normal file
@ -0,0 +1,150 @@
|
||||
"""
|
||||
Execution limits layer for GraphEngine.
|
||||
|
||||
This layer monitors workflow execution to enforce limits on:
|
||||
- Maximum execution steps
|
||||
- Maximum execution time
|
||||
|
||||
When limits are exceeded, the layer automatically aborts execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from enum import StrEnum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from graphon.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(StrEnum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
TIME_LIMIT = "time_limit"
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
"""
|
||||
Layer that enforces execution limits for workflows.
|
||||
|
||||
Monitors:
|
||||
- Step count: Tracks number of node executions
|
||||
- Time limit: Monitors total execution time
|
||||
|
||||
Automatically aborts execution when limits are exceeded.
|
||||
"""
|
||||
|
||||
def __init__(self, max_steps: int, max_time: int) -> None:
|
||||
"""
|
||||
Initialize the execution limits layer.
|
||||
|
||||
Args:
|
||||
max_steps: Maximum number of execution steps allowed
|
||||
max_time: Maximum execution time in seconds allowed
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
self.max_time = max_time
|
||||
|
||||
# Runtime tracking
|
||||
self.start_time: float | None = None
|
||||
self.step_count = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# State tracking
|
||||
self._execution_started = False
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False # Track if abort command has been sent
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Called when graph execution starts."""
|
||||
self.start_time = time.time()
|
||||
self.step_count = 0
|
||||
self._execution_started = True
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False
|
||||
|
||||
self.logger.debug("Execution limits monitoring started")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
Monitors execution progress and enforces limits.
|
||||
"""
|
||||
if not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Track step count for node execution events
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self.step_count += 1
|
||||
self.logger.debug("Step %d started: %s", self.step_count, event.node_id)
|
||||
|
||||
# Check step limit when node execution completes
|
||||
if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
|
||||
if self._reached_step_limitation():
|
||||
self._send_abort_command(LimitType.STEP_LIMIT)
|
||||
|
||||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
self._execution_ended = True
|
||||
|
||||
if self.start_time:
|
||||
total_time = time.time() - self.start_time
|
||||
self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)
|
||||
|
||||
def _reached_step_limitation(self) -> bool:
|
||||
"""Check if step count limit has been exceeded."""
|
||||
return self.step_count > self.max_steps
|
||||
|
||||
def _reached_time_limitation(self) -> bool:
|
||||
"""Check if time limit has been exceeded."""
|
||||
return self.start_time is not None and (time.time() - self.start_time) > self.max_time
|
||||
|
||||
def _send_abort_command(self, limit_type: LimitType) -> None:
|
||||
"""
|
||||
Send abort command due to limit violation.
|
||||
|
||||
Args:
|
||||
limit_type: Type of limit exceeded
|
||||
"""
|
||||
if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Format detailed reason message
|
||||
if limit_type == LimitType.STEP_LIMIT:
|
||||
reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
|
||||
elif limit_type == LimitType.TIME_LIMIT:
|
||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||
|
||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||
|
||||
try:
|
||||
# Send abort command to the engine
|
||||
abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
|
||||
self.command_channel.send_command(abort_command)
|
||||
|
||||
# Mark that abort has been sent to prevent duplicate commands
|
||||
self._abort_sent = True
|
||||
|
||||
self.logger.debug("Abort command sent to engine")
|
||||
|
||||
except Exception:
|
||||
self.logger.exception("Failed to send abort command")
|
||||
79
api/graphon/graph_engine/manager.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""
|
||||
GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Callers must provide a Redis client dependency from outside the workflow package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from graphon.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
|
||||
from graphon.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
PauseCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngineManager:
|
||||
"""
|
||||
Manager for sending control commands to GraphEngine instances.
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
"""
|
||||
|
||||
_redis_client: RedisClientProtocol
|
||||
|
||||
def __init__(self, redis_client: RedisClientProtocol) -> None:
|
||||
self._redis_client = redis_client
|
||||
|
||||
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
Args:
|
||||
task_id: The task ID of the workflow to stop
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
self._send_command(task_id, abort_command)
|
||||
|
||||
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
"""Send a pause command to a running workflow."""
|
||||
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
self._send_command(task_id, pause_command)
|
||||
|
||||
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
self._send_command(task_id, update_command)
|
||||
|
||||
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(self._redis_client, channel_key)
|
||||
|
||||
try:
|
||||
channel.send_command(command)
|
||||
except Exception:
|
||||
# Log and swallow the error if Redis is unavailable;
# the legacy control mechanisms will still work
|
||||
logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id)
|
||||
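A usage sketch; any client object satisfying `RedisClientProtocol` works, and a redis-py `Redis` instance is used here purely as an example:

```python
import redis

# Illustrative client; the caller decides which Redis client to inject.
manager = GraphEngineManager(redis_client=redis.Redis())
manager.send_stop_command(task_id="task-123", reason="User requested stop")
```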
14
api/graphon/graph_engine/orchestration/__init__.py
Normal file
@ -0,0 +1,14 @@
"""
Orchestration subsystem for graph engine.

This package coordinates the overall execution flow between
different subsystems.
"""

from .dispatcher import Dispatcher
from .execution_coordinator import ExecutionCoordinator

__all__ = [
    "Dispatcher",
    "ExecutionCoordinator",
]
143
api/graphon/graph_engine/orchestration/dispatcher.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""
|
||||
Main dispatcher for processing events from workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from graphon.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from ..event_management import EventManager
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class Dispatcher:
|
||||
"""
|
||||
Main dispatcher that processes events from the event queue.
|
||||
|
||||
This runs in a separate thread and coordinates event processing
|
||||
with timeout and completion detection.
|
||||
"""
|
||||
|
||||
_COMMAND_TRIGGER_EVENTS = (
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dispatcher.
|
||||
|
||||
Args:
|
||||
event_queue: Queue of events from workers
|
||||
event_handler: Event handler registry for processing events
|
||||
execution_coordinator: Coordinator for execution flow
|
||||
event_emitter: Optional event manager to signal completion
|
||||
"""
|
||||
self._event_queue = event_queue
|
||||
self._event_handler = event_handler
|
||||
self._execution_coordinator = execution_coordinator
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the dispatcher thread."""
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
self._process_commands()
|
||||
paused = False
|
||||
while not self._stop_event.is_set():
|
||||
if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete:
|
||||
break
|
||||
if self._execution_coordinator.paused:
|
||||
paused = True
|
||||
break
|
||||
|
||||
self._execution_coordinator.check_scaling()
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
self._process_commands(event)
|
||||
except queue.Empty:
|
||||
time.sleep(0.1)
|
||||
|
||||
self._process_commands()
|
||||
if paused:
|
||||
self._drain_events_until_idle()
|
||||
else:
|
||||
self._drain_event_queue()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
self._execution_coordinator.mark_failed(e)
|
||||
|
||||
finally:
|
||||
self._execution_coordinator.mark_complete()
|
||||
# Signal the event emitter that execution is complete
|
||||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
||||
|
||||
def _process_commands(self, event: GraphNodeEventBase | None = None):
|
||||
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
|
||||
self._execution_coordinator.process_commands()
|
||||
|
||||
def _drain_event_queue(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
event = self._event_queue.get(block=False)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
def _drain_events_until_idle(self) -> None:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
self._process_commands(event)
|
||||
except queue.Empty:
|
||||
if not self._execution_coordinator.has_executing_nodes():
|
||||
break
|
||||
self._drain_event_queue()
|
||||
104
api/graphon/graph_engine/orchestration/execution_coordinator.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""
|
||||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..worker_management import WorkerPool
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionCoordinator:
|
||||
"""
|
||||
Coordinates overall execution flow between subsystems.
|
||||
|
||||
This provides high-level coordination methods used by the
|
||||
dispatcher to manage execution state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_execution: GraphExecution,
|
||||
state_manager: GraphStateManager,
|
||||
command_processor: CommandProcessor,
|
||||
worker_pool: WorkerPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the execution coordinator.
|
||||
|
||||
Args:
|
||||
graph_execution: Graph execution aggregate
|
||||
state_manager: Unified state manager
|
||||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self._graph_execution = graph_execution
|
||||
self._state_manager = state_manager
|
||||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
def process_commands(self) -> None:
|
||||
"""Process any pending commands."""
|
||||
self._command_processor.process_commands()
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
"""Check and perform worker scaling if needed."""
|
||||
self._worker_pool.check_and_scale()
|
||||
|
||||
@property
|
||||
def execution_complete(self):
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
@property
|
||||
def aborted(self):
|
||||
return self._graph_execution.aborted or self._graph_execution.has_error
|
||||
|
||||
@property
|
||||
def paused(self) -> bool:
|
||||
"""Expose whether the underlying graph execution is paused."""
|
||||
return self._graph_execution.is_paused
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
if self._graph_execution.is_paused:
|
||||
return
|
||||
if not self._graph_execution.completed:
|
||||
self._graph_execution.complete()
|
||||
|
||||
def mark_failed(self, error: Exception) -> None:
|
||||
"""
|
||||
Mark execution as failed.
|
||||
|
||||
Args:
|
||||
error: The error that caused failure
|
||||
"""
|
||||
self._graph_execution.fail(error)
|
||||
|
||||
def handle_pause_if_needed(self) -> None:
|
||||
"""If the execution has been paused, stop workers immediately."""
|
||||
|
||||
if not self._graph_execution.is_paused:
|
||||
return
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
def handle_abort_if_needed(self) -> None:
|
||||
"""If the execution has been aborted, stop workers immediately."""
|
||||
|
||||
if not self._graph_execution.aborted:
|
||||
return
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
def has_executing_nodes(self) -> bool:
|
||||
"""Return True if any nodes are currently marked as executing."""
|
||||
# This check is only safe once execution has already paused.
|
||||
# Before pause, executing state can change concurrently, which makes the result unreliable.
|
||||
if not self._graph_execution.is_paused:
|
||||
raise AssertionError("has_executing_nodes should only be called after execution is paused")
|
||||
return self._state_manager.get_executing_count() > 0
|
||||
41
api/graphon/graph_engine/protocols/command_channel.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""
|
||||
CommandChannel protocol for GraphEngine command communication.
|
||||
|
||||
This protocol defines the interface for sending and receiving commands
|
||||
to/from a GraphEngine instance, supporting both local and distributed scenarios.
|
||||
"""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
class CommandChannel(Protocol):
|
||||
"""
|
||||
Protocol for bidirectional command communication with GraphEngine.
|
||||
|
||||
Since each GraphEngine instance processes only one workflow execution,
|
||||
this channel is dedicated to that single execution.
|
||||
"""
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch pending commands for this GraphEngine instance.
|
||||
|
||||
Called by GraphEngine to poll for commands that need to be processed.
|
||||
|
||||
Returns:
|
||||
List of pending commands (may be empty)
|
||||
"""
|
||||
...
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to be processed by this GraphEngine instance.
|
||||
|
||||
Called by external systems to send control commands to the running workflow.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
...
|
||||
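Since `CommandChannel` is a structural `Protocol`, any object with these two methods satisfies it. A minimal list-backed sketch (not the real in-memory channel shipped in this package, just an illustration):

```python
from graphon.graph_engine.entities.commands import GraphEngineCommand


class ListBackedChannel:
    """Toy channel for illustration; the real InMemoryChannel lives elsewhere."""

    def __init__(self) -> None:
        self._pending: list[GraphEngineCommand] = []

    def fetch_commands(self) -> list[GraphEngineCommand]:
        # Return everything queued so far and clear the backlog.
        commands, self._pending = self._pending, []
        return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        self._pending.append(command)
```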
12
api/graphon/graph_engine/ready_queue/__init__.py
Normal file
@ -0,0 +1,12 @@
"""
Ready queue implementations for GraphEngine.

This package contains the protocol and implementations for managing
the queue of nodes ready for execution.
"""

from .factory import create_ready_queue_from_state
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueue, ReadyQueueState

__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]
37
api/graphon/graph_engine/ready_queue/factory.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""
|
||||
Factory for creating ReadyQueue instances from serialized state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
from .protocol import ReadyQueueState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protocol import ReadyQueue
|
||||
|
||||
|
||||
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
|
||||
"""
|
||||
Create a ReadyQueue instance from a serialized state.
|
||||
|
||||
Args:
|
||||
state: The serialized queue state as a ReadyQueueState model
|
||||
|
||||
Returns:
|
||||
A ReadyQueue instance initialized with the given state
|
||||
|
||||
Raises:
|
||||
ValueError: If the queue type is unknown or version is unsupported
|
||||
"""
|
||||
if state.type == "InMemoryReadyQueue":
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
|
||||
queue = InMemoryReadyQueue()
|
||||
# Always pass as JSON string to loads()
|
||||
queue.loads(state.model_dump_json())
|
||||
return queue
|
||||
else:
|
||||
raise ValueError(f"Unknown ready queue type: {state.type}")
|
||||
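A round-trip sketch showing how a dumped queue can be recreated through the factory (uses only the classes defined in this package):

```python
from graphon.graph_engine.ready_queue import (
    InMemoryReadyQueue,
    ReadyQueueState,
    create_ready_queue_from_state,
)

original = InMemoryReadyQueue()
original.put("node_a")
original.put("node_b")

# dumps() produces the JSON form of ReadyQueueState; parse it back into the model
state = ReadyQueueState.model_validate_json(original.dumps())
restored = create_ready_queue_from_state(state)

assert restored.get(timeout=1) == "node_a"
```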
140
api/graphon/graph_engine/ready_queue/in_memory.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""
|
||||
In-memory implementation of the ReadyQueue protocol.
|
||||
|
||||
This implementation wraps Python's standard queue.Queue and adds
|
||||
serialization capabilities for state storage.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import final
|
||||
|
||||
from .protocol import ReadyQueue, ReadyQueueState
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryReadyQueue(ReadyQueue):
|
||||
"""
|
||||
In-memory ready queue implementation with serialization support.
|
||||
|
||||
This implementation uses Python's queue.Queue internally and provides
|
||||
methods to serialize and restore the queue state.
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int = 0) -> None:
|
||||
"""
|
||||
Initialize the in-memory ready queue.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum size of the queue (0 for unlimited)
|
||||
"""
|
||||
self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize)
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""
|
||||
Add a node ID to the ready queue.
|
||||
|
||||
Args:
|
||||
item: The node ID to add to the queue
|
||||
"""
|
||||
self._queue.put(item)
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""
|
||||
Retrieve and remove a node ID from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for an item (None for blocking)
|
||||
|
||||
Returns:
|
||||
The node ID retrieved from the queue
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires and no item is available
|
||||
"""
|
||||
if timeout is None:
|
||||
return self._queue.get(block=True)
|
||||
return self._queue.get(timeout=timeout)
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""
|
||||
Indicate that a previously retrieved task is complete.
|
||||
|
||||
Used by worker threads to signal task completion for
|
||||
join() synchronization.
|
||||
"""
|
||||
self._queue.task_done()
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""
|
||||
Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
True if the queue has no items, False otherwise
|
||||
"""
|
||||
return self._queue.empty()
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""
|
||||
Get the approximate size of the queue.
|
||||
|
||||
Returns:
|
||||
The approximate number of items in the queue
|
||||
"""
|
||||
return self._queue.qsize()
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""
|
||||
Serialize the queue state to a JSON string for storage.
|
||||
|
||||
Returns:
|
||||
A JSON string containing the serialized queue state
|
||||
"""
|
||||
# Extract all items from the queue without removing them
|
||||
items: list[str] = []
|
||||
temp_items: list[str] = []
|
||||
|
||||
# Drain the queue temporarily to get all items
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
temp_items.append(item)
|
||||
items.append(item)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Put items back in the same order
|
||||
for item in temp_items:
|
||||
self._queue.put(item)
|
||||
|
||||
state = ReadyQueueState(
|
||||
type="InMemoryReadyQueue",
|
||||
version="1.0",
|
||||
items=items,
|
||||
)
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""
|
||||
Restore the queue state from a JSON string.
|
||||
|
||||
Args:
|
||||
data: The JSON string containing the serialized queue state to restore
|
||||
"""
|
||||
state = ReadyQueueState.model_validate_json(data)
|
||||
|
||||
if state.type != "InMemoryReadyQueue":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported version: {state.version}")
|
||||
|
||||
# Clear the current queue
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Restore items
|
||||
for item in state.items:
|
||||
self._queue.put(item)
|
||||
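A small sketch of the serialization behaviour: `dumps()` snapshots the pending node IDs without consuming them, and `loads()` replaces the current contents:

```python
from graphon.graph_engine.ready_queue.in_memory import InMemoryReadyQueue

q = InMemoryReadyQueue()
q.put("node_1")
q.put("node_2")

snapshot = q.dumps()  # the queue still holds both items afterwards
assert q.qsize() == 2

restored = InMemoryReadyQueue()
restored.loads(snapshot)
assert restored.get(timeout=1) == "node_1"
```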
104
api/graphon/graph_engine/ready_queue/protocol.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""
|
||||
ReadyQueue protocol for GraphEngine node execution queue.
|
||||
|
||||
This protocol defines the interface for managing the queue of nodes ready
|
||||
for execution, supporting both in-memory and persistent storage scenarios.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ReadyQueueState(BaseModel):
|
||||
"""
|
||||
Pydantic model for serialized ready queue state.
|
||||
|
||||
This defines the structure of the data returned by dumps()
|
||||
and expected by loads() for ready queue serialization.
|
||||
"""
|
||||
|
||||
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
|
||||
version: str = Field(description="Serialization format version")
|
||||
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
|
||||
|
||||
|
||||
class ReadyQueue(Protocol):
|
||||
"""
|
||||
Protocol for managing nodes ready for execution in GraphEngine.
|
||||
|
||||
This protocol defines the interface that any ready queue implementation
|
||||
must provide, enabling both in-memory queues and persistent queues
|
||||
that can be serialized for state storage.
|
||||
"""
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""
|
||||
Add a node ID to the ready queue.
|
||||
|
||||
Args:
|
||||
item: The node ID to add to the queue
|
||||
"""
|
||||
...
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""
|
||||
Retrieve and remove a node ID from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for an item (None for blocking)
|
||||
|
||||
Returns:
|
||||
The node ID retrieved from the queue
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires and no item is available
|
||||
"""
|
||||
...
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""
|
||||
Indicate that a previously retrieved task is complete.
|
||||
|
||||
Used by worker threads to signal task completion for
|
||||
join() synchronization.
|
||||
"""
|
||||
...
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""
|
||||
Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
True if the queue has no items, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""
|
||||
Get the approximate size of the queue.
|
||||
|
||||
Returns:
|
||||
The approximate number of items in the queue
|
||||
"""
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""
|
||||
Serialize the queue state to a JSON string for storage.
|
||||
|
||||
Returns:
|
||||
A JSON string containing the serialized queue state
|
||||
that can be persisted and later restored
|
||||
"""
|
||||
...
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""
|
||||
Restore the queue state from a JSON string.
|
||||
|
||||
Args:
|
||||
data: The JSON string containing the serialized queue state to restore
|
||||
"""
|
||||
...
|
||||
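For reference, the serialized form produced by `dumps()` is just this Pydantic model rendered as JSON; a sketch of one snapshot:

```python
from graphon.graph_engine.ready_queue.protocol import ReadyQueueState

state = ReadyQueueState(type="InMemoryReadyQueue", version="1.0", items=["node_1", "node_2"])
print(state.model_dump_json())
# {"type":"InMemoryReadyQueue","version":"1.0","items":["node_1","node_2"]}
```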
10
api/graphon/graph_engine/response_coordinator/__init__.py
Normal file
@ -0,0 +1,10 @@
"""
ResponseStreamCoordinator - Coordinates streaming output from response nodes

This component manages response streaming sessions and ensures ordered streaming
of responses based on upstream node outputs and constants.
"""

from .coordinator import ResponseStreamCoordinator

__all__ = ["ResponseStreamCoordinator"]
697
api/graphon/graph_engine/response_coordinator/coordinator.py
Normal file
@ -0,0 +1,697 @@
|
||||
"""
|
||||
Main ResponseStreamCoordinator implementation.
|
||||
|
||||
This module contains the public ResponseStreamCoordinator class that manages
|
||||
response streaming sessions and ensures ordered streaming of responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import Literal, TypeAlias, final
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphon.enums import NodeExecutionType, NodeState
|
||||
from graphon.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from graphon.nodes.base.template import TextSegment, VariableSegment
|
||||
from graphon.runtime import VariablePool
|
||||
from graphon.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type definitions
|
||||
NodeID: TypeAlias = str
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
class ResponseSessionState(BaseModel):
|
||||
"""Serializable representation of a response session."""
|
||||
|
||||
node_id: str
|
||||
index: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class StreamBufferState(BaseModel):
|
||||
"""Serializable representation of buffered stream chunks."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StreamPositionState(BaseModel):
|
||||
"""Serializable representation for stream read positions."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
position: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorState(BaseModel):
|
||||
"""Serialized snapshot of ResponseStreamCoordinator."""
|
||||
|
||||
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
|
||||
version: str = Field(default="1.0")
|
||||
response_nodes: Sequence[str] = Field(default_factory=list)
|
||||
active_session: ResponseSessionState | None = None
|
||||
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
node_execution_ids: dict[str, str] = Field(default_factory=dict)
|
||||
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
|
||||
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
|
||||
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
|
||||
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@final
|
||||
class ResponseStreamCoordinator:
|
||||
"""
|
||||
Manages response streaming sessions without relying on global state.
|
||||
|
||||
Ensures ordered streaming of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
|
||||
"""
|
||||
Initialize coordinator with variable pool.
|
||||
|
||||
Args:
|
||||
variable_pool: VariablePool instance for accessing node variables
|
||||
graph: Graph instance for looking up node information
|
||||
"""
|
||||
self._variable_pool = variable_pool
|
||||
self._graph = graph
|
||||
self._active_session: ResponseSession | None = None
|
||||
self._waiting_sessions: deque[ResponseSession] = deque()
|
||||
self._lock = RLock()
|
||||
|
||||
# Internal stream management (replacing OutputRegistry)
|
||||
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
|
||||
self._stream_positions: dict[tuple[str, ...], int] = {}
|
||||
self._closed_streams: set[tuple[str, ...]] = set()
|
||||
|
||||
# Track response nodes
|
||||
self._response_nodes: set[NodeID] = set()
|
||||
|
||||
# Store paths for each response node
|
||||
self._paths_maps: dict[NodeID, list[Path]] = {}
|
||||
|
||||
# Track node execution IDs and types for proper event forwarding
|
||||
self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id
|
||||
|
||||
# Track response sessions to ensure only one per node
|
||||
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
|
||||
|
||||
def register(self, response_node_id: NodeID) -> None:
|
||||
with self._lock:
|
||||
if response_node_id in self._response_nodes:
|
||||
return
|
||||
self._response_nodes.add(response_node_id)
|
||||
|
||||
# Build and save paths map for this response node
|
||||
paths_map = self._build_paths_map(response_node_id)
|
||||
self._paths_maps[response_node_id] = paths_map
|
||||
|
||||
# Create and store response session for this node
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
session = ResponseSession.from_node(response_node)
|
||||
self._response_sessions[response_node_id] = session
|
||||
|
||||
def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
|
||||
"""Track the execution ID for a node when it starts executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
execution_id: The execution ID from NodeRunStartedEvent
|
||||
"""
|
||||
with self._lock:
|
||||
self._node_execution_ids[node_id] = execution_id
|
||||
|
||||
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
|
||||
"""Get the execution ID for a node, creating one if it doesn't exist.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The execution ID for the node
|
||||
"""
|
||||
with self._lock:
|
||||
if node_id not in self._node_execution_ids:
|
||||
self._node_execution_ids[node_id] = str(uuid4())
|
||||
return self._node_execution_ids[node_id]
|
||||
|
||||
def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
|
||||
"""
|
||||
Build a paths map for a response node by finding all paths from root node
|
||||
to the response node, recording branch edges along each path.
|
||||
|
||||
Args:
|
||||
response_node_id: ID of the response node to analyze
|
||||
|
||||
Returns:
|
||||
List of Path objects, where each path contains branch edge IDs
|
||||
"""
|
||||
# Get root node ID
|
||||
root_node_id = self._graph.root_node.id
|
||||
|
||||
# If root is the response node, return empty path
|
||||
if root_node_id == response_node_id:
|
||||
return [Path()]
|
||||
|
||||
# Extract variable selectors from the response node's template
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
response_session = ResponseSession.from_node(response_node)
|
||||
template = response_session.template
|
||||
|
||||
# Collect all variable selectors from the template
|
||||
variable_selectors: set[tuple[str, ...]] = set()
|
||||
for segment in template.segments:
|
||||
if isinstance(segment, VariableSegment):
|
||||
variable_selectors.add(tuple(segment.selector[:2]))
|
||||
|
||||
# Step 1: Find all complete paths from root to response node
|
||||
all_complete_paths: list[list[EdgeID]] = []
|
||||
|
||||
def find_paths(
|
||||
current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
|
||||
) -> None:
|
||||
"""Recursively find all paths from current node to target node."""
|
||||
if current_node_id == target_node_id:
|
||||
# Found a complete path, store it
|
||||
all_complete_paths.append(current_path.copy())
|
||||
return
|
||||
|
||||
# Mark as visited to avoid cycles
|
||||
visited.add(current_node_id)
|
||||
|
||||
# Explore outgoing edges
|
||||
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
|
||||
for edge in outgoing_edges:
|
||||
edge_id = edge.id
|
||||
next_node_id = edge.head
|
||||
|
||||
# Skip if already visited in this path
|
||||
if next_node_id not in visited:
|
||||
# Add edge to path and recurse
|
||||
new_path = current_path + [edge_id]
|
||||
find_paths(next_node_id, target_node_id, new_path, visited.copy())
|
||||
|
||||
# Start searching from root node
|
||||
find_paths(root_node_id, response_node_id, [], set())
|
||||
|
||||
# Step 2: For each complete path, filter edges based on node blocking behavior
|
||||
filtered_paths: list[Path] = []
|
||||
for path in all_complete_paths:
|
||||
blocking_edges: list[str] = []
|
||||
for edge_id in path:
|
||||
edge = self._graph.edges[edge_id]
|
||||
source_node = self._graph.nodes[edge.tail]
|
||||
|
||||
# Check if node is a branch, container, or response node
|
||||
if source_node.execution_type in {
|
||||
NodeExecutionType.BRANCH,
|
||||
NodeExecutionType.CONTAINER,
|
||||
NodeExecutionType.RESPONSE,
|
||||
} or source_node.blocks_variable_output(variable_selectors):
|
||||
blocking_edges.append(edge_id)
|
||||
|
||||
# Keep the path even if it's empty
|
||||
filtered_paths.append(Path(edges=blocking_edges))
|
||||
|
||||
return filtered_paths
|
||||
|
||||
def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
Handle when an edge is taken (selected by a branch node).
|
||||
|
||||
This method updates the paths for all response nodes by removing
|
||||
the taken edge. If any response node has an empty path after removal,
|
||||
it means the node is now deterministically reachable and should start.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge that was taken
|
||||
|
||||
Returns:
|
||||
List of events to emit from starting new sessions
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
with self._lock:
|
||||
# Check each response node in order
|
||||
for response_node_id in self._response_nodes:
|
||||
if response_node_id not in self._paths_maps:
|
||||
continue
|
||||
|
||||
paths = self._paths_maps[response_node_id]
|
||||
has_reachable_path = False
|
||||
|
||||
# Update each path by removing the taken edge
|
||||
for path in paths:
|
||||
# Remove the taken edge from this path
|
||||
path.remove_edge(edge_id)
|
||||
|
||||
# Check if this path is now empty (node is reachable)
|
||||
if path.is_empty():
|
||||
has_reachable_path = True
|
||||
|
||||
# If node is now reachable (has empty path), start/queue session
|
||||
if has_reachable_path:
|
||||
# Pass the node_id to the activation method
|
||||
# The method will handle checking and removing from map
|
||||
events.extend(self._active_or_queue_session(response_node_id))
|
||||
return events
|
||||
|
||||
def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
Start a session immediately if no active session, otherwise queue it.
|
||||
Only activates sessions that exist in the _response_sessions map.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the response node to activate
|
||||
|
||||
Returns:
|
||||
List of events from flush attempt if session started immediately
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Get the session from our map (only activate if it exists)
|
||||
session = self._response_sessions.get(node_id)
|
||||
if not session:
|
||||
return events
|
||||
|
||||
# Remove from map to ensure it won't be activated again
|
||||
del self._response_sessions[node_id]
|
||||
|
||||
if self._active_session is None:
|
||||
self._active_session = session
|
||||
|
||||
# Try to flush immediately
|
||||
events.extend(self.try_flush())
|
||||
else:
|
||||
# Queue the session if another is active
|
||||
self._waiting_sessions.append(session)
|
||||
|
||||
return events
|
||||
|
||||
def intercept_event(
|
||||
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
|
||||
) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
with self._lock:
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._append_stream_chunk(event.selector, event)
|
||||
if event.is_final:
|
||||
self._close_stream(event.selector)
|
||||
return self.try_flush()
|
||||
else:
|
||||
# Skip because we share the same variable pool.
|
||||
#
|
||||
# for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
# self._variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
return self.try_flush()
|
||||
|
||||
def _create_stream_chunk_event(
|
||||
self,
|
||||
node_id: str,
|
||||
execution_id: str,
|
||||
selector: Sequence[str],
|
||||
chunk: str,
|
||||
is_final: bool = False,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Create a stream chunk event with consistent structure.
|
||||
|
||||
For selectors with special prefixes (sys, env, conversation), we use the
|
||||
active response node's information since these are not actual node IDs.
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
# Use the active response node for special selectors
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
node = self._graph.nodes[node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
|
||||
"""Process a variable segment. Returns (events, is_complete).
|
||||
|
||||
Handles both regular node selectors and special system selectors (sys, env, conversation).
|
||||
For special selectors, we attribute the output to the active response node.
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
is_complete = False
|
||||
|
||||
# Determine which node to attribute the output to
|
||||
# For special selectors (sys, env, conversation), use the active response node
|
||||
# For regular selectors, use the source node
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
# Special selector - use active response node
|
||||
output_node_id = self._active_session.node_id
|
||||
else:
|
||||
# Regular node selector
|
||||
output_node_id = source_selector_prefix
|
||||
execution_id = self._get_or_create_execution_id(output_node_id)
|
||||
|
||||
# Stream all available chunks
|
||||
while self._has_unread_stream(segment.selector):
|
||||
if event := self._pop_stream_chunk(segment.selector):
|
||||
# For special selectors, we need to update the event to use
|
||||
# the active response node's information
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
# Create a new event with the response node's information
|
||||
# but keep the original selector
|
||||
updated_event = NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=event.selector, # Keep original selector
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
events.append(updated_event)
|
||||
else:
|
||||
# Regular node selector - use event as is
|
||||
events.append(event)
|
||||
|
||||
# Check if this is the last chunk by looking ahead
|
||||
stream_closed = self._is_stream_closed(segment.selector)
|
||||
# Check if stream is closed to determine if segment is complete
|
||||
if stream_closed:
|
||||
is_complete = True
|
||||
|
||||
elif value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
)
|
||||
is_complete = True
|
||||
|
||||
return events, is_complete
|
||||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""Process a text segment. Returns (events, is_complete)."""
|
||||
assert self._active_session is not None
|
||||
current_response_node = self._graph.nodes[self._active_session.node_id]
|
||||
|
||||
# Use get_or_create_execution_id to ensure we have a consistent ID
|
||||
execution_id = self._get_or_create_execution_id(current_response_node.id)
|
||||
|
||||
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
event = self._create_stream_chunk_event(
|
||||
node_id=current_response_node.id,
|
||||
execution_id=execution_id,
|
||||
selector=[current_response_node.id, "answer"], # FIXME(-LAN-)
|
||||
chunk=segment.text,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
return [event]
|
||||
|
||||
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
|
||||
with self._lock:
|
||||
if not self._active_session:
|
||||
return []
|
||||
|
||||
template = self._active_session.template
|
||||
response_node_id = self._active_session.node_id
|
||||
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Process segments sequentially from current index
|
||||
while self._active_session.index < len(template.segments):
|
||||
segment = template.segments[self._active_session.index]
|
||||
|
||||
if isinstance(segment, VariableSegment):
|
||||
# Check if the source node for this variable is skipped
|
||||
# Only check for actual nodes, not special selectors (sys, env, conversation)
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
if source_selector_prefix in self._graph.nodes:
|
||||
source_node = self._graph.nodes[source_selector_prefix]
|
||||
|
||||
if source_node.state == NodeState.SKIPPED:
|
||||
# Skip this variable segment if the source node is skipped
|
||||
self._active_session.index += 1
|
||||
continue
|
||||
|
||||
segment_events, is_complete = self._process_variable_segment(segment)
|
||||
events.extend(segment_events)
|
||||
|
||||
# Only advance index if this variable segment is complete
|
||||
if is_complete:
|
||||
self._active_session.index += 1
|
||||
else:
|
||||
# Wait for more data
|
||||
break
|
||||
|
||||
else:
|
||||
segment_events = self._process_text_segment(segment)
|
||||
events.extend(segment_events)
|
||||
self._active_session.index += 1
|
||||
|
||||
if self._active_session.is_complete():
|
||||
# End current session and get events from starting next session
|
||||
next_session_events = self.end_session(response_node_id)
|
||||
events.extend(next_session_events)
|
||||
|
||||
return events
|
||||
|
||||
def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
End the active session for a response node.
|
||||
Automatically starts the next waiting session if available.
|
||||
|
||||
Args:
|
||||
node_id: ID of the response node ending its session
|
||||
|
||||
Returns:
|
||||
List of events from starting the next session
|
||||
"""
|
||||
with self._lock:
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
if self._active_session and self._active_session.node_id == node_id:
|
||||
self._active_session = None
|
||||
|
||||
# Try to start next waiting session
|
||||
if self._waiting_sessions:
|
||||
next_session = self._waiting_sessions.popleft()
|
||||
self._active_session = next_session
|
||||
|
||||
# Immediately try to flush any available segments
|
||||
events = self.try_flush()
|
||||
|
||||
return events
|
||||
|
||||
# ============= Internal Stream Management Methods =============
|
||||
|
||||
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Append a stream chunk to the internal buffer.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
event: The NodeRunStreamChunkEvent to append
|
||||
|
||||
Raises:
|
||||
ValueError: If the stream is already closed
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key in self._closed_streams:
|
||||
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
self._stream_buffers[key] = []
|
||||
self._stream_positions[key] = 0
|
||||
|
||||
self._stream_buffers[key].append(event)
|
||||
|
||||
def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
|
||||
"""
|
||||
Pop the next unread stream chunk from the buffer.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
The next event, or None if no unread events available
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
return None
|
||||
|
||||
position = self._stream_positions.get(key, 0)
|
||||
buffer = self._stream_buffers[key]
|
||||
|
||||
if position >= len(buffer):
|
||||
return None
|
||||
|
||||
event = buffer[position]
|
||||
self._stream_positions[key] = position + 1
|
||||
return event
|
||||
|
||||
def _has_unread_stream(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if the stream has unread events.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if there are unread events, False otherwise
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
return False
|
||||
|
||||
position = self._stream_positions.get(key, 0)
|
||||
return position < len(self._stream_buffers[key])
|
||||
|
||||
def _close_stream(self, selector: Sequence[str]) -> None:
|
||||
"""
|
||||
Mark a stream as closed (no more chunks can be appended).
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
"""
|
||||
key = tuple(selector)
|
||||
self._closed_streams.add(key)
|
||||
|
||||
def _is_stream_closed(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if a stream is closed.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if the stream is closed, False otherwise
|
||||
"""
|
||||
key = tuple(selector)
|
||||
return key in self._closed_streams
|
||||
|
||||
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
|
||||
"""Convert an in-memory session into its serializable form."""
|
||||
|
||||
if session is None:
|
||||
return None
|
||||
return ResponseSessionState(node_id=session.node_id, index=session.index)
|
||||
|
||||
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
|
||||
"""Rebuild a response session from serialized data."""
|
||||
|
||||
node = self._graph.nodes.get(session_state.node_id)
|
||||
if node is None:
|
||||
raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
|
||||
|
||||
session = ResponseSession.from_node(node)
|
||||
session.index = session_state.index
|
||||
return session
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize coordinator state to JSON."""
|
||||
|
||||
with self._lock:
|
||||
state = ResponseStreamCoordinatorState(
|
||||
response_nodes=sorted(self._response_nodes),
|
||||
active_session=self._serialize_session(self._active_session),
|
||||
waiting_sessions=[
|
||||
session_state
|
||||
for session in list(self._waiting_sessions)
|
||||
if (session_state := self._serialize_session(session)) is not None
|
||||
],
|
||||
pending_sessions=[
|
||||
session_state
|
||||
for _, session in sorted(self._response_sessions.items())
|
||||
if (session_state := self._serialize_session(session)) is not None
|
||||
],
|
||||
node_execution_ids=dict(sorted(self._node_execution_ids.items())),
|
||||
paths_map={
|
||||
node_id: [path.edges.copy() for path in paths]
|
||||
for node_id, paths in sorted(self._paths_maps.items())
|
||||
},
|
||||
stream_buffers=[
|
||||
StreamBufferState(
|
||||
selector=selector,
|
||||
events=[event.model_copy(deep=True) for event in events],
|
||||
)
|
||||
for selector, events in sorted(self._stream_buffers.items())
|
||||
],
|
||||
stream_positions=[
|
||||
StreamPositionState(selector=selector, position=position)
|
||||
for selector, position in sorted(self._stream_positions.items())
|
||||
],
|
||||
closed_streams=sorted(self._closed_streams),
|
||||
)
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore coordinator state from JSON."""
|
||||
|
||||
state = ResponseStreamCoordinatorState.model_validate_json(data)
|
||||
|
||||
if state.type != "ResponseStreamCoordinator":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||
|
||||
with self._lock:
|
||||
self._response_nodes = set(state.response_nodes)
|
||||
self._paths_maps = {
|
||||
node_id: [Path(edges=list(path_edges)) for path_edges in paths]
|
||||
for node_id, paths in state.paths_map.items()
|
||||
}
|
||||
self._node_execution_ids = dict(state.node_execution_ids)
|
||||
|
||||
self._stream_buffers = {
|
||||
tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
|
||||
for buffer in state.stream_buffers
|
||||
}
|
||||
self._stream_positions = {
|
||||
tuple(position.selector): position.position for position in state.stream_positions
|
||||
}
|
||||
for selector in self._stream_buffers:
|
||||
self._stream_positions.setdefault(selector, 0)
|
||||
|
||||
self._closed_streams = {tuple(selector) for selector in state.closed_streams}
|
||||
|
||||
self._waiting_sessions = deque(
|
||||
self._session_from_state(session_state) for session_state in state.waiting_sessions
|
||||
)
|
||||
self._response_sessions = {
|
||||
session_state.node_id: self._session_from_state(session_state)
|
||||
for session_state in state.pending_sessions
|
||||
}
|
||||
self._active_session = self._session_from_state(state.active_session) if state.active_session else None
|
||||
35
api/graphon/graph_engine/response_coordinator/path.py
Normal file
@ -0,0 +1,35 @@
"""
Internal path representation for response coordinator.

This module contains the private Path class used internally by ResponseStreamCoordinator
to track execution paths to response nodes.
"""

from dataclasses import dataclass, field
from typing import TypeAlias

EdgeID: TypeAlias = str


@dataclass
class Path:
    """
    Represents a path of branch edges that must be taken to reach a response node.

    Note: This is an internal class not exposed in the public API.
    """

    edges: list[EdgeID] = field(default_factory=list[EdgeID])

    def contains_edge(self, edge_id: EdgeID) -> bool:
        """Check if this path contains the given edge."""
        return edge_id in self.edges

    def remove_edge(self, edge_id: EdgeID) -> None:
        """Remove the given edge from this path in place."""
        if self.contains_edge(edge_id):
            self.edges.remove(edge_id)

    def is_empty(self) -> bool:
        """Check if the path has no edges (node is reachable)."""
        return len(self.edges) == 0
66
api/graphon/graph_engine/response_coordinator/session.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""
|
||||
Internal response session management for response coordinator.
|
||||
|
||||
This module contains the private ResponseSession class used internally
|
||||
by ResponseStreamCoordinator to manage streaming sessions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol, cast
|
||||
|
||||
from graphon.nodes.base.template import Template
|
||||
from graphon.runtime.graph_runtime_state import NodeProtocol
|
||||
|
||||
|
||||
class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
|
||||
"""Structural contract required from nodes that can open a response session."""
|
||||
|
||||
def get_streaming_template(self) -> Template: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseSession:
|
||||
"""
|
||||
Represents an active response streaming session.
|
||||
|
||||
Note: This is an internal class not exposed in the public API.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
template: Template # Template object from the response node
|
||||
index: int = 0 # Current position in the template segments
|
||||
|
||||
@classmethod
|
||||
def from_node(cls, node: NodeProtocol) -> ResponseSession:
|
||||
"""
|
||||
Create a ResponseSession from a response-capable node.
|
||||
|
||||
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
|
||||
At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which
|
||||
graph nodes should be treated as response-capable before they reach this factory.
|
||||
|
||||
Args:
|
||||
node: Node from the materialized workflow graph.
|
||||
|
||||
Returns:
|
||||
ResponseSession configured with the node's streaming template
|
||||
|
||||
Raises:
|
||||
TypeError: If node does not implement the response-session streaming contract.
|
||||
"""
|
||||
response_node = cast(_ResponseSessionNodeProtocol, node)
|
||||
try:
|
||||
template = response_node.get_streaming_template()
|
||||
except AttributeError as exc:
|
||||
raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc
|
||||
|
||||
return cls(
|
||||
node_id=node.id,
|
||||
template=template,
|
||||
)
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if all segments in the template have been processed."""
|
||||
return self.index >= len(self.template.segments)
|
||||
204
api/graphon/graph_engine/worker.py
Normal file
@ -0,0 +1,204 @@
|
||||
"""
|
||||
Worker - Thread implementation for queue-based node execution
|
||||
|
||||
Workers pull node IDs from the ready_queue, execute nodes, and push events
|
||||
to the event_queue for the dispatcher to process.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_engine.layers.base import GraphEngineLayer
|
||||
from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@final
|
||||
class Worker(threading.Thread):
|
||||
"""
|
||||
Worker thread that executes nodes from the ready queue.
|
||||
|
||||
Workers continuously pull node IDs from the ready_queue, execute the
|
||||
corresponding nodes, and push the resulting events to the event_queue
|
||||
for the dispatcher to process.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ready_queue: ReadyQueue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
worker_id: int = 0,
|
||||
execution_context: AbstractContextManager[object] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize worker thread.
|
||||
|
||||
Args:
|
||||
ready_queue: Ready queue containing node IDs ready for execution
|
||||
event_queue: Queue for pushing execution events
|
||||
graph: Graph containing nodes to execute
|
||||
layers: Graph engine layers for node execution hooks
|
||||
worker_id: Unique identifier for this worker
|
||||
execution_context: Optional execution context for context preservation
|
||||
"""
|
||||
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._execution_context = execution_context
|
||||
self._stop_event = threading.Event()
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
self._current_node_started_at: datetime | None = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
"""Check if the worker is currently idle."""
|
||||
# Worker is idle if more than 0.2 seconds have passed since it last processed a task
|
||||
return (time.time() - self._last_task_time) > 0.2
|
||||
|
||||
@property
|
||||
def idle_duration(self) -> float:
|
||||
"""Get the duration in seconds since the worker last processed a task."""
|
||||
return time.time() - self._last_task_time
|
||||
|
||||
@property
|
||||
def worker_id(self) -> int:
|
||||
"""Get the worker's ID."""
|
||||
return self._worker_id
|
||||
|
||||
@override
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Main worker loop.
|
||||
|
||||
Continuously pulls node IDs from ready_queue, executes them,
|
||||
and pushes events to event_queue until stopped.
|
||||
"""
|
||||
while not self._stop_event.is_set():
|
||||
# Try to get a node ID from the ready queue (with timeout)
|
||||
try:
|
||||
node_id = self._ready_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._current_node_started_at = None
|
||||
self._execute_node(node)
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
self._event_queue.put(
|
||||
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
|
||||
)
|
||||
finally:
|
||||
self._current_node_started_at = None
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
Execute a single node and handle its events.
|
||||
|
||||
Args:
|
||||
node: The node instance to execute
|
||||
"""
|
||||
node.ensure_execution_id()
|
||||
|
||||
error: Exception | None = None
|
||||
result_event: GraphNodeEventBase | None = None
|
||||
|
||||
# Execute the node with preserved context if execution context is provided
|
||||
if self._execution_context is not None:
|
||||
with self._execution_context:
|
||||
self._invoke_node_run_start_hooks(node)
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
except Exception as exc:
|
||||
error = exc
|
||||
raise
|
||||
finally:
|
||||
self._invoke_node_run_end_hooks(node, error, result_event)
|
||||
else:
|
||||
self._invoke_node_run_start_hooks(node)
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
except Exception as exc:
|
||||
error = exc
|
||||
raise
|
||||
finally:
|
||||
self._invoke_node_run_end_hooks(node, error, result_event)
|
||||
|
||||
def _invoke_node_run_start_hooks(self, node: Node) -> None:
|
||||
"""Invoke on_node_run_start hooks for all layers."""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_node_run_start(node)
|
||||
except Exception:
|
||||
# Silently ignore layer errors to prevent disrupting node execution
|
||||
continue
|
||||
|
||||
def _invoke_node_run_end_hooks(
|
||||
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
|
||||
) -> None:
|
||||
"""Invoke on_node_run_end hooks for all layers."""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_node_run_end(node, error, result_event)
|
||||
except Exception:
|
||||
# Silently ignore layer errors to prevent disrupting node execution
|
||||
continue
|
||||
|
||||
def _build_fallback_failure_event(
|
||||
self, node: Node, error: Exception, *, started_at: datetime | None = None
|
||||
) -> NodeRunFailedEvent:
|
||||
"""Build a failed event when worker-level execution aborts before a node emits its own result event."""
|
||||
failure_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
error_message = str(error)
|
||||
return NodeRunFailedEvent(
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=error_message,
|
||||
start_at=started_at or failure_time,
|
||||
finished_at=failure_time,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
error_type=type(error).__name__,
|
||||
),
|
||||
)
|
||||
12
api/graphon/graph_engine/worker_management/__init__.py
Normal file
@ -0,0 +1,12 @@
"""
Worker management subsystem for graph engine.

This package manages the worker pool, including creation,
scaling, and activity tracking.
"""

from .worker_pool import WorkerPool

__all__ = [
    "WorkerPool",
]
277
api/graphon/graph_engine/worker_management/worker_pool.py
Normal file
@ -0,0 +1,277 @@
|
||||
"""
|
||||
Simple worker pool that consolidates functionality.
|
||||
|
||||
This is a simpler implementation that merges WorkerPool, ActivityTracker,
|
||||
DynamicScaler, and WorkerFactory into a single class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import final
|
||||
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import GraphNodeEventBase
|
||||
|
||||
from ..config import GraphEngineConfig
|
||||
from ..layers.base import GraphEngineLayer
|
||||
from ..ready_queue import ReadyQueue
|
||||
from ..worker import Worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class WorkerPool:
|
||||
"""
|
||||
Simple worker pool with integrated management.
|
||||
|
||||
This class consolidates all worker management functionality into
|
||||
a single, simpler implementation without excessive abstraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ready_queue: ReadyQueue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
config: GraphEngineConfig,
|
||||
execution_context: AbstractContextManager[object] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the simple worker pool.
|
||||
|
||||
Args:
|
||||
ready_queue: Ready queue for nodes ready for execution
|
||||
event_queue: Queue for worker events
|
||||
graph: The workflow graph
|
||||
layers: Graph engine layers for node execution hooks
|
||||
config: GraphEngine worker pool configuration
|
||||
execution_context: Optional execution context for context preservation
|
||||
"""
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._execution_context = execution_context
|
||||
self._layers = layers
|
||||
self._config = config
|
||||
|
||||
# Worker management
|
||||
self._workers: list[Worker] = []
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
def start(self, initial_count: int | None = None) -> None:
|
||||
"""
|
||||
Start the worker pool.
|
||||
|
||||
Args:
|
||||
initial_count: Number of workers to start with (auto-calculated if None)
|
||||
"""
|
||||
with self._lock:
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Calculate initial worker count
|
||||
if initial_count is None:
|
||||
node_count = len(self._graph.nodes)
|
||||
if node_count < 10:
|
||||
initial_count = self._config.min_workers
|
||||
elif node_count < 50:
|
||||
initial_count = min(self._config.min_workers + 1, self._config.max_workers)
|
||||
else:
|
||||
initial_count = min(self._config.min_workers + 2, self._config.max_workers)
|
||||
|
||||
logger.debug(
|
||||
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
|
||||
initial_count,
|
||||
node_count,
|
||||
self._config.min_workers,
|
||||
self._config.max_workers,
|
||||
)
|
||||
|
||||
# Create initial workers
|
||||
for _ in range(initial_count):
|
||||
self._create_worker()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop all workers in the pool."""
|
||||
with self._lock:
|
||||
self._running = False
|
||||
worker_count = len(self._workers)
|
||||
|
||||
if worker_count > 0:
|
||||
logger.debug("Stopping worker pool: %d workers", worker_count)
|
||||
|
||||
# Stop all workers
|
||||
for worker in self._workers:
|
||||
worker.stop()
|
||||
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
def _create_worker(self) -> None:
|
||||
"""Create and start a new worker."""
|
||||
worker_id = self._worker_counter
|
||||
self._worker_counter += 1
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
layers=self._layers,
|
||||
worker_id=worker_id,
|
||||
execution_context=self._execution_context,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
self._workers.append(worker)
|
||||
|
||||
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
|
||||
"""Remove a specific worker from the pool."""
|
||||
# Stop the worker
|
||||
worker.stop()
|
||||
|
||||
# Wait for it to finish
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
# Remove from list
|
||||
if worker in self._workers:
|
||||
self._workers.remove(worker)
|
||||
|
||||
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
|
||||
"""
|
||||
Try to scale up workers if needed.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
|
||||
Returns:
|
||||
True if scaled up, False otherwise
|
||||
"""
|
||||
if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers:
|
||||
old_count = current_count
|
||||
self._create_worker()
|
||||
|
||||
logger.debug(
|
||||
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
queue_depth,
|
||||
self._config.scale_up_threshold,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
|
||||
"""
|
||||
Try to scale down workers if we have excess capacity.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
active_count: Number of active workers
|
||||
idle_count: Number of idle workers
|
||||
|
||||
Returns:
|
||||
True if scaled down, False otherwise
|
||||
"""
|
||||
# Skip if we're at minimum or have no idle workers
|
||||
if current_count <= self._config.min_workers or idle_count == 0:
|
||||
return False
|
||||
|
||||
# Check if we have excess capacity
|
||||
has_excess_capacity = (
|
||||
queue_depth <= active_count # Active workers can handle current queue
|
||||
or idle_count > active_count # More idle than active workers
|
||||
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
|
||||
)
|
||||
|
||||
if not has_excess_capacity:
|
||||
return False
|
||||
|
||||
# Find and remove idle workers that have been idle long enough
|
||||
workers_to_remove: list[tuple[Worker, int]] = []
|
||||
|
||||
for worker in self._workers:
|
||||
# Check if worker is idle and has exceeded idle time threshold
|
||||
if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time:
|
||||
# Don't remove if it would leave us unable to handle the queue
|
||||
remaining_workers = current_count - len(workers_to_remove) - 1
|
||||
if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2):
|
||||
workers_to_remove.append((worker, worker.worker_id))
|
||||
# Only remove one worker per check to avoid aggressive scaling
|
||||
break
|
||||
|
||||
# Remove idle workers if any found
|
||||
if workers_to_remove:
|
||||
old_count = current_count
|
||||
for worker, worker_id in workers_to_remove:
|
||||
self._remove_worker(worker, worker_id)
|
||||
|
||||
logger.debug(
|
||||
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
|
||||
"queue_depth=%d, active=%d, idle=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
len(workers_to_remove),
|
||||
self._config.scale_down_idle_time,
|
||||
queue_depth,
|
||||
active_count,
|
||||
idle_count - len(workers_to_remove),
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_and_scale(self) -> None:
|
||||
"""Check and perform scaling if needed."""
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
current_count = len(self._workers)
|
||||
queue_depth = self._ready_queue.qsize()
|
||||
|
||||
# Count active vs idle workers by querying their state directly
|
||||
idle_count = sum(1 for worker in self._workers if worker.is_idle)
|
||||
active_count = current_count - idle_count
|
||||
|
||||
# Try to scale up if queue is backing up
|
||||
self._try_scale_up(queue_depth, current_count)
|
||||
|
||||
# Try to scale down if we have excess capacity
|
||||
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
|
||||
|
||||
def get_worker_count(self) -> int:
|
||||
"""Get current number of workers."""
|
||||
with self._lock:
|
||||
return len(self._workers)
|
||||
|
||||
def get_status(self) -> dict[str, int]:
|
||||
"""
|
||||
Get pool status information.
|
||||
|
||||
Returns:
|
||||
Dictionary with status information
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_workers": len(self._workers),
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"min_workers": self._config.min_workers,
|
||||
"max_workers": self._config.max_workers,
|
||||
}
|
||||
84
api/graphon/graph_events/__init__.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Agent events
|
||||
from .agent import NodeRunAgentLogEvent
|
||||
|
||||
# Base events
|
||||
from .base import (
|
||||
BaseGraphEvent,
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
)
|
||||
|
||||
# Graph events
|
||||
from .graph import (
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
||||
# Iteration events
|
||||
from .iteration import (
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
)
|
||||
|
||||
# Loop events
|
||||
from .loop import (
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
)
|
||||
|
||||
# Node events
|
||||
from .node import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunVariableUpdatedEvent,
|
||||
is_node_result_event,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseGraphEvent",
|
||||
"GraphEngineEvent",
|
||||
"GraphNodeEventBase",
|
||||
"GraphRunAbortedEvent",
|
||||
"GraphRunFailedEvent",
|
||||
"GraphRunPartialSucceededEvent",
|
||||
"GraphRunPausedEvent",
|
||||
"GraphRunStartedEvent",
|
||||
"GraphRunSucceededEvent",
|
||||
"NodeRunAgentLogEvent",
|
||||
"NodeRunExceptionEvent",
|
||||
"NodeRunFailedEvent",
|
||||
"NodeRunHumanInputFormFilledEvent",
|
||||
"NodeRunHumanInputFormTimeoutEvent",
|
||||
"NodeRunIterationFailedEvent",
|
||||
"NodeRunIterationNextEvent",
|
||||
"NodeRunIterationStartedEvent",
|
||||
"NodeRunIterationSucceededEvent",
|
||||
"NodeRunLoopFailedEvent",
|
||||
"NodeRunLoopNextEvent",
|
||||
"NodeRunLoopStartedEvent",
|
||||
"NodeRunLoopSucceededEvent",
|
||||
"NodeRunPauseRequestedEvent",
|
||||
"NodeRunRetrieverResourceEvent",
|
||||
"NodeRunRetryEvent",
|
||||
"NodeRunStartedEvent",
|
||||
"NodeRunStreamChunkEvent",
|
||||
"NodeRunSucceededEvent",
|
||||
"NodeRunVariableUpdatedEvent",
|
||||
"is_node_result_event",
|
||||
]
|
||||
17
api/graphon/graph_events/agent.py
Normal file
@ -0,0 +1,17 @@
from collections.abc import Mapping
from typing import Any

from pydantic import Field

from .base import GraphAgentNodeEventBase


class NodeRunAgentLogEvent(GraphAgentNodeEventBase):
    message_id: str = Field(..., description="message id")
    label: str = Field(..., description="label")
    node_execution_id: str = Field(..., description="node execution id")
    parent_id: str | None = Field(..., description="parent id")
    error: str | None = Field(..., description="error")
    status: str = Field(..., description="status")
    data: Mapping[str, Any] = Field(..., description="data")
    metadata: Mapping[str, object] = Field(default_factory=dict)
31
api/graphon/graph_events/base.py
Normal file
@ -0,0 +1,31 @@
from pydantic import BaseModel, Field

from graphon.enums import NodeType
from graphon.node_events import NodeRunResult


class GraphEngineEvent(BaseModel):
    pass


class BaseGraphEvent(GraphEngineEvent):
    pass


class GraphNodeEventBase(GraphEngineEvent):
    id: str = Field(..., description="node execution id")
    node_id: str
    node_type: NodeType

    in_iteration_id: str | None = None
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""

    # The version of the node, or "1" if not specified.
    node_version: str = "1"
    node_run_result: NodeRunResult = Field(default_factory=NodeRunResult)


class GraphAgentNodeEventBase(GraphNodeEventBase):
    pass
57
api/graphon/graph_events/graph.py
Normal file
@ -0,0 +1,57 @@
|
||||
from pydantic import Field
|
||||
|
||||
from graphon.entities.pause_reason import PauseReason
|
||||
from graphon.entities.workflow_start_reason import WorkflowStartReason
|
||||
from graphon.graph_events import BaseGraphEvent
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
# Reason is emitted for workflow start events and is always set.
|
||||
reason: WorkflowStartReason = Field(
|
||||
default=WorkflowStartReason.INITIAL,
|
||||
description="reason for workflow start",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
"""Event emitted when a run completes successfully with final outputs."""
|
||||
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Final workflow outputs keyed by output selector.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
exceptions_count: int = Field(description="exception count", default=0)
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
"""Event emitted when a run finishes with partial success and failures."""
|
||||
|
||||
exceptions_count: int = Field(..., description="exception count")
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs that were materialised before failures occurred.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunAbortedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is aborted by user command."""
|
||||
|
||||
reason: str | None = Field(default=None, description="reason for abort")
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs produced before the abort was requested.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunPausedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is paused by user command."""
|
||||
|
||||
reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list)
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs available to the client while the run is paused.",
|
||||
)
|
||||
0
api/graphon/graph_events/human_input.py
Normal file
40
api/graphon/graph_events/iteration.py
Normal file
@ -0,0 +1,40 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
|
||||
|
||||
class NodeRunIterationStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class NodeRunIterationNextEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Any = None
|
||||
|
||||
|
||||
class NodeRunIterationSucceededEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class NodeRunIterationFailedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
40
api/graphon/graph_events/loop.py
Normal file
@ -0,0 +1,40 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
|
||||
|
||||
class NodeRunLoopStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class NodeRunLoopNextEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
index: int = Field(..., description="index")
|
||||
pre_loop_output: Any = None
|
||||
|
||||
|
||||
class NodeRunLoopSucceededEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class NodeRunLoopFailedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
106
api/graphon/graph_events/node.py
Normal file
@ -0,0 +1,106 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from graphon.entities.pause_reason import PauseReason
|
||||
from graphon.variables.variables import Variable
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
|
||||
|
||||
class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
predecessor_node_id: str | None = None
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode
|
||||
provider_type: str = ""
|
||||
provider_id: str = ""
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
||||
# Spec-compliant fields
|
||||
selector: Sequence[str] = Field(
|
||||
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
|
||||
)
|
||||
chunk: str = Field(..., description="the actual chunk content")
|
||||
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(GraphNodeEventBase):
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunVariableUpdatedEvent(GraphNodeEventBase):
|
||||
"""Request that the engine apply a variable update before downstream observers continue."""
|
||||
|
||||
variable: Variable = Field(..., description="Updated variable payload to apply.")
|
||||
|
||||
|
||||
class NodeRunFailedEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
finished_at: datetime | None = Field(default=None, description="node finish time")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
|
||||
|
||||
class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
|
||||
"""Emitted when a HumanInput form is submitted and before the node finishes."""
|
||||
|
||||
node_title: str = Field(..., description="HumanInput node title")
|
||||
rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
|
||||
action_id: str = Field(..., description="User action identifier chosen in the form.")
|
||||
action_text: str = Field(..., description="Display text of the chosen action button.")
|
||||
|
||||
|
||||
class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
|
||||
"""Emitted when a HumanInput form times out."""
|
||||
|
||||
node_title: str = Field(..., description="HumanInput node title")
|
||||
expiration_time: datetime = Field(..., description="Form expiration time")
|
||||
|
||||
|
||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
|
||||
|
||||
def is_node_result_event(event: GraphNodeEventBase) -> bool:
|
||||
"""
|
||||
Check if an event is a final result event from node execution.
|
||||
|
||||
A result event indicates the completion of a node execution and contains
|
||||
runtime information such as inputs, outputs, or error details.
|
||||
|
||||
Args:
|
||||
event: The event to check
|
||||
|
||||
Returns:
|
||||
True if the event is a node result event (succeeded/failed/paused), False otherwise
|
||||
"""
|
||||
return isinstance(
|
||||
event,
|
||||
(
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
),
|
||||
)
|
||||
51
api/graphon/model_runtime/README.md
Normal file
@ -0,0 +1,51 @@
# Model Runtime

This module provides the interface for invoking and authenticating the various models, and gives Dify a unified set of information and credential form rules for model providers.

- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to extend the set of supported models horizontally.
- On the other hand, providers and models only need to be defined in the backend to be displayed directly in the frontend interface, with no frontend logic changes required.

## Features

- Supports capability invocation for 6 model types

  - `LLM` - text completion, chat, and pre-computed token counting
  - `Text Embedding Model` - text embedding and pre-computed token counting
  - `Rerank Model` - segment reranking
  - `Speech-to-text Model` - speech-to-text conversion
  - `Text-to-speech Model` - text-to-speech conversion
  - `Moderation` - content moderation

- Model provider display

  Displays a list of all supported providers, including provider names, icons, supported model types, predefined model lists, configuration methods, and credential form rules.

- Selectable model list display

  After configuring provider/model credentials, the dropdown (application orchestration interface/default model) shows the available LLM list. Greyed-out items are predefined models from providers without configured credentials, so users can see which models are supported.

  In addition, this list returns the configurable parameter information and rules for each LLM. These parameters are all defined in the backend, so different models can expose different parameter sets.

- Provider/model credential authentication

  The provider list returns the configuration information for the credentials form, and the credentials can be validated through Runtime's interface.

## Structure

Model Runtime is divided into three layers (a minimal sketch of how they compose follows the list):

- The outermost layer is the factory method

  It provides methods for obtaining all providers and all model lists, getting provider instances, and authenticating provider/model credentials.

- The second layer is the provider layer

  It provides the current provider's model list, model instance retrieval, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.

- The bottom layer is the model layer

  It offers direct invocation of the various model types, predefined model configuration information, retrieval of predefined/remote model lists, and model credential authentication. Individual models provide additional special methods, such as the LLM's pre-computed token counting and cost information methods, **allowing horizontal expansion** for different models under the same provider (within the supported model types).
|
||||
|
||||
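The three layers compose roughly as sketched below. This is a minimal, hypothetical sketch: the factory class and accessor names (`ModelProviderFactory`, `get_provider_instance`, `get_model_instance`) are illustrative assumptions, while the keyword arguments to `invoke` mirror the callback signatures defined elsewhere in this module.

```python
# Hypothetical sketch of the factory -> provider -> model flow described above.
# Class and method names are illustrative; consult the actual factory/provider/model classes.
factory = ModelProviderFactory()                       # outermost layer: factory
providers = factory.get_providers()                    # list every supported provider

provider = factory.get_provider_instance("openai")     # provider layer
provider.validate_provider_credentials({"api_key": "sk-..."})

llm = provider.get_model_instance(model_type="llm")    # model layer
result = llm.invoke(
    model="gpt-4o-mini",
    credentials={"api_key": "sk-..."},
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.7},
)
```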
## Documentation

For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
64
api/graphon/model_runtime/README_CN.md
Normal file
64
api/graphon/model_runtime/README_CN.md
Normal file
@ -0,0 +1,64 @@
# Model Runtime

This module provides the invocation and authentication interfaces for each model, and gives Dify unified provider information and credential form rules for model providers.

- On one hand, it decouples models from upstream and downstream processes, making it easy for developers to extend model support horizontally.
- On the other hand, providers and models only need to be defined in the backend to be displayed directly in the frontend interface, with no changes to frontend logic.

## Features

- Capability invocation for 6 model types

  - `LLM` - text completion, chat, and pre-computed token counting
  - `Text Embedding Model` - text embedding and pre-computed token counting
  - `Rerank Model` - segment reranking
  - `Speech-to-text Model` - speech-to-text conversion
  - `Text-to-speech Model` - text-to-speech conversion
  - `Moderation` - content moderation

- Model provider display

  Displays the list of all supported providers; besides the provider name and icon, it also returns the supported model types, predefined model lists, configuration methods, and credential form rules.

- Selectable model list display

  After provider/model credentials are configured, the dropdown (application orchestration interface / default model) shows the available LLM list. Greyed-out items are predefined models from providers whose credentials are not yet configured, so users can see which models are supported.

  The list also returns the configurable parameter information and rules for each LLM. These parameters are all defined in the backend; instead of the previous 5 fixed parameters, each model can expose the parameters it actually supports.

- Provider/model credential authentication

  The provider list returns the configuration information for the credentials form, and the credentials can be validated through the interface provided by the Runtime.

## Structure

Model Runtime is divided into three layers:

- The outermost layer is the factory method

  It provides methods for obtaining all providers and all model lists, getting provider instances, and authenticating provider/model credentials.

- The second layer is the provider layer

  It provides the current provider's model list, model instance retrieval, provider credential authentication, and provider configuration rules, **allowing horizontal expansion** to support different providers.

  For provider/model credentials there are two cases:

  - Hosted providers such as OpenAI need authentication credentials like an **api_key**.
  - Locally deployed providers such as [**Xinference**](https://github.com/xorbitsai/inference) need address credentials like a **server_url**, and sometimes a model identifier credential such as **model_uid**. Once these credentials are defined at the provider layer, they can be displayed directly in the frontend without modifying frontend logic.

  After the credentials are configured, the **Schema** (credential form rules) required by the corresponding provider can be fetched directly through DifyRuntime's external interface, so new providers/models can be supported without changing frontend logic.

- The bottom layer is the model layer

  It offers direct invocation of each model type, predefined model configuration information, retrieval of predefined/remote model lists, and model credential authentication. Individual model types add special methods, such as the LLM's pre-computed token counting and price information methods, **allowing horizontal expansion** for different models under the same provider (within the supported model types).

  Here we first need to distinguish model parameters from model credentials (see the sketch after this list).

  - Model parameters (**defined at this layer**): parameters that change frequently and are adjusted at any time, such as the LLM's **max_tokens** and **temperature**. They are tuned by the user in the frontend, so the backend must define the rules for each parameter so the frontend can display and adjust them. In DifyRuntime they are usually passed as **model_parameters: dict[str, any]**.

  - Model credentials (**defined at the provider layer**): parameters that rarely change once configured, such as **api_key** and **server_url**. In DifyRuntime they are usually passed as **credentials: dict[str, any]**; the provider-layer credentials are passed straight through to this layer and do not need to be defined again.
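The sketch below illustrates this split, using the two dictionaries as they typically appear in DifyRuntime call signatures; the concrete values are examples only.

```python
# Illustrative only: provider-level credentials vs. per-call model parameters.

# Credentials: defined at the provider layer and configured once.
credentials = {"api_key": "sk-..."}                      # hosted provider such as OpenAI
# credentials = {"server_url": "http://localhost:9997",  # self-hosted provider such as Xinference
#                "model_uid": "my-model"}

# Model parameters: defined at the model layer and adjusted per request from the frontend.
model_parameters = {"temperature": 0.7, "max_tokens": 512}
```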
## Documentation

For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
0
api/graphon/model_runtime/__init__.py
Normal file
0
api/graphon/model_runtime/__init__.py
Normal file
0
api/graphon/model_runtime/callbacks/__init__.py
Normal file
0
api/graphon/model_runtime/callbacks/__init__.py
Normal file
159
api/graphon/model_runtime/callbacks/base_callback.py
Normal file
159
api/graphon/model_runtime/callbacks/base_callback.py
Normal file
@ -0,0 +1,159 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
"red": "31;1",
|
||||
}
|
||||
|
||||
|
||||
class Callback(ABC):
|
||||
"""
|
||||
Base class for callbacks.
|
||||
Only for LLM.
|
||||
"""
|
||||
|
||||
raise_error: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def on_before_invoke(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
invocation_context: Mapping[str, object] | None = None,
|
||||
):
|
||||
"""
|
||||
Before invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: optional end-user identifier for the invocation
|
||||
:param invocation_context: opaque request metadata for the current invocation
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def on_new_chunk(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
chunk: LLMResultChunk,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
invocation_context: Mapping[str, object] | None = None,
|
||||
):
|
||||
"""
|
||||
On new chunk callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: optional end-user identifier for the invocation
|
||||
:param invocation_context: opaque request metadata for the current invocation
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def on_after_invoke(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
result: LLMResult,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
invocation_context: Mapping[str, object] | None = None,
|
||||
):
|
||||
"""
|
||||
After invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param result: result
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: optional end-user identifier for the invocation
|
||||
:param invocation_context: opaque request metadata for the current invocation
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def on_invoke_error(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
ex: Exception,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
invocation_context: Mapping[str, object] | None = None,
|
||||
):
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param ex: exception
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: optional end-user identifier for the invocation
|
||||
:param invocation_context: opaque request metadata for the current invocation
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def print_text(self, text: str, color: str | None = None, end: str = ""):
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = self._get_colored_text(text, color) if color else text
|
||||
print(text_to_print, end=end)
|
||||
|
||||
def _get_colored_text(self, text: str, color: str) -> str:
|
||||
"""Get colored text."""
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||
180
api/graphon/model_runtime/callbacks/logging_callback.py
Normal file
180
api/graphon/model_runtime/callbacks/logging_callback.py
Normal file
@ -0,0 +1,180 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.callbacks.base_callback import Callback
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoggingCallback(Callback):
|
||||
def on_before_invoke(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
invocation_context: Mapping[str, object] | None = None,
|
||||
):
|
||||
"""
|
||||
Before invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: optional end-user identifier for the invocation
|
||||
:param invocation_context: opaque request metadata for the current invocation
|
||||
"""
|
||||
self.print_text("\n[on_llm_before_invoke]\n", color="blue")
|
||||
self.print_text(f"Model: {model}\n", color="blue")
|
||||
self.print_text("Parameters:\n", color="blue")
|
||||
for key, value in model_parameters.items():
|
||||
self.print_text(f"\t{key}: {value}\n", color="blue")
|
||||
|
||||
if stop:
|
||||
self.print_text(f"\tstop: {stop}\n", color="blue")
|
||||
|
||||
if tools:
|
||||
self.print_text("\tTools:\n", color="blue")
|
||||
for tool in tools:
|
||||
self.print_text(f"\t\t{tool.name}\n", color="blue")
|
||||
|
||||
self.print_text(f"Stream: {stream}\n", color="blue")
|
||||
if user:
|
||||
self.print_text(f"User: {user}\n", color="blue")
|
||||
|
||||
if invocation_context:
|
||||
self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue")
|
||||
|
||||
self.print_text("Prompt messages:\n", color="blue")
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.name:
|
||||
self.print_text(f"\tname: {prompt_message.name}\n", color="blue")
|
||||
|
||||
self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
|
||||
self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")
|
||||
|
||||
if stream:
|
||||
self.print_text("\n[on_llm_new_chunk]")
|
||||
|
||||
def on_new_chunk(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
chunk: LLMResultChunk,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
invocation_context: Mapping[str, object] | None = None,
|
||||
):
|
||||
"""
|
||||
On new chunk callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param chunk: chunk
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param invocation_context: opaque request metadata for the current invocation
|
||||
"""
|
||||
_ = user, invocation_context
|
||||
sys.stdout.write(cast(str, chunk.delta.message.content))
|
||||
sys.stdout.flush()
|
||||
|
||||
def on_after_invoke(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
result: LLMResult,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
invocation_context: Mapping[str, object] | None = None,
|
||||
):
|
||||
"""
|
||||
After invoke callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param result: result
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param invocation_context: opaque request metadata for the current invocation
|
||||
"""
|
||||
_ = user, invocation_context
|
||||
self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
|
||||
self.print_text(f"Content: {result.message.content}\n", color="yellow")
|
||||
|
||||
if result.message.tool_calls:
|
||||
self.print_text("Tool calls:\n", color="yellow")
|
||||
for tool_call in result.message.tool_calls:
|
||||
self.print_text(f"\t{tool_call.id}\n", color="yellow")
|
||||
self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
|
||||
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow")
|
||||
|
||||
self.print_text(f"Model: {result.model}\n", color="yellow")
|
||||
self.print_text(f"Usage: {result.usage}\n", color="yellow")
|
||||
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow")
|
||||
|
||||
def on_invoke_error(
|
||||
self,
|
||||
llm_instance: AIModel,
|
||||
ex: Exception,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
invocation_context: Mapping[str, object] | None = None,
|
||||
):
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
:param llm_instance: LLM instance
|
||||
:param ex: exception
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param invocation_context: opaque request metadata for the current invocation
|
||||
"""
|
||||
_ = user, invocation_context
|
||||
self.print_text("\n[on_llm_invoke_error]\n", color="red")
|
||||
logger.exception(ex)
|
||||
43
api/graphon/model_runtime/entities/__init__.py
Normal file
43
api/graphon/model_runtime/entities/__init__.py
Normal file
@ -0,0 +1,43 @@
|
||||
from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from .message_entities import (
|
||||
AssistantPromptMessage,
|
||||
AudioPromptMessageContent,
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageRole,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from .model_entities import ModelPropertyKey
|
||||
|
||||
__all__ = [
|
||||
"AssistantPromptMessage",
|
||||
"AudioPromptMessageContent",
|
||||
"DocumentPromptMessageContent",
|
||||
"ImagePromptMessageContent",
|
||||
"LLMMode",
|
||||
"LLMResult",
|
||||
"LLMResultChunk",
|
||||
"LLMResultChunkDelta",
|
||||
"LLMUsage",
|
||||
"ModelPropertyKey",
|
||||
"MultiModalPromptMessageContent",
|
||||
"PromptMessage",
|
||||
"PromptMessageContent",
|
||||
"PromptMessageContentType",
|
||||
"PromptMessageRole",
|
||||
"PromptMessageTool",
|
||||
"SystemPromptMessage",
|
||||
"TextPromptMessageContent",
|
||||
"ToolPromptMessage",
|
||||
"UserPromptMessage",
|
||||
"VideoPromptMessageContent",
|
||||
]
|
||||
16
api/graphon/model_runtime/entities/common_entities.py
Normal file
16
api/graphon/model_runtime/entities/common_entities.py
Normal file
@ -0,0 +1,16 @@
from pydantic import BaseModel, model_validator


class I18nObject(BaseModel):
    """
    Model class for i18n object.
    """

    zh_Hans: str | None = None
    en_US: str

    @model_validator(mode="after")
    def _(self):
        if not self.zh_Hans:
            self.zh_Hans = self.en_US
        return self
130
api/graphon/model_runtime/entities/defaults.py
Normal file
130
api/graphon/model_runtime/entities/defaults.py
Normal file
@ -0,0 +1,130 @@
|
||||
from graphon.model_runtime.entities.model_entities import DefaultParameterName
|
||||
|
||||
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||
DefaultParameterName.TEMPERATURE: {
|
||||
"label": {
|
||||
"en_US": "Temperature",
|
||||
"zh_Hans": "温度",
|
||||
},
|
||||
"type": "float",
|
||||
"help": {
|
||||
"en_US": "Controls randomness. Lower temperature results in less random completions."
|
||||
" As the temperature approaches zero, the model will become deterministic and repetitive."
|
||||
" Higher temperature results in more random completions.",
|
||||
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。"
|
||||
"较高的温度会导致更多的随机完成。",
|
||||
},
|
||||
"required": False,
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"precision": 2,
|
||||
},
|
||||
DefaultParameterName.TOP_P: {
|
||||
"label": {
|
||||
"en_US": "Top P",
|
||||
"zh_Hans": "Top P",
|
||||
},
|
||||
"type": "float",
|
||||
"help": {
|
||||
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
|
||||
" are considered.",
|
||||
"zh_Hans": "通过核心采样控制多样性:0.5 表示考虑了一半的所有可能性加权选项。",
|
||||
},
|
||||
"required": False,
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"precision": 2,
|
||||
},
|
||||
DefaultParameterName.TOP_K: {
|
||||
"label": {
|
||||
"en_US": "Top K",
|
||||
"zh_Hans": "Top K",
|
||||
},
|
||||
"type": "int",
|
||||
"help": {
|
||||
"en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.",
|
||||
"zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。",
|
||||
},
|
||||
"required": False,
|
||||
"default": 50,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"precision": 0,
|
||||
},
|
||||
DefaultParameterName.PRESENCE_PENALTY: {
|
||||
"label": {
|
||||
"en_US": "Presence Penalty",
|
||||
"zh_Hans": "存在惩罚",
|
||||
},
|
||||
"type": "float",
|
||||
"help": {
|
||||
"en_US": "Applies a penalty to the log-probability of tokens already in the text.",
|
||||
"zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
|
||||
},
|
||||
"required": False,
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"precision": 2,
|
||||
},
|
||||
DefaultParameterName.FREQUENCY_PENALTY: {
|
||||
"label": {
|
||||
"en_US": "Frequency Penalty",
|
||||
"zh_Hans": "频率惩罚",
|
||||
},
|
||||
"type": "float",
|
||||
"help": {
|
||||
"en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
|
||||
"zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
|
||||
},
|
||||
"required": False,
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"precision": 2,
|
||||
},
|
||||
DefaultParameterName.MAX_TOKENS: {
|
||||
"label": {
|
||||
"en_US": "Max Tokens",
|
||||
"zh_Hans": "最大 Token 数",
|
||||
},
|
||||
"type": "int",
|
||||
"help": {
|
||||
"en_US": "Specifies the upper limit on the length of generated results."
|
||||
" If the generated results are truncated, you can increase this parameter.",
|
||||
"zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
|
||||
},
|
||||
"required": False,
|
||||
"default": 64,
|
||||
"min": 1,
|
||||
"max": 2048,
|
||||
"precision": 0,
|
||||
},
|
||||
DefaultParameterName.RESPONSE_FORMAT: {
|
||||
"label": {
|
||||
"en_US": "Response Format",
|
||||
"zh_Hans": "回复格式",
|
||||
},
|
||||
"type": "string",
|
||||
"help": {
|
||||
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
|
||||
" such as JSON, XML, etc.",
|
||||
"zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等",
|
||||
},
|
||||
"required": False,
|
||||
"options": ["JSON", "XML"],
|
||||
},
|
||||
DefaultParameterName.JSON_SCHEMA: {
|
||||
"label": {
|
||||
"en_US": "JSON Schema",
|
||||
},
|
||||
"type": "text",
|
||||
"help": {
|
||||
"en_US": "Set a response json schema will ensure LLM to adhere it.",
|
||||
"zh_Hans": "设置返回的 json schema,llm 将按照它返回",
|
||||
},
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
219
api/graphon/model_runtime/entities/llm_entities.py
Normal file
219
api/graphon/model_runtime/entities/llm_entities.py
Normal file
@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import Any, TypedDict, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelUsage, PriceInfo
|
||||
|
||||
|
||||
class LLMMode(StrEnum):
|
||||
"""
|
||||
Enum class for large language model mode.
|
||||
"""
|
||||
|
||||
COMPLETION = "completion"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
class LLMUsageMetadata(TypedDict, total=False):
|
||||
"""
|
||||
TypedDict for LLM usage metadata.
|
||||
All fields are optional.
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_unit_price: Union[float, str]
|
||||
completion_unit_price: Union[float, str]
|
||||
total_price: Union[float, str]
|
||||
currency: str
|
||||
prompt_price_unit: Union[float, str]
|
||||
completion_price_unit: Union[float, str]
|
||||
prompt_price: Union[float, str]
|
||||
completion_price: Union[float, str]
|
||||
latency: float
|
||||
time_to_first_token: float
|
||||
time_to_generate: float
|
||||
|
||||
|
||||
class LLMUsage(ModelUsage):
|
||||
"""
|
||||
Model class for llm usage.
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
prompt_unit_price: Decimal
|
||||
prompt_price_unit: Decimal
|
||||
prompt_price: Decimal
|
||||
completion_tokens: int
|
||||
completion_unit_price: Decimal
|
||||
completion_price_unit: Decimal
|
||||
completion_price: Decimal
|
||||
total_tokens: int
|
||||
total_price: Decimal
|
||||
currency: str
|
||||
latency: float
|
||||
time_to_first_token: float | None = None
|
||||
time_to_generate: float | None = None
|
||||
|
||||
@classmethod
|
||||
def empty_usage(cls):
|
||||
return cls(
|
||||
prompt_tokens=0,
|
||||
prompt_unit_price=Decimal("0.0"),
|
||||
prompt_price_unit=Decimal("0.0"),
|
||||
prompt_price=Decimal("0.0"),
|
||||
completion_tokens=0,
|
||||
completion_unit_price=Decimal("0.0"),
|
||||
completion_price_unit=Decimal("0.0"),
|
||||
completion_price=Decimal("0.0"),
|
||||
total_tokens=0,
|
||||
total_price=Decimal("0.0"),
|
||||
currency="USD",
|
||||
latency=0.0,
|
||||
time_to_first_token=None,
|
||||
time_to_generate=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
|
||||
"""
|
||||
Create LLMUsage instance from metadata dictionary with default values.
|
||||
|
||||
Args:
|
||||
metadata: TypedDict containing usage metadata
|
||||
|
||||
Returns:
|
||||
LLMUsage instance with values from metadata or defaults
|
||||
"""
|
||||
prompt_tokens = metadata.get("prompt_tokens", 0)
|
||||
completion_tokens = metadata.get("completion_tokens", 0)
|
||||
total_tokens = metadata.get("total_tokens", 0)
|
||||
|
||||
# If total_tokens is not provided but prompt and completion tokens are,
|
||||
# calculate total_tokens
|
||||
if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
return cls(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
|
||||
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
|
||||
total_price=Decimal(str(metadata.get("total_price", 0))),
|
||||
currency=metadata.get("currency", "USD"),
|
||||
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
|
||||
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
|
||||
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
|
||||
completion_price=Decimal(str(metadata.get("completion_price", 0))),
|
||||
latency=metadata.get("latency", 0.0),
|
||||
time_to_first_token=metadata.get("time_to_first_token"),
|
||||
time_to_generate=metadata.get("time_to_generate"),
|
||||
)
|
||||
|
||||
def plus(self, other: LLMUsage) -> LLMUsage:
|
||||
"""
|
||||
Add two LLMUsage instances together.
|
||||
|
||||
:param other: Another LLMUsage instance to add
|
||||
:return: A new LLMUsage instance with summed values
|
||||
"""
|
||||
if self.total_tokens == 0:
|
||||
return other
|
||||
else:
|
||||
return LLMUsage(
|
||||
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
||||
prompt_unit_price=other.prompt_unit_price,
|
||||
prompt_price_unit=other.prompt_price_unit,
|
||||
prompt_price=self.prompt_price + other.prompt_price,
|
||||
completion_tokens=self.completion_tokens + other.completion_tokens,
|
||||
completion_unit_price=other.completion_unit_price,
|
||||
completion_price_unit=other.completion_price_unit,
|
||||
completion_price=self.completion_price + other.completion_price,
|
||||
total_tokens=self.total_tokens + other.total_tokens,
|
||||
total_price=self.total_price + other.total_price,
|
||||
currency=other.currency,
|
||||
latency=self.latency + other.latency,
|
||||
time_to_first_token=other.time_to_first_token,
|
||||
time_to_generate=other.time_to_generate,
|
||||
)
|
||||
|
||||
def __add__(self, other: LLMUsage) -> LLMUsage:
|
||||
"""
|
||||
Overload the + operator to add two LLMUsage instances.
|
||||
|
||||
:param other: Another LLMUsage instance to add
|
||||
:return: A new LLMUsage instance with summed values
|
||||
"""
|
||||
return self.plus(other)
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
"""
|
||||
Model class for llm result.
|
||||
"""
|
||||
|
||||
id: str | None = None
|
||||
model: str
|
||||
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
|
||||
message: AssistantPromptMessage
|
||||
usage: LLMUsage
|
||||
system_fingerprint: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
|
||||
|
||||
class LLMStructuredOutput(BaseModel):
|
||||
"""
|
||||
Model class for llm structured output.
|
||||
"""
|
||||
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
|
||||
|
||||
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
|
||||
"""
|
||||
Model class for llm result with structured output.
|
||||
"""
|
||||
|
||||
|
||||
class LLMResultChunkDelta(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk delta.
|
||||
"""
|
||||
|
||||
index: int
|
||||
message: AssistantPromptMessage
|
||||
usage: LLMUsage | None = None
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
class LLMResultChunk(BaseModel):
|
||||
"""
|
||||
Model class for llm result chunk.
|
||||
"""
|
||||
|
||||
model: str
|
||||
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
|
||||
system_fingerprint: str | None = None
|
||||
delta: LLMResultChunkDelta
|
||||
|
||||
|
||||
class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
|
||||
"""
|
||||
Model class for llm result chunk with structured output.
|
||||
"""
|
||||
|
||||
|
||||
class NumTokensResult(PriceInfo):
|
||||
"""
|
||||
Model class for number of tokens result.
|
||||
"""
|
||||
|
||||
tokens: int
|
||||
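A short usage sketch for the `LLMUsage` helpers defined above: accumulating per-chunk usage into a running total with `empty_usage()` and the overloaded `+`. The chunk stream itself is assumed to come from a streamed LLM invocation.

```python
# Illustrative sketch: accumulate token usage across streamed chunks.
# `chunks` is an assumed iterable of LLMResultChunk from an LLM invocation; it is not defined here.
total_usage = LLMUsage.empty_usage()
for chunk in chunks:
    if chunk.delta.usage:
        total_usage = total_usage + chunk.delta.usage  # LLMUsage.__add__ delegates to plus()
print(total_usage.total_tokens, total_usage.total_price)
```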
279
api/graphon/model_runtime/entities/message_entities.py
Normal file
279
api/graphon/model_runtime/entities/message_entities.py
Normal file
@ -0,0 +1,279 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
|
||||
class PromptMessageRole(StrEnum):
|
||||
"""
|
||||
Enum class for prompt message.
|
||||
"""
|
||||
|
||||
SYSTEM = auto()
|
||||
USER = auto()
|
||||
ASSISTANT = auto()
|
||||
TOOL = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> PromptMessageRole:
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid prompt message type value {value}")
|
||||
|
||||
|
||||
class PromptMessageTool(BaseModel):
|
||||
"""
|
||||
Model class for prompt message tool.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
||||
|
||||
|
||||
class PromptMessageFunction(BaseModel):
|
||||
"""
|
||||
Model class for prompt message function.
|
||||
"""
|
||||
|
||||
type: str = "function"
|
||||
function: PromptMessageTool
|
||||
|
||||
|
||||
class PromptMessageContentType(StrEnum):
|
||||
"""
|
||||
Enum class for prompt message content type.
|
||||
"""
|
||||
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
AUDIO = auto()
|
||||
VIDEO = auto()
|
||||
DOCUMENT = auto()
|
||||
|
||||
|
||||
class PromptMessageContent(ABC, BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType
|
||||
|
||||
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for text prompt message content.
|
||||
"""
|
||||
|
||||
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore
|
||||
data: str
|
||||
|
||||
|
||||
class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
"""
|
||||
Model class for multi-modal prompt message content.
|
||||
"""
|
||||
|
||||
format: str = Field(default=..., description="the format of multi-modal file")
|
||||
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
||||
url: str = Field(default="", description="the url of multi-modal file")
|
||||
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
|
||||
filename: str = Field(default="", description="the filename of multi-modal file")
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
|
||||
|
||||
|
||||
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore
|
||||
|
||||
|
||||
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore
|
||||
|
||||
|
||||
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
||||
"""
|
||||
Model class for image prompt message content.
|
||||
"""
|
||||
|
||||
class DETAIL(StrEnum):
|
||||
LOW = auto()
|
||||
HIGH = auto()
|
||||
|
||||
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
||||
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore
|
||||
|
||||
|
||||
PromptMessageContentUnionTypes = Annotated[
|
||||
Union[
|
||||
TextPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
DocumentPromptMessageContent,
|
||||
AudioPromptMessageContent,
|
||||
VideoPromptMessageContent,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
|
||||
PromptMessageContentType.TEXT: TextPromptMessageContent,
|
||||
PromptMessageContentType.IMAGE: ImagePromptMessageContent,
|
||||
PromptMessageContentType.AUDIO: AudioPromptMessageContent,
|
||||
PromptMessageContentType.VIDEO: VideoPromptMessageContent,
|
||||
PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
|
||||
}
|
||||
|
||||
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
|
||||
role: PromptMessageRole
|
||||
content: str | list[PromptMessageContentUnionTypes] | None = None
|
||||
name: str | None = None
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check if prompt message is empty.
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
return not self.content
|
||||
|
||||
def get_text_content(self) -> str:
|
||||
"""
|
||||
Get text content from prompt message.
|
||||
|
||||
:return: Text content as string, empty string if no text content
|
||||
"""
|
||||
if isinstance(self.content, str):
|
||||
return self.content
|
||||
elif isinstance(self.content, list):
|
||||
text_parts = []
|
||||
for item in self.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return "".join(text_parts)
|
||||
else:
|
||||
return ""
|
||||
|
||||
@field_validator("content", mode="before")
|
||||
@classmethod
|
||||
def validate_content(cls, v):
|
||||
if isinstance(v, list):
|
||||
prompts = []
|
||||
for prompt in v:
|
||||
if isinstance(prompt, PromptMessageContent):
|
||||
if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
|
||||
prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
|
||||
elif isinstance(prompt, dict):
|
||||
prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
|
||||
else:
|
||||
raise ValueError(f"invalid prompt message {prompt}")
|
||||
prompts.append(prompt)
|
||||
return prompts
|
||||
return v
|
||||
|
||||
@field_serializer("content")
|
||||
def serialize_content(
|
||||
self, content: Union[str, Sequence[PromptMessageContent]] | None
|
||||
) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None:
|
||||
if content is None or isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
|
||||
return content
|
||||
|
||||
|
||||
class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for user prompt message.
|
||||
"""
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.USER
|
||||
|
||||
|
||||
class AssistantPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for assistant prompt message.
|
||||
"""
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call.
|
||||
"""
|
||||
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Model class for assistant prompt message tool call function.
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
id: str
|
||||
type: str
|
||||
function: ToolCallFunction
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check if prompt message is empty.
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
return super().is_empty() and not self.tool_calls
|
||||
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for system prompt message.
|
||||
"""
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.SYSTEM
|
||||
|
||||
|
||||
class ToolPromptMessage(PromptMessage):
|
||||
"""
|
||||
Model class for tool prompt message.
|
||||
"""
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.TOOL
|
||||
tool_call_id: str
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check if prompt message is empty.
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
return super().is_empty() and not self.tool_call_id
|
||||
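A small construction sketch for the message entities above: building a multimodal user message and reading back its text content. The image URL is a placeholder.

```python
# Illustrative sketch: a user message that mixes text and image content.
message = UserPromptMessage(
    content=[
        TextPromptMessageContent(data="What is shown in this image?"),
        ImagePromptMessageContent(
            format="png",
            mime_type="image/png",
            url="https://example.com/cat.png",  # placeholder URL
            detail=ImagePromptMessageContent.DETAIL.LOW,
        ),
    ]
)
assert message.get_text_content() == "What is shown in this image?"
```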
242
api/graphon/model_runtime/entities/model_entities.py
Normal file
242
api/graphon/model_runtime/entities/model_entities.py
Normal file
@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ModelType(StrEnum):
|
||||
"""
|
||||
Enum class for model type.
|
||||
"""
|
||||
|
||||
LLM = auto()
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = auto()
|
||||
SPEECH2TEXT = auto()
|
||||
MODERATION = auto()
|
||||
TTS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> ModelType:
|
||||
"""
|
||||
Get model type from origin model type.
|
||||
|
||||
:return: model type
|
||||
"""
|
||||
if origin_model_type in {"text-generation", cls.LLM}:
|
||||
return cls.LLM
|
||||
elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}:
|
||||
return cls.TEXT_EMBEDDING
|
||||
elif origin_model_type in {"reranking", cls.RERANK}:
|
||||
return cls.RERANK
|
||||
elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}:
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type in {"tts", cls.TTS}:
|
||||
return cls.TTS
|
||||
elif origin_model_type == cls.MODERATION:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
raise ValueError(f"invalid origin model type {origin_model_type}")
|
||||
|
||||
def to_origin_model_type(self) -> str:
|
||||
"""
|
||||
Get origin model type from model type.
|
||||
|
||||
:return: origin model type
|
||||
"""
|
||||
if self == self.LLM:
|
||||
return "text-generation"
|
||||
elif self == self.TEXT_EMBEDDING:
|
||||
return "embeddings"
|
||||
elif self == self.RERANK:
|
||||
return "reranking"
|
||||
elif self == self.SPEECH2TEXT:
|
||||
return "speech2text"
|
||||
elif self == self.TTS:
|
||||
return "tts"
|
||||
elif self == self.MODERATION:
|
||||
return "moderation"
|
||||
else:
|
||||
raise ValueError(f"invalid model type {self}")
|
||||
|
||||
|
||||
class FetchFrom(StrEnum):
|
||||
"""
|
||||
Enum class for fetch from.
|
||||
"""
|
||||
|
||||
PREDEFINED_MODEL = "predefined-model"
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class ModelFeature(StrEnum):
|
||||
"""
|
||||
Enum class for llm feature.
|
||||
"""
|
||||
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
STRUCTURED_OUTPUT = "structured-output"
|
||||
|
||||
|
||||
class DefaultParameterName(StrEnum):
|
||||
"""
|
||||
Enum class for parameter template variable.
|
||||
"""
|
||||
|
||||
TEMPERATURE = auto()
|
||||
TOP_P = auto()
|
||||
TOP_K = auto()
|
||||
PRESENCE_PENALTY = auto()
|
||||
FREQUENCY_PENALTY = auto()
|
||||
MAX_TOKENS = auto()
|
||||
RESPONSE_FORMAT = auto()
|
||||
JSON_SCHEMA = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: Any) -> DefaultParameterName:
|
||||
"""
|
||||
Get parameter name from value.
|
||||
|
||||
:param value: parameter value
|
||||
:return: parameter name
|
||||
"""
|
||||
for name in cls:
|
||||
if name.value == value:
|
||||
return name
|
||||
raise ValueError(f"invalid parameter name {value}")
|
||||
|
||||
|
||||
class ParameterType(StrEnum):
|
||||
"""
|
||||
Enum class for parameter type.
|
||||
"""
|
||||
|
||||
FLOAT = auto()
|
||||
INT = auto()
|
||||
STRING = auto()
|
||||
BOOLEAN = auto()
|
||||
TEXT = auto()
|
||||
|
||||
|
||||
class ModelPropertyKey(StrEnum):
|
||||
"""
|
||||
Enum class for model property key.
|
||||
"""
|
||||
|
||||
MODE = auto()
|
||||
CONTEXT_SIZE = auto()
|
||||
MAX_CHUNKS = auto()
|
||||
FILE_UPLOAD_LIMIT = auto()
|
||||
SUPPORTED_FILE_EXTENSIONS = auto()
|
||||
MAX_CHARACTERS_PER_CHUNK = auto()
|
||||
DEFAULT_VOICE = auto()
|
||||
VOICES = auto()
|
||||
WORD_LIMIT = auto()
|
||||
AUDIO_TYPE = auto()
|
||||
MAX_WORKERS = auto()
|
||||
|
||||
|
||||
class ProviderModel(BaseModel):
|
||||
"""
|
||||
Model class for provider model.
|
||||
"""
|
||||
|
||||
model: str
|
||||
label: I18nObject
|
||||
model_type: ModelType
|
||||
features: list[ModelFeature] | None = None
|
||||
fetch_from: FetchFrom
|
||||
model_properties: dict[ModelPropertyKey, Any]
|
||||
deprecated: bool = False
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def support_structure_output(self) -> bool:
|
||||
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
|
||||
|
||||
|
||||
class ParameterRule(BaseModel):
|
||||
"""
|
||||
Model class for parameter rule.
|
||||
"""
|
||||
|
||||
name: str
|
||||
use_template: str | None = None
|
||||
label: I18nObject
|
||||
type: ParameterType
|
||||
help: I18nObject | None = None
|
||||
required: bool = False
|
||||
default: Any | None = None
|
||||
min: float | None = None
|
||||
max: float | None = None
|
||||
precision: int | None = None
|
||||
options: list[str] = []
|
||||
|
||||
|
||||
class PriceConfig(BaseModel):
|
||||
"""
|
||||
Model class for pricing info.
|
||||
"""
|
||||
|
||||
input: Decimal
|
||||
output: Decimal | None = None
|
||||
unit: Decimal
|
||||
currency: str
|
||||
|
||||
|
||||
class AIModelEntity(ProviderModel):
|
||||
"""
|
||||
Model class for AI model.
|
||||
"""
|
||||
|
||||
parameter_rules: list[ParameterRule] = []
|
||||
pricing: PriceConfig | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_model(self):
|
||||
supported_schema_keys = ["json_schema"]
|
||||
schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
|
||||
if not schema_key:
|
||||
return self
|
||||
if self.features is None:
|
||||
self.features = [ModelFeature.STRUCTURED_OUTPUT]
|
||||
else:
|
||||
if ModelFeature.STRUCTURED_OUTPUT not in self.features:
|
||||
self.features.append(ModelFeature.STRUCTURED_OUTPUT)
|
||||
return self
|
||||
|
||||
|
||||
class ModelUsage(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class PriceType(StrEnum):
|
||||
"""
|
||||
Enum class for price type.
|
||||
"""
|
||||
|
||||
INPUT = auto()
|
||||
OUTPUT = auto()
|
||||
|
||||
|
||||
class PriceInfo(BaseModel):
|
||||
"""
|
||||
Model class for price info.
|
||||
"""
|
||||
|
||||
unit_price: Decimal
|
||||
unit: Decimal
|
||||
total_amount: Decimal
|
||||
currency: str
|
||||
179
api/graphon/model_runtime/entities/provider_entities.py
Normal file
179
api/graphon/model_runtime/entities/provider_entities.py
Normal file
@ -0,0 +1,179 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
class ConfigurateMethod(StrEnum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
|
||||
PREDEFINED_MODEL = "predefined-model"
|
||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||
|
||||
|
||||
class FormType(StrEnum):
|
||||
"""
|
||||
Enum class for form type.
|
||||
"""
|
||||
|
||||
TEXT_INPUT = "text-input"
|
||||
SECRET_INPUT = "secret-input"
|
||||
SELECT = auto()
|
||||
RADIO = auto()
|
||||
SWITCH = auto()
|
||||
|
||||
|
||||
class FormShowOnObject(BaseModel):
|
||||
"""
|
||||
Model class for form show on.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value: str
|
||||
|
||||
|
||||
class FormOption(BaseModel):
|
||||
"""
|
||||
Model class for form option.
|
||||
"""
|
||||
|
||||
label: I18nObject
|
||||
value: str
|
||||
show_on: list[FormShowOnObject] = []
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
if not self.label:
|
||||
self.label = I18nObject(en_US=self.value)
|
||||
return self
|
||||
|
||||
|
||||
class CredentialFormSchema(BaseModel):
|
||||
"""
|
||||
Model class for credential form schema.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
label: I18nObject
|
||||
type: FormType
|
||||
required: bool = True
|
||||
default: str | None = None
|
||||
options: list[FormOption] | None = None
|
||||
placeholder: I18nObject | None = None
|
||||
max_length: int = 0
|
||||
show_on: list[FormShowOnObject] = []
|
||||
|
||||
|
||||
class ProviderCredentialSchema(BaseModel):
|
||||
"""
|
||||
Model class for provider credential schema.
|
||||
"""
|
||||
|
||||
credential_form_schemas: list[CredentialFormSchema]
|
||||
|
||||
|
||||
class FieldModelSchema(BaseModel):
|
||||
label: I18nObject
|
||||
placeholder: I18nObject | None = None
|
||||
|
||||
|
||||
class ModelCredentialSchema(BaseModel):
|
||||
"""
|
||||
Model class for model credential schema.
|
||||
"""
|
||||
|
||||
model: FieldModelSchema
|
||||
credential_form_schemas: list[CredentialFormSchema]
|
||||
|
||||
|
||||
class SimpleProviderEntity(BaseModel):
|
||||
"""
|
||||
Simplified provider schema exposed to callers.
|
||||
|
||||
`provider` is the canonical runtime identifier. `provider_name` is an optional
|
||||
compatibility alias for short-name lookups and is empty when no alias exists.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
provider_name: str = ""
|
||||
label: I18nObject
|
||||
icon_small: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
supported_model_types: Sequence[ModelType]
|
||||
models: list[AIModelEntity] = []
|
||||
|
||||
|
||||
class ProviderHelpEntity(BaseModel):
|
||||
"""
|
||||
Model class for provider help.
|
||||
"""
|
||||
|
||||
title: I18nObject
|
||||
url: I18nObject
|
||||
|
||||
|
||||
class ProviderEntity(BaseModel):
|
||||
"""
|
||||
Runtime-native provider schema.
|
||||
|
||||
`provider` is the canonical runtime identifier. `provider_name` is a
|
||||
compatibility alias for callers that still resolve providers by short name and
|
||||
is empty when no alias exists.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
provider_name: str = ""
|
||||
label: I18nObject
|
||||
description: I18nObject | None = None
|
||||
icon_small: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
background: str | None = None
|
||||
help: ProviderHelpEntity | None = None
|
||||
supported_model_types: Sequence[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
models: list[AIModelEntity] = Field(default_factory=list)
|
||||
provider_credential_schema: ProviderCredentialSchema | None = None
|
||||
model_credential_schema: ModelCredentialSchema | None = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
# position from plugin _position.yaml
|
||||
position: dict[str, list[str]] | None = {}
|
||||
|
||||
@field_validator("models", mode="before")
|
||||
@classmethod
|
||||
def validate_models(cls, v):
|
||||
# returns EmptyList if v is empty
|
||||
if not v:
|
||||
return []
|
||||
return v
|
||||
|
||||
def to_simple_provider(self) -> SimpleProviderEntity:
|
||||
"""
|
||||
Convert to simple provider.
|
||||
|
||||
:return: simple provider
|
||||
"""
|
||||
return SimpleProviderEntity(
|
||||
provider=self.provider,
|
||||
provider_name=self.provider_name,
|
||||
label=self.label,
|
||||
icon_small=self.icon_small,
|
||||
supported_model_types=self.supported_model_types,
|
||||
models=self.models,
|
||||
)
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
|
||||
"""
|
||||
Model class for provider config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
credentials: dict
|
||||
27
api/graphon/model_runtime/entities/rerank_entities.py
Normal file
27
api/graphon/model_runtime/entities/rerank_entities.py
Normal file
@ -0,0 +1,27 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MultimodalRerankInput(TypedDict):
|
||||
content: str
|
||||
content_type: str
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
"""
|
||||
Model class for rerank document.
|
||||
"""
|
||||
|
||||
index: int
|
||||
text: str
|
||||
score: float
|
||||
|
||||
|
||||
class RerankResult(BaseModel):
|
||||
"""
|
||||
Model class for rerank result.
|
||||
"""
|
||||
|
||||
model: str
|
||||
docs: list[RerankDocument]
|
||||
@ -0,0 +1,47 @@
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelUsage
|
||||
|
||||
|
||||
class EmbeddingInputType(StrEnum):
|
||||
"""Embedding request input variants understood by the model runtime."""
|
||||
|
||||
DOCUMENT = auto()
|
||||
QUERY = auto()
|
||||
|
||||
|
||||
class EmbeddingUsage(ModelUsage):
|
||||
"""
|
||||
Model class for embedding usage.
|
||||
"""
|
||||
|
||||
tokens: int
|
||||
total_tokens: int
|
||||
unit_price: Decimal
|
||||
price_unit: Decimal
|
||||
total_price: Decimal
|
||||
currency: str
|
||||
latency: float
|
||||
|
||||
|
||||
class EmbeddingResult(BaseModel):
|
||||
"""
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
|
||||
model: str
|
||||
embeddings: list[list[float]]
|
||||
usage: EmbeddingUsage
|
||||
|
||||
|
||||
class FileEmbeddingResult(BaseModel):
|
||||
"""
|
||||
Model class for file embedding result.
|
||||
"""
|
||||
|
||||
model: str
|
||||
embeddings: list[list[float]]
|
||||
usage: EmbeddingUsage
|
||||
0
api/graphon/model_runtime/errors/__init__.py
Normal file
0
api/graphon/model_runtime/errors/__init__.py
Normal file
41
api/graphon/model_runtime/errors/invoke.py
Normal file
41
api/graphon/model_runtime/errors/invoke.py
Normal file
@ -0,0 +1,41 @@
class InvokeError(ValueError):
    """Base class for all LLM exceptions."""

    description: str | None = None

    def __init__(self, description: str | None = None):
        if description is not None:
            self.description = description

    def __str__(self):
        return self.description or self.__class__.__name__


class InvokeConnectionError(InvokeError):
    """Raised when the invocation fails with a connection error."""

    description = "Connection Error"


class InvokeServerUnavailableError(InvokeError):
    """Raised when the invocation fails because the server is unavailable."""

    description = "Server Unavailable Error"


class InvokeRateLimitError(InvokeError):
    """Raised when the invocation is rejected due to rate limiting."""

    description = "Rate Limit Error"


class InvokeAuthorizationError(InvokeError):
    """Raised when the invocation fails with an authorization error."""

    description = "Incorrect model credentials provided, please check and try again. "


class InvokeBadRequestError(InvokeError):
    """Raised when the invocation fails with a bad request error."""

    description = "Bad Request Error"
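A hedged handling sketch for the error hierarchy above: more specific subclasses first, `InvokeError` as the catch-all. `call_model()` stands in for any model invocation and is not part of this module.

```python
# Illustrative sketch: handling invocation failures with the hierarchy above.
try:
    call_model()  # placeholder for an actual model invocation
except InvokeRateLimitError:
    pass  # back off and retry later
except InvokeAuthorizationError as e:
    print(str(e))  # falls back to the class-level description
except InvokeError as e:
    # Catch-all for connection, server-unavailable, and bad-request errors.
    print(e.description or e.__class__.__name__)
```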
6
api/graphon/model_runtime/errors/validate.py
Normal file
6
api/graphon/model_runtime/errors/validate.py
Normal file
@ -0,0 +1,6 @@
class CredentialsValidateFailedError(ValueError):
    """
    Raised when provider or model credential validation fails.
    """

    pass
3
api/graphon/model_runtime/memory/__init__.py
Normal file
3
api/graphon/model_runtime/memory/__init__.py
Normal file
@ -0,0 +1,3 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory

__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]