mirror of
https://github.com/langgenius/dify.git
synced 2026-03-27 17:19:55 +08:00
refactor(api): continue decoupling dify_graph from API concerns (#33580)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
@ -2,7 +2,7 @@ import contextlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
@ -14,6 +14,7 @@ from dify_graph.enums import (
|
||||
)
|
||||
from dify_graph.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
@ -31,14 +32,13 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
|
||||
from dify_graph.utils.condition.processor import ConditionProcessor
|
||||
from dify_graph.variables import Segment, SegmentType
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from dify_graph.variables import Segment, SegmentType, TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_DEFAULT_CHILD_ABORT_REASON = "child graph aborted"
|
||||
|
||||
|
||||
class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
@ -91,7 +91,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
loop_variable_selectors[loop_variable.label] = variable_selector
|
||||
inputs[loop_variable.label] = processed_segment.value
|
||||
|
||||
start_at = naive_utc_now()
|
||||
start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
@ -124,10 +124,13 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
self._clear_loop_subgraph_variables(loop_node_ids)
|
||||
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
|
||||
|
||||
loop_start_time = naive_utc_now()
|
||||
reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
|
||||
loop_start_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
try:
|
||||
reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
|
||||
finally:
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
# Track loop duration
|
||||
loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds()
|
||||
loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds()
|
||||
|
||||
# Accumulate outputs from the sub-graph's response nodes
|
||||
for key, value in graph_engine.graph_runtime_state.outputs.items():
|
||||
@ -142,9 +145,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
# For other outputs, just update
|
||||
self.graph_runtime_state.set_output(key, value)
|
||||
|
||||
# Accumulate usage from the sub-graph execution
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
@ -256,6 +256,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
yield event
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END:
|
||||
reach_break_node = True
|
||||
if isinstance(event, GraphRunAbortedEvent):
|
||||
raise RuntimeError(event.reason or _DEFAULT_CHILD_ABORT_REASON)
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
raise Exception(event.error)
|
||||
|
||||
@ -410,7 +412,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams for child graph execution.
|
||||
graph_init_params = GraphInitParams(
|
||||
@ -420,16 +421,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
|
||||
# Create a new GraphRuntimeState for this iteration
|
||||
graph_runtime_state_copy = GraphRuntimeState(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
start_at=start_at.timestamp(),
|
||||
)
|
||||
|
||||
return self.graph_runtime_state.create_child_engine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
graph_config=self.graph_config,
|
||||
root_node_id=root_node_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user