refactor: file saver decouple db engine and ssrf proxy (#33076)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei
2026-03-09 16:09:44 +08:00
committed by GitHub
parent 6c19e75969
commit bbfa28e8a7
13 changed files with 40 additions and 60 deletions

View File

@ -44,7 +44,6 @@ forbidden_modules =
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.file_saver -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
@ -114,7 +113,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.tool.tool_node -> models
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
@ -135,7 +133,6 @@ ignore_imports =
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.file_saver -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.nodes.agent.agent_node -> models

View File

@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -57,7 +56,7 @@ class ToolFileApi(Resource):
raise Forbidden("Invalid request.")
try:
tool_file_manager = ToolFileManager(engine=global_db.engine)
tool_file_manager = ToolFileManager()
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
file_id,
)

View File

@ -10,28 +10,18 @@ from typing import Union
from uuid import uuid4
import httpx
from sqlalchemy.orm import Session
from configs import dify_config
from core.db.session_factory import session_factory
from core.helper import ssrf_proxy
from extensions.ext_database import db as global_db
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
from sqlalchemy.engine import Engine
class ToolFileManager:
_engine: Engine
def __init__(self, engine: Engine | None = None):
if engine is None:
engine = global_db.engine
self._engine = engine
@staticmethod
def sign_file(tool_file_id: str, extension: str) -> str:
"""
@ -89,7 +79,7 @@ class ToolFileManager:
filepath = f"tools/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
@ -132,7 +122,7 @@ class ToolFileManager:
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
@ -157,7 +147,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
@ -181,7 +171,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.where(
@ -225,7 +215,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(

View File

@ -250,6 +250,7 @@ class DifyNodeFactory(NodeFactory):
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
http_client=self._http_request_http_client,
)
if node_type == NodeType.DATASOURCE:
@ -292,6 +293,7 @@ class DifyNodeFactory(NodeFactory):
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
http_client=self._http_request_http_client,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:

View File

@ -14,7 +14,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from dify_graph.variables import (
ArrayFileSegment,
@ -47,8 +46,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
@ -56,8 +53,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
*,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
id=id,
@ -69,14 +64,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
self._file_outputs = []
self._rag_retrieval = rag_retrieval
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver
@classmethod
def version(cls):
return "1"

View File

@ -1,14 +1,11 @@
import mimetypes
import typing as tp
from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db as global_db
from dify_graph.nodes.protocols import HttpClientProtocol
class LLMFileSaver(tp.Protocol):
@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol):
raise NotImplementedError()
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None:
def _factory():
return global_db.engine
engine_factory = _factory
self._engine_factory = engine_factory
def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
self._user_id = user_id
self._tenant_id = tenant_id
self._http_client = http_client
def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory())
return ToolFileManager()
def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response = self._http_client.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")

View File

@ -64,6 +64,7 @@ from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
ArrayFileSegment,
@ -127,6 +128,7 @@ class LLMNode(Node[LLMNodeData]):
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_instance: ModelInstance,
http_client: HttpClientProtocol,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@ -149,6 +151,7 @@ class LLMNode(Node[LLMNodeData]):
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver

View File

@ -28,6 +28,7 @@ from dify_graph.nodes.llm import (
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
@ -68,6 +69,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
http_client: HttpClientProtocol,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@ -90,6 +92,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver

View File

@ -11,6 +11,7 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
@ -74,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
http_client=MagicMock(spec=HttpClientProtocol),
)
return node

View File

@ -22,6 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
from dify_graph.nodes.template_transform import TemplateTransformNode
from dify_graph.nodes.template_transform.template_renderer import (
@ -65,6 +66,8 @@ class MockNodeMixin:
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
# LLM-like nodes now require an http_client; provide a mock by default for tests.
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
# Ensure TemplateTransformNode receives a renderer now required by constructor
if isinstance(self, TemplateTransformNode):

View File

@ -112,7 +112,6 @@ class TestKnowledgeRetrievalNode:
# Assert
assert node.id == node_id
assert node._rag_retrieval == mock_rag_retrieval
assert node._llm_file_saver is not None
def test_run_with_no_query_or_attachment(
self,

View File

@ -1,10 +1,10 @@
import uuid
from typing import NamedTuple
from unittest import mock
from unittest.mock import MagicMock
import httpx
import pytest
from sqlalchemy import Engine
from core.helper import ssrf_proxy
from core.tools import signature
@ -44,7 +44,6 @@ class TestFileSaverImpl:
)
mock_tool_file.id = _gen_id()
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine)
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
@ -53,11 +52,12 @@ class TestFileSaverImpl:
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
mocked_sign_file.return_value = mock_signed_url
http_client = MagicMock()
storage_file_manager = FileSaverImpl(
user_id=user_id,
tenant_id=tenant_id,
engine_factory=mocked_engine,
http_client=http_client,
)
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
@ -87,16 +87,18 @@ class TestFileSaverImpl:
status_code=401,
request=mock_request,
)
http_client = MagicMock()
http_client.get.return_value = mock_response
file_saver = FileSaverImpl(
user_id=_gen_id(),
tenant_id=_gen_id(),
http_client=http_client,
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
with pytest.raises(httpx.HTTPStatusError) as exc:
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_get.assert_called_once_with(_TEST_URL)
http_client.get.assert_called_once_with(_TEST_URL)
assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
@ -112,8 +114,10 @@ class TestFileSaverImpl:
headers={"Content-Type": mime_type},
request=mock_request,
)
http_client = MagicMock()
http_client.get.return_value = mock_response
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client)
mock_tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,

View File

@ -111,6 +111,7 @@ def llm_node(
"id": "1",
"data": llm_node_data.model_dump(),
}
http_client = mock.MagicMock()
node = LLMNode(
id="1",
config=node_config,
@ -120,6 +121,7 @@ def llm_node(
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
http_client=http_client,
)
return node
@ -632,6 +634,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
"id": "1",
"data": llm_node_data.model_dump(),
}
http_client = mock.MagicMock()
node = LLMNode(
id="1",
config=node_config,
@ -641,6 +644,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
http_client=http_client,
)
return node, mock_file_saver