From bbfa28e8a756ee696d9e47937384964f67783fd8 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 9 Mar 2026 16:09:44 +0800 Subject: [PATCH] refactor: file saver decouple db engine and ssrf proxy (#33076) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/.importlinter | 3 --- api/controllers/files/tool_files.py | 3 +-- api/core/tools/tool_file_manager.py | 22 +++++------------- api/core/workflow/node_factory.py | 2 ++ .../knowledge_retrieval_node.py | 13 ----------- api/dify_graph/nodes/llm/file_saver.py | 23 ++++--------------- api/dify_graph/nodes/llm/node.py | 3 +++ .../question_classifier_node.py | 3 +++ .../workflow/nodes/test_llm.py | 2 ++ .../workflow/graph_engine/test_mock_nodes.py | 3 +++ .../test_knowledge_retrieval_node.py | 1 - .../workflow/nodes/llm/test_file_saver.py | 18 +++++++++------ .../core/workflow/nodes/llm/test_node.py | 4 ++++ 13 files changed, 40 insertions(+), 60 deletions(-) diff --git a/api/.importlinter b/api/.importlinter index e4536b1f10..57773f57d6 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -44,7 +44,6 @@ forbidden_modules = allow_indirect_imports = True ignore_imports = dify_graph.nodes.agent.agent_node -> extensions.ext_database - dify_graph.nodes.llm.file_saver -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis @@ -114,7 +113,6 @@ ignore_imports = dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer dify_graph.nodes.tool.tool_node -> models dify_graph.nodes.agent.agent_node -> models.model - dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy dify_graph.nodes.llm.node -> core.helper.code_executor dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output @@ -135,7 +133,6 @@ ignore_imports = dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager dify_graph.nodes.tool.tool_node -> core.tools.errors dify_graph.nodes.agent.agent_node -> extensions.ext_database - dify_graph.nodes.llm.file_saver -> extensions.ext_database dify_graph.nodes.llm.node -> extensions.ext_database dify_graph.nodes.tool.tool_node -> extensions.ext_database dify_graph.nodes.agent.agent_node -> models diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index f6032a8e49..9e3fb3a90b 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html from controllers.files import files_ns from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager -from extensions.ext_database import db as global_db DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -57,7 +56,7 @@ class ToolFileApi(Resource): raise Forbidden("Invalid request.") try: - tool_file_manager = ToolFileManager(engine=global_db.engine) + tool_file_manager = ToolFileManager() stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id( file_id, ) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 83e4e53418..d16b919561 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -10,28 +10,18 @@ from typing import Union from uuid import uuid4 import httpx -from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from extensions.ext_database import db as global_db from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile logger = logging.getLogger(__name__) -from sqlalchemy.engine import Engine - class ToolFileManager: - _engine: Engine - - def __init__(self, engine: Engine | None = None): - if engine is None: - engine = global_db.engine - self._engine = engine - @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -89,7 +79,7 @@ class ToolFileManager: filepath = f"tools/{tenant_id}/{unique_filename}" storage.save(filepath, file_binary) - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, @@ -132,7 +122,7 @@ class ToolFileManager: filename = f"{unique_name}{extension}" filepath = f"tools/{tenant_id}/{filename}" storage.save(filepath, blob) - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, @@ -157,7 +147,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file: ToolFile | None = ( session.query(ToolFile) .where( @@ -181,7 +171,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: message_file: MessageFile | None = ( session.query(MessageFile) .where( @@ -225,7 +215,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file: ToolFile | None = ( session.query(ToolFile) .where( diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 4cbee08a65..c1475f2f18 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -250,6 +250,7 @@ class DifyNodeFactory(NodeFactory): model_factory=self._llm_model_factory, model_instance=model_instance, memory=memory, + http_client=self._http_request_http_client, ) if node_type == NodeType.DATASOURCE: @@ -292,6 +293,7 @@ class DifyNodeFactory(NodeFactory): model_factory=self._llm_model_factory, model_instance=model_instance, memory=memory, + http_client=self._http_request_http_client, ) if node_type == NodeType.PARAMETER_EXTRACTOR: diff --git a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 803c9ccaf9..c67e14ce17 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -14,7 +14,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node -from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source from dify_graph.variables import ( ArrayFileSegment, @@ -47,8 +46,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD # Output variable for file _file_outputs: list["File"] - _llm_file_saver: LLMFileSaver - def __init__( self, id: str, @@ -56,8 +53,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", rag_retrieval: RAGRetrievalProtocol, - *, - llm_file_saver: LLMFileSaver | None = None, ): super().__init__( id=id, @@ -69,14 +64,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD self._file_outputs = [] self._rag_retrieval = rag_retrieval - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - ) - self._llm_file_saver = llm_file_saver - @classmethod def version(cls): return "1" diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py index b4f64f4093..50e52a3b6f 100644 --- a/api/dify_graph/nodes/llm/file_saver.py +++ b/api/dify_graph/nodes/llm/file_saver.py @@ -1,14 +1,11 @@ import mimetypes import typing as tp -from sqlalchemy import Engine - from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.helper import ssrf_proxy from core.tools.signature import sign_tool_file from core.tools.tool_file_manager import ToolFileManager from dify_graph.file import File, FileTransferMethod, FileType -from extensions.ext_database import db as global_db +from dify_graph.nodes.protocols import HttpClientProtocol class LLMFileSaver(tp.Protocol): @@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol): raise NotImplementedError() -EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] - - class FileSaverImpl(LLMFileSaver): - _engine_factory: EngineFactory _tenant_id: str _user_id: str - def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None): - if engine_factory is None: - - def _factory(): - return global_db.engine - - engine_factory = _factory - self._engine_factory = engine_factory + def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): self._user_id = user_id self._tenant_id = tenant_id + self._http_client = http_client def _get_tool_file_manager(self): - return ToolFileManager(engine=self._engine_factory()) + return ToolFileManager() def save_remote_url(self, url: str, file_type: FileType) -> File: - http_response = ssrf_proxy.get(url) + http_response = self._http_client.get(url) http_response.raise_for_status() data = http_response.content mime_type_from_header = http_response.headers.get("Content-Type") diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index c7697a0972..5e59c96cd6 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -64,6 +64,7 @@ from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.runtime import VariablePool from dify_graph.variables import ( ArrayFileSegment, @@ -127,6 +128,7 @@ class LLMNode(Node[LLMNodeData]): credentials_provider: CredentialsProvider, model_factory: ModelFactory, model_instance: ModelInstance, + http_client: HttpClientProtocol, memory: PromptMessageMemory | None = None, llm_file_saver: LLMFileSaver | None = None, ): @@ -149,6 +151,7 @@ class LLMNode(Node[LLMNodeData]): llm_file_saver = FileSaverImpl( user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + http_client=http_client, ) self._llm_file_saver = llm_file_saver diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 97535d832d..443d216186 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -28,6 +28,7 @@ from dify_graph.nodes.llm import ( ) from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.protocols import HttpClientProtocol from libs.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData @@ -68,6 +69,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): credentials_provider: "CredentialsProvider", model_factory: "ModelFactory", model_instance: ModelInstance, + http_client: HttpClientProtocol, memory: PromptMessageMemory | None = None, llm_file_saver: LLMFileSaver | None = None, ): @@ -90,6 +92,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): llm_file_saver = FileSaverImpl( user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + http_client=http_client, ) self._llm_file_saver = llm_file_saver diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index b4779ebcdd..2aca9f5157 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -11,6 +11,7 @@ from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.node_events import StreamCompletedEvent from dify_graph.nodes.llm.node import LLMNode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from extensions.ext_database import db @@ -74,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), + http_client=MagicMock(spec=HttpClientProtocol), ) return node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 3f458e9de9..43fadadbc2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -22,6 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode from dify_graph.nodes.llm import LLMNode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.parameter_extractor import ParameterExtractorNode +from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.nodes.question_classifier import QuestionClassifierNode from dify_graph.nodes.template_transform import TemplateTransformNode from dify_graph.nodes.template_transform.template_renderer import ( @@ -65,6 +66,8 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + # LLM-like nodes now require an http_client; provide a mock by default for tests. + kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) # Ensure TemplateTransformNode receives a renderer now required by constructor if isinstance(self, TemplateTransformNode): diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 9dacc5a39b..e929d652fd 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -112,7 +112,6 @@ class TestKnowledgeRetrievalNode: # Assert assert node.id == node_id assert node._rag_retrieval == mock_rag_retrieval - assert node._llm_file_saver is not None def test_run_with_no_query_or_attachment( self, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index a3afd1ed5c..b0f0fd428b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -1,10 +1,10 @@ import uuid from typing import NamedTuple from unittest import mock +from unittest.mock import MagicMock import httpx import pytest -from sqlalchemy import Engine from core.helper import ssrf_proxy from core.tools import signature @@ -44,7 +44,6 @@ class TestFileSaverImpl: ) mock_tool_file.id = _gen_id() mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - mocked_engine = mock.MagicMock(spec=Engine) mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) @@ -53,11 +52,12 @@ class TestFileSaverImpl: # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) mocked_sign_file.return_value = mock_signed_url + http_client = MagicMock() storage_file_manager = FileSaverImpl( user_id=user_id, tenant_id=tenant_id, - engine_factory=mocked_engine, + http_client=http_client, ) file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) @@ -87,16 +87,18 @@ class TestFileSaverImpl: status_code=401, request=mock_request, ) + http_client = MagicMock() + http_client.get.return_value = mock_response + file_saver = FileSaverImpl( user_id=_gen_id(), tenant_id=_gen_id(), + http_client=http_client, ) - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) with pytest.raises(httpx.HTTPStatusError) as exc: file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - mock_get.assert_called_once_with(_TEST_URL) + http_client.get.assert_called_once_with(_TEST_URL) assert exc.value.response.status_code == 401 def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): @@ -112,8 +114,10 @@ class TestFileSaverImpl: headers={"Content-Type": mime_type}, request=mock_request, ) + http_client = MagicMock() + http_client.get.return_value = mock_response - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id) + file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) mock_tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 90308facc3..d56035b6bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -111,6 +111,7 @@ def llm_node( "id": "1", "data": llm_node_data.model_dump(), } + http_client = mock.MagicMock() node = LLMNode( id="1", config=node_config, @@ -120,6 +121,7 @@ def llm_node( model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, + http_client=http_client, ) return node @@ -632,6 +634,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat "id": "1", "data": llm_node_data.model_dump(), } + http_client = mock.MagicMock() node = LLMNode( id="1", config=node_config, @@ -641,6 +644,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, + http_client=http_client, ) return node, mock_file_saver