mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
feat(graph-engine): make layer runtime state non-null and bound early (#30552)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
|||||||
"""
|
"""
|
||||||
if isinstance(session_factory, Engine):
|
if isinstance(session_factory, Engine):
|
||||||
session_factory = sessionmaker(session_factory)
|
session_factory = sessionmaker(session_factory)
|
||||||
|
super().__init__()
|
||||||
self._session_maker = session_factory
|
self._session_maker = session_factory
|
||||||
self._state_owner_user_id = state_owner_user_id
|
self._state_owner_user_id = state_owner_user_id
|
||||||
self._generate_entity = generate_entity
|
self._generate_entity = generate_entity
|
||||||
@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
|||||||
if not isinstance(event, GraphRunPausedEvent):
|
if not isinstance(event, GraphRunPausedEvent):
|
||||||
return
|
return
|
||||||
|
|
||||||
assert self.graph_runtime_state is not None
|
|
||||||
|
|
||||||
entity_wrapper: _GenerateEntityUnion
|
entity_wrapper: _GenerateEntityUnion
|
||||||
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
|
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
|
||||||
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
|
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
|
||||||
|
|||||||
@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer):
|
|||||||
trigger_log_id: str,
|
trigger_log_id: str,
|
||||||
session_maker: sessionmaker[Session],
|
session_maker: sessionmaker[Session],
|
||||||
):
|
):
|
||||||
|
super().__init__()
|
||||||
self.trigger_log_id = trigger_log_id
|
self.trigger_log_id = trigger_log_id
|
||||||
self.start_time = start_time
|
self.start_time = start_time
|
||||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||||
@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer):
|
|||||||
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
|
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
|
||||||
|
|
||||||
# Extract relevant data from result
|
# Extract relevant data from result
|
||||||
if not self.graph_runtime_state:
|
|
||||||
logger.exception("Graph runtime state is not set")
|
|
||||||
return
|
|
||||||
|
|
||||||
outputs = self.graph_runtime_state.outputs
|
outputs = self.graph_runtime_state.outputs
|
||||||
|
|
||||||
# BASICLY, workflow_execution_id is the same as workflow_run_id
|
# BASICLY, workflow_execution_id is the same as workflow_run_id
|
||||||
|
|||||||
@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
|
|||||||
engine.layer(ExecutionLimitsLayer(max_nodes=100))
|
engine.layer(ExecutionLimitsLayer(max_nodes=100))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
|
||||||
|
can assume `graph_runtime_state` is available.
|
||||||
|
|
||||||
### Event-Driven Architecture
|
### Event-Driven Architecture
|
||||||
|
|
||||||
All node executions emit events for monitoring and integration:
|
All node executions emit events for monitoring and integration:
|
||||||
|
|||||||
@ -212,9 +212,16 @@ class GraphEngine:
|
|||||||
if id(node.graph_runtime_state) != expected_state_id:
|
if id(node.graph_runtime_state) != expected_state_id:
|
||||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||||
|
|
||||||
|
def _bind_layer_context(
|
||||||
|
self,
|
||||||
|
layer: GraphEngineLayer,
|
||||||
|
) -> None:
|
||||||
|
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
|
||||||
|
|
||||||
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
|
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
|
||||||
"""Add a layer for extending functionality."""
|
"""Add a layer for extending functionality."""
|
||||||
self._layers.append(layer)
|
self._layers.append(layer)
|
||||||
|
self._bind_layer_context(layer)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||||
@ -301,14 +308,7 @@ class GraphEngine:
|
|||||||
def _initialize_layers(self) -> None:
|
def _initialize_layers(self) -> None:
|
||||||
"""Initialize layers with context."""
|
"""Initialize layers with context."""
|
||||||
self._event_manager.set_layers(self._layers)
|
self._event_manager.set_layers(self._layers)
|
||||||
# Create a read-only wrapper for the runtime state
|
|
||||||
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
|
|
||||||
for layer in self._layers:
|
for layer in self._layers:
|
||||||
try:
|
|
||||||
layer.initialize(read_only_state, self._command_channel)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
layer.on_graph_start()
|
layer.on_graph_start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
|
|||||||
|
|
||||||
Abstract base class for layers.
|
Abstract base class for layers.
|
||||||
|
|
||||||
- `initialize()` - Receive runtime context
|
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
|
||||||
- `on_graph_start()` - Execution start hook
|
- `on_graph_start()` - Execution start hook
|
||||||
- `on_event()` - Process all events
|
- `on_event()` - Process all events
|
||||||
- `on_graph_end()` - Execution end hook
|
- `on_graph_end()` - Execution end hook
|
||||||
@ -34,6 +34,9 @@ engine.layer(debug_layer)
|
|||||||
engine.run()
|
engine.run()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`engine.layer()` binds the read-only runtime state before execution, so
|
||||||
|
`graph_runtime_state` is always available inside layer hooks.
|
||||||
|
|
||||||
## Custom Layers
|
## Custom Layers
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|||||||
@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node
|
|||||||
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
||||||
|
|
||||||
|
|
||||||
|
class GraphEngineLayerNotInitializedError(Exception):
|
||||||
|
"""Raised when a layer's runtime state is accessed before initialization."""
|
||||||
|
|
||||||
|
def __init__(self, layer_name: str | None = None) -> None:
|
||||||
|
name = layer_name or "GraphEngineLayer"
|
||||||
|
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
|
||||||
|
|
||||||
|
|
||||||
class GraphEngineLayer(ABC):
|
class GraphEngineLayer(ABC):
|
||||||
"""
|
"""
|
||||||
Abstract base class for GraphEngine layers.
|
Abstract base class for GraphEngine layers.
|
||||||
@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||||
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||||
self.command_channel: CommandChannel | None = None
|
self.command_channel: CommandChannel | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
|
||||||
|
if self._graph_runtime_state is None:
|
||||||
|
raise GraphEngineLayerNotInitializedError(type(self).__name__)
|
||||||
|
return self._graph_runtime_state
|
||||||
|
|
||||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the layer with engine dependencies.
|
Initialize the layer with engine dependencies.
|
||||||
|
|
||||||
Called by GraphEngine before execution starts to inject the read-only runtime state
|
Called by GraphEngine to inject the read-only runtime state and command channel.
|
||||||
and command channel. This allows layers to observe engine context and send
|
This is invoked when the layer is registered with a `GraphEngine` instance.
|
||||||
commands, but prevents direct state modification.
|
Implementations should be idempotent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_runtime_state: Read-only view of the runtime state
|
graph_runtime_state: Read-only view of the runtime state
|
||||||
command_channel: Channel for sending commands to the engine
|
command_channel: Channel for sending commands to the engine
|
||||||
"""
|
"""
|
||||||
self.graph_runtime_state = graph_runtime_state
|
self._graph_runtime_state = graph_runtime_state
|
||||||
self.command_channel = command_channel
|
self.command_channel = command_channel
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
|
|||||||
self.logger.info("=" * 80)
|
self.logger.info("=" * 80)
|
||||||
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
||||||
self.logger.info("=" * 80)
|
self.logger.info("=" * 80)
|
||||||
|
# Log initial state
|
||||||
if self.graph_runtime_state:
|
self.logger.info("Initial State:")
|
||||||
# Log initial state
|
|
||||||
self.logger.info("Initial State:")
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_event(self, event: GraphEngineEvent) -> None:
|
def on_event(self, event: GraphEngineEvent) -> None:
|
||||||
@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
|
|||||||
self.logger.info(" Node retries: %s", self.retry_count)
|
self.logger.info(" Node retries: %s", self.retry_count)
|
||||||
|
|
||||||
# Log final state if available
|
# Log final state if available
|
||||||
if self.graph_runtime_state and self.include_outputs:
|
if self.include_outputs and self.graph_runtime_state.outputs:
|
||||||
if self.graph_runtime_state.outputs:
|
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
|
||||||
|
|
||||||
self.logger.info("=" * 80)
|
self.logger.info("=" * 80)
|
||||||
|
|||||||
@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
|||||||
if update_finished:
|
if update_finished:
|
||||||
execution.finished_at = naive_utc_now()
|
execution.finished_at = naive_utc_now()
|
||||||
runtime_state = self.graph_runtime_state
|
runtime_state = self.graph_runtime_state
|
||||||
if runtime_state is None:
|
|
||||||
return
|
|
||||||
execution.total_tokens = runtime_state.total_tokens
|
execution.total_tokens = runtime_state.total_tokens
|
||||||
execution.total_steps = runtime_state.node_run_steps
|
execution.total_steps = runtime_state.node_run_steps
|
||||||
execution.outputs = execution.outputs or runtime_state.outputs
|
execution.outputs = execution.outputs or runtime_state.outputs
|
||||||
@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
|||||||
|
|
||||||
def _system_variables(self) -> Mapping[str, Any]:
|
def _system_variables(self) -> Mapping[str, Any]:
|
||||||
runtime_state = self.graph_runtime_state
|
runtime_state = self.graph_runtime_state
|
||||||
if runtime_state is None:
|
|
||||||
return {}
|
|
||||||
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
|||||||
from core.workflow.entities.pause_reason import SchedulingPause
|
from core.workflow.entities.pause_reason import SchedulingPause
|
||||||
from core.workflow.enums import WorkflowExecutionStatus
|
from core.workflow.enums import WorkflowExecutionStatus
|
||||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||||
|
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
|
||||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||||
@ -569,10 +570,10 @@ class TestPauseStatePersistenceLayerTestContainers:
|
|||||||
"""Test that layer requires proper initialization before handling events."""
|
"""Test that layer requires proper initialization before handling events."""
|
||||||
# Arrange
|
# Arrange
|
||||||
layer = self._create_pause_state_persistence_layer()
|
layer = self._create_pause_state_persistence_layer()
|
||||||
# Don't initialize - graph_runtime_state should not be set
|
# Don't initialize - graph_runtime_state should be uninitialized
|
||||||
|
|
||||||
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])
|
||||||
|
|
||||||
# Act & Assert - Should raise AttributeError
|
# Act & Assert - Should raise GraphEngineLayerNotInitializedError
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||||
layer.on_event(event)
|
layer.on_event(event)
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from core.app.layers.pause_state_persist_layer import (
|
|||||||
from core.variables.segments import Segment
|
from core.variables.segments import Segment
|
||||||
from core.workflow.entities.pause_reason import SchedulingPause
|
from core.workflow.entities.pause_reason import SchedulingPause
|
||||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||||
|
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
|
||||||
from core.workflow.graph_events.graph import (
|
from core.workflow.graph_events.graph import (
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
GraphRunPausedEvent,
|
GraphRunPausedEvent,
|
||||||
@ -209,8 +210,9 @@ class TestPauseStatePersistenceLayer:
|
|||||||
|
|
||||||
assert layer._session_maker is session_factory
|
assert layer._session_maker is session_factory
|
||||||
assert layer._state_owner_user_id == state_owner_user_id
|
assert layer._state_owner_user_id == state_owner_user_id
|
||||||
assert not hasattr(layer, "graph_runtime_state")
|
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||||
assert not hasattr(layer, "command_channel")
|
_ = layer.graph_runtime_state
|
||||||
|
assert layer.command_channel is None
|
||||||
|
|
||||||
def test_initialize_sets_dependencies(self):
|
def test_initialize_sets_dependencies(self):
|
||||||
session_factory = Mock(name="session_factory")
|
session_factory = Mock(name="session_factory")
|
||||||
@ -295,7 +297,7 @@ class TestPauseStatePersistenceLayer:
|
|||||||
mock_factory.assert_not_called()
|
mock_factory.assert_not_called()
|
||||||
mock_repo.create_workflow_pause.assert_not_called()
|
mock_repo.create_workflow_pause.assert_not_called()
|
||||||
|
|
||||||
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
|
def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self):
|
||||||
session_factory = Mock(name="session_factory")
|
session_factory = Mock(name="session_factory")
|
||||||
layer = PauseStatePersistenceLayer(
|
layer = PauseStatePersistenceLayer(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
@ -305,7 +307,7 @@ class TestPauseStatePersistenceLayer:
|
|||||||
|
|
||||||
event = TestDataFactory.create_graph_run_paused_event()
|
event = TestDataFactory.create_graph_run_paused_event()
|
||||||
|
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||||
layer.on_event(event)
|
layer.on_event(event)
|
||||||
|
|
||||||
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
|||||||
@ -0,0 +1,56 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.workflow.graph_engine import GraphEngine
|
||||||
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
|
from core.workflow.graph_engine.layers.base import (
|
||||||
|
GraphEngineLayer,
|
||||||
|
GraphEngineLayerNotInitializedError,
|
||||||
|
)
|
||||||
|
from core.workflow.graph_events import GraphEngineEvent
|
||||||
|
|
||||||
|
from ..test_table_runner import WorkflowRunner
|
||||||
|
|
||||||
|
|
||||||
|
class LayerForTest(GraphEngineLayer):
|
||||||
|
def on_graph_start(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_event(self, event: GraphEngineEvent) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_graph_end(self, error: Exception | None) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_runtime_state_raises_when_uninitialized() -> None:
|
||||||
|
layer = LayerForTest()
|
||||||
|
|
||||||
|
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||||
|
_ = layer.graph_runtime_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer_runtime_state_available_after_engine_layer() -> None:
|
||||||
|
runner = WorkflowRunner()
|
||||||
|
fixture_data = runner.load_fixture("simple_passthrough_workflow")
|
||||||
|
graph, graph_runtime_state = runner.create_graph_from_fixture(
|
||||||
|
fixture_data,
|
||||||
|
inputs={"query": "test layer state"},
|
||||||
|
)
|
||||||
|
engine = GraphEngine(
|
||||||
|
workflow_id="test_workflow",
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
command_channel=InMemoryChannel(),
|
||||||
|
)
|
||||||
|
|
||||||
|
layer = LayerForTest()
|
||||||
|
engine.layer(layer)
|
||||||
|
|
||||||
|
outputs = layer.graph_runtime_state.outputs
|
||||||
|
ready_queue_size = layer.graph_runtime_state.ready_queue_size
|
||||||
|
|
||||||
|
assert outputs == {}
|
||||||
|
assert isinstance(ready_queue_size, int)
|
||||||
|
assert ready_queue_size >= 0
|
||||||
Reference in New Issue
Block a user