mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
fix: workflow token usage (#26723)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
This commit is contained in:
@ -7,6 +7,7 @@ from collections.abc import Mapping
|
|||||||
from functools import singledispatchmethod
|
from functools import singledispatchmethod
|
||||||
from typing import TYPE_CHECKING, final
|
from typing import TYPE_CHECKING, final
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.workflow.entities import GraphRuntimeState
|
from core.workflow.entities import GraphRuntimeState
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
@ -125,6 +126,7 @@ class EventHandler:
|
|||||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||||
is_initial_attempt = node_execution.retry_count == 0
|
is_initial_attempt = node_execution.retry_count == 0
|
||||||
node_execution.mark_started(event.id)
|
node_execution.mark_started(event.id)
|
||||||
|
self._graph_runtime_state.increment_node_run_steps()
|
||||||
|
|
||||||
# Track in response coordinator for stream ordering
|
# Track in response coordinator for stream ordering
|
||||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||||
@ -163,6 +165,8 @@ class EventHandler:
|
|||||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||||
node_execution.mark_taken()
|
node_execution.mark_taken()
|
||||||
|
|
||||||
|
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||||
|
|
||||||
# Store outputs in variable pool
|
# Store outputs in variable pool
|
||||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||||
|
|
||||||
@ -212,6 +216,8 @@ class EventHandler:
|
|||||||
node_execution.mark_failed(event.error)
|
node_execution.mark_failed(event.error)
|
||||||
self._graph_execution.record_node_failure()
|
self._graph_execution.record_node_failure()
|
||||||
|
|
||||||
|
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||||
|
|
||||||
result = self._error_handler.handle_node_failure(event)
|
result = self._error_handler.handle_node_failure(event)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
@ -235,6 +241,8 @@ class EventHandler:
|
|||||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||||
node_execution.mark_taken()
|
node_execution.mark_taken()
|
||||||
|
|
||||||
|
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||||
|
|
||||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||||
|
|
||||||
@ -286,6 +294,19 @@ class EventHandler:
|
|||||||
self._state_manager.enqueue_node(event.node_id)
|
self._state_manager.enqueue_node(event.node_id)
|
||||||
self._state_manager.start_execution(event.node_id)
|
self._state_manager.start_execution(event.node_id)
|
||||||
|
|
||||||
|
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
|
||||||
|
"""Accumulate token usage into the shared runtime state."""
|
||||||
|
if usage.total_tokens <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._graph_runtime_state.add_tokens(usage.total_tokens)
|
||||||
|
|
||||||
|
current_usage = self._graph_runtime_state.llm_usage
|
||||||
|
if current_usage.total_tokens == 0:
|
||||||
|
self._graph_runtime_state.llm_usage = usage
|
||||||
|
else:
|
||||||
|
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||||
|
|
||||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||||
"""
|
"""
|
||||||
Store node outputs in the variable pool.
|
Store node outputs in the variable pool.
|
||||||
|
|||||||
Reference in New Issue
Block a user