refactor: make code simpler

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-18 17:50:18 +08:00
parent 2175ae0e97
commit 5cad0caaae
6 changed files with 23 additions and 23 deletions

View File

@ -1,7 +1,6 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from contextlib import AbstractContextManager
from typing import Any, cast
from configs import dify_config
@ -40,9 +39,6 @@ logger = logging.getLogger(__name__)
class _WorkflowChildEngineBuilder:
def __init__(self, execution_context: AbstractContextManager[object] | None = None) -> None:
self._execution_context = execution_context
@staticmethod
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
"""
@ -97,7 +93,6 @@ class _WorkflowChildEngineBuilder:
command_channel=command_channel,
config=config,
child_engine_builder=self,
execution_context=self._execution_context,
)
child_engine.layer(LLMQuotaLayer())
for layer in layers:
@ -149,7 +144,8 @@ class WorkflowEntry:
self.command_channel = command_channel
execution_context = capture_current_context()
self._child_engine_builder = _WorkflowChildEngineBuilder(execution_context=execution_context)
graph_runtime_state.execution_context = execution_context
self._child_engine_builder = _WorkflowChildEngineBuilder()
self.graph_engine = GraphEngine(
workflow_id=workflow_id,
graph=graph,
@ -162,7 +158,6 @@ class WorkflowEntry:
scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME,
),
child_engine_builder=self._child_engine_builder,
execution_context=execution_context,
)
# Add debug logging layer when in debug mode

View File

@ -10,7 +10,6 @@ from __future__ import annotations
import logging
import queue
from collections.abc import Generator, Mapping
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, cast, final
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
@ -77,7 +76,6 @@ class GraphEngine:
command_channel: CommandChannel,
config: GraphEngineConfig = _DEFAULT_CONFIG,
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
execution_context: AbstractContextManager[object] | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
@ -90,8 +88,6 @@ class GraphEngine:
self._child_engine_builder = child_engine_builder
if child_engine_builder is not None:
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
if execution_context is not None:
self._graph_runtime_state.execution_context = execution_context
# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from contextlib import AbstractContextManager, nullcontext
from contextlib import AbstractContextManager
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
@ -336,10 +336,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _capture_execution_context(self) -> AbstractContextManager[object]:
"""Return the application-supplied execution context for parallel iterations."""
execution_context = self.graph_runtime_state.execution_context
if execution_context is not None:
return execution_context
return nullcontext()
return self.graph_runtime_state.execution_context
def _handle_iteration_success(
self,

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import importlib
import json
from collections.abc import Mapping, Sequence
from contextlib import AbstractContextManager
from contextlib import AbstractContextManager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
@ -235,7 +235,7 @@ class GraphRuntimeState:
self._response_coordinator = response_coordinator
# Application code injects this when worker threads must restore request
# or framework-local state. It is intentionally excluded from snapshots.
self._execution_context = execution_context
self._execution_context = execution_context if execution_context is not None else nullcontext()
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
@ -335,12 +335,12 @@ class GraphRuntimeState:
return self._response_coordinator
@property
def execution_context(self) -> AbstractContextManager[object] | None:
def execution_context(self) -> AbstractContextManager[object]:
return self._execution_context
@execution_context.setter
def execution_context(self, value: AbstractContextManager[object] | None) -> None:
self._execution_context = value
self._execution_context = value if value is not None else nullcontext()
# ------------------------------------------------------------------
# Scalar state

View File

@ -23,6 +23,17 @@ class StubCoordinator:
class TestGraphRuntimeState:
def test_execution_context_defaults_to_empty_context(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
with state.execution_context:
assert state.execution_context is not None
state.execution_context = None
with state.execution_context:
assert state.execution_context is not None
def test_property_getters_and_setters(self):
# FIXME(-LAN-): Mock VariablePool if needed
variable_pool = VariablePool()

View File

@ -120,6 +120,7 @@ class TestWorkflowEntryInit:
def test_applies_debug_and_observability_layers(self):
graph_engine = MagicMock()
graph_runtime_state = SimpleNamespace(execution_context=None)
debug_layer = sentinel.debug_layer
execution_limits_layer = sentinel.execution_limits_layer
llm_quota_layer = sentinel.llm_quota_layer
@ -153,7 +154,7 @@ class TestWorkflowEntryInit:
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
variable_pool=sentinel.variable_pool,
graph_runtime_state=sentinel.graph_runtime_state,
graph_runtime_state=graph_runtime_state,
command_channel=None,
)
@ -161,12 +162,12 @@ class TestWorkflowEntryInit:
graph_engine_cls.assert_called_once_with(
workflow_id="workflow-id-123456",
graph=sentinel.graph,
graph_runtime_state=sentinel.graph_runtime_state,
graph_runtime_state=graph_runtime_state,
command_channel=sentinel.command_channel,
config=sentinel.graph_engine_config,
child_engine_builder=entry._child_engine_builder,
execution_context=sentinel.execution_context,
)
assert graph_runtime_state.execution_context is sentinel.execution_context
debug_logging_layer.assert_called_once_with(
level="DEBUG",
include_inputs=True,