mirror of
https://github.com/langgenius/dify.git
synced 2026-03-12 10:38:54 +08:00
feat(graph_engine): add ready_queue state persistence to GraphRuntimeState
- Add ReadyQueueState TypedDict for type-safe queue serialization - Add ready_queue attribute to GraphRuntimeState for initializing with pre-existing queue state - Update GraphEngine to load ready_queue from GraphRuntimeState on initialization - Implement proper type hints using ReadyQueueState for better type safety - Add comprehensive tests for ready_queue loading functionality The ready_queue is read-only after initialization and allows resuming workflow execution with a pre-populated queue of nodes ready to execute.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
@ -7,6 +7,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
from .variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.ready_queue import ReadyQueueState
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
# Private attributes to prevent direct modification
|
||||
@ -16,6 +19,7 @@ class GraphRuntimeState(BaseModel):
|
||||
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
||||
_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
_node_run_steps: int = PrivateAttr(default=0)
|
||||
_ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -25,6 +29,7 @@ class GraphRuntimeState(BaseModel):
|
||||
llm_usage: LLMUsage | None = None,
|
||||
outputs: dict[str, Any] | None = None,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue: "ReadyQueueState | dict[str, object] | None" = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
"""Initialize the GraphRuntimeState with validation."""
|
||||
@ -51,6 +56,10 @@ class GraphRuntimeState(BaseModel):
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = node_run_steps
|
||||
|
||||
if ready_queue is None:
|
||||
ready_queue = {}
|
||||
self._ready_queue = deepcopy(ready_queue)
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> VariablePool:
|
||||
"""Get the variable pool."""
|
||||
@ -133,3 +142,8 @@ class GraphRuntimeState(BaseModel):
|
||||
if tokens < 0:
|
||||
raise ValueError("tokens must be non-negative")
|
||||
self._total_tokens += tokens
|
||||
|
||||
@property
|
||||
def ready_queue(self) -> "ReadyQueueState | dict[str, object]":
|
||||
"""Get a copy of the ready queue state."""
|
||||
return deepcopy(self._ready_queue)
|
||||
|
||||
@ -106,6 +106,16 @@ class GraphEngine:
|
||||
# === Execution Queues ===
|
||||
# Queue for nodes ready to execute
|
||||
self._ready_queue = InMemoryReadyQueue()
|
||||
# Load ready queue state from GraphRuntimeState if not empty
|
||||
ready_queue_state = self._graph_runtime_state.ready_queue
|
||||
if ready_queue_state:
|
||||
# Import ReadyQueueState here to avoid circular imports
|
||||
from .ready_queue import ReadyQueueState
|
||||
|
||||
# Ensure we have a ReadyQueueState object
|
||||
if isinstance(ready_queue_state, dict):
|
||||
ready_queue_state = ReadyQueueState(**ready_queue_state) # type: ignore
|
||||
self._ready_queue.loads(ready_queue_state)
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
|
||||
@ -6,6 +6,6 @@ the queue of nodes ready for execution.
|
||||
"""
|
||||
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
from .protocol import ReadyQueue
|
||||
from .protocol import ReadyQueue, ReadyQueueState
|
||||
|
||||
__all__ = ["InMemoryReadyQueue", "ReadyQueue"]
|
||||
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState"]
|
||||
|
||||
@ -8,6 +8,8 @@ serialization capabilities for state storage.
|
||||
import queue
|
||||
from typing import final
|
||||
|
||||
from .protocol import ReadyQueueState
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryReadyQueue:
|
||||
@ -80,12 +82,12 @@ class InMemoryReadyQueue:
|
||||
"""
|
||||
return self._queue.qsize()
|
||||
|
||||
def dumps(self) -> dict[str, object]:
|
||||
def dumps(self) -> ReadyQueueState:
|
||||
"""
|
||||
Serialize the queue state for storage.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the serialized queue state
|
||||
A ReadyQueueState dictionary containing the serialized queue state
|
||||
"""
|
||||
# Extract all items from the queue without removing them
|
||||
items: list[str] = []
|
||||
@ -104,14 +106,14 @@ class InMemoryReadyQueue:
|
||||
for item in temp_items:
|
||||
self._queue.put(item)
|
||||
|
||||
return {
|
||||
"type": "InMemoryReadyQueue",
|
||||
"version": "1.0",
|
||||
"items": items,
|
||||
"maxsize": self._queue.maxsize,
|
||||
}
|
||||
return ReadyQueueState(
|
||||
type="InMemoryReadyQueue",
|
||||
version="1.0",
|
||||
items=items,
|
||||
maxsize=self._queue.maxsize,
|
||||
)
|
||||
|
||||
def loads(self, data: dict[str, object]) -> None:
|
||||
def loads(self, data: ReadyQueueState) -> None:
|
||||
"""
|
||||
Restore the queue state from serialized data.
|
||||
|
||||
|
||||
@ -5,7 +5,21 @@ This protocol defines the interface for managing the queue of nodes ready
|
||||
for execution, supporting both in-memory and persistent storage scenarios.
|
||||
"""
|
||||
|
||||
from typing import Protocol
|
||||
from typing import Protocol, TypedDict
|
||||
|
||||
|
||||
class ReadyQueueState(TypedDict):
|
||||
"""
|
||||
TypedDict for serialized ready queue state.
|
||||
|
||||
This defines the structure of the dictionary returned by dumps()
|
||||
and expected by loads() for ready queue serialization.
|
||||
"""
|
||||
|
||||
type: str # Queue implementation type (e.g., "InMemoryReadyQueue")
|
||||
version: str # Serialization format version
|
||||
items: list[str] # List of node IDs in the queue
|
||||
maxsize: int # Maximum queue size (0 for unlimited)
|
||||
|
||||
|
||||
class ReadyQueue(Protocol):
|
||||
@ -68,17 +82,17 @@ class ReadyQueue(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def dumps(self) -> dict[str, object]:
|
||||
def dumps(self) -> ReadyQueueState:
|
||||
"""
|
||||
Serialize the queue state for storage.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the serialized queue state
|
||||
A ReadyQueueState dictionary containing the serialized queue state
|
||||
that can be persisted and later restored
|
||||
"""
|
||||
...
|
||||
|
||||
def loads(self, data: dict[str, object]) -> None:
|
||||
def loads(self, data: ReadyQueueState) -> None:
|
||||
"""
|
||||
Restore the queue state from serialized data.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user