mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
Fix: surface workflow container LLM usage (#27021)
This commit is contained in:
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
from flask import Flask, current_app
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
@ -34,6 +35,7 @@ from core.workflow.node_events import (
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
|
||||
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||
|
||||
|
||||
class IterationNode(Node):
|
||||
class IterationNode(LLMUsageTrackingMixin, Node):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
@ -118,6 +120,7 @@ class IterationNode(Node):
|
||||
started_at = naive_utc_now()
|
||||
iter_run_map: dict[str, float] = {}
|
||||
outputs: list[object] = []
|
||||
usage_accumulator = [LLMUsage.empty_usage()]
|
||||
|
||||
yield IterationStartedEvent(
|
||||
start_at=started_at,
|
||||
@ -130,22 +133,27 @@ class IterationNode(Node):
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
|
||||
self._accumulate_usage(usage_accumulator[0])
|
||||
yield from self._handle_iteration_success(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
usage=usage_accumulator[0],
|
||||
)
|
||||
except IterationNodeError as e:
|
||||
self._accumulate_usage(usage_accumulator[0])
|
||||
yield from self._handle_iteration_failure(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
usage=usage_accumulator[0],
|
||||
error=e,
|
||||
)
|
||||
|
||||
@ -196,6 +204,7 @@ class IterationNode(Node):
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
usage_accumulator: list[LLMUsage],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
if self._node_data.is_parallel:
|
||||
# Parallel mode execution
|
||||
@ -203,6 +212,7 @@ class IterationNode(Node):
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
else:
|
||||
# Sequential mode execution
|
||||
@ -228,6 +238,9 @@ class IterationNode(Node):
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
usage_accumulator[0] = self._merge_usage(
|
||||
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
|
||||
)
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
def _execute_parallel_iterations(
|
||||
@ -235,6 +248,7 @@ class IterationNode(Node):
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
usage_accumulator: list[LLMUsage],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
# Initialize outputs list with None values to maintain order
|
||||
outputs.extend([None] * len(iterator_list_value))
|
||||
@ -245,7 +259,16 @@ class IterationNode(Node):
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[
|
||||
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
|
||||
Future[
|
||||
tuple[
|
||||
datetime,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
int,
|
||||
dict[str, VariableUnion],
|
||||
LLMUsage,
|
||||
]
|
||||
],
|
||||
int,
|
||||
] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
@ -264,7 +287,14 @@ class IterationNode(Node):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
|
||||
(
|
||||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
tokens_used,
|
||||
conversation_snapshot,
|
||||
iteration_usage,
|
||||
) = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
@ -276,6 +306,8 @@ class IterationNode(Node):
|
||||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||
|
||||
# Sync conversation variables after iteration completion
|
||||
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
||||
|
||||
@ -303,7 +335,7 @@ class IterationNode(Node):
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@ -332,6 +364,7 @@ class IterationNode(Node):
|
||||
output_value,
|
||||
graph_engine.graph_runtime_state.total_tokens,
|
||||
conversation_snapshot,
|
||||
graph_engine.graph_runtime_state.llm_usage,
|
||||
)
|
||||
|
||||
def _handle_iteration_success(
|
||||
@ -341,6 +374,8 @@ class IterationNode(Node):
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
*,
|
||||
usage: LLMUsage,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
# Flatten the list of lists if all outputs are lists
|
||||
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
||||
@ -351,7 +386,9 @@ class IterationNode(Node):
|
||||
outputs={"output": flattened_outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
)
|
||||
@ -362,8 +399,11 @@ class IterationNode(Node):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": flattened_outputs},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
@ -400,6 +440,8 @@ class IterationNode(Node):
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
*,
|
||||
usage: LLMUsage,
|
||||
error: IterationNodeError,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
# Flatten the list of lists if all outputs are lists (even in failure case)
|
||||
@ -411,7 +453,9 @@ class IterationNode(Node):
|
||||
outputs={"output": flattened_outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
error=str(error),
|
||||
@ -420,6 +464,12 @@ class IterationNode(Node):
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(error),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user