mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 23:18:05 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -10,6 +10,8 @@ from typing_extensions import TypeIs
|
||||
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
@ -217,6 +219,13 @@ class IterationNode(Node):
|
||||
graph_engine=graph_engine,
|
||||
)
|
||||
|
||||
# Sync conversation variables after each iteration completes
|
||||
self._sync_conversation_variables_from_snapshot(
|
||||
self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
@ -235,7 +244,10 @@ class IterationNode(Node):
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
|
||||
future_to_index: dict[
|
||||
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
|
||||
int,
|
||||
] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
yield IterationNextEvent(index=index)
|
||||
future = executor.submit(
|
||||
@ -252,7 +264,7 @@ class IterationNode(Node):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
iter_start_at, events, output_value, tokens_used = result
|
||||
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
@ -264,6 +276,9 @@ class IterationNode(Node):
|
||||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
# Sync conversation variables after iteration completion
|
||||
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
@ -288,7 +303,7 @@ class IterationNode(Node):
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@ -307,8 +322,17 @@ class IterationNode(Node):
|
||||
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||
return (
|
||||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
graph_engine.graph_runtime_state.total_tokens,
|
||||
conversation_snapshot,
|
||||
)
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
@ -430,6 +454,23 @@ class IterationNode(Node):
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
|
||||
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
|
||||
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
|
||||
parent_pool = self.graph_runtime_state.variable_pool
|
||||
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
|
||||
current_keys = set(parent_conversations.keys())
|
||||
snapshot_keys = set(snapshot.keys())
|
||||
|
||||
for removed_key in current_keys - snapshot_keys:
|
||||
parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key))
|
||||
|
||||
for name, variable in snapshot.items():
|
||||
parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable)
|
||||
|
||||
def _append_iteration_info_to_event(
|
||||
self,
|
||||
event: GraphNodeEventBase,
|
||||
|
||||
Reference in New Issue
Block a user