mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 05:37:42 +08:00
refactor: file saver decouple db engine and ssrf proxy (#33076)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -44,7 +44,6 @@ forbidden_modules =
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
@ -114,7 +113,6 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.tool.tool_node -> models
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
@ -135,7 +133,6 @@ ignore_imports =
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
|
||||
@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html
|
||||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@ -57,7 +56,7 @@ class ToolFileApi(Resource):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
tool_file_manager = ToolFileManager(engine=global_db.engine)
|
||||
tool_file_manager = ToolFileManager()
|
||||
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
|
||||
file_id,
|
||||
)
|
||||
|
||||
@ -10,28 +10,18 @@ from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db as global_db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
_engine: Engine
|
||||
|
||||
def __init__(self, engine: Engine | None = None):
|
||||
if engine is None:
|
||||
engine = global_db.engine
|
||||
self._engine = engine
|
||||
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
@ -89,7 +79,7 @@ class ToolFileManager:
|
||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@ -132,7 +122,7 @@ class ToolFileManager:
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@ -157,7 +147,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
@ -181,7 +171,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
@ -225,7 +215,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
|
||||
@ -250,6 +250,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
@ -292,6 +293,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
|
||||
@ -14,7 +14,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
@ -47,8 +46,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
@ -56,8 +53,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -69,14 +64,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
self._file_outputs = []
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
@ -1,14 +1,11 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db as global_db
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
self._http_client = http_client
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
return ToolFileManager()
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response = self._http_client.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
|
||||
@ -64,6 +64,7 @@ from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
@ -127,6 +128,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
credentials_provider: CredentialsProvider,
|
||||
model_factory: ModelFactory,
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@ -149,6 +151,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
http_client=http_client,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ from dify_graph.nodes.llm import (
|
||||
)
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
@ -68,6 +69,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@ -90,6 +92,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
http_client=http_client,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
@ -74,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
@ -22,6 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from dify_graph.nodes.llm import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNode
|
||||
from dify_graph.nodes.template_transform import TemplateTransformNode
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
@ -65,6 +66,8 @@ class MockNodeMixin:
|
||||
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
|
||||
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
|
||||
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
|
||||
# LLM-like nodes now require an http_client; provide a mock by default for tests.
|
||||
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
|
||||
|
||||
# Ensure TemplateTransformNode receives a renderer now required by constructor
|
||||
if isinstance(self, TemplateTransformNode):
|
||||
|
||||
@ -112,7 +112,6 @@ class TestKnowledgeRetrievalNode:
|
||||
# Assert
|
||||
assert node.id == node_id
|
||||
assert node._rag_retrieval == mock_rag_retrieval
|
||||
assert node._llm_file_saver is not None
|
||||
|
||||
def test_run_with_no_query_or_attachment(
|
||||
self,
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import uuid
|
||||
from typing import NamedTuple
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools import signature
|
||||
@ -44,7 +44,6 @@ class TestFileSaverImpl:
|
||||
)
|
||||
mock_tool_file.id = _gen_id()
|
||||
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||
mocked_engine = mock.MagicMock(spec=Engine)
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||
@ -53,11 +52,12 @@ class TestFileSaverImpl:
|
||||
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
|
||||
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
|
||||
mocked_sign_file.return_value = mock_signed_url
|
||||
http_client = MagicMock()
|
||||
|
||||
storage_file_manager = FileSaverImpl(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
engine_factory=mocked_engine,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
|
||||
@ -87,16 +87,18 @@ class TestFileSaverImpl:
|
||||
status_code=401,
|
||||
request=mock_request,
|
||||
)
|
||||
http_client = MagicMock()
|
||||
http_client.get.return_value = mock_response
|
||||
|
||||
file_saver = FileSaverImpl(
|
||||
user_id=_gen_id(),
|
||||
tenant_id=_gen_id(),
|
||||
http_client=http_client,
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc:
|
||||
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
http_client.get.assert_called_once_with(_TEST_URL)
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
|
||||
@ -112,8 +114,10 @@ class TestFileSaverImpl:
|
||||
headers={"Content-Type": mime_type},
|
||||
request=mock_request,
|
||||
)
|
||||
http_client = MagicMock()
|
||||
http_client.get.return_value = mock_response
|
||||
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client)
|
||||
mock_tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@ -111,6 +111,7 @@ def llm_node(
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
}
|
||||
http_client = mock.MagicMock()
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
@ -120,6 +121,7 @@ def llm_node(
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node
|
||||
|
||||
@ -632,6 +634,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
}
|
||||
http_client = mock.MagicMock()
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
@ -641,6 +644,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node, mock_file_saver
|
||||
|
||||
|
||||
Reference in New Issue
Block a user