mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
refactor: rename model instance extraction method to improve clarity and update related logic in LLMQuotaLayer; enhance unit tests for model instance handling
This commit is contained in:
@ -10,6 +10,8 @@ from dify_graph.graph_events.node import NodeRunSucceededEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
|
||||
_FETCH_MODEL_CONFIG_PATH = "dify_graph.nodes.llm.llm_utils.fetch_model_config"
|
||||
|
||||
|
||||
def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
@ -25,44 +27,52 @@ def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
def _make_llm_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock:
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.node_type = node_type
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_dify_context.return_value.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
node.node_data.model = MagicMock(name="model-config")
|
||||
return node
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = _make_llm_node()
|
||||
fake_instance = MagicMock(name="model-instance")
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
with (
|
||||
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
|
||||
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct,
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=node.model_instance,
|
||||
model_instance=fake_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_question_classifier_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node = _make_llm_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER)
|
||||
node.id = "question-classifier-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_dify_context.return_value.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
fake_instance = MagicMock(name="model-instance")
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
with (
|
||||
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
|
||||
patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct,
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=node.model_instance,
|
||||
model_instance=fake_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
@ -74,8 +84,6 @@ def test_non_llm_node_is_ignored() -> None:
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.START
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_dify_context.return_value.tenant_id = "tenant-id"
|
||||
node._model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
@ -86,19 +94,17 @@ def test_non_llm_node_is_ignored() -> None:
|
||||
|
||||
def test_quota_error_is_handled_in_layer() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_dify_context.return_value.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
node = _make_llm_node()
|
||||
fake_instance = MagicMock(name="model-instance")
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=ValueError("quota exceeded"),
|
||||
with (
|
||||
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
|
||||
patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=ValueError("quota exceeded"),
|
||||
),
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
@ -108,21 +114,19 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.require_dify_context.return_value.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
node = _make_llm_node()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
fake_instance = MagicMock(name="model-instance")
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("No credits remaining"),
|
||||
with (
|
||||
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
|
||||
patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("No credits remaining"),
|
||||
),
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
@ -138,17 +142,18 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.model_instance = object()
|
||||
node = _make_llm_node()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
fake_instance = MagicMock(name="model-instance")
|
||||
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||
with (
|
||||
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
|
||||
patch(
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||
),
|
||||
):
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
@ -164,16 +169,17 @@ def test_quota_precheck_passes_without_abort() -> None:
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.model_instance = object()
|
||||
node = _make_llm_node()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
fake_instance = MagicMock(name="model-instance")
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
|
||||
with (
|
||||
patch(_FETCH_MODEL_CONFIG_PATH, return_value=(fake_instance, MagicMock())),
|
||||
patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check,
|
||||
):
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
mock_check.assert_called_once_with(model_instance=node.model_instance)
|
||||
mock_check.assert_called_once_with(model_instance=fake_instance)
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user