From ef2b5d6107eca6dcb5b609867b2a652bc6ccc46c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 1 Mar 2026 23:25:36 +0800 Subject: [PATCH] refactor(api): move llm quota deduction to app graph layer (#32786) --- api/.importlinter | 10 +- api/core/app/llm/__init__.py | 4 + api/core/app/llm/quota.py | 93 ++++++++++ api/core/app/workflow/layers/__init__.py | 2 + api/core/app/workflow/layers/llm_quota.py | 128 +++++++++++++ api/core/plugin/backwards_invocation/model.py | 14 +- .../processor/paragraph_index_processor.py | 4 +- .../router/multi_dataset_react_route.py | 4 +- .../nodes/iteration/iteration_node.py | 2 + api/core/workflow/nodes/llm/llm_utils.py | 73 +------- api/core/workflow/nodes/llm/node.py | 6 +- api/core/workflow/nodes/loop/loop_node.py | 2 + .../parameter_extractor_node.py | 7 +- .../question_classifier_node.py | 4 + api/core/workflow/workflow_entry.py | 2 + .../graph_engine/layers/test_llm_quota.py | 174 ++++++++++++++++++ 16 files changed, 434 insertions(+), 95 deletions(-) create mode 100644 api/core/app/llm/quota.py create mode 100644 api/core/app/workflow/layers/llm_quota.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py diff --git a/api/.importlinter b/api/.importlinter index 49cf70d61a..3b1f58d886 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -29,6 +29,8 @@ ignore_imports = core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory + core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota + core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine core.workflow.nodes.iteration.iteration_node -> core.workflow.graph @@ -107,14 +109,12 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> core.tools.tool_manager core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory + core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory - core.workflow.nodes.llm.llm_utils -> configs core.workflow.nodes.llm.llm_utils -> core.model_manager core.workflow.nodes.llm.protocols -> core.model_manager core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model core.workflow.nodes.llm.llm_utils -> models.model - core.workflow.nodes.llm.llm_utils -> models.provider - core.workflow.nodes.llm.llm_utils -> services.credit_pool_service core.workflow.nodes.llm.node -> core.tools.signature core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler core.workflow.nodes.tool.tool_node -> core.tools.tool_engine @@ -135,8 +135,8 @@ ignore_imports = core.workflow.nodes.start.start_node -> core.app.app_config.entities core.workflow.workflow_entry -> core.app.apps.exc core.workflow.workflow_entry -> core.app.entities.app_invoke_entities + core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota core.workflow.workflow_entry -> core.app.workflow.node_factory - core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager core.workflow.nodes.tool.tool_node -> 
core.tools.utils.message_transformer @@ -180,7 +180,7 @@ ignore_imports = core.workflow.workflow_entry -> extensions.otel.runtime core.workflow.nodes.agent.agent_node -> models core.workflow.nodes.base.node -> models.enums - core.workflow.nodes.llm.llm_utils -> models.provider_ids + core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota core.workflow.nodes.llm.node -> models.model core.workflow.workflow_entry -> models.enums core.workflow.nodes.agent.agent_node -> services diff --git a/api/core/app/llm/__init__.py b/api/core/app/llm/__init__.py index 5ac76c8086..f069bede74 100644 --- a/api/core/app/llm/__init__.py +++ b/api/core/app/llm/__init__.py @@ -1 +1,5 @@ """LLM-related application services.""" + +from .quota import deduct_llm_quota, ensure_llm_quota_available + +__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"] diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py new file mode 100644 index 0000000000..1c66c8c1ff --- /dev/null +++ b/api/core/app/llm/quota.py @@ -0,0 +1,93 @@ +from sqlalchemy import update +from sqlalchemy.orm import Session + +from configs import dify_config +from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMUsage +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now +from models.provider import Provider, ProviderType +from models.provider_ids import ModelProviderID + + +def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None: + provider_model_bundle = model_instance.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + provider_model = provider_configuration.get_provider_model( + model_type=model_instance.model_type_instance.model_type, + model=model_instance.model_name, + ) + if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.") + + +def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: + provider_model_bundle = model_instance.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration + + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + + if quota_configuration.quota_limit == -1: + return + + break + + used_quota = None + if quota_unit: + if quota_unit == QuotaUnit.TOKENS: + used_quota = usage.total_tokens + elif quota_unit == QuotaUnit.CREDITS: + used_quota = dify_config.get_model_credits(model_instance.model_name) + else: + used_quota = 1 + + if used_quota is not None and system_configuration.current_quota_type is not None: + if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + ) + elif system_configuration.current_quota_type == ProviderQuotaType.PAID: + from 
services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + else: + with Session(db.engine) as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. + Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=naive_utc_now(), + ) + ) + session.execute(stmt) + session.commit() diff --git a/api/core/app/workflow/layers/__init__.py b/api/core/app/workflow/layers/__init__.py index 945f75303c..7d5841275d 100644 --- a/api/core/app/workflow/layers/__init__.py +++ b/api/core/app/workflow/layers/__init__.py @@ -1,9 +1,11 @@ """Workflow-level GraphEngine layers that depend on outer infrastructure.""" +from .llm_quota import LLMQuotaLayer from .observability import ObservabilityLayer from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer __all__ = [ + "LLMQuotaLayer", "ObservabilityLayer", "PersistenceWorkflowInfo", "WorkflowPersistenceLayer", diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py new file mode 100644 index 0000000000..45fb84c81f --- /dev/null +++ b/api/core/app/workflow/layers/llm_quota.py @@ -0,0 +1,128 @@ +""" +LLM quota deduction layer for GraphEngine. + +This layer centralizes model-quota deduction outside node implementations. +""" + +import logging +from typing import TYPE_CHECKING, cast, final + +from typing_extensions import override + +from core.app.llm import deduct_llm_quota, ensure_llm_quota_available +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance +from core.workflow.enums import NodeType +from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase +from core.workflow.graph_events.node import NodeRunSucceededEvent +from core.workflow.nodes.base.node import Node + +if TYPE_CHECKING: + from core.workflow.nodes.llm.node import LLMNode + from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode + from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode + +logger = logging.getLogger(__name__) + + +@final +class LLMQuotaLayer(GraphEngineLayer): + """Graph layer that applies LLM quota deduction after node execution.""" + + def __init__(self) -> None: + super().__init__() + self._abort_sent = False + + @override + def on_graph_start(self) -> None: + self._abort_sent = False + + @override + def on_event(self, event: GraphEngineEvent) -> None: + _ = event + + @override + def on_graph_end(self, error: Exception | None) -> None: + _ = error + + @override + def on_node_run_start(self, node: Node) -> None: + if self._abort_sent: + return + + model_instance = self._extract_model_instance(node) + if model_instance is None: + return + + try: + ensure_llm_quota_available(model_instance=model_instance) + except QuotaExceededError as exc: + self._set_stop_event(node) + self._send_abort_command(reason=str(exc)) + logger.warning("LLM quota check failed, 
node_id=%s, error=%s", node.id, exc) + + @override + def on_node_run_end( + self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: + if error is not None or not isinstance(result_event, NodeRunSucceededEvent): + return + + model_instance = self._extract_model_instance(node) + if model_instance is None: + return + + try: + deduct_llm_quota( + tenant_id=node.tenant_id, + model_instance=model_instance, + usage=result_event.node_run_result.llm_usage, + ) + except QuotaExceededError as exc: + self._set_stop_event(node) + self._send_abort_command(reason=str(exc)) + logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc) + except Exception: + logger.exception("LLM quota deduction failed, node_id=%s", node.id) + + @staticmethod + def _set_stop_event(node: Node) -> None: + stop_event = getattr(node.graph_runtime_state, "stop_event", None) + if stop_event is not None: + stop_event.set() + + def _send_abort_command(self, *, reason: str) -> None: + if not self.command_channel or self._abort_sent: + return + + try: + self.command_channel.send_command( + AbortCommand( + command_type=CommandType.ABORT, + reason=reason, + ) + ) + self._abort_sent = True + except Exception: + logger.exception("Failed to send quota abort command") + + @staticmethod + def _extract_model_instance(node: Node) -> ModelInstance | None: + try: + match node.node_type: + case NodeType.LLM: + return cast("LLMNode", node).model_instance + case NodeType.PARAMETER_EXTRACTOR: + return cast("ParameterExtractorNode", node).model_instance + case NodeType.QUESTION_CLASSIFIER: + return cast("QuestionClassifierNode", node).model_instance + case _: + return None + except AttributeError: + logger.warning( + "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s", + node.id, + ) + return None diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 6cdc047a64..4ecc22834d 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -2,6 +2,7 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator +from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import ( @@ -29,7 +30,6 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from core.workflow.nodes.llm import llm_utils from models.account import Tenant @@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): def handle() -> Generator[LLMResultChunk, None, None]: for chunk in response: if chunk.delta.usage: - llm_utils.deduct_llm_quota( - tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage - ) + deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage) chunk.prompt_messages = [] yield chunk return handle() else: if response.usage: - llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) + deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]: yield LLMResultChunk( @@ -126,16 
+124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]: for chunk in response: if chunk.delta.usage: - llm_utils.deduct_llm_quota( - tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage - ) + deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage) chunk.prompt_messages = [] yield chunk return handle() else: if response.usage: - llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) + deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) def handle_non_streaming( response: LLMResultWithStructuredOutput, diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index cfeee4afc7..df5c89a522 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,6 +8,7 @@ from typing import Any, cast logger = logging.getLogger(__name__) +from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance @@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file import File, FileTransferMethod, FileType, file_manager -from core.workflow.nodes.llm import llm_utils from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs import helper @@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Deduct quota for summary generation (same as workflow nodes) try: - llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) except Exception as e: # Log but don't fail summary generation if quota deduction fails logger.warning("Failed to deduct quota for summary generation: %s", str(e)) diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 8f3bec2704..fa2007122d 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.llm import deduct_llm_quota from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool @@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from core.workflow.nodes.llm import llm_utils PREFIX = """Respond to the human as helpfully and accurately as 
possible. You have access to the following tools:""" @@ -162,7 +162,7 @@ class ReactMultiDatasetRouter: text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota - llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) return text, usage diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 5e7aa2a751..54b0561dd8 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _create_graph_engine(self, index: int, item: object): # Import dependencies + from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph @@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs config=GraphEngineConfig(), ) + graph_engine.layer(LLMQuotaLayer()) return graph_engine diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index f753e19897..b751640e1b 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,14 +1,11 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import select, update +from sqlalchemy import select from sqlalchemy.orm import Session -from configs import dify_config -from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig @@ -17,10 +14,7 @@ from core.workflow.file.models import File from core.workflow.runtime import VariablePool from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment from extensions.ext_database import db -from libs.datetime_utils import naive_utc_now from models.model import Conversation -from models.provider import Provider, ProviderType -from models.provider_ids import ModelProviderID from .exc import InvalidVariableTypeError @@ -68,68 +62,3 @@ def fetch_memory( memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) return memory - - -def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage): - provider_model_bundle = model_instance.provider_model_bundle - provider_configuration = provider_model_bundle.configuration - - if provider_configuration.using_provider_type != ProviderType.SYSTEM: - return - - system_configuration = provider_configuration.system_configuration - - quota_unit = None - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.quota_type == system_configuration.current_quota_type: - quota_unit = quota_configuration.quota_unit - - if quota_configuration.quota_limit == -1: - return - - break - - used_quota = None - if quota_unit: - if quota_unit == 
QuotaUnit.TOKENS: - used_quota = usage.total_tokens - elif quota_unit == QuotaUnit.CREDITS: - used_quota = dify_config.get_model_credits(model_instance.model_name) - else: - used_quota = 1 - - if used_quota is not None and system_configuration.current_quota_type is not None: - if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: - from services.credit_pool_service import CreditPoolService - - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - ) - elif system_configuration.current_quota_type == ProviderQuotaType.PAID: - from services.credit_pool_service import CreditPoolService - - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - pool_type="paid", - ) - else: - with Session(db.engine) as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) - ) - session.execute(stmt) - session.commit() diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 057a144e89..4378201eee 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -278,8 +278,6 @@ class LLMNode(Node[LLMNodeData]): else None ) - # deduct quota - llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) break elif isinstance(event, LLMStructuredOutput): structured_output = event @@ -1234,6 +1232,10 @@ class LLMNode(Node[LLMNodeData]): def retry(self) -> bool: return self.node_data.retry_config.retry_enabled + @property + def model_instance(self) -> ModelInstance: + return self._model_instance + def _combine_message_content_with_role( *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index c546df1fba..40ec0cf8b1 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): def _create_graph_engine(self, start_at: datetime, root_node_id: str): # Import dependencies + from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph @@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs config=GraphEngineConfig(), ) + graph_engine.layer(LLMQuotaLayer()) return graph_engine diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 66ef17e585..af3a4cdad3 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -308,9 +308,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): usage = invoke_result.usage tool_call = invoke_result.message.tool_calls[0] if 
invoke_result.message.tool_calls else None - # deduct quota - llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - return text, usage, tool_call def _generate_function_call_prompt( @@ -828,6 +825,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): return rest_tokens + @property + def model_instance(self) -> ModelInstance: + return self._model_instance + @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 464d9b6b9c..5d5edcc0f7 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -240,6 +240,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): llm_usage=usage, ) + @property + def model_instance(self) -> ModelInstance: + return self._model_instance + @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index a724fbcab7..2ea4266b16 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -6,6 +6,7 @@ from typing import Any, cast from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID @@ -106,6 +107,7 @@ class WorkflowEntry: max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME ) self.graph_engine.layer(limits_layer) + self.graph_engine.layer(LLMQuotaLayer()) # Add observability layer when OTel is enabled if dify_config.ENABLE_OTEL or is_instrument_flag_enabled(): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py new file mode 100644 index 0000000000..9a491d24e1 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -0,0 +1,174 @@ +import threading +from datetime import datetime +from unittest.mock import MagicMock, patch + +from core.app.workflow.layers.llm_quota import LLMQuotaLayer +from core.errors.error import QuotaExceededError +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph_engine.entities.commands import CommandType +from core.workflow.graph_events.node import NodeRunSucceededEvent +from core.workflow.node_events import NodeRunResult + + +def _build_succeeded_event() -> NodeRunSucceededEvent: + return NodeRunSucceededEvent( + id="execution-id", + node_id="llm-node-id", + node_type=NodeType.LLM, + start_at=datetime.now(), + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"question": "hello"}, + llm_usage=LLMUsage.empty_usage(), + ), + ) + + +def test_deduct_quota_called_for_successful_llm_node() -> None: + layer = LLMQuotaLayer() + node = MagicMock() + node.id = "llm-node-id" + node.execution_id = "execution-id" + node.node_type = NodeType.LLM + node.tenant_id = "tenant-id" + 
node.model_instance = object() + + result_event = _build_succeeded_event() + with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + layer.on_node_run_end(node=node, error=None, result_event=result_event) + + mock_deduct.assert_called_once_with( + tenant_id="tenant-id", + model_instance=node.model_instance, + usage=result_event.node_run_result.llm_usage, + ) + + +def test_deduct_quota_called_for_question_classifier_node() -> None: + layer = LLMQuotaLayer() + node = MagicMock() + node.id = "question-classifier-node-id" + node.execution_id = "execution-id" + node.node_type = NodeType.QUESTION_CLASSIFIER + node.tenant_id = "tenant-id" + node.model_instance = object() + + result_event = _build_succeeded_event() + with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + layer.on_node_run_end(node=node, error=None, result_event=result_event) + + mock_deduct.assert_called_once_with( + tenant_id="tenant-id", + model_instance=node.model_instance, + usage=result_event.node_run_result.llm_usage, + ) + + +def test_non_llm_node_is_ignored() -> None: + layer = LLMQuotaLayer() + node = MagicMock() + node.id = "start-node-id" + node.execution_id = "execution-id" + node.node_type = NodeType.START + node.tenant_id = "tenant-id" + node._model_instance = object() + + result_event = _build_succeeded_event() + with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + layer.on_node_run_end(node=node, error=None, result_event=result_event) + + mock_deduct.assert_not_called() + + +def test_quota_error_is_handled_in_layer() -> None: + layer = LLMQuotaLayer() + node = MagicMock() + node.id = "llm-node-id" + node.execution_id = "execution-id" + node.node_type = NodeType.LLM + node.tenant_id = "tenant-id" + node.model_instance = object() + + result_event = _build_succeeded_event() + with patch( + "core.app.workflow.layers.llm_quota.deduct_llm_quota", + autospec=True, + side_effect=ValueError("quota exceeded"), + ): + layer.on_node_run_end(node=node, error=None, result_event=result_event) + + +def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: + layer = LLMQuotaLayer() + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = MagicMock() + node.id = "llm-node-id" + node.execution_id = "execution-id" + node.node_type = NodeType.LLM + node.tenant_id = "tenant-id" + node.model_instance = object() + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.stop_event = stop_event + + result_event = _build_succeeded_event() + with patch( + "core.app.workflow.layers.llm_quota.deduct_llm_quota", + autospec=True, + side_effect=QuotaExceededError("No credits remaining"), + ): + layer.on_node_run_end(node=node, error=None, result_event=result_event) + + assert stop_event.is_set() + layer.command_channel.send_command.assert_called_once() + abort_command = layer.command_channel.send_command.call_args.args[0] + assert abort_command.command_type == CommandType.ABORT + assert abort_command.reason == "No credits remaining" + + +def test_quota_precheck_failure_aborts_workflow_immediately() -> None: + layer = LLMQuotaLayer() + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = MagicMock() + node.id = "llm-node-id" + node.node_type = NodeType.LLM + node.model_instance = object() + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.stop_event = stop_event + + with patch( + 
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available", + autospec=True, + side_effect=QuotaExceededError("Model provider openai quota exceeded."), + ): + layer.on_node_run_start(node) + + assert stop_event.is_set() + layer.command_channel.send_command.assert_called_once() + abort_command = layer.command_channel.send_command.call_args.args[0] + assert abort_command.command_type == CommandType.ABORT + assert abort_command.reason == "Model provider openai quota exceeded." + + +def test_quota_precheck_passes_without_abort() -> None: + layer = LLMQuotaLayer() + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = MagicMock() + node.id = "llm-node-id" + node.node_type = NodeType.LLM + node.model_instance = object() + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.stop_event = stop_event + + with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check: + layer.on_node_run_start(node) + + assert not stop_event.is_set() + mock_check.assert_called_once_with(model_instance=node.model_instance) + layer.command_channel.send_command.assert_not_called()