mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
feat(graph_engine): support dumps and loads in GraphExecution
This commit is contained in:
@ -1,12 +1,94 @@
|
||||
"""
|
||||
GraphExecution aggregate root managing the overall graph execution state.
|
||||
"""
|
||||
"""GraphExecution aggregate root managing the overall graph execution state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from importlib import import_module
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
|
||||
class GraphExecutionErrorState(BaseModel):
|
||||
"""Serializable representation of an execution error."""
|
||||
|
||||
module: str = Field(description="Module containing the exception class")
|
||||
qualname: str = Field(description="Qualified name of the exception class")
|
||||
message: str | None = Field(default=None, description="Exception message string")
|
||||
|
||||
|
||||
class NodeExecutionState(BaseModel):
|
||||
"""Serializable representation of a node execution entity."""
|
||||
|
||||
node_id: str
|
||||
state: NodeState = Field(default=NodeState.UNKNOWN)
|
||||
retry_count: int = Field(default=0)
|
||||
execution_id: str | None = Field(default=None)
|
||||
error: str | None = Field(default=None)
|
||||
|
||||
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Pydantic model describing serialized GraphExecution state."""
|
||||
|
||||
type: Literal["GraphExecution"] = Field(default="GraphExecution")
|
||||
version: str = Field(default="1.0")
|
||||
workflow_id: str
|
||||
started: bool = Field(default=False)
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||
"""Convert an exception into its serializable representation."""
|
||||
|
||||
if error is None:
|
||||
return None
|
||||
|
||||
return GraphExecutionErrorState(
|
||||
module=error.__class__.__module__,
|
||||
qualname=error.__class__.__qualname__,
|
||||
message=str(error),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
|
||||
"""Locate an exception class from its module and qualified name."""
|
||||
|
||||
module = import_module(module_name)
|
||||
attr: object = module
|
||||
for part in qualname.split("."):
|
||||
attr = getattr(attr, part)
|
||||
|
||||
if isinstance(attr, type) and issubclass(attr, Exception):
|
||||
return attr
|
||||
|
||||
raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")
|
||||
|
||||
|
||||
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
|
||||
"""Reconstruct an exception instance from serialized data."""
|
||||
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
exception_class = _resolve_exception_class(state.module, state.qualname)
|
||||
if state.message is None:
|
||||
return exception_class()
|
||||
return exception_class(state.message)
|
||||
except Exception:
|
||||
# Fallback to RuntimeError when reconstruction fails
|
||||
if state.message is None:
|
||||
return RuntimeError(state.qualname)
|
||||
return RuntimeError(state.message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphExecution:
|
||||
"""
|
||||
@ -69,3 +151,57 @@ class GraphExecution:
|
||||
if not self.error:
|
||||
return None
|
||||
return str(self.error)
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the aggregate state into a JSON string."""
|
||||
|
||||
node_states = [
|
||||
NodeExecutionState(
|
||||
node_id=node_id,
|
||||
state=node_execution.state,
|
||||
retry_count=node_execution.retry_count,
|
||||
execution_id=node_execution.execution_id,
|
||||
error=node_execution.error,
|
||||
)
|
||||
for node_id, node_execution in sorted(self.node_executions.items())
|
||||
]
|
||||
|
||||
state = GraphExecutionState(
|
||||
workflow_id=self.workflow_id,
|
||||
started=self.started,
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
error=_serialize_error(self.error),
|
||||
node_executions=node_states,
|
||||
)
|
||||
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore aggregate state from a serialized JSON string."""
|
||||
|
||||
state = GraphExecutionState.model_validate_json(data)
|
||||
|
||||
if state.type != "GraphExecution":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||
|
||||
if self.workflow_id != state.workflow_id:
|
||||
raise ValueError("Serialized workflow_id does not match aggregate identity")
|
||||
|
||||
self.started = state.started
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.node_executions = {
|
||||
item.node_id: NodeExecution(
|
||||
node_id=item.node_id,
|
||||
state=item.state,
|
||||
retry_count=item.retry_count,
|
||||
execution_id=item.execution_id,
|
||||
error=item.error,
|
||||
)
|
||||
for item in state.node_executions
|
||||
}
|
||||
|
||||
@ -68,6 +68,8 @@ class GraphEngine:
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||
if graph_runtime_state.graph_execution_json != "":
|
||||
self._graph_execution.loads(graph_runtime_state.graph_execution_json)
|
||||
|
||||
# === Core Dependencies ===
|
||||
# Graph structure and configuration
|
||||
|
||||
Reference in New Issue
Block a user