mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 18:27:27 +08:00
refactor: make code simpler
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
@ -40,9 +39,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
def __init__(self, execution_context: AbstractContextManager[object] | None = None) -> None:
|
||||
self._execution_context = execution_context
|
||||
|
||||
@staticmethod
|
||||
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
|
||||
"""
|
||||
@ -97,7 +93,6 @@ class _WorkflowChildEngineBuilder:
|
||||
command_channel=command_channel,
|
||||
config=config,
|
||||
child_engine_builder=self,
|
||||
execution_context=self._execution_context,
|
||||
)
|
||||
child_engine.layer(LLMQuotaLayer())
|
||||
for layer in layers:
|
||||
@ -149,7 +144,8 @@ class WorkflowEntry:
|
||||
|
||||
self.command_channel = command_channel
|
||||
execution_context = capture_current_context()
|
||||
self._child_engine_builder = _WorkflowChildEngineBuilder(execution_context=execution_context)
|
||||
graph_runtime_state.execution_context = execution_context
|
||||
self._child_engine_builder = _WorkflowChildEngineBuilder()
|
||||
self.graph_engine = GraphEngine(
|
||||
workflow_id=workflow_id,
|
||||
graph=graph,
|
||||
@ -162,7 +158,6 @@ class WorkflowEntry:
|
||||
scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME,
|
||||
),
|
||||
child_engine_builder=self._child_engine_builder,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
# Add debug logging layer when in debug mode
|
||||
|
||||
@ -10,7 +10,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator, Mapping
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
@ -77,7 +76,6 @@ class GraphEngine:
|
||||
command_channel: CommandChannel,
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
|
||||
execution_context: AbstractContextManager[object] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
@ -90,8 +88,6 @@ class GraphEngine:
|
||||
self._child_engine_builder = child_engine_builder
|
||||
if child_engine_builder is not None:
|
||||
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
|
||||
if execution_context is not None:
|
||||
self._graph_runtime_state.execution_context = execution_context
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
@ -336,10 +336,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
def _capture_execution_context(self) -> AbstractContextManager[object]:
|
||||
"""Return the application-supplied execution context for parallel iterations."""
|
||||
execution_context = self.graph_runtime_state.execution_context
|
||||
if execution_context is not None:
|
||||
return execution_context
|
||||
return nullcontext()
|
||||
return self.graph_runtime_state.execution_context
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import importlib
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from contextlib import AbstractContextManager
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
||||
@ -235,7 +235,7 @@ class GraphRuntimeState:
|
||||
self._response_coordinator = response_coordinator
|
||||
# Application code injects this when worker threads must restore request
|
||||
# or framework-local state. It is intentionally excluded from snapshots.
|
||||
self._execution_context = execution_context
|
||||
self._execution_context = execution_context if execution_context is not None else nullcontext()
|
||||
self._pending_response_coordinator_dump: str | None = None
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
@ -335,12 +335,12 @@ class GraphRuntimeState:
|
||||
return self._response_coordinator
|
||||
|
||||
@property
|
||||
def execution_context(self) -> AbstractContextManager[object] | None:
|
||||
def execution_context(self) -> AbstractContextManager[object]:
|
||||
return self._execution_context
|
||||
|
||||
@execution_context.setter
|
||||
def execution_context(self, value: AbstractContextManager[object] | None) -> None:
|
||||
self._execution_context = value
|
||||
self._execution_context = value if value is not None else nullcontext()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Scalar state
|
||||
|
||||
@ -23,6 +23,17 @@ class StubCoordinator:
|
||||
|
||||
|
||||
class TestGraphRuntimeState:
|
||||
def test_execution_context_defaults_to_empty_context(self):
|
||||
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
|
||||
|
||||
with state.execution_context:
|
||||
assert state.execution_context is not None
|
||||
|
||||
state.execution_context = None
|
||||
|
||||
with state.execution_context:
|
||||
assert state.execution_context is not None
|
||||
|
||||
def test_property_getters_and_setters(self):
|
||||
# FIXME(-LAN-): Mock VariablePool if needed
|
||||
variable_pool = VariablePool()
|
||||
|
||||
@ -120,6 +120,7 @@ class TestWorkflowEntryInit:
|
||||
|
||||
def test_applies_debug_and_observability_layers(self):
|
||||
graph_engine = MagicMock()
|
||||
graph_runtime_state = SimpleNamespace(execution_context=None)
|
||||
debug_layer = sentinel.debug_layer
|
||||
execution_limits_layer = sentinel.execution_limits_layer
|
||||
llm_quota_layer = sentinel.llm_quota_layer
|
||||
@ -153,7 +154,7 @@ class TestWorkflowEntryInit:
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
variable_pool=sentinel.variable_pool,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=None,
|
||||
)
|
||||
|
||||
@ -161,12 +162,12 @@ class TestWorkflowEntryInit:
|
||||
graph_engine_cls.assert_called_once_with(
|
||||
workflow_id="workflow-id-123456",
|
||||
graph=sentinel.graph,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=sentinel.command_channel,
|
||||
config=sentinel.graph_engine_config,
|
||||
child_engine_builder=entry._child_engine_builder,
|
||||
execution_context=sentinel.execution_context,
|
||||
)
|
||||
assert graph_runtime_state.execution_context is sentinel.execution_context
|
||||
debug_logging_layer.assert_called_once_with(
|
||||
level="DEBUG",
|
||||
include_inputs=True,
|
||||
|
||||
Reference in New Issue
Block a user