mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
feat(graph_engine): dump and load ready queue
This commit is contained in:
@ -31,7 +31,6 @@ ignore_imports =
|
|||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||||
core.workflow.entities.graph_runtime_state -> core.workflow.graph_engine.ready_queue
|
|
||||||
|
|
||||||
[importlinter:contract:rsc]
|
[importlinter:contract:rsc]
|
||||||
name = RSC
|
name = RSC
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, PrivateAttr
|
from pydantic import BaseModel, PrivateAttr
|
||||||
|
|
||||||
@ -7,9 +6,6 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
|||||||
|
|
||||||
from .variable_pool import VariablePool
|
from .variable_pool import VariablePool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.workflow.graph_engine.ready_queue import ReadyQueueState
|
|
||||||
|
|
||||||
|
|
||||||
class GraphRuntimeState(BaseModel):
|
class GraphRuntimeState(BaseModel):
|
||||||
# Private attributes to prevent direct modification
|
# Private attributes to prevent direct modification
|
||||||
@ -19,17 +15,18 @@ class GraphRuntimeState(BaseModel):
|
|||||||
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
||||||
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
|
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
|
||||||
_node_run_steps: int = PrivateAttr(default=0)
|
_node_run_steps: int = PrivateAttr(default=0)
|
||||||
_ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict)
|
_ready_queue_json: str = PrivateAttr()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
start_at: float,
|
start_at: float,
|
||||||
total_tokens: int = 0,
|
total_tokens: int = 0,
|
||||||
llm_usage: LLMUsage | None = None,
|
llm_usage: LLMUsage | None = None,
|
||||||
outputs: dict[str, Any] | None = None,
|
outputs: dict[str, object] | None = None,
|
||||||
node_run_steps: int = 0,
|
node_run_steps: int = 0,
|
||||||
ready_queue: "ReadyQueueState | dict[str, object] | None" = None,
|
ready_queue_json: str = "",
|
||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
):
|
):
|
||||||
"""Initialize the GraphRuntimeState with validation."""
|
"""Initialize the GraphRuntimeState with validation."""
|
||||||
@ -56,9 +53,7 @@ class GraphRuntimeState(BaseModel):
|
|||||||
raise ValueError("node_run_steps must be non-negative")
|
raise ValueError("node_run_steps must be non-negative")
|
||||||
self._node_run_steps = node_run_steps
|
self._node_run_steps = node_run_steps
|
||||||
|
|
||||||
if ready_queue is None:
|
self._ready_queue_json = ready_queue_json
|
||||||
ready_queue = {}
|
|
||||||
self._ready_queue = deepcopy(ready_queue)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variable_pool(self) -> VariablePool:
|
def variable_pool(self) -> VariablePool:
|
||||||
@ -99,24 +94,24 @@ class GraphRuntimeState(BaseModel):
|
|||||||
self._llm_usage = value.model_copy()
|
self._llm_usage = value.model_copy()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def outputs(self) -> dict[str, Any]:
|
def outputs(self) -> dict[str, object]:
|
||||||
"""Get a copy of the outputs dictionary."""
|
"""Get a copy of the outputs dictionary."""
|
||||||
return deepcopy(self._outputs)
|
return deepcopy(self._outputs)
|
||||||
|
|
||||||
@outputs.setter
|
@outputs.setter
|
||||||
def outputs(self, value: dict[str, Any]) -> None:
|
def outputs(self, value: dict[str, object]) -> None:
|
||||||
"""Set the outputs dictionary."""
|
"""Set the outputs dictionary."""
|
||||||
self._outputs = deepcopy(value)
|
self._outputs = deepcopy(value)
|
||||||
|
|
||||||
def set_output(self, key: str, value: Any) -> None:
|
def set_output(self, key: str, value: object) -> None:
|
||||||
"""Set a single output value."""
|
"""Set a single output value."""
|
||||||
self._outputs[key] = deepcopy(value)
|
self._outputs[key] = deepcopy(value)
|
||||||
|
|
||||||
def get_output(self, key: str, default: Any = None) -> Any:
|
def get_output(self, key: str, default: object = None) -> object:
|
||||||
"""Get a single output value."""
|
"""Get a single output value."""
|
||||||
return deepcopy(self._outputs.get(key, default))
|
return deepcopy(self._outputs.get(key, default))
|
||||||
|
|
||||||
def update_outputs(self, updates: dict[str, Any]) -> None:
|
def update_outputs(self, updates: dict[str, object]) -> None:
|
||||||
"""Update multiple output values."""
|
"""Update multiple output values."""
|
||||||
for key, value in updates.items():
|
for key, value in updates.items():
|
||||||
self._outputs[key] = deepcopy(value)
|
self._outputs[key] = deepcopy(value)
|
||||||
@ -144,6 +139,6 @@ class GraphRuntimeState(BaseModel):
|
|||||||
self._total_tokens += tokens
|
self._total_tokens += tokens
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ready_queue(self) -> "ReadyQueueState | dict[str, object]":
|
def ready_queue_json(self) -> str:
|
||||||
"""Get a copy of the ready queue state."""
|
"""Get a copy of the ready queue state."""
|
||||||
return deepcopy(self._ready_queue)
|
return self._ready_queue_json
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from core.workflow.entities import GraphRuntimeState
|
|||||||
from core.workflow.enums import NodeExecutionType
|
from core.workflow.enums import NodeExecutionType
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
|
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
|
||||||
|
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphNodeEventBase,
|
GraphNodeEventBase,
|
||||||
@ -38,7 +39,7 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
|
|||||||
from .layers.base import GraphEngineLayer
|
from .layers.base import GraphEngineLayer
|
||||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||||
from .protocols.command_channel import CommandChannel
|
from .protocols.command_channel import CommandChannel
|
||||||
from .ready_queue import InMemoryReadyQueue
|
from .ready_queue import ReadyQueueState, create_ready_queue_from_state
|
||||||
from .response_coordinator import ResponseStreamCoordinator
|
from .response_coordinator import ResponseStreamCoordinator
|
||||||
from .worker_management import WorkerPool
|
from .worker_management import WorkerPool
|
||||||
|
|
||||||
@ -104,18 +105,13 @@ class GraphEngine:
|
|||||||
self._scale_down_idle_time = scale_down_idle_time
|
self._scale_down_idle_time = scale_down_idle_time
|
||||||
|
|
||||||
# === Execution Queues ===
|
# === Execution Queues ===
|
||||||
# Queue for nodes ready to execute
|
# Create ready queue from saved state or initialize new one
|
||||||
self._ready_queue = InMemoryReadyQueue()
|
if self._graph_runtime_state.ready_queue_json == "":
|
||||||
# Load ready queue state from GraphRuntimeState if not empty
|
self._ready_queue = InMemoryReadyQueue()
|
||||||
ready_queue_state = self._graph_runtime_state.ready_queue
|
else:
|
||||||
if ready_queue_state:
|
ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
|
||||||
# Import ReadyQueueState here to avoid circular imports
|
self._ready_queue = create_ready_queue_from_state(ready_queue_state)
|
||||||
from .ready_queue import ReadyQueueState
|
|
||||||
|
|
||||||
# Ensure we have a ReadyQueueState object
|
|
||||||
if isinstance(ready_queue_state, dict):
|
|
||||||
ready_queue_state = ReadyQueueState(**ready_queue_state) # type: ignore
|
|
||||||
self._ready_queue.loads(ready_queue_state)
|
|
||||||
# Queue for events generated during execution
|
# Queue for events generated during execution
|
||||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,8 @@ This package contains the protocol and implementations for managing
|
|||||||
the queue of nodes ready for execution.
|
the queue of nodes ready for execution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .factory import create_ready_queue_from_state
|
||||||
from .in_memory import InMemoryReadyQueue
|
from .in_memory import InMemoryReadyQueue
|
||||||
from .protocol import ReadyQueue, ReadyQueueState
|
from .protocol import ReadyQueue, ReadyQueueState
|
||||||
|
|
||||||
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState"]
|
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]
|
||||||
|
|||||||
35
api/core/workflow/graph_engine/ready_queue/factory.py
Normal file
35
api/core/workflow/graph_engine/ready_queue/factory.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Factory for creating ReadyQueue instances from serialized state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from .in_memory import InMemoryReadyQueue
|
||||||
|
from .protocol import ReadyQueueState
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .protocol import ReadyQueue
|
||||||
|
|
||||||
|
|
||||||
|
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
|
||||||
|
"""
|
||||||
|
Create a ReadyQueue instance from a serialized state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ReadyQueue instance initialized with the given state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the queue type is unknown or version is unsupported
|
||||||
|
"""
|
||||||
|
if state.type == "InMemoryReadyQueue":
|
||||||
|
if state.version != "1.0":
|
||||||
|
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
|
||||||
|
queue = InMemoryReadyQueue()
|
||||||
|
# Always pass as JSON string to loads()
|
||||||
|
queue.loads(state.model_dump_json())
|
||||||
|
return queue
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown ready queue type: {state.type}")
|
||||||
@ -82,12 +82,12 @@ class InMemoryReadyQueue:
|
|||||||
"""
|
"""
|
||||||
return self._queue.qsize()
|
return self._queue.qsize()
|
||||||
|
|
||||||
def dumps(self) -> ReadyQueueState:
|
def dumps(self) -> str:
|
||||||
"""
|
"""
|
||||||
Serialize the queue state for storage.
|
Serialize the queue state to a JSON string for storage.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ReadyQueueState dictionary containing the serialized queue state
|
A JSON string containing the serialized queue state
|
||||||
"""
|
"""
|
||||||
# Extract all items from the queue without removing them
|
# Extract all items from the queue without removing them
|
||||||
items: list[str] = []
|
items: list[str] = []
|
||||||
@ -106,25 +106,27 @@ class InMemoryReadyQueue:
|
|||||||
for item in temp_items:
|
for item in temp_items:
|
||||||
self._queue.put(item)
|
self._queue.put(item)
|
||||||
|
|
||||||
return ReadyQueueState(
|
state = ReadyQueueState(
|
||||||
type="InMemoryReadyQueue",
|
type="InMemoryReadyQueue",
|
||||||
version="1.0",
|
version="1.0",
|
||||||
items=items,
|
items=items,
|
||||||
maxsize=self._queue.maxsize,
|
|
||||||
)
|
)
|
||||||
|
return state.model_dump_json()
|
||||||
|
|
||||||
def loads(self, data: ReadyQueueState) -> None:
|
def loads(self, data: str) -> None:
|
||||||
"""
|
"""
|
||||||
Restore the queue state from serialized data.
|
Restore the queue state from a JSON string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: The serialized queue state to restore
|
data: The JSON string containing the serialized queue state to restore
|
||||||
"""
|
"""
|
||||||
if data.get("type") != "InMemoryReadyQueue":
|
state = ReadyQueueState.model_validate_json(data)
|
||||||
raise ValueError(f"Invalid serialized data type: {data.get('type')}")
|
|
||||||
|
|
||||||
if data.get("version") != "1.0":
|
if state.type != "InMemoryReadyQueue":
|
||||||
raise ValueError(f"Unsupported version: {data.get('version')}")
|
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||||
|
|
||||||
|
if state.version != "1.0":
|
||||||
|
raise ValueError(f"Unsupported version: {state.version}")
|
||||||
|
|
||||||
# Clear the current queue
|
# Clear the current queue
|
||||||
while not self._queue.empty():
|
while not self._queue.empty():
|
||||||
@ -134,11 +136,5 @@ class InMemoryReadyQueue:
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Restore items
|
# Restore items
|
||||||
items = data.get("items", [])
|
for item in state.items:
|
||||||
if not isinstance(items, list):
|
|
||||||
raise ValueError("Invalid items data: expected list")
|
|
||||||
|
|
||||||
for item in items:
|
|
||||||
if not isinstance(item, str):
|
|
||||||
raise ValueError(f"Invalid item type: expected str, got {type(item).__name__}")
|
|
||||||
self._queue.put(item)
|
self._queue.put(item)
|
||||||
|
|||||||
@ -5,21 +5,23 @@ This protocol defines the interface for managing the queue of nodes ready
|
|||||||
for execution, supporting both in-memory and persistent storage scenarios.
|
for execution, supporting both in-memory and persistent storage scenarios.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Protocol, TypedDict
|
from collections.abc import Sequence
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class ReadyQueueState(TypedDict):
|
class ReadyQueueState(BaseModel):
|
||||||
"""
|
"""
|
||||||
TypedDict for serialized ready queue state.
|
Pydantic model for serialized ready queue state.
|
||||||
|
|
||||||
This defines the structure of the dictionary returned by dumps()
|
This defines the structure of the data returned by dumps()
|
||||||
and expected by loads() for ready queue serialization.
|
and expected by loads() for ready queue serialization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str # Queue implementation type (e.g., "InMemoryReadyQueue")
|
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
|
||||||
version: str # Serialization format version
|
version: str = Field(description="Serialization format version")
|
||||||
items: list[str] # List of node IDs in the queue
|
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
|
||||||
maxsize: int # Maximum queue size (0 for unlimited)
|
|
||||||
|
|
||||||
|
|
||||||
class ReadyQueue(Protocol):
|
class ReadyQueue(Protocol):
|
||||||
@ -82,21 +84,21 @@ class ReadyQueue(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def dumps(self) -> ReadyQueueState:
|
def dumps(self) -> str:
|
||||||
"""
|
"""
|
||||||
Serialize the queue state for storage.
|
Serialize the queue state to a JSON string for storage.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ReadyQueueState dictionary containing the serialized queue state
|
A JSON string containing the serialized queue state
|
||||||
that can be persisted and later restored
|
that can be persisted and later restored
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def loads(self, data: ReadyQueueState) -> None:
|
def loads(self, data: str) -> None:
|
||||||
"""
|
"""
|
||||||
Restore the queue state from serialized data.
|
Restore the queue state from a JSON string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: The serialized queue state to restore
|
data: The JSON string containing the serialized queue state to restore
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|||||||
@ -19,7 +19,7 @@ class Path:
|
|||||||
Note: This is an internal class not exposed in the public API.
|
Note: This is an internal class not exposed in the public API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
edges: list[EdgeID] = field(default_factory=list)
|
edges: list[EdgeID] = field(default_factory=list[EdgeID])
|
||||||
|
|
||||||
def contains_edge(self, edge_id: EdgeID) -> bool:
|
def contains_edge(self, edge_id: EdgeID) -> bool:
|
||||||
"""Check if this path contains the given edge."""
|
"""Check if this path contains the given edge."""
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import pytest
|
|||||||
|
|
||||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.ready_queue import ReadyQueueState
|
|
||||||
|
|
||||||
|
|
||||||
class TestGraphRuntimeState:
|
class TestGraphRuntimeState:
|
||||||
@ -96,44 +95,3 @@ class TestGraphRuntimeState:
|
|||||||
# Test add_tokens validation
|
# Test add_tokens validation
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
state.add_tokens(-1)
|
state.add_tokens(-1)
|
||||||
|
|
||||||
def test_deep_copy_for_nested_objects(self):
|
|
||||||
variable_pool = VariablePool()
|
|
||||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
|
||||||
|
|
||||||
# Test deep copy for nested dict
|
|
||||||
nested_data = {"level1": {"level2": {"value": "test"}}}
|
|
||||||
state.set_output("nested", nested_data)
|
|
||||||
|
|
||||||
retrieved = state.get_output("nested")
|
|
||||||
retrieved["level1"]["level2"]["value"] = "modified"
|
|
||||||
|
|
||||||
# Original should remain unchanged
|
|
||||||
assert state.get_output("nested")["level1"]["level2"]["value"] == "test"
|
|
||||||
|
|
||||||
def test_ready_queue_property(self):
|
|
||||||
variable_pool = VariablePool()
|
|
||||||
|
|
||||||
# Test default empty ready_queue
|
|
||||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
|
||||||
assert state.ready_queue == {}
|
|
||||||
|
|
||||||
# Test initialization with ready_queue data as ReadyQueueState
|
|
||||||
queue_data = ReadyQueueState(type="InMemoryReadyQueue", version="1.0", items=["node1", "node2"], maxsize=0)
|
|
||||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=queue_data)
|
|
||||||
assert state.ready_queue == queue_data
|
|
||||||
|
|
||||||
# Test with different ready_queue data at initialization
|
|
||||||
another_queue_data = ReadyQueueState(
|
|
||||||
type="InMemoryReadyQueue",
|
|
||||||
version="1.0",
|
|
||||||
items=["node3", "node4", "node5"],
|
|
||||||
maxsize=0,
|
|
||||||
)
|
|
||||||
another_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=another_queue_data)
|
|
||||||
assert another_state.ready_queue == another_queue_data
|
|
||||||
|
|
||||||
# Test immutability - modifying retrieved queue doesn't affect internal state
|
|
||||||
retrieved_queue = state.ready_queue
|
|
||||||
retrieved_queue["items"].append("node6")
|
|
||||||
assert len(state.ready_queue["items"]) == 2 # Should still be 2, not 3
|
|
||||||
|
|||||||
@ -744,78 +744,3 @@ def test_event_sequence_validation_with_table_tests():
|
|||||||
else:
|
else:
|
||||||
assert result.event_sequence_match is True
|
assert result.event_sequence_match is True
|
||||||
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
|
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"
|
||||||
|
|
||||||
|
|
||||||
def test_ready_queue_state_loading():
|
|
||||||
"""
|
|
||||||
Test that the ready_queue state is properly loaded from GraphRuntimeState
|
|
||||||
during GraphEngine initialization.
|
|
||||||
"""
|
|
||||||
# Use the TableTestRunner to create a proper workflow instance
|
|
||||||
runner = TableTestRunner()
|
|
||||||
|
|
||||||
# Create a simple workflow
|
|
||||||
test_case = WorkflowTestCase(
|
|
||||||
fixture_path="simple_passthrough_workflow",
|
|
||||||
inputs={"query": "test"},
|
|
||||||
expected_outputs={"query": "test"},
|
|
||||||
description="Test ready_queue loading",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load the workflow fixture
|
|
||||||
workflow_runner = runner.workflow_runner
|
|
||||||
fixture_data = workflow_runner.load_fixture("simple_passthrough_workflow")
|
|
||||||
|
|
||||||
# Create graph and runtime state with pre-populated ready_queue
|
|
||||||
ready_queue_data = {
|
|
||||||
"type": "InMemoryReadyQueue",
|
|
||||||
"version": "1.0",
|
|
||||||
"items": ["node1", "node2", "node3"],
|
|
||||||
"maxsize": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# We need to create the graph first, then create a new GraphRuntimeState with ready_queue
|
|
||||||
graph, original_runtime_state = workflow_runner.create_graph_from_fixture(fixture_data, query="test")
|
|
||||||
|
|
||||||
# Create a new GraphRuntimeState with the ready_queue data
|
|
||||||
from core.workflow.entities import GraphRuntimeState
|
|
||||||
from core.workflow.graph_engine.ready_queue import ReadyQueueState
|
|
||||||
|
|
||||||
# Convert ready_queue_data to ReadyQueueState
|
|
||||||
ready_queue_state = ReadyQueueState(**ready_queue_data)
|
|
||||||
|
|
||||||
graph_runtime_state = GraphRuntimeState(
|
|
||||||
variable_pool=original_runtime_state.variable_pool,
|
|
||||||
start_at=original_runtime_state.start_at,
|
|
||||||
ready_queue=ready_queue_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update all nodes to use the new GraphRuntimeState
|
|
||||||
for node in graph.nodes.values():
|
|
||||||
node.graph_runtime_state = graph_runtime_state
|
|
||||||
|
|
||||||
# Create GraphEngine
|
|
||||||
command_channel = InMemoryChannel()
|
|
||||||
engine = GraphEngine(
|
|
||||||
tenant_id="test-tenant",
|
|
||||||
app_id="test-app",
|
|
||||||
workflow_id="test-workflow",
|
|
||||||
user_id="test-user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
|
||||||
graph_config={},
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
command_channel=command_channel,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify that the ready_queue was loaded from GraphRuntimeState
|
|
||||||
assert engine._ready_queue.qsize() == 3
|
|
||||||
|
|
||||||
# Verify the initial state matches what was provided
|
|
||||||
initial_queue_state = engine.graph_runtime_state.ready_queue
|
|
||||||
assert initial_queue_state["type"] == "InMemoryReadyQueue"
|
|
||||||
assert initial_queue_state["version"] == "1.0"
|
|
||||||
assert len(initial_queue_state["items"]) == 3
|
|
||||||
assert initial_queue_state["items"] == ["node1", "node2", "node3"]
|
|
||||||
|
|||||||
Reference in New Issue
Block a user