mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 23:36:20 +08:00
649 lines
26 KiB
Python
649 lines
26 KiB
Python
import logging
|
|
from collections.abc import Generator, Mapping, Sequence
|
|
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
|
from datetime import UTC, datetime
|
|
from typing import TYPE_CHECKING, Any, NewType, cast
|
|
|
|
from typing_extensions import TypeIs
|
|
|
|
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
|
from dify_graph.enums import (
|
|
NodeExecutionType,
|
|
NodeType,
|
|
WorkflowNodeExecutionMetadataKey,
|
|
WorkflowNodeExecutionStatus,
|
|
)
|
|
from dify_graph.graph_events import (
|
|
GraphNodeEventBase,
|
|
GraphRunFailedEvent,
|
|
GraphRunPartialSucceededEvent,
|
|
GraphRunSucceededEvent,
|
|
)
|
|
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
|
from dify_graph.node_events import (
|
|
IterationFailedEvent,
|
|
IterationNextEvent,
|
|
IterationStartedEvent,
|
|
IterationSucceededEvent,
|
|
NodeEventBase,
|
|
NodeRunResult,
|
|
StreamCompletedEvent,
|
|
)
|
|
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
|
from dify_graph.nodes.base.node import Node
|
|
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
|
from dify_graph.runtime import VariablePool
|
|
from dify_graph.variables import IntegerVariable, NoneSegment
|
|
from dify_graph.variables.segments import ArrayAnySegment, ArraySegment
|
|
from dify_graph.variables.variables import Variable
|
|
from libs.datetime_utils import naive_utc_now
|
|
|
|
from .exc import (
|
|
InvalidIteratorValueError,
|
|
IterationGraphNotFoundError,
|
|
IterationIndexNotFoundError,
|
|
IterationNodeError,
|
|
IteratorVariableNotFoundError,
|
|
StartNodeIdNotFoundError,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from dify_graph.context import IExecutionContext
|
|
from dify_graph.graph_engine import GraphEngine
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
|
|
|
|
|
class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|
"""
|
|
Iteration Node.
|
|
"""
|
|
|
|
node_type = NodeType.ITERATION
|
|
execution_type = NodeExecutionType.CONTAINER
|
|
|
|
@classmethod
|
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
|
return {
|
|
"type": "iteration",
|
|
"config": {
|
|
"is_parallel": False,
|
|
"parallel_nums": 10,
|
|
"error_handle_mode": ErrorHandleMode.TERMINATED,
|
|
"flatten_output": True,
|
|
},
|
|
}
|
|
|
|
@classmethod
|
|
def version(cls) -> str:
|
|
return "1"
|
|
|
|
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore
|
|
variable = self._get_iterator_variable()
|
|
|
|
if self._is_empty_iteration(variable):
|
|
yield from self._handle_empty_iteration(variable)
|
|
return
|
|
|
|
iterator_list_value = self._validate_and_get_iterator_list(variable)
|
|
inputs = {"iterator_selector": iterator_list_value}
|
|
|
|
self._validate_start_node()
|
|
|
|
started_at = naive_utc_now()
|
|
iter_run_map: dict[str, float] = {}
|
|
outputs: list[object] = []
|
|
usage_accumulator = [LLMUsage.empty_usage()]
|
|
|
|
yield IterationStartedEvent(
|
|
start_at=started_at,
|
|
inputs=inputs,
|
|
metadata={"iteration_length": len(iterator_list_value)},
|
|
)
|
|
|
|
try:
|
|
yield from self._execute_iterations(
|
|
iterator_list_value=iterator_list_value,
|
|
outputs=outputs,
|
|
iter_run_map=iter_run_map,
|
|
usage_accumulator=usage_accumulator,
|
|
)
|
|
|
|
self._accumulate_usage(usage_accumulator[0])
|
|
yield from self._handle_iteration_success(
|
|
started_at=started_at,
|
|
inputs=inputs,
|
|
outputs=outputs,
|
|
iterator_list_value=iterator_list_value,
|
|
iter_run_map=iter_run_map,
|
|
usage=usage_accumulator[0],
|
|
)
|
|
except IterationNodeError as e:
|
|
self._accumulate_usage(usage_accumulator[0])
|
|
yield from self._handle_iteration_failure(
|
|
started_at=started_at,
|
|
inputs=inputs,
|
|
outputs=outputs,
|
|
iterator_list_value=iterator_list_value,
|
|
iter_run_map=iter_run_map,
|
|
usage=usage_accumulator[0],
|
|
error=e,
|
|
)
|
|
|
|
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
|
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
|
|
|
if not variable:
|
|
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
|
|
|
|
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
|
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
|
|
|
return variable
|
|
|
|
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
|
|
return isinstance(variable, NoneSegment) or len(variable.value) == 0
|
|
|
|
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
|
|
# Try our best to preserve the type information.
|
|
if isinstance(variable, ArraySegment):
|
|
output = variable.model_copy(update={"value": []})
|
|
else:
|
|
output = ArrayAnySegment(value=[])
|
|
|
|
yield StreamCompletedEvent(
|
|
node_run_result=NodeRunResult(
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
|
# from graph definition?
|
|
outputs={"output": output},
|
|
)
|
|
)
|
|
|
|
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
|
|
iterator_list_value = variable.to_object()
|
|
|
|
if not isinstance(iterator_list_value, list):
|
|
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
|
|
|
return cast(list[object], iterator_list_value)
|
|
|
|
def _validate_start_node(self) -> None:
|
|
if not self.node_data.start_node_id:
|
|
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
|
|
|
def _execute_iterations(
|
|
self,
|
|
iterator_list_value: Sequence[object],
|
|
outputs: list[object],
|
|
iter_run_map: dict[str, float],
|
|
usage_accumulator: list[LLMUsage],
|
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
|
if self.node_data.is_parallel:
|
|
# Parallel mode execution
|
|
yield from self._execute_parallel_iterations(
|
|
iterator_list_value=iterator_list_value,
|
|
outputs=outputs,
|
|
iter_run_map=iter_run_map,
|
|
usage_accumulator=usage_accumulator,
|
|
)
|
|
else:
|
|
# Sequential mode execution
|
|
for index, item in enumerate(iterator_list_value):
|
|
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
|
yield IterationNextEvent(index=index)
|
|
|
|
graph_engine = self._create_graph_engine(index, item)
|
|
|
|
# Run the iteration
|
|
yield from self._run_single_iter(
|
|
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
|
outputs=outputs,
|
|
graph_engine=graph_engine,
|
|
)
|
|
|
|
# Sync conversation variables after each iteration completes
|
|
self._sync_conversation_variables_from_snapshot(
|
|
self._extract_conversation_variable_snapshot(
|
|
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
|
)
|
|
)
|
|
|
|
# Accumulate usage from this iteration
|
|
usage_accumulator[0] = self._merge_usage(
|
|
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
|
|
)
|
|
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
|
|
|
def _execute_parallel_iterations(
|
|
self,
|
|
iterator_list_value: Sequence[object],
|
|
outputs: list[object],
|
|
iter_run_map: dict[str, float],
|
|
usage_accumulator: list[LLMUsage],
|
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
|
# Initialize outputs list with None values to maintain order
|
|
outputs.extend([None] * len(iterator_list_value))
|
|
|
|
# Determine the number of parallel workers
|
|
max_workers = min(self.node_data.parallel_nums, len(iterator_list_value))
|
|
|
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
# Submit all iteration tasks
|
|
future_to_index: dict[
|
|
Future[
|
|
tuple[
|
|
datetime,
|
|
list[GraphNodeEventBase],
|
|
object | None,
|
|
dict[str, Variable],
|
|
LLMUsage,
|
|
]
|
|
],
|
|
int,
|
|
] = {}
|
|
for index, item in enumerate(iterator_list_value):
|
|
yield IterationNextEvent(index=index)
|
|
future = executor.submit(
|
|
self._execute_single_iteration_parallel,
|
|
index=index,
|
|
item=item,
|
|
execution_context=self._capture_execution_context(),
|
|
)
|
|
future_to_index[future] = index
|
|
|
|
# Process completed iterations as they finish
|
|
for future in as_completed(future_to_index):
|
|
index = future_to_index[future]
|
|
try:
|
|
result = future.result()
|
|
(
|
|
iter_start_at,
|
|
events,
|
|
output_value,
|
|
conversation_snapshot,
|
|
iteration_usage,
|
|
) = result
|
|
|
|
# Update outputs at the correct index
|
|
outputs[index] = output_value
|
|
|
|
# Yield all events from this iteration
|
|
yield from events
|
|
|
|
# Update tokens and timing
|
|
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
|
|
|
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
|
|
|
# Sync conversation variables after iteration completion
|
|
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
|
|
|
except Exception as e:
|
|
# Handle errors based on error_handle_mode
|
|
match self.node_data.error_handle_mode:
|
|
case ErrorHandleMode.TERMINATED:
|
|
# Cancel remaining futures and re-raise
|
|
for f in future_to_index:
|
|
if f != future:
|
|
f.cancel()
|
|
raise IterationNodeError(str(e))
|
|
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
|
outputs[index] = None
|
|
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
|
outputs[index] = None # Will be filtered later
|
|
|
|
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
|
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
|
outputs[:] = [output for output in outputs if output is not None]
|
|
|
|
def _execute_single_iteration_parallel(
|
|
self,
|
|
index: int,
|
|
item: object,
|
|
execution_context: "IExecutionContext",
|
|
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
|
"""Execute a single iteration in parallel mode and return results."""
|
|
with execution_context:
|
|
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
|
events: list[GraphNodeEventBase] = []
|
|
outputs_temp: list[object] = []
|
|
|
|
graph_engine = self._create_graph_engine(index, item)
|
|
|
|
# Collect events instead of yielding them directly
|
|
for event in self._run_single_iter(
|
|
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
|
outputs=outputs_temp,
|
|
graph_engine=graph_engine,
|
|
):
|
|
events.append(event)
|
|
|
|
# Get the output value from the temporary outputs list
|
|
output_value = outputs_temp[0] if outputs_temp else None
|
|
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
|
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
|
)
|
|
|
|
return (
|
|
iter_start_at,
|
|
events,
|
|
output_value,
|
|
conversation_snapshot,
|
|
graph_engine.graph_runtime_state.llm_usage,
|
|
)
|
|
|
|
def _capture_execution_context(self) -> "IExecutionContext":
|
|
"""Capture current execution context for parallel iterations."""
|
|
from dify_graph.context import capture_current_context
|
|
|
|
return capture_current_context()
|
|
|
|
def _handle_iteration_success(
|
|
self,
|
|
started_at: datetime,
|
|
inputs: dict[str, Sequence[object]],
|
|
outputs: list[object],
|
|
iterator_list_value: Sequence[object],
|
|
iter_run_map: dict[str, float],
|
|
*,
|
|
usage: LLMUsage,
|
|
) -> Generator[NodeEventBase, None, None]:
|
|
# Flatten the list of lists if all outputs are lists
|
|
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
|
|
|
yield IterationSucceededEvent(
|
|
start_at=started_at,
|
|
inputs=inputs,
|
|
outputs={"output": flattened_outputs},
|
|
steps=len(iterator_list_value),
|
|
metadata={
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
|
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
|
},
|
|
)
|
|
|
|
# Yield final success event
|
|
yield StreamCompletedEvent(
|
|
node_run_result=NodeRunResult(
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
outputs={"output": flattened_outputs},
|
|
metadata={
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
|
},
|
|
llm_usage=usage,
|
|
)
|
|
)
|
|
|
|
def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]:
|
|
"""
|
|
Flatten the outputs list if all elements are lists.
|
|
This maintains backward compatibility with version 1.8.1 behavior.
|
|
|
|
If flatten_output is False, returns outputs as-is (nested structure).
|
|
If flatten_output is True (default), flattens the list if all elements are lists.
|
|
"""
|
|
# If flatten_output is disabled, return outputs as-is
|
|
if not self.node_data.flatten_output:
|
|
return outputs
|
|
|
|
if not outputs:
|
|
return outputs
|
|
|
|
# Check if all non-None outputs are lists
|
|
non_none_outputs: list[object] = [output for output in outputs if output is not None]
|
|
if not non_none_outputs:
|
|
return outputs
|
|
|
|
if all(isinstance(output, list) for output in non_none_outputs):
|
|
# Flatten the list of lists
|
|
flattened: list[Any] = []
|
|
for output in outputs:
|
|
if isinstance(output, list):
|
|
flattened.extend(output)
|
|
elif output is not None:
|
|
# This shouldn't happen based on our check, but handle it gracefully
|
|
flattened.append(output)
|
|
return flattened
|
|
|
|
return outputs
|
|
|
|
def _handle_iteration_failure(
|
|
self,
|
|
started_at: datetime,
|
|
inputs: dict[str, Sequence[object]],
|
|
outputs: list[object],
|
|
iterator_list_value: Sequence[object],
|
|
iter_run_map: dict[str, float],
|
|
*,
|
|
usage: LLMUsage,
|
|
error: IterationNodeError,
|
|
) -> Generator[NodeEventBase, None, None]:
|
|
# Flatten the list of lists if all outputs are lists (even in failure case)
|
|
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
|
|
|
yield IterationFailedEvent(
|
|
start_at=started_at,
|
|
inputs=inputs,
|
|
outputs={"output": flattened_outputs},
|
|
steps=len(iterator_list_value),
|
|
metadata={
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
|
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
|
},
|
|
error=str(error),
|
|
)
|
|
yield StreamCompletedEvent(
|
|
node_run_result=NodeRunResult(
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
error=str(error),
|
|
metadata={
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
|
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
|
},
|
|
llm_usage=usage,
|
|
)
|
|
)
|
|
|
|
@classmethod
|
|
def _extract_variable_selector_to_variable_mapping(
|
|
cls,
|
|
*,
|
|
graph_config: Mapping[str, Any],
|
|
node_id: str,
|
|
node_data: Mapping[str, Any],
|
|
) -> Mapping[str, Sequence[str]]:
|
|
# Create typed NodeData from dict
|
|
typed_node_data = IterationNodeData.model_validate(node_data)
|
|
|
|
variable_mapping: dict[str, Sequence[str]] = {
|
|
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
|
}
|
|
iteration_node_ids = set()
|
|
|
|
# Find all nodes that belong to this loop
|
|
nodes = graph_config.get("nodes", [])
|
|
for node in nodes:
|
|
node_data = node.get("data", {})
|
|
if node_data.get("iteration_id") == node_id:
|
|
in_iteration_node_id = node.get("id")
|
|
if in_iteration_node_id:
|
|
iteration_node_ids.add(in_iteration_node_id)
|
|
|
|
# Get node configs from graph_config instead of non-existent node_id_config_mapping
|
|
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
|
for sub_node_id, sub_node_config in node_configs.items():
|
|
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
|
|
continue
|
|
|
|
# variable selector to variable mapping
|
|
try:
|
|
# Get node class
|
|
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
|
|
|
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
|
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
|
continue
|
|
node_version = sub_node_config.get("data", {}).get("version", "1")
|
|
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
|
|
|
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
|
graph_config=graph_config, config=sub_node_config
|
|
)
|
|
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
|
except NotImplementedError:
|
|
sub_node_variable_mapping = {}
|
|
|
|
# remove iteration variables
|
|
sub_node_variable_mapping = {
|
|
sub_node_id + "." + key: value
|
|
for key, value in sub_node_variable_mapping.items()
|
|
if value[0] != node_id
|
|
}
|
|
|
|
variable_mapping.update(sub_node_variable_mapping)
|
|
|
|
# remove variable out from iteration
|
|
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
|
|
|
|
return variable_mapping
|
|
|
|
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
|
|
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
|
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
|
|
|
|
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
|
|
parent_pool = self.graph_runtime_state.variable_pool
|
|
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
|
|
|
current_keys = set(parent_conversations.keys())
|
|
snapshot_keys = set(snapshot.keys())
|
|
|
|
for removed_key in current_keys - snapshot_keys:
|
|
parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key))
|
|
|
|
for name, variable in snapshot.items():
|
|
parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable)
|
|
|
|
def _append_iteration_info_to_event(
|
|
self,
|
|
event: GraphNodeEventBase,
|
|
iter_run_index: int,
|
|
):
|
|
event.in_iteration_id = self._node_id
|
|
iter_metadata = {
|
|
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id,
|
|
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
|
|
}
|
|
|
|
current_metadata = event.node_run_result.metadata
|
|
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
|
|
event.node_run_result.metadata = {**current_metadata, **iter_metadata}
|
|
|
|
def _run_single_iter(
|
|
self,
|
|
*,
|
|
variable_pool: VariablePool,
|
|
outputs: list[object],
|
|
graph_engine: "GraphEngine",
|
|
) -> Generator[GraphNodeEventBase, None, None]:
|
|
rst = graph_engine.run()
|
|
# get current iteration index
|
|
index_variable = variable_pool.get([self._node_id, "index"])
|
|
if not isinstance(index_variable, IntegerVariable):
|
|
raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
|
|
current_index = index_variable.value
|
|
for event in rst:
|
|
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
|
|
continue
|
|
|
|
if isinstance(event, GraphNodeEventBase):
|
|
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
|
yield event
|
|
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
|
result = variable_pool.get(self.node_data.output_selector)
|
|
if result is None:
|
|
outputs.append(None)
|
|
else:
|
|
outputs.append(result.to_object())
|
|
return
|
|
elif isinstance(event, GraphRunFailedEvent):
|
|
match self.node_data.error_handle_mode:
|
|
case ErrorHandleMode.TERMINATED:
|
|
raise IterationNodeError(event.error)
|
|
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
|
outputs.append(None)
|
|
return
|
|
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
|
return
|
|
|
|
def _create_graph_engine(self, index: int, item: object):
|
|
# Import dependencies
|
|
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
|
from core.workflow.node_factory import DifyNodeFactory
|
|
from dify_graph.entities import GraphInitParams
|
|
from dify_graph.graph import Graph
|
|
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
|
|
from dify_graph.graph_engine.command_channels import InMemoryChannel
|
|
from dify_graph.runtime import GraphRuntimeState
|
|
|
|
# Create GraphInitParams from node attributes
|
|
graph_init_params = GraphInitParams(
|
|
tenant_id=self.tenant_id,
|
|
app_id=self.app_id,
|
|
workflow_id=self.workflow_id,
|
|
graph_config=self.graph_config,
|
|
user_id=self.user_id,
|
|
user_from=self.user_from.value,
|
|
invoke_from=self.invoke_from.value,
|
|
call_depth=self.workflow_call_depth,
|
|
)
|
|
# Create a deep copy of the variable pool for each iteration
|
|
variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True)
|
|
|
|
# append iteration variable (item, index) to variable pool
|
|
variable_pool_copy.add([self._node_id, "index"], index)
|
|
variable_pool_copy.add([self._node_id, "item"], item)
|
|
|
|
# Create a new GraphRuntimeState for this iteration
|
|
graph_runtime_state_copy = GraphRuntimeState(
|
|
variable_pool=variable_pool_copy,
|
|
start_at=self.graph_runtime_state.start_at,
|
|
total_tokens=0,
|
|
node_run_steps=0,
|
|
)
|
|
|
|
# Create a new node factory with the new GraphRuntimeState
|
|
node_factory = DifyNodeFactory(
|
|
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
|
)
|
|
|
|
# Initialize the iteration graph with the new node factory
|
|
iteration_graph = Graph.init(
|
|
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
|
|
)
|
|
|
|
if not iteration_graph:
|
|
raise IterationGraphNotFoundError("iteration graph not found")
|
|
|
|
# Create a new GraphEngine for this iteration
|
|
graph_engine = GraphEngine(
|
|
workflow_id=self.workflow_id,
|
|
graph=iteration_graph,
|
|
graph_runtime_state=graph_runtime_state_copy,
|
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
|
config=GraphEngineConfig(),
|
|
)
|
|
graph_engine.layer(LLMQuotaLayer())
|
|
|
|
return graph_engine
|