refactor: rename model instance extraction method to improve clarity and update related logic in LLMQuotaLayer; enhance unit tests for model instance handling

This commit is contained in:
Novice
2026-03-24 15:33:48 +08:00
parent e6c4bf7320
commit 7e65659239
3 changed files with 86 additions and 68 deletions

View File

@ -10,6 +10,8 @@ from dify_graph.graph_events.node import NodeRunSucceededEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeRunResult
_FETCH_MODEL_CONFIG_PATH = "dify_graph.nodes.llm.llm_utils.fetch_model_config"
def _build_succeeded_event() -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
@ -25,44 +27,52 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
)
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
def _make_llm_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.node_type = node_type
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node.model_instance = object()
node.node_data.model = MagicMock(name="model-config")
return node
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = _make_llm_node()
fake_instance = MagicMock(name="model-instance")
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct,
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
model_instance=fake_instance,
usage=result_event.node_run_result.llm_usage,
)
def test_deduct_quota_called_for_question_classifier_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node = _make_llm_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
node.id = "question-classifier-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node.model_instance = object()
fake_instance = MagicMock(name="model-instance")
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct,
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
model_instance=fake_instance,
usage=result_event.node_run_result.llm_usage,
)
@ -74,8 +84,6 @@ def test_non_llm_node_is_ignored() -> None:
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.START
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node._model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
@ -86,19 +94,17 @@ def test_non_llm_node_is_ignored() -> None:
def test_quota_error_is_handled_in_layer() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node.model_instance = object()
node = _make_llm_node()
fake_instance = MagicMock(name="model-instance")
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
@ -108,21 +114,19 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = BuiltinNodeTypes.LLM
node.tenant_id = "tenant-id"
node.require_dify_context.return_value.tenant_id = "tenant-id"
node.model_instance = object()
node = _make_llm_node()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
fake_instance = MagicMock(name="model-instance")
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
@ -138,17 +142,18 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance = object()
node = _make_llm_node()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
fake_instance = MagicMock(name="model-instance")
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
),
):
layer.on_node_run_start(node)
@ -164,16 +169,17 @@ def test_quota_precheck_passes_without_abort() -> None:
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = BuiltinNodeTypes.LLM
node.model_instance = object()
node = _make_llm_node()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
fake_instance = MagicMock(name="model-instance")
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
with (
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check,
):
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=node.model_instance)
mock_check.assert_called_once_with(model_instance=fake_instance)
layer.command_channel.send_command.assert_not_called()