refactor(api): continue decoupling dify_graph from API concerns (#33580)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
-LAN-
2026-03-25 20:32:24 +08:00
committed by GitHub
parent b7b9b003c9
commit 56593f20b0
487 changed files with 17999 additions and 9186 deletions

View File

@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool):
yield self.create_text_message("ok")
def _build_tool() -> _BuiltinDummyTool:
def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool:
entity = ToolEntity(
identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"),
parameters=[],
)
runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER)
runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER)
return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime)
@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type():
def test_invoke_model_calls_model_invocation_utils_invoke():
tool = _build_tool()
tool = _build_tool(user_id="runtime-user")
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke:
assert (
tool.invoke_model(
@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke():
)
== "result"
)
mock_invoke.assert_called_once()
mock_invoke.assert_called_once_with(
user_id="u1",
tenant_id="tenant-1",
tool_type=ToolProviderType.BUILT_IN,
tool_name="tool-a",
prompt_messages=[UserPromptMessage(content="hello")],
caller_user_id="runtime-user",
)
def test_get_max_tokens_returns_value():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096):
tool = _build_tool(user_id="runtime-user")
with patch(
"core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096
) as mock_get:
assert tool.get_max_tokens() == 4096
mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user")
def test_get_prompt_tokens_returns_value():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7):
tool = _build_tool(user_id="runtime-user")
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate:
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
mock_calculate.assert_called_once_with(
tenant_id="tenant-1",
prompt_messages=[UserPromptMessage(content="hello")],
user_id="runtime-user",
)
def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing():
tool = _build_tool()
with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate:
assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7
mock_calculate.assert_called_once_with(
tenant_id="tenant-1",
prompt_messages=[UserPromptMessage(content="hello")],
user_id=None,
)
def test_runtime_none_raises():

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import calendar
import math
from datetime import date
from types import SimpleNamespace
import pytest
@ -98,7 +100,13 @@ def test_timezone_conversion_tool():
def test_weekday_tool():
weekday_tool = _build_builtin_tool(WeekdayTool)
valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text
assert "January 1, 2024" in valid
expected_date = date(2024, 1, 1)
expected_message = (
f"{calendar.month_name[expected_date.month]} "
f"{expected_date.day}, {expected_date.year} "
f"is {calendar.day_name[expected_date.weekday()]}."
)
assert valid == expected_message
invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[
0
].message.text
@ -186,13 +194,19 @@ def test_asr_invalid_file():
def test_asr_valid_file_invocation(monkeypatch):
asr = _build_builtin_tool(ASRTool)
model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})()
model_instance = type("M", (), {"invoke_speech2text": lambda self, file: "transcript"})()
model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})()
monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes")
monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager)
captured_manager_kwargs = {}
monkeypatch.setattr(
"core.tools.builtin_tool.providers.audio.tools.asr.ModelManager.for_tenant",
lambda **kwargs: captured_manager_kwargs.update(kwargs) or model_manager,
)
audio_file = SimpleNamespace(type=FileType.AUDIO)
ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text
assert ok == "transcript"
assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"}
def test_asr_available_models_and_runtime_parameters(monkeypatch):
@ -208,6 +222,7 @@ def test_asr_available_models_and_runtime_parameters(monkeypatch):
def test_tts_invoke_returns_messages(monkeypatch):
tts = _build_builtin_tool(TTSTool)
captured_manager_kwargs = {}
voices_model_instance = type(
"TTSM",
(),
@ -217,11 +232,15 @@ def test_tts_invoke_returns_messages(monkeypatch):
},
)()
monkeypatch.setattr(
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager",
lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(),
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant",
lambda **kwargs: (
captured_manager_kwargs.update(kwargs)
or type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})()
),
)
messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))
assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB]
assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"}
def test_tts_get_available_models_requires_runtime():
@ -254,8 +273,8 @@ def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices):
},
)()
monkeypatch.setattr(
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager",
lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(),
"core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant",
lambda **_: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(),
)
with pytest.raises(ValueError, match="no voice available"):
list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"}))

View File

@ -6,7 +6,13 @@ from urllib.parse import parse_qs, urlparse
import pytest
from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature
from core.tools.signature import (
get_signed_file_url_for_plugin,
sign_tool_file,
sign_upload_file,
verify_plugin_file_signature,
verify_tool_file_signature,
)
def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
@ -117,3 +123,82 @@ def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatc
assert query["timestamp"][0]
assert query["nonce"][0]
assert query["sign"][0]
def test_get_signed_file_url_for_plugin_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x06" * 16)
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 60)
url = get_signed_file_url_for_plugin(
filename="report.pdf",
mimetype="application/pdf",
tenant_id="tenant-id",
user_id="user-id",
)
parsed = urlparse(url)
query = parse_qs(parsed.query)
assert parsed.netloc == "internal.example.com"
assert parsed.path == "/files/upload/for-plugin"
assert query["tenant_id"] == ["tenant-id"]
assert query["user_id"] == ["user-id"]
assert (
verify_plugin_file_signature(
filename="report.pdf",
mimetype="application/pdf",
tenant_id="tenant-id",
user_id="user-id",
timestamp=query["timestamp"][0],
nonce=query["nonce"][0],
sign=query["sign"][0],
)
is True
)
def test_verify_plugin_file_signature_rejects_invalid_signatures(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000)
monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x07" * 16)
monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com")
monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "")
monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 30)
url = get_signed_file_url_for_plugin(
filename="report.pdf",
mimetype="application/pdf",
tenant_id="tenant-id",
user_id="user-id",
)
query = parse_qs(urlparse(url).query)
assert (
verify_plugin_file_signature(
filename="report.pdf",
mimetype="application/pdf",
tenant_id="tenant-id",
user_id="user-id",
timestamp=query["timestamp"][0],
nonce=query["nonce"][0],
sign="bad-signature",
)
is False
)
monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000100)
assert (
verify_plugin_file_signature(
filename="report.pdf",
mimetype="application/pdf",
tenant_id="tenant-id",
user_id="user-id",
timestamp=query["timestamp"][0],
nonce=query["nonce"][0],
sign=query["sign"][0],
)
is False
)

View File

@ -14,6 +14,7 @@ import httpx
import pytest
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file import FileTransferMethod
def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]:
@ -232,7 +233,14 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None:
def test_get_file_generator_returns_stream_when_found() -> None:
# Arrange
manager = ToolFileManager()
tool_file = SimpleNamespace(file_key="k2", mimetype="image/png")
tool_file = SimpleNamespace(
id="tool123",
file_key="k2",
mimetype="image/png",
original_url=None,
name="image.png",
size=12,
)
session = Mock()
session.query.return_value.where.return_value.first.return_value = tool_file
@ -240,10 +248,10 @@ def test_get_file_generator_returns_stream_when_found() -> None:
with patch("core.tools.tool_file_manager.storage") as storage:
stream = iter([b"a", b"b"])
storage.load_stream.return_value = stream
with (
_patch_session_factory(session),
patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"),
):
with _patch_session_factory(session):
result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123")
assert list(result_stream) == [b"a", b"b"]
assert result_file == "validated-file"
assert result_file is not None
assert result_file.related_id == "tool123"
assert result_file.mime_type == "image/png"
assert result_file.transfer_method == FileTransferMethod.TOOL_FILE

View File

@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
@ -421,7 +422,7 @@ def test_get_agent_runtime_apply_runtime_parameters():
tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime):
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime:
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}):
manager = Mock()
manager.decrypt_tool_parameters.return_value = {"query": "decrypted"}
@ -437,12 +438,23 @@ def test_get_agent_runtime_apply_runtime_parameters():
tenant_id="tenant-1",
app_id="app-1",
agent_tool=agent_tool,
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
variable_pool=None,
)
assert result is tool_runtime
assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted"
mock_get_tool_runtime.assert_called_once_with(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tenant_id="tenant-1",
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=None,
)
def test_get_workflow_runtime_apply_runtime_parameters():
@ -463,7 +475,7 @@ def test_get_workflow_runtime_apply_runtime_parameters():
)
tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2):
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime:
with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}):
manager = Mock()
manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"}
@ -473,12 +485,23 @@ def test_get_workflow_runtime_apply_runtime_parameters():
app_id="app-1",
node_id="node-1",
workflow_tool=workflow_tool,
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
variable_pool=None,
)
assert workflow_result is tool_runtime2
assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec"
mock_get_tool_runtime.assert_called_once_with(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tenant_id="tenant-1",
user_id="user-1",
invoke_from=InvokeFrom.DEBUGGER,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=None,
)
def test_get_agent_runtime_raises_when_runtime_missing():
@ -520,17 +543,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters():
tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={}))
tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param])
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity):
with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime:
result = ToolManager.get_tool_runtime_from_plugin(
tool_type=ToolProviderType.API,
tenant_id="tenant-1",
provider="api-1",
tool_name="search",
tool_parameters={"q": "hello", "llm": "ignore"},
user_id="user-1",
)
assert result is tool_entity
assert tool_entity.runtime.runtime_parameters == {"q": "hello"}
mock_get_tool_runtime.assert_called_once_with(
provider_type=ToolProviderType.API,
provider_id="api-1",
tool_name="search",
tenant_id="tenant-1",
user_id="user-1",
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=None,
)
def test_hardcoded_provider_icon_success():

View File

@ -84,3 +84,24 @@ def test_transform_tool_invoke_messages_mimetype_key_present_but_none():
# meta is preserved (still contains mime_type: None)
assert "mime_type" in (o.meta or {})
assert o.meta["mime_type"] is None
assert o.meta["tool_file_id"] == "fake-tool-file-id"
def test_transform_tool_invoke_messages_parses_existing_tool_file_link_meta():
msg = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=ToolInvokeMessage.TextMessage(text="/files/tools/existing-tool-file.png"),
meta={},
)
out = list(
mt.ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=_gen([msg]),
user_id="u1",
tenant_id="t1",
conversation_id="c1",
)
)
assert len(out) == 1
assert out[0].meta["tool_file_id"] == "existing-tool-file"

View File

@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat
manager = Mock()
manager.get_default_model_instance.return_value = model_instance
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
if error_match:
with pytest.raises(InvokeModelError, match=error_match):
ModelInvocationUtils.get_max_llm_context_tokens("tenant")
ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1")
else:
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected
assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1")
def test_calculate_tokens_handles_missing_model():
manager = Mock()
manager.get_default_model_instance.return_value = None
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
with pytest.raises(InvokeModelError, match="Model not found"):
ModelInvocationUtils.calculate_tokens("tenant", [])
mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None)
def test_invoke_success_and_error_mappings():
@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings():
db_mock = SimpleNamespace(session=Mock())
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
response = ModelInvocationUtils.invoke(
@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings():
tool_type="builtin",
tool_name="tool-a",
prompt_messages=[],
caller_user_id="caller-1",
)
assert response.message.content == "ok"
assert db_mock.session.add.call_count == 1
assert db_mock.session.commit.call_count == 2
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1")
@pytest.mark.parametrize(
@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected):
db_mock = SimpleNamespace(session=Mock())
with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager):
with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory:
with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke):
with patch("core.tools.utils.model_invocation_utils.db", db_mock):
with pytest.raises(InvokeModelError, match=expected):
@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected):
tool_name="tool-a",
prompt_messages=[],
)
mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1")

View File

@ -24,7 +24,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.errors import ToolInvokeError
from core.tools.workflow_as_tool.tool import WorkflowTool
from dify_graph.file import FILE_MODEL_IDENTITY
from dify_graph.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType
class StubScalars:
@ -439,6 +439,32 @@ def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool:
def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch):
"""Transform args into parameters and files payloads."""
tool = _setup_transform_args_tool(monkeypatch)
build_file_from_stored_mapping = MagicMock(
side_effect=[
SimpleNamespace(
transfer_method=FileTransferMethod.TOOL_FILE,
type=FileType.IMAGE,
reference="tool-1",
generate_url=lambda: None,
),
SimpleNamespace(
transfer_method=FileTransferMethod.LOCAL_FILE,
type=FileType.DOCUMENT,
reference="upload-1",
generate_url=lambda: None,
),
SimpleNamespace(
transfer_method=FileTransferMethod.REMOTE_URL,
type=FileType.DOCUMENT,
reference=None,
generate_url=lambda: "https://example.com/a.pdf",
),
]
)
monkeypatch.setattr(
"core.tools.workflow_as_tool.tool.build_file_from_stored_mapping",
build_file_from_stored_mapping,
)
params, files = tool._transform_args(
{
@ -470,6 +496,8 @@ def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch):
assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files)
assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files)
assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files)
assert build_file_from_stored_mapping.call_count == 3
assert all(call.kwargs["tenant_id"] == "test_tool" for call in build_file_from_stored_mapping.call_args_list)
def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch):