feat(graph_engine): add ready_queue state persistence to GraphRuntimeState

- Add ReadyQueueState TypedDict for type-safe queue serialization
- Add ready_queue attribute to GraphRuntimeState for initializing with pre-existing queue state
- Update GraphEngine to load ready_queue from GraphRuntimeState on initialization
- Implement proper type hints using ReadyQueueState for better type safety
- Add comprehensive tests for ready_queue loading functionality

The ready_queue is read-only after initialization and allows resuming workflow
execution with a pre-populated queue of nodes ready to execute.
This commit is contained in:
-LAN-
2025-09-15 03:05:10 +08:00
parent 0f15a2baca
commit b4ef1de30f
7 changed files with 159 additions and 16 deletions

View File

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, PrivateAttr
@ -7,6 +7,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from .variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.graph_engine.ready_queue import ReadyQueueState
class GraphRuntimeState(BaseModel):
# Private attributes to prevent direct modification
@ -16,6 +19,7 @@ class GraphRuntimeState(BaseModel):
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
_node_run_steps: int = PrivateAttr(default=0)
_ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict)
def __init__(
self,
@ -25,6 +29,7 @@ class GraphRuntimeState(BaseModel):
llm_usage: LLMUsage | None = None,
outputs: dict[str, Any] | None = None,
node_run_steps: int = 0,
ready_queue: "ReadyQueueState | dict[str, object] | None" = None,
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
@ -51,6 +56,10 @@ class GraphRuntimeState(BaseModel):
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
if ready_queue is None:
ready_queue = {}
self._ready_queue = deepcopy(ready_queue)
@property
def variable_pool(self) -> VariablePool:
"""Get the variable pool."""
@ -133,3 +142,8 @@ class GraphRuntimeState(BaseModel):
if tokens < 0:
raise ValueError("tokens must be non-negative")
self._total_tokens += tokens
@property
def ready_queue(self) -> "ReadyQueueState | dict[str, object]":
"""Get a copy of the ready queue state."""
return deepcopy(self._ready_queue)