mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 23:36:20 +08:00
243 lines
8.2 KiB
Python
243 lines
8.2 KiB
Python
"""GraphExecution aggregate root managing the overall graph execution state."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from importlib import import_module
|
|
from typing import Literal
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from dify_graph.entities.pause_reason import PauseReason
|
|
from dify_graph.enums import NodeState
|
|
from dify_graph.runtime.graph_runtime_state import GraphExecutionProtocol
|
|
|
|
from .node_execution import NodeExecution
|
|
|
|
|
|
class GraphExecutionErrorState(BaseModel):
    """Serializable representation of an execution error."""

    # The exception is stored as the location of its class (module plus
    # qualified name) and its stringified message, not as a live object.
    module: str = Field(description="Module containing the exception class")
    qualname: str = Field(description="Qualified name of the exception class")
    # Optional: a missing message is represented as None.
    message: str | None = Field(None, description="Exception message string")
|
|
|
|
|
|
class NodeExecutionState(BaseModel):
    """Serializable representation of a node execution entity."""

    # Only the node identifier is required; every other field has a default.
    node_id: str
    state: NodeState = NodeState.UNKNOWN
    retry_count: int = 0
    execution_id: str | None = None
    # Node-level error is kept as plain text.
    error: str | None = None
|
|
|
|
|
|
class GraphExecutionState(BaseModel):
    """Pydantic model describing serialized GraphExecution state."""

    # Envelope: payload discriminator and format version.
    type: Literal["GraphExecution"] = "GraphExecution"
    version: str = "1.0"

    # Identity of the workflow this state belongs to (required).
    workflow_id: str

    # Lifecycle flags.
    started: bool = False
    completed: bool = False
    aborted: bool = False
    paused: bool = False
    pause_reasons: list[PauseReason] = Field(default_factory=list)

    # Failure information.
    error: GraphExecutionErrorState | None = None
    exceptions_count: int = 0

    # One entry per node execution.
    node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
|
|
|
|
|
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
    """Build the serializable snapshot of ``error``.

    Returns ``None`` when there is no error; otherwise a
    ``GraphExecutionErrorState`` holding the module and qualified name of the
    exception's class together with ``str(error)``.
    """
    if error is None:
        return None

    exc_cls = error.__class__
    return GraphExecutionErrorState(
        module=exc_cls.__module__,
        qualname=exc_cls.__qualname__,
        message=str(error),
    )
|
|
|
|
|
|
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
    """Import ``module_name`` and walk ``qualname`` down to an exception class.

    ``qualname`` is split on dots and each component is looked up as an
    attribute of the previous object, starting from the imported module.

    Raises:
        TypeError: if the resolved object is not a subclass of ``Exception``.

    Failures from ``import_module`` or ``getattr`` are not caught here and
    propagate to the caller.
    """
    target: object = import_module(module_name)
    for name in qualname.split("."):
        target = getattr(target, name)

    # Guard clause: reject anything that is not an Exception subclass.
    if not (isinstance(target, type) and issubclass(target, Exception)):
        raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")

    return target
|
|
|
|
|
|
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
    """Rebuild an exception instance from its serialized snapshot.

    Returns ``None`` for a ``None`` snapshot. Otherwise the recorded class is
    resolved and instantiated with the stored message, or with no arguments
    when the message is ``None``. If anything in that process raises, a
    ``RuntimeError`` is returned instead, carrying the message or, when there
    is no message, the recorded qualified name.
    """
    if state is None:
        return None

    # Positional arguments for the exception constructor.
    ctor_args = () if state.message is None else (state.message,)

    try:
        return _resolve_exception_class(state.module, state.qualname)(*ctor_args)
    except Exception:
        # Fallback to RuntimeError when reconstruction fails
        return RuntimeError(state.qualname if state.message is None else state.message)
|
|
|
|
|
|
@dataclass
class GraphExecution:
    """
    Aggregate root for graph execution.

    This manages the overall execution state of a workflow graph,
    coordinating between multiple node executions.

    The lifecycle flags (``started``, ``completed``, ``aborted``, ``paused``)
    are plain booleans mutated by the methods below; ``dumps``/``loads``
    round-trip the whole aggregate through a JSON string.
    """

    workflow_id: str
    started: bool = False
    completed: bool = False
    aborted: bool = False
    paused: bool = False
    pause_reasons: list[PauseReason] = field(default_factory=list)
    error: Exception | None = None
    # Calling the parameterized alias yields a plain empty dict.
    node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
    exceptions_count: int = 0

    def start(self) -> None:
        """Mark the graph execution as started.

        Raises:
            RuntimeError: if the execution was already started.
        """
        if self.started:
            raise RuntimeError("Graph execution already started")
        self.started = True

    def complete(self) -> None:
        """Mark the graph execution as completed.

        Raises:
            RuntimeError: if the execution has not started, or has already
                completed.
        """
        if not self.started:
            raise RuntimeError("Cannot complete execution that hasn't started")
        if self.completed:
            raise RuntimeError("Graph execution already completed")
        self.completed = True

    def abort(self, reason: str) -> None:
        """Abort the graph execution.

        Sets ``aborted`` and records a ``RuntimeError`` describing ``reason``
        as the execution error. No precondition is checked.
        """
        self.aborted = True
        self.error = RuntimeError(f"Aborted: {reason}")

    def pause(self, reason: PauseReason) -> None:
        """Pause the graph execution without marking it complete.

        Each call appends ``reason`` to ``pause_reasons``.

        Raises:
            RuntimeError: if the execution has completed or been aborted.
        """
        if self.completed:
            raise RuntimeError("Cannot pause execution that has completed")
        if self.aborted:
            raise RuntimeError("Cannot pause execution that has been aborted")
        self.paused = True
        self.pause_reasons.append(reason)

    def fail(self, error: Exception) -> None:
        """Mark the graph execution as failed.

        Stores ``error`` and sets ``completed``; unlike ``complete`` this does
        not check whether the execution was started or already completed.
        """
        self.error = error
        self.completed = True

    def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
        """Get or create a node execution entity.

        A new ``NodeExecution`` is created and stored on first access for a
        given ``node_id``; later calls return the same stored object.
        """
        if node_id not in self.node_executions:
            self.node_executions[node_id] = NodeExecution(node_id=node_id)
        return self.node_executions[node_id]

    @property
    def is_running(self) -> bool:
        """Check if the execution is currently running.

        Running means started and neither completed, aborted nor paused.
        """
        return self.started and not self.completed and not self.aborted and not self.paused

    @property
    def is_paused(self) -> bool:
        """Check if the execution is currently paused."""
        return self.paused

    @property
    def has_error(self) -> bool:
        """Check if the execution has encountered an error."""
        return self.error is not None

    @property
    def error_message(self) -> str | None:
        """Get the error message if an error exists.

        Returns ``None`` exactly when ``has_error`` is false.
        """
        # Identity check, not truthiness: an exception whose class defines
        # __bool__/__len__ may be falsy, and must still yield its message so
        # this property agrees with ``has_error``.
        if self.error is None:
            return None
        return str(self.error)

    def dumps(self) -> str:
        """Serialize the aggregate state into a JSON string.

        Node executions are emitted in sorted order of their dictionary
        entries (keys are unique, so effectively ordered by node id).
        """
        node_states = [
            NodeExecutionState(
                node_id=node_id,
                state=node_execution.state,
                retry_count=node_execution.retry_count,
                execution_id=node_execution.execution_id,
                error=node_execution.error,
            )
            for node_id, node_execution in sorted(self.node_executions.items())
        ]

        state = GraphExecutionState(
            workflow_id=self.workflow_id,
            started=self.started,
            completed=self.completed,
            aborted=self.aborted,
            paused=self.paused,
            pause_reasons=self.pause_reasons,
            error=_serialize_error(self.error),
            exceptions_count=self.exceptions_count,
            node_executions=node_states,
        )

        return state.model_dump_json()

    def loads(self, data: str) -> None:
        """Restore aggregate state from a serialized JSON string.

        The payload is parsed and checked before any attribute of this
        aggregate is assigned. Existing node executions are replaced by the
        ones described in the payload.

        Raises:
            ValueError: if the payload type or version is unsupported, or its
                ``workflow_id`` differs from this aggregate's.

        Errors raised by ``GraphExecutionState.model_validate_json`` for
        malformed input are not caught here.
        """
        state = GraphExecutionState.model_validate_json(data)

        if state.type != "GraphExecution":
            raise ValueError(f"Invalid serialized data type: {state.type}")

        if state.version != "1.0":
            raise ValueError(f"Unsupported serialized version: {state.version}")

        if self.workflow_id != state.workflow_id:
            raise ValueError("Serialized workflow_id does not match aggregate identity")

        self.started = state.started
        self.completed = state.completed
        self.aborted = state.aborted
        self.paused = state.paused
        self.pause_reasons = state.pause_reasons
        self.error = _deserialize_error(state.error)
        self.exceptions_count = state.exceptions_count
        self.node_executions = {
            item.node_id: NodeExecution(
                node_id=item.node_id,
                state=item.state,
                retry_count=item.retry_count,
                execution_id=item.execution_id,
                error=item.error,
            )
            for item in state.node_executions
        }

    def record_node_failure(self) -> None:
        """Increment the count of node failures encountered during execution."""
        self.exceptions_count += 1
|
|
|
|
|
|
# Builds a throwaway GraphExecution at import time and binds it to a name
# annotated as GraphExecutionProtocol. Presumably this exists so a static type
# checker reports any drift between the class and the protocol — TODO confirm.
_: GraphExecutionProtocol = GraphExecution(workflow_id="")
|