mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
refactor(api): continue decoupling dify_graph from API concerns (#33580)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool):
|
||||
yield self.create_text_message("ok")
|
||||
|
||||
|
||||
def _build_tool() -> _BuiltinDummyTool:
|
||||
def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
|
||||
parameters=[],
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
|
||||
runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER)
|
||||
return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime)
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type():
|
||||
|
||||
|
||||
def test_invoke_model_calls_model_invocation_utils_invoke():
|
||||
tool = _build_tool()
|
||||
tool = _build_tool(user_id="runtime-user")
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke:
|
||||
assert (
|
||||
tool.invoke_model(
|
||||
@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke():
|
||||
)
|
||||
== "result"
|
||||
)
|
||||
mock_invoke.assert_called_once()
|
||||
mock_invoke.assert_called_once_with(
|
||||
user_id="u1",
|
||||
tenant_id="tenant-1",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
tool_name="tool-a",
|
||||
prompt_messages=[UserPromptMessage(content="hello")],
|
||||
caller_user_id="runtime-user",
|
||||
)
|
||||
|
||||
|
||||
def test_get_max_tokens_returns_value():
|
||||
tool = _build_tool()
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096):
|
||||
tool = _build_tool(user_id="runtime-user")
|
||||
with patch(
|
||||
"core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096
|
||||
) as mock_get:
|
||||
assert tool.get_max_tokens() == 4096
|
||||
mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user")
|
||||
|
||||
|
||||
def test_get_prompt_tokens_returns_value():
|
||||
tool = _build_tool()
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7):
|
||||
tool = _build_tool(user_id="runtime-user")
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate:
|
||||
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
|
||||
mock_calculate.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
prompt_messages=[UserPromptMessage(content="hello")],
|
||||
user_id="runtime-user",
|
||||
)
|
||||
|
||||
|
||||
def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing():
|
||||
tool = _build_tool()
|
||||
|
||||
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate:
|
||||
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
|
||||
|
||||
mock_calculate.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
prompt_messages=[UserPromptMessage(content="hello")],
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_none_raises():
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import calendar
|
||||
import math
|
||||
from datetime import date
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@ -98,7 +100,13 @@ def test_timezone_conversion_tool():
|
||||
def test_weekday_tool():
|
||||
weekday_tool = _build_builtin_tool(WeekdayTool)
|
||||
valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text
|
||||
assert "January 1, 2024" in valid
|
||||
expected_date = date(2024, 1, 1)
|
||||
expected_message = (
|
||||
f"{calendar.month_name[expected_date.month]} "
|
||||
f"{expected_date.day}, {expected_date.year} "
|
||||
f"is {calendar.day_name[expected_date.weekday()]}."
|
||||
)
|
||||
assert valid == expected_message
|
||||
invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[
|
||||
0
|
||||
].message.text
|
||||
@ -186,13 +194,19 @@ def test_asr_invalid_file():
|
||||
|
||||
def test_asr_valid_file_invocation(monkeypatch):
|
||||
asr = _build_builtin_tool(ASRTool)
|
||||
model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})()
|
||||
model_instance = type("M", (), {"invoke_speech2text": lambda self, file: "transcript"})()
|
||||
model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})()
|
||||
monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes")
|
||||
monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager)
|
||||
captured_manager_kwargs = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.audio.tools.asr.ModelManager.for_tenant",
|
||||
lambda **kwargs: captured_manager_kwargs.update(kwargs) or model_manager,
|
||||
)
|
||||
audio_file = SimpleNamespace(type=FileType.AUDIO)
|
||||
ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text
|
||||
assert ok == "transcript"
|
||||
assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"}
|
||||
|
||||
|
||||
def test_asr_available_models_and_runtime_parameters(monkeypatch):
|
||||
@ -208,6 +222,7 @@ def test_asr_available_models_and_runtime_parameters(monkeypatch):
|
||||
|
||||
def test_tts_invoke_returns_messages(monkeypatch):
|
||||
tts = _build_builtin_tool(TTSTool)
|
||||
captured_manager_kwargs = {}
|
||||
voices_model_instance = type(
|
||||
"TTSM",
|
||||
(),
|
||||
@ -217,11 +232,15 @@ def test_tts_invoke_returns_messages(monkeypatch):
|
||||
},
|
||||
)()
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager",
|
||||
lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(),
|
||||
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant",
|
||||
lambda **kwargs: (
|
||||
captured_manager_kwargs.update(kwargs)
|
||||
or type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})()
|
||||
),
|
||||
)
|
||||
messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
|
||||
assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB]
|
||||
assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"}
|
||||
|
||||
|
||||
def test_tts_get_available_models_requires_runtime():
|
||||
@ -254,8 +273,8 @@ def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices):
|
||||
},
|
||||
)()
|
||||
monkeypatch.setattr(
|
||||
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager",
|
||||
lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(),
|
||||
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant",
|
||||
lambda **_: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(),
|
||||
)
|
||||
with pytest.raises(ValueError, match="no voice available"):
|
||||
list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
|
||||
|
||||
@ -6,7 +6,13 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature
|
||||
from core.tools.signature import (
|
||||
get_signed_file_url_for_plugin,
|
||||
sign_tool_file,
|
||||
sign_upload_file,
|
||||
verify_plugin_file_signature,
|
||||
verify_tool_file_signature,
|
||||
)
|
||||
|
||||
|
||||
def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -117,3 +123,82 @@ def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatc
|
||||
assert query["timestamp"][0]
|
||||
assert query["nonce"][0]
|
||||
assert query["sign"][0]
|
||||
|
||||
|
||||
def test_get_signed_file_url_for_plugin_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x06" * 16)
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 60)
|
||||
|
||||
url = get_signed_file_url_for_plugin(
|
||||
filename="report.pdf",
|
||||
mimetype="application/pdf",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
query = parse_qs(parsed.query)
|
||||
|
||||
assert parsed.netloc == "internal.example.com"
|
||||
assert parsed.path == "/files/upload/for-plugin"
|
||||
assert query["tenant_id"] == ["tenant-id"]
|
||||
assert query["user_id"] == ["user-id"]
|
||||
assert (
|
||||
verify_plugin_file_signature(
|
||||
filename="report.pdf",
|
||||
mimetype="application/pdf",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
timestamp=query["timestamp"][0],
|
||||
nonce=query["nonce"][0],
|
||||
sign=query["sign"][0],
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_verify_plugin_file_signature_rejects_invalid_signatures(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x07" * 16)
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "")
|
||||
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 30)
|
||||
|
||||
url = get_signed_file_url_for_plugin(
|
||||
filename="report.pdf",
|
||||
mimetype="application/pdf",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
)
|
||||
query = parse_qs(urlparse(url).query)
|
||||
|
||||
assert (
|
||||
verify_plugin_file_signature(
|
||||
filename="report.pdf",
|
||||
mimetype="application/pdf",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
timestamp=query["timestamp"][0],
|
||||
nonce=query["nonce"][0],
|
||||
sign="bad-signature",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000100)
|
||||
assert (
|
||||
verify_plugin_file_signature(
|
||||
filename="report.pdf",
|
||||
mimetype="application/pdf",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
timestamp=query["timestamp"][0],
|
||||
nonce=query["nonce"][0],
|
||||
sign=query["sign"][0],
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
@ -14,6 +14,7 @@ import httpx
|
||||
import pytest
|
||||
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.file import FileTransferMethod
|
||||
|
||||
|
||||
def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]:
|
||||
@ -232,7 +233,14 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
|
||||
def test_get_file_generator_returns_stream_when_found() -> None:
|
||||
# Arrange
|
||||
manager = ToolFileManager()
|
||||
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
|
||||
tool_file = SimpleNamespace(
|
||||
id="tool123",
|
||||
file_key="k2",
|
||||
mimetype="image/png",
|
||||
original_url=None,
|
||||
name="image.png",
|
||||
size=12,
|
||||
)
|
||||
session = Mock()
|
||||
session.query.return_value.where.return_value.first.return_value = tool_file
|
||||
|
||||
@ -240,10 +248,10 @@ def test_get_file_generator_returns_stream_when_found() -> None:
|
||||
with patch("core.tools.tool_file_manager.storage") as storage:
|
||||
stream = iter([b"a", b"b"])
|
||||
storage.load_stream.return_value = stream
|
||||
with (
|
||||
_patch_session_factory(session),
|
||||
patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"),
|
||||
):
|
||||
with _patch_session_factory(session):
|
||||
result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123")
|
||||
assert list(result_stream) == [b"a", b"b"]
|
||||
assert result_file == "validated-file"
|
||||
assert result_file is not None
|
||||
assert result_file.related_id == "tool123"
|
||||
assert result_file.mime_type == "image/png"
|
||||
assert result_file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||
|
||||
@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
@ -421,7 +422,7 @@ def test_get_agent_runtime_apply_runtime_parameters():
|
||||
tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter])
|
||||
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime):
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime:
|
||||
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}):
|
||||
manager = Mock()
|
||||
manager.decrypt_tool_parameters.return_value = {"query": "decrypted"}
|
||||
@ -437,12 +438,23 @@ def test_get_agent_runtime_apply_runtime_parameters():
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
agent_tool=agent_tool,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
variable_pool=None,
|
||||
)
|
||||
|
||||
assert result is tool_runtime
|
||||
assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted"
|
||||
mock_get_tool_runtime.assert_called_once_with(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credential_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_workflow_runtime_apply_runtime_parameters():
|
||||
@ -463,7 +475,7 @@ def test_get_workflow_runtime_apply_runtime_parameters():
|
||||
)
|
||||
tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter])
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2):
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime:
|
||||
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}):
|
||||
manager = Mock()
|
||||
manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"}
|
||||
@ -473,12 +485,23 @@ def test_get_workflow_runtime_apply_runtime_parameters():
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
workflow_tool=workflow_tool,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
variable_pool=None,
|
||||
)
|
||||
|
||||
assert workflow_result is tool_runtime2
|
||||
assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec"
|
||||
mock_get_tool_runtime.assert_called_once_with(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
credential_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_agent_runtime_raises_when_runtime_missing():
|
||||
@ -520,17 +543,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters():
|
||||
tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
|
||||
tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param])
|
||||
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity):
|
||||
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime:
|
||||
result = ToolManager.get_tool_runtime_from_plugin(
|
||||
tool_type=ToolProviderType.API,
|
||||
tenant_id="tenant-1",
|
||||
provider="api-1",
|
||||
tool_name="search",
|
||||
tool_parameters={"q": "hello", "llm": "ignore"},
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
assert result is tool_entity
|
||||
assert tool_entity.runtime.runtime_parameters == {"q": "hello"}
|
||||
mock_get_tool_runtime.assert_called_once_with(
|
||||
provider_type=ToolProviderType.API,
|
||||
provider_id="api-1",
|
||||
tool_name="search",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||
credential_id=None,
|
||||
)
|
||||
|
||||
|
||||
def test_hardcoded_provider_icon_success():
|
||||
|
||||
@ -84,3 +84,24 @@ def test_transform_tool_invoke_messages_mimetype_key_present_but_none():
|
||||
# meta is preserved (still contains mime_type: None)
|
||||
assert "mime_type" in (o.meta or {})
|
||||
assert o.meta["mime_type"] is None
|
||||
assert o.meta["tool_file_id"] == "fake-tool-file-id"
|
||||
|
||||
|
||||
def test_transform_tool_invoke_messages_parses_existing_tool_file_link_meta():
|
||||
msg = ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text="/files/tools/existing-tool-file.png"),
|
||||
meta={},
|
||||
)
|
||||
|
||||
out = list(
|
||||
mt.ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=_gen([msg]),
|
||||
user_id="u1",
|
||||
tenant_id="t1",
|
||||
conversation_id="c1",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(out) == 1
|
||||
assert out[0].meta["tool_file_id"] == "existing-tool-file"
|
||||
|
||||
@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat
|
||||
manager = Mock()
|
||||
manager.get_default_model_instance.return_value = model_instance
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
|
||||
if error_match:
|
||||
with pytest.raises(InvokeModelError, match=error_match):
|
||||
ModelInvocationUtils.get_max_llm_context_tokens("tenant")
|
||||
ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1")
|
||||
else:
|
||||
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected
|
||||
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected
|
||||
|
||||
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1")
|
||||
|
||||
|
||||
def test_calculate_tokens_handles_missing_model():
|
||||
manager = Mock()
|
||||
manager.get_default_model_instance.return_value = None
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
|
||||
with pytest.raises(InvokeModelError, match="Model not found"):
|
||||
ModelInvocationUtils.calculate_tokens("tenant", [])
|
||||
mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None)
|
||||
|
||||
|
||||
def test_invoke_success_and_error_mappings():
|
||||
@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings():
|
||||
|
||||
db_mock = SimpleNamespace(session=Mock())
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
|
||||
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
|
||||
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
|
||||
response = ModelInvocationUtils.invoke(
|
||||
@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings():
|
||||
tool_type="builtin",
|
||||
tool_name="tool-a",
|
||||
prompt_messages=[],
|
||||
caller_user_id="caller-1",
|
||||
)
|
||||
|
||||
assert response.message.content == "ok"
|
||||
assert db_mock.session.add.call_count == 1
|
||||
assert db_mock.session.commit.call_count == 2
|
||||
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected):
|
||||
|
||||
db_mock = SimpleNamespace(session=Mock())
|
||||
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
|
||||
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
|
||||
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
|
||||
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
|
||||
with pytest.raises(InvokeModelError, match=expected):
|
||||
@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected):
|
||||
tool_name="tool-a",
|
||||
prompt_messages=[],
|
||||
)
|
||||
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1")
|
||||
|
||||
@ -24,7 +24,7 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType
|
||||
|
||||
|
||||
class StubScalars:
|
||||
@ -439,6 +439,32 @@ def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool:
|
||||
def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Transform args into parameters and files payloads."""
|
||||
tool = _setup_transform_args_tool(monkeypatch)
|
||||
build_file_from_stored_mapping = MagicMock(
|
||||
side_effect=[
|
||||
SimpleNamespace(
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
type=FileType.IMAGE,
|
||||
reference="tool-1",
|
||||
generate_url=lambda: None,
|
||||
),
|
||||
SimpleNamespace(
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
type=FileType.DOCUMENT,
|
||||
reference="upload-1",
|
||||
generate_url=lambda: None,
|
||||
),
|
||||
SimpleNamespace(
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
type=FileType.DOCUMENT,
|
||||
reference=None,
|
||||
generate_url=lambda: "https://example.com/a.pdf",
|
||||
),
|
||||
]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.workflow_as_tool.tool.build_file_from_stored_mapping",
|
||||
build_file_from_stored_mapping,
|
||||
)
|
||||
|
||||
params, files = tool._transform_args(
|
||||
{
|
||||
@ -470,6 +496,8 @@ def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch):
|
||||
assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files)
|
||||
assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files)
|
||||
assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files)
|
||||
assert build_file_from_stored_mapping.call_count == 3
|
||||
assert all(call.kwargs["tenant_id"] == "test_tool" for call in build_file_from_stored_mapping.call_args_list)
|
||||
|
||||
|
||||
def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
Reference in New Issue
Block a user