From 85e144cf648f7523030d50efdcb46dbbcd2e67bc Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 19 May 2026 19:41:46 +0800 Subject: [PATCH] fix(api): centralize remote file retrieval Introduce a unified remote file fetcher that resolves first-party signed file URLs through database records and storage before falling back to the SSRF-protected HTTP client. Route backend remote-file call sites through the new boundary, remove obsolete file signature verification helpers, and document when to use remote_fetcher versus ssrf_proxy. --- api/controllers/console/remote_files.py | 12 +- api/controllers/web/remote_files.py | 12 +- api/core/app/workflow/file_runtime.py | 4 +- .../datasource/datasource_file_manager.py | 24 +- api/core/file/__init__.py | 5 + api/core/file/remote_fetcher.py | 363 ++++++++++++++++++ api/core/helper/download.py | 4 +- api/core/helper/ssrf_proxy.py | 12 +- api/core/rag/extractor/extract_processor.py | 4 +- api/core/rag/extractor/word_extractor.py | 8 +- .../index_processor/index_processor_base.py | 4 +- api/core/tools/tool_file_manager.py | 24 +- api/core/tools/utils/web_reader_tool.py | 8 +- api/core/workflow/node_factory.py | 8 +- api/factories/file_factory/remote.py | 4 +- api/services/app_dsl_service.py | 4 +- .../rag_pipeline/rag_pipeline_dsl_service.py | 4 +- .../services/test_app_dsl_service.py | 10 +- .../controllers/console/test_remote_files.py | 50 ++- .../core/app/workflow/test_file_runtime.py | 6 +- .../test_datasource_file_manager.py | 39 +- .../core/file/test_remote_fetcher.py | 258 +++++++++++++ .../unit_tests/core/helper/test_download.py | 8 +- .../rag/extractor/test_extract_processor.py | 2 +- .../rag/indexing/test_index_processor_base.py | 14 +- .../core/tools/test_tool_file_manager.py | 41 +- .../core/tools/utils/test_web_reader_tool.py | 24 +- .../core/workflow/test_node_factory.py | 7 +- .../core/workflow/test_workflow_entry.py | 4 +- .../factories/test_build_from_mapping.py | 2 +- .../unit_tests/factories/test_file_factory.py | 
2 +- .../test_rag_pipeline_dsl_service.py | 9 +- 32 files changed, 768 insertions(+), 212 deletions(-) create mode 100644 api/core/file/__init__.py create mode 100644 api/core/file/remote_fetcher.py create mode 100644 api/tests/unit_tests/core/file/test_remote_fetcher.py diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 9f7fe6379c..e4c8643a6a 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -13,7 +13,7 @@ from controllers.common.errors import ( from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import with_current_user -from core.helper import ssrf_proxy +from core.file import remote_fetcher from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from graphon.file import helpers as file_helpers @@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource): @login_required def get(self, url: str): decoded_url = helpers.decode_remote_url(url, request.query_string) - resp = ssrf_proxy.head(decoded_url) + resp = remote_fetcher.head(decoded_url) if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(decoded_url, timeout=3) + resp = remote_fetcher.get(decoded_url, timeout=3) resp.raise_for_status() return RemoteFileInfo( file_type=resp.headers.get("Content-Type", "application/octet-stream"), @@ -58,9 +58,9 @@ class RemoteFileUpload(Resource): # Try to fetch remote file metadata/content first try: - resp = ssrf_proxy.head(url=url) + resp = remote_fetcher.head(url=url) if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + resp = remote_fetcher.get(url=url, timeout=3, follow_redirects=True) if resp.status_code != httpx.codes.OK: # Normalize into a user-friendly error message expected by tests raise RemoteFileUploadError(f"Failed to fetch file from {url}: 
{resp.text}") @@ -74,7 +74,7 @@ class RemoteFileUpload(Resource): raise FileTooLargeError() # Load content if needed - content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + content = resp.content if resp.request.method == "GET" else remote_fetcher.get(url).content try: upload_file = FileService(db.engine).upload_file( diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index c18c05d3e9..c0e48bc138 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -9,7 +9,7 @@ from controllers.common.errors import ( RemoteFileUploadError, UnsupportedFileTypeError, ) -from core.helper import ssrf_proxy +from core.file import remote_fetcher from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from graphon.file import helpers as file_helpers @@ -60,10 +60,10 @@ class RemoteFileInfoApi(WebApiResource): HTTPException: If the remote file cannot be accessed """ decoded_url = helpers.decode_remote_url(url, request.query_string) - resp = ssrf_proxy.head(decoded_url) + resp = remote_fetcher.head(decoded_url) if resp.status_code != httpx.codes.OK: # failed back to get method - resp = ssrf_proxy.get(decoded_url, timeout=3) + resp = remote_fetcher.get(decoded_url, timeout=3) resp.raise_for_status() info = RemoteFileInfo( file_type=resp.headers.get("Content-Type", "application/octet-stream"), @@ -112,9 +112,9 @@ class RemoteFileUploadApi(WebApiResource): url = str(payload.url) try: - resp = ssrf_proxy.head(url=url) + resp = remote_fetcher.head(url=url) if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + resp = remote_fetcher.get(url=url, timeout=3, follow_redirects=True) if resp.status_code != httpx.codes.OK: raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") except httpx.RequestError as e: @@ -125,7 +125,7 @@ class RemoteFileUploadApi(WebApiResource): 
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): raise FileTooLargeError - content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + content = resp.content if resp.request.method == "GET" else remote_fetcher.get(url).content try: upload_file = FileService(db.engine).upload_file( diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 90fdf41022..dec92c03b6 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Literal, override from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol from core.db.session_factory import session_factory -from core.helper.ssrf_proxy import graphon_ssrf_proxy +from core.file import remote_fetcher from core.tools.signature import sign_tool_file from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage @@ -46,7 +46,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): @override def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) + return remote_fetcher.graphon_remote_file_fetcher.get(url, follow_redirects=follow_redirects) @override def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 79b84a28be..8efa24c1ea 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -12,7 +12,7 @@ from uuid import uuid4 import httpx from configs import dify_config -from core.helper import ssrf_proxy +from core.file import remote_fetcher from extensions.ext_database import db from extensions.ext_storage import 
storage from extensions.storage.storage_type import StorageType @@ -44,26 +44,6 @@ class DatasourceFileManager: return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - @staticmethod - def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - """ - data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" - recalculated_sign = hmac.new( - dify_config.SECRET_KEY.encode(), - data_to_sign.encode(), - hashlib.sha256, - ).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - # verify signature - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT - @staticmethod def create_file_by_raw( *, @@ -117,7 +97,7 @@ class DatasourceFileManager: ) -> ToolFile: # try to download image try: - response = ssrf_proxy.get(file_url) + response = remote_fetcher.get(file_url) response.raise_for_status() blob = response.content except httpx.TimeoutException: diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py new file mode 100644 index 0000000000..4f48fffbf5 --- /dev/null +++ b/api/core/file/__init__.py @@ -0,0 +1,5 @@ +"""File retrieval helpers shared by backend file-oriented workflows.""" + +from . import remote_fetcher + +__all__ = ["remote_fetcher"] diff --git a/api/core/file/remote_fetcher.py b/api/core/file/remote_fetcher.py new file mode 100644 index 0000000000..b5c067c5ba --- /dev/null +++ b/api/core/file/remote_fetcher.py @@ -0,0 +1,363 @@ +"""Unified remote-file retrieval with Dify signed file URL resolution. + +Use this module for backend workflows whose intent is to fetch remote file content +or remote file metadata from a URL, even when the URL originally came from a user +upload, a workflow variable, a tool/datasource file, or an app DSL. 
GET/HEAD +requests can resolve Dify-signed file URLs locally through DB + storage before +falling back to the SSRF-protected network client. + +Use `core.helper.ssrf_proxy` directly only for generic outbound HTTP where the +URL is not being treated as a remote file, such as HTTP Request nodes, external +API integrations, auth discovery, or user-configured tool calls. Those calls must +stay as real network requests and should not reinterpret Dify file URLs as stored +files. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import re +import time +import urllib.parse +from dataclasses import dataclass +from typing import Any, Literal + +import httpx + +from configs import dify_config +from core.app.file_access import DatabaseFileAccessController +from core.db.session_factory import session_factory +from core.helper import ssrf_proxy +from core.helper.ssrf_proxy import ( + SSRF_DEFAULT_MAX_RETRIES, + _to_graphon_http_response, + max_retries_exceeded_error, + request_error, +) +from extensions.ext_storage import storage +from models import ToolFile, UploadFile + +_UPLOAD_FILE_PATH_PATTERN = re.compile( + r"^/files/(?P<file_id>[a-fA-F0-9-]+)/(?P<preview_kind>file-preview|image-preview)$" +) +_TOOL_FILE_PATH_PATTERN = re.compile(r"^/files/tools/(?P<file_id>[a-fA-F0-9-]+)\.(?P<extension>[^/]+)$") +_DATASOURCE_FILE_PATH_PATTERN = re.compile( + r"^/files/datasources/(?P<file_id>[a-fA-F0-9-]+)\.(?P<extension>[^/]+)$" +) + +_file_access_controller = DatabaseFileAccessController() + + +@dataclass(frozen=True) +class _SignedFileUrl: + file_id: str + preview_kind: Literal["file-preview", "image-preview"] + record_kind: Literal["upload", "tool", "datasource"] + + +def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + """Fetch remote file content or metadata. + + GET and HEAD requests for Dify-owned signed file URLs are served from local + storage. Every other request is delegated unchanged to the SSRF proxy. 
+ """ + + normalized_method = method.upper() + if normalized_method in {"GET", "HEAD"}: + response = _resolve_dify_signed_file_url(normalized_method, url) + if response is not None: + return response + return ssrf_proxy.make_request(method=method, url=url, max_retries=max_retries, **kwargs) + + +def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + """Fetch remote file content, resolving Dify-owned signed file URLs locally.""" + + response = _resolve_dify_signed_file_url("GET", url) + if response is not None: + return response + return ssrf_proxy.get(url=url, max_retries=max_retries, **kwargs) + + +def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + """Fetch remote file metadata, resolving Dify-owned signed file URLs locally.""" + + response = _resolve_dify_signed_file_url("HEAD", url) + if response is not None: + return response + return ssrf_proxy.head(url=url, max_retries=max_retries, **kwargs) + + +def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return ssrf_proxy.post(url=url, max_retries=max_retries, **kwargs) + + +def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return ssrf_proxy.put(url=url, max_retries=max_retries, **kwargs) + + +def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return ssrf_proxy.delete(url=url, max_retries=max_retries, **kwargs) + + +def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return ssrf_proxy.patch(url=url, max_retries=max_retries, **kwargs) + + +class GraphonRemoteFileFetcher: + """Graphon HTTP-client adapter backed by the unified remote-file fetcher.""" + + @property + def max_retries_exceeded_error(self) -> type[Exception]: + return max_retries_exceeded_error + + @property + def request_error(self) -> type[Exception]: + return 
request_error + + def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any): + return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs)) + + def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any): + return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs)) + + def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any): + return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs)) + + def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any): + return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs)) + + def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any): + return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs)) + + def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any): + return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs)) + + +def _resolve_dify_signed_file_url(method: Literal["GET", "HEAD"], url: str) -> httpx.Response | None: + parsed_url = urllib.parse.urlparse(url) + if not _is_dify_file_origin(parsed_url): + return None + + signed_file_url = _parse_signed_file_path(parsed_url.path) + if signed_file_url is None: + return None + + query = urllib.parse.parse_qs(parsed_url.query, keep_blank_values=True) + timestamp = _single_query_value(query, "timestamp") + nonce = _single_query_value(query, "nonce") + sign = _single_query_value(query, "sign") + if timestamp is None or nonce is None or sign is None: + return None + + if not _verify_signed_file_url( + signed_file_url=signed_file_url, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ): + return None + + if signed_file_url.record_kind == "upload": + return _build_upload_file_response(method=method, url=url, file_id=signed_file_url.file_id) + if signed_file_url.record_kind == "tool": + 
return _build_tool_file_response(method=method, url=url, file_id=signed_file_url.file_id) + return _build_datasource_file_response(method=method, url=url, file_id=signed_file_url.file_id) + + +def _parse_signed_file_path(path: str) -> _SignedFileUrl | None: + upload_match = _UPLOAD_FILE_PATH_PATTERN.match(path) + if upload_match: + return _SignedFileUrl( + file_id=upload_match.group("file_id"), + preview_kind=upload_match.group("preview_kind"), + record_kind="upload", + ) + + tool_match = _TOOL_FILE_PATH_PATTERN.match(path) + if tool_match: + return _SignedFileUrl( + file_id=tool_match.group("file_id"), + preview_kind="file-preview", + record_kind="tool", + ) + + datasource_match = _DATASOURCE_FILE_PATH_PATTERN.match(path) + if datasource_match: + return _SignedFileUrl( + file_id=datasource_match.group("file_id"), + preview_kind="file-preview", + record_kind="datasource", + ) + + return None + + +def _is_dify_file_origin(parsed_url: urllib.parse.ParseResult) -> bool: + if parsed_url.scheme not in {"http", "https"} or not parsed_url.hostname: + return False + + url_origin = _origin_parts(parsed_url) + if url_origin is None: + return False + + allowed_origins = { + origin + for configured_url in [dify_config.FILES_URL, dify_config.INTERNAL_FILES_URL] + if configured_url and (origin := _origin_parts(urllib.parse.urlparse(configured_url))) is not None + } + return url_origin in allowed_origins + + +def _origin_parts(parsed_url: urllib.parse.ParseResult) -> tuple[str, str, int] | None: + if parsed_url.scheme not in {"http", "https"} or not parsed_url.hostname: + return None + return parsed_url.scheme, parsed_url.hostname.lower(), parsed_url.port or _default_port(parsed_url.scheme) + + +def _default_port(scheme: str) -> int: + return 443 if scheme == "https" else 80 + + +def _single_query_value(query: dict[str, list[str]], key: str) -> str | None: + values = query.get(key) + if not values or len(values) != 1: + return None + return values[0] + + +def 
_verify_signed_file_url( + *, + signed_file_url: _SignedFileUrl, + timestamp: str, + nonce: str, + sign: str, +) -> bool: + try: + current_time = int(time.time()) + signed_at = int(timestamp) + except ValueError: + return False + + if current_time - signed_at > dify_config.FILES_ACCESS_TIMEOUT: + return False + + payload = f"{signed_file_url.preview_kind}|{signed_file_url.file_id}|{timestamp}|{nonce}" + recalculated = hmac.new(dify_config.SECRET_KEY.encode(), payload.encode(), hashlib.sha256).digest() + expected = base64.urlsafe_b64encode(recalculated).decode() + return hmac.compare_digest(sign, expected) + + +def _build_upload_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response: + with session_factory.create_session() as session: + upload_file = _file_access_controller.get_upload_file(session=session, file_id=file_id) + if upload_file is None: + return _build_response(method=method, url=url, status_code=404) + + content = b"" if method == "HEAD" else storage.load_once(upload_file.key) + return _build_response( + method=method, + url=url, + status_code=200, + content=content, + content_length=upload_file.size, + content_type=upload_file.mime_type, + filename=upload_file.name, + ) + + +def _build_tool_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response: + with session_factory.create_session() as session: + tool_file = _file_access_controller.get_tool_file(session=session, file_id=file_id) + if tool_file is None: + return _build_response(method=method, url=url, status_code=404) + + content = b"" if method == "HEAD" else storage.load_once(tool_file.file_key) + return _build_response( + method=method, + url=url, + status_code=200, + content=content, + content_length=tool_file.size, + content_type=tool_file.mimetype, + filename=tool_file.name, + ) + + +def _build_datasource_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response: + with 
session_factory.create_session() as session: + upload_file = _file_access_controller.get_upload_file(session=session, file_id=file_id) + if upload_file is not None: + return _build_upload_file_record_response(method=method, url=url, upload_file=upload_file) + + tool_file = _file_access_controller.get_tool_file(session=session, file_id=file_id) + if tool_file is not None: + return _build_tool_file_record_response(method=method, url=url, tool_file=tool_file) + + return _build_response(method=method, url=url, status_code=404) + + +def _build_upload_file_record_response( + *, + method: Literal["GET", "HEAD"], + url: str, + upload_file: UploadFile, +) -> httpx.Response: + content = b"" if method == "HEAD" else storage.load_once(upload_file.key) + return _build_response( + method=method, + url=url, + status_code=200, + content=content, + content_length=upload_file.size, + content_type=upload_file.mime_type, + filename=upload_file.name, + ) + + +def _build_tool_file_record_response( + *, + method: Literal["GET", "HEAD"], + url: str, + tool_file: ToolFile, +) -> httpx.Response: + content = b"" if method == "HEAD" else storage.load_once(tool_file.file_key) + return _build_response( + method=method, + url=url, + status_code=200, + content=content, + content_length=tool_file.size, + content_type=tool_file.mimetype, + filename=tool_file.name, + ) + + +def _build_response( + *, + method: Literal["GET", "HEAD"], + url: str, + status_code: int, + content: bytes = b"", + content_length: int | None = None, + content_type: str | None = None, + filename: str | None = None, +) -> httpx.Response: + headers: dict[str, str] = {} + if content_type: + headers["Content-Type"] = content_type + if content_length is not None and content_length >= 0: + headers["Content-Length"] = str(content_length) + if filename: + headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{urllib.parse.quote(filename)}" + return httpx.Response( + status_code=status_code, + headers=headers, + 
content=content, + request=httpx.Request(method, url), + ) + + +graphon_remote_file_fetcher = GraphonRemoteFileFetcher() diff --git a/api/core/helper/download.py b/api/core/helper/download.py index 96400e8ba5..2919c9eeae 100644 --- a/api/core/helper/download.py +++ b/api/core/helper/download.py @@ -1,8 +1,8 @@ -from core.helper import ssrf_proxy +from core.file import remote_fetcher def download_with_size_limit(url, max_download_size: int, **kwargs): - response = ssrf_proxy.get(url, follow_redirects=True, **kwargs) + response = remote_fetcher.get(url, follow_redirects=True, **kwargs) if response.status_code == 404: raise ValueError("file not found") diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index b2493934bf..f2a7cac1bd 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,5 +1,13 @@ -""" -Proxy requests to avoid SSRF +"""SSRF-protected HTTP client for generic outbound requests. + +Use this module when the URL represents a normal external HTTP interaction that +must go through network/proxy policy exactly as requested, such as HTTP Request +nodes, provider/API integrations, auth discovery, or custom tool calls. + +Do not use this directly for "remote file" retrieval. File downloads, probes, +and metadata checks should use `core.file.remote_fetcher` instead so Dify-signed +file URLs can be resolved through DB + storage before falling back to this SSRF +client. 
""" import logging diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index e49e814149..21d9d7f2e3 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -5,7 +5,7 @@ from typing import Union from urllib.parse import unquote from configs import dify_config -from core.helper import ssrf_proxy +from core.file import remote_fetcher from core.rag.extractor.csv_extractor import CSVExtractor from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -55,7 +55,7 @@ class ExtractProcessor: @classmethod def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT}) + response = remote_fetcher.get(url, headers={"User-Agent": USER_AGENT}) with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(url).suffix diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 60f8906181..beb628b90f 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,6 +1,6 @@ """Word (.docx) document extractor used for RAG ingestion. -Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). +Supports local file paths and remote URLs downloaded through the unified remote-file fetcher. 
""" import inspect @@ -17,7 +17,7 @@ from docx.oxml.ns import qn from docx.text.run import Run from configs import dify_config -from core.helper import ssrf_proxy +from core.file import remote_fetcher from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db @@ -51,7 +51,7 @@ class WordExtractor(BaseExtractor): # If the file is a web path, download it to a temporary file, and use that if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): - response = ssrf_proxy.get(self.file_path) + response = remote_fetcher.get(self.file_path) if response.status_code != 200: response.close() @@ -120,7 +120,7 @@ class WordExtractor(BaseExtractor): if not self._is_valid_url(url): continue try: - response = ssrf_proxy.get(url) + response = remote_fetcher.get(url) except Exception as e: logger.warning("Failed to download image from URL: %s: %s", url, str(e)) continue diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index a3b6e0dbd2..f18f0d609b 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -15,7 +15,7 @@ from sqlalchemy import select from configs import dify_config from core.entities.knowledge_entities import PreviewDetail -from core.helper import ssrf_proxy +from core.file import remote_fetcher from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType @@ -243,7 +243,7 @@ class BaseIndexProcessor(ABC): try: # Download with timeout - response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT) + response = remote_fetcher.get(image_url, timeout=DOWNLOAD_TIMEOUT) response.raise_for_status() # Check Content-Length header if available diff --git 
a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index f2552e7cbd..8d50b933d1 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -13,7 +13,7 @@ from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory -from core.helper import ssrf_proxy +from core.file import remote_fetcher from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type @@ -60,26 +60,6 @@ class ToolFileManager: return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - @staticmethod - def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - """ - data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" - recalculated_sign = hmac.new( - dify_config.SECRET_KEY.encode(), - data_to_sign.encode(), - hashlib.sha256, - ).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - # verify signature - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT - def create_file_by_raw( self, *, @@ -129,7 +109,7 @@ class ToolFileManager: ) -> ToolFile: # try to download image try: - response = ssrf_proxy.get(file_url) + response = remote_fetcher.get(file_url) response.raise_for_status() blob = response.content except httpx.TimeoutException: diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 94a2c0427b..658169a372 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -9,7 +9,7 @@ import charset_normalizer import cloudscraper from readabilipy import simple_json_from_html_string -from core.helper import ssrf_proxy +from core.file import remote_fetcher from 
core.rag.extractor import extract_processor from core.rag.extractor.extract_processor import ExtractProcessor @@ -38,7 +38,7 @@ def get_url(url: str, user_agent: str | None = None) -> str: main_content_type = None supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] - response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) + response = remote_fetcher.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) if response.status_code == 200: # check content-type @@ -60,10 +60,10 @@ def get_url(url: str, user_agent: str | None = None) -> str: if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) - response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) + response = remote_fetcher.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() - scraper.perform_request = ssrf_proxy.make_request + scraper.perform_request = remote_fetcher.make_request response = scraper.get(url, headers=headers, timeout=(120, 300)) if response.status_code != 200: diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 4a27b2c623..c9cd737a8d 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -11,6 +11,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm.model_access import build_dify_model_access, fetch_model_config from core.db.session_factory import session_factory +from core.file import remote_fetcher from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -307,6 +308,7 @@ class DifyNodeFactory(NodeFactory): self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = 
dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = graphon_ssrf_proxy + self._remote_file_http_client = remote_fetcher.graphon_remote_file_fetcher self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( self._dify_context, conversation_id_getter=self._conversation_id, @@ -318,7 +320,7 @@ class DifyNodeFactory(NodeFactory): ) self._llm_file_saver = build_dify_llm_file_saver( run_context=self._dify_context, - http_client=self._http_request_http_client, + http_client=self._remote_file_http_client, conversation_id_getter=self._conversation_id, ) self._human_input_runtime = DifyHumanInputNodeRuntime( @@ -416,7 +418,7 @@ class DifyNodeFactory(NodeFactory): ), BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { "unstructured_api_config": self._document_extractor_unstructured_api_config, - "http_client": self._http_request_http_client, + "http_client": self._remote_file_http_client, }, BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, @@ -528,7 +530,7 @@ class DifyNodeFactory(NodeFactory): if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER: node_init_kwargs["template_renderer"] = self._jinja2_template_renderer if include_http_client: - node_init_kwargs["http_client"] = self._http_request_http_client + node_init_kwargs["http_client"] = self._remote_file_http_client if include_llm_file_saver: node_init_kwargs["llm_file_saver"] = self._llm_file_saver if include_prompt_message_serializer: diff --git a/api/factories/file_factory/remote.py b/api/factories/file_factory/remote.py index 9b8f94b1f3..7109180718 100644 --- a/api/factories/file_factory/remote.py +++ b/api/factories/file_factory/remote.py @@ -16,7 +16,7 @@ import uuid import httpx from werkzeug.http import parse_options_header -from core.helper import ssrf_proxy +from core.file import remote_fetcher def extract_filename(url_or_path: str, content_disposition: str | None) -> str | None: @@ -81,7 +81,7 @@ def 
get_remote_file_info(url: str) -> tuple[str, str, int]: filename = os.path.basename(url_path) mime_type = _guess_mime_type(filename) - resp = ssrf_proxy.head(url, follow_redirects=True) + resp = remote_fetcher.head(url, follow_redirects=True) if resp.status_code == httpx.codes.OK: content_disposition = resp.headers.get("Content-Disposition") extracted_filename = extract_filename(url_path, content_disposition) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 7ba2b64c74..82e14357a1 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -17,7 +17,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.dsl_version import CURRENT_APP_DSL_VERSION -from core.helper import ssrf_proxy +from core.file import remote_fetcher from core.plugin.entities.plugin import PluginDependency from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, @@ -126,7 +126,7 @@ class AppDslService: ): yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") yaml_url = yaml_url.replace("/blob/", "/") - response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) + response = remote_fetcher.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) response.raise_for_status() content = response.content.decode() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 99fd3f5628..16f16a77fa 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -17,7 +17,7 @@ from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session -from core.helper import ssrf_proxy +from core.file import remote_fetcher from core.helper.name_generator import generate_incremental_name from core.plugin.entities.plugin import PluginDependency from core.rag.index_processor.constant.index_type 
import IndexTechniqueType @@ -125,7 +125,7 @@ class RagPipelineDslService: ): yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") yaml_url = yaml_url.replace("/blob/", "/") - response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) + response = remote_fetcher.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) response.raise_for_status() content = response.content.decode() diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index ca3ae6d0cf..900bce5da5 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -316,7 +316,7 @@ class TestAppDslService: self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch ): monkeypatch.setattr( - app_dsl_service.ssrf_proxy, + app_dsl_service.remote_fetcher, "get", lambda _url, **_kw: (_ for _ in ()).throw(RuntimeError("boom")), ) @@ -336,7 +336,7 @@ class TestAppDslService: response = MagicMock() response.content = b"" response.raise_for_status.return_value = None - monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response) + monkeypatch.setattr(app_dsl_service.remote_fetcher, "get", lambda _url, **_kw: response) service = AppDslService(db_session_with_containers) result = service.import_app( @@ -353,7 +353,7 @@ class TestAppDslService: response = MagicMock() response.content = b"x" * (DSL_MAX_SIZE + 1) response.raise_for_status.return_value = None - monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response) + monkeypatch.setattr(app_dsl_service.remote_fetcher, "get", lambda _url, **_kw: response) service = AppDslService(db_session_with_containers) result = service.import_app( @@ -379,7 +379,7 @@ class TestAppDslService: 
response.raise_for_status.return_value = None return response - monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + monkeypatch.setattr(app_dsl_service.remote_fetcher, "get", fake_get) service = AppDslService(db_session_with_containers) result = service.import_app( @@ -409,7 +409,7 @@ class TestAppDslService: response.raise_for_status.return_value = None return response - monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + monkeypatch.setattr(app_dsl_service.remote_fetcher, "get", fake_get) service = AppDslService(db_session_with_containers) result = service.import_app( diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py index 7c7abdcf2d..bb696c040f 100644 --- a/api/tests/unit_tests/controllers/console/test_remote_files.py +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -100,8 +100,8 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest ) head_mock = MagicMock(return_value=head_resp) get_mock = MagicMock() - monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + monkeypatch.setattr(remote_files_module.remote_fetcher, "head", head_mock) + monkeypatch.setattr(remote_files_module.remote_fetcher, "get", get_mock) with app.test_request_context(method="GET"): payload = handler(api, url=encoded_url) @@ -123,8 +123,8 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: method="HEAD", ) head_mock = MagicMock(return_value=head_resp) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + monkeypatch.setattr(remote_files_module.remote_fetcher, "head", head_mock) + monkeypatch.setattr(remote_files_module.remote_fetcher, "get", MagicMock()) with 
app.test_request_context(f"/remote-files/{target_url}?{query}", method="GET"): payload = handler(api, url=target_url) @@ -139,9 +139,13 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo decoded_url = "https://example.com/test.txt" encoded_url = urllib.parse.quote(decoded_url, safe="") - monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503))) + monkeypatch.setattr( + remote_files_module.remote_fetcher, + "head", + MagicMock(return_value=_FakeResponse(status_code=503)), + ) get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET")) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + monkeypatch.setattr(remote_files_module.remote_fetcher, "get", get_mock) with app.test_request_context(method="GET"): payload = handler(api, url=encoded_url) @@ -155,10 +159,14 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc handler = _unwrap(api.post) url = "https://example.com/report.txt" - monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404))) + monkeypatch.setattr( + remote_files_module.remote_fetcher, + "head", + MagicMock(return_value=_FakeResponse(status_code=404)), + ) get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content") get_mock = MagicMock(return_value=get_resp) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + monkeypatch.setattr(remote_files_module.remote_fetcher, "get", get_mock) file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) upload_file = SimpleNamespace( @@ -196,13 +204,13 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( url = "https://example.com/photo.jpg" monkeypatch.setattr( - remote_files_module.ssrf_proxy, + remote_files_module.remote_fetcher, "head", MagicMock(return_value=_FakeResponse(status_code=200, 
method="HEAD", content=b"head-content")), ) extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content") get_mock = MagicMock(return_value=extra_get_resp) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + monkeypatch.setattr(remote_files_module.remote_fetcher, "get", get_mock) file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) upload_file = SimpleNamespace( @@ -230,9 +238,13 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat handler = _unwrap(api.post) url = "https://example.com/fail.txt" - monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500))) monkeypatch.setattr( - remote_files_module.ssrf_proxy, + remote_files_module.remote_fetcher, + "head", + MagicMock(return_value=_FakeResponse(status_code=500)), + ) + monkeypatch.setattr( + remote_files_module.remote_fetcher, "get", MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")), ) @@ -249,7 +261,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte request = httpx.Request("HEAD", url) monkeypatch.setattr( - remote_files_module.ssrf_proxy, + remote_files_module.remote_fetcher, "head", MagicMock(side_effect=httpx.RequestError("network down", request=request)), ) @@ -265,11 +277,11 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk url = "https://example.com/large.bin" monkeypatch.setattr( - remote_files_module.ssrf_proxy, + remote_files_module.remote_fetcher, "head", MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), ) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + monkeypatch.setattr(remote_files_module.remote_fetcher, "get", MagicMock()) _, current_user = _mock_upload_dependencies(monkeypatch, file_size_within_limit=False) @@ -284,11 +296,11 @@ def 
test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp url = "https://example.com/large.bin" monkeypatch.setattr( - remote_files_module.ssrf_proxy, + remote_files_module.remote_fetcher, "head", MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), ) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + monkeypatch.setattr(remote_files_module.remote_fetcher, "get", MagicMock()) file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded") @@ -303,11 +315,11 @@ def test_remote_file_upload_translates_service_unsupported_type_error(app, monke url = "https://example.com/file.exe" monkeypatch.setattr( - remote_files_module.ssrf_proxy, + remote_files_module.remote_fetcher, "head", MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), ) - monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + monkeypatch.setattr(remote_files_module.remote_fetcher, "get", MagicMock()) file_service_cls, current_user = _mock_upload_dependencies(monkeypatch) file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError() diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index 701863b927..0025c21f43 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -351,7 +351,11 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M assert runtime.multimodal_send_format == "url" - with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get: + with patch.object( + file_runtime.remote_fetcher.graphon_remote_file_fetcher, + "get", + return_value="response", + ) as mock_get: assert 
runtime.http_get("http://example", follow_redirects=False) == "response" mock_get.assert_called_once_with("http://example", follow_redirects=False) diff --git a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py index cee7d46083..4eaff201f0 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py @@ -1,6 +1,3 @@ -import base64 -import hashlib -import hmac from unittest.mock import MagicMock, patch import httpx @@ -34,34 +31,6 @@ class TestDatasourceFileManager: assert f"nonce={mock_urandom.return_value.hex()}" in signed_url assert "sign=" in signed_url - @patch("core.datasource.datasource_file_manager.time.time") - @patch("core.datasource.datasource_file_manager.dify_config") - def test_verify_file(self, mock_config, mock_time): - # Setup - mock_config.SECRET_KEY = "test_secret" - mock_config.FILES_ACCESS_TIMEOUT = 300 - mock_time.return_value = 1700000000 - - datasource_file_id = "file_id_123" - timestamp = "1699999800" # 200 seconds ago - nonce = "some_nonce" - - # Manually calculate sign - data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" - secret_key = b"test_secret" - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - # Execute & Verify Success - assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True - - # Verify Failure - Wrong Sign - assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False - - # Verify Failure - Timeout - mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout) - assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False - @patch("core.datasource.datasource_file_manager.db") 
@patch("core.datasource.datasource_file_manager.storage") @patch("core.datasource.datasource_file_manager.uuid4") @@ -170,7 +139,7 @@ class TestDatasourceFileManager: assert upload_file.name == "unique_hex.pdf" assert upload_file.extension == ".pdf" - @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.remote_fetcher") @patch("core.datasource.datasource_file_manager.db") @patch("core.datasource.datasource_file_manager.storage") @patch("core.datasource.datasource_file_manager.uuid4") @@ -190,7 +159,7 @@ class TestDatasourceFileManager: # Verify assert tool_file.mimetype == "image/png" # Guessed from .png in URL - @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.remote_fetcher") @patch("core.datasource.datasource_file_manager.db") @patch("core.datasource.datasource_file_manager.storage") @patch("core.datasource.datasource_file_manager.uuid4") @@ -212,7 +181,7 @@ class TestDatasourceFileManager: # Verify assert tool_file.mimetype == "application/octet-stream" - @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.remote_fetcher") @patch("core.datasource.datasource_file_manager.db") @patch("core.datasource.datasource_file_manager.storage") @patch("core.datasource.datasource_file_manager.uuid4") @@ -235,7 +204,7 @@ class TestDatasourceFileManager: assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg" mock_storage.save.assert_called_once() - @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.remote_fetcher") def test_create_file_by_url_timeout(self, mock_ssrf): # Setup mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout") diff --git a/api/tests/unit_tests/core/file/test_remote_fetcher.py b/api/tests/unit_tests/core/file/test_remote_fetcher.py new file mode 100644 index 0000000000..c5a45ceec6 --- /dev/null +++ 
b/api/tests/unit_tests/core/file/test_remote_fetcher.py @@ -0,0 +1,258 @@ +import base64 +import hashlib +import hmac +import urllib.parse +from types import SimpleNamespace +from unittest.mock import MagicMock + +import httpx + +from core.file import remote_fetcher + +UPLOAD_FILE_ID = "1602650a-4fe4-423c-85a2-af76c083e3c4" +TOOL_FILE_ID = "2602650a-4fe4-423c-85a2-af76c083e3c4" +DATASOURCE_FILE_ID = "3602650a-4fe4-423c-85a2-af76c083e3c4" + + +def _signed_url(*, base_url: str, path: str, payload: str, secret: str = "test-secret") -> str: + timestamp = "1700000000" + nonce = "nonce" + signature = hmac.new( + secret.encode(), + f"{payload}|{timestamp}|{nonce}".encode(), + hashlib.sha256, + ).digest() + query = urllib.parse.urlencode( + { + "timestamp": timestamp, + "nonce": nonce, + "sign": base64.urlsafe_b64encode(signature).decode(), + } + ) + return f"{base_url}{path}?{query}" + + +def _patch_file_fetcher_config(monkeypatch): + monkeypatch.setattr(remote_fetcher.dify_config, "FILES_URL", "http://localhost:5001") + monkeypatch.setattr(remote_fetcher.dify_config, "INTERNAL_FILES_URL", "http://api:5001") + monkeypatch.setattr(remote_fetcher.dify_config, "SECRET_KEY", "test-secret") + monkeypatch.setattr(remote_fetcher.dify_config, "FILES_ACCESS_TIMEOUT", 3600) + monkeypatch.setattr(remote_fetcher.time, "time", lambda: 1700000100) + + +def _patch_session(monkeypatch): + session = MagicMock() + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + monkeypatch.setattr(remote_fetcher.session_factory, "create_session", MagicMock(return_value=session_cm)) + return session + + +def test_get_signed_upload_file_url_reads_storage_without_ssrf(monkeypatch): + _patch_file_fetcher_config(monkeypatch) + session = _patch_session(monkeypatch) + upload_file = SimpleNamespace( + id=UPLOAD_FILE_ID, + key="upload_files/tenant/hello.txt", + name="hello.txt", + mime_type="text/plain", + size=5, + extension="txt", + ) + 
monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", MagicMock(return_value=upload_file)) + monkeypatch.setattr(remote_fetcher.storage, "load_once", MagicMock(return_value=b"hello")) + ssrf_get = MagicMock() + monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get) + url = _signed_url( + base_url="http://localhost:5001", + path=f"/files/{UPLOAD_FILE_ID}/file-preview", + payload=f"file-preview|{UPLOAD_FILE_ID}", + ) + + response = remote_fetcher.get(url) + + assert response.status_code == 200 + assert response.content == b"hello" + assert response.headers["Content-Type"] == "text/plain" + assert response.headers["Content-Length"] == "5" + assert response.request.method == "GET" + remote_fetcher._file_access_controller.get_upload_file.assert_called_once_with( + session=session, + file_id=UPLOAD_FILE_ID, + ) + remote_fetcher.storage.load_once.assert_called_once_with("upload_files/tenant/hello.txt") + ssrf_get.assert_not_called() + + +def test_head_signed_upload_file_url_returns_metadata_without_storage_content(monkeypatch): + _patch_file_fetcher_config(monkeypatch) + session = _patch_session(monkeypatch) + upload_file = SimpleNamespace( + id=UPLOAD_FILE_ID, + key="upload_files/tenant/hello.txt", + name="hello.txt", + mime_type="text/plain", + size=5, + extension="txt", + ) + monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", MagicMock(return_value=upload_file)) + load_once = MagicMock(return_value=b"hello") + monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once) + ssrf_head = MagicMock() + monkeypatch.setattr(remote_fetcher.ssrf_proxy, "head", ssrf_head) + url = _signed_url( + base_url="http://localhost:5001", + path=f"/files/{UPLOAD_FILE_ID}/file-preview", + payload=f"file-preview|{UPLOAD_FILE_ID}", + ) + + response = remote_fetcher.head(url) + + assert response.status_code == 200 + assert response.content == b"" + assert response.headers["Content-Type"] == "text/plain" + assert 
response.headers["Content-Length"] == "5" + assert response.request.method == "HEAD" + remote_fetcher._file_access_controller.get_upload_file.assert_called_once_with( + session=session, + file_id=UPLOAD_FILE_ID, + ) + load_once.assert_not_called() + ssrf_head.assert_not_called() + + +def test_invalid_signature_delegates_to_ssrf_proxy(monkeypatch): + _patch_file_fetcher_config(monkeypatch) + proxy_response = httpx.Response(403, request=httpx.Request("GET", "http://localhost:5001/bad")) + ssrf_get = MagicMock(return_value=proxy_response) + monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get) + url = ( + f"http://localhost:5001/files/{UPLOAD_FILE_ID}/file-preview" + "?timestamp=1700000000&nonce=nonce&sign=bad" + ) + + response = remote_fetcher.get(url, timeout=3) + + assert response is proxy_response + ssrf_get.assert_called_once_with(url=url, max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES, timeout=3) + + +def test_host_mismatch_delegates_to_ssrf_proxy(monkeypatch): + _patch_file_fetcher_config(monkeypatch) + url = _signed_url( + base_url="http://example.com", + path=f"/files/{UPLOAD_FILE_ID}/file-preview", + payload=f"file-preview|{UPLOAD_FILE_ID}", + ) + proxy_response = httpx.Response(200, request=httpx.Request("GET", url), content=b"remote") + ssrf_get = MagicMock(return_value=proxy_response) + monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get) + + response = remote_fetcher.get(url) + + assert response is proxy_response + ssrf_get.assert_called_once_with(url=url, max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES) + + +def test_unsupported_dify_path_delegates_to_ssrf_proxy(monkeypatch): + _patch_file_fetcher_config(monkeypatch) + url = _signed_url( + base_url="http://localhost:5001", + path=f"/files/{UPLOAD_FILE_ID}/not-preview", + payload=f"file-preview|{UPLOAD_FILE_ID}", + ) + proxy_response = httpx.Response(404, request=httpx.Request("HEAD", url)) + ssrf_head = MagicMock(return_value=proxy_response) + 
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "head", ssrf_head) + + response = remote_fetcher.head(url, follow_redirects=True) + + assert response is proxy_response + ssrf_head.assert_called_once_with( + url=url, + max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES, + follow_redirects=True, + ) + + +def test_signed_upload_file_url_returns_404_when_record_missing(monkeypatch): + _patch_file_fetcher_config(monkeypatch) + _patch_session(monkeypatch) + monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", MagicMock(return_value=None)) + ssrf_get = MagicMock() + monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get) + url = _signed_url( + base_url="http://localhost:5001", + path=f"/files/{UPLOAD_FILE_ID}/file-preview", + payload=f"file-preview|{UPLOAD_FILE_ID}", + ) + + response = remote_fetcher.get(url) + + assert response.status_code == 404 + assert response.content == b"" + ssrf_get.assert_not_called() + + +def test_get_signed_tool_file_url_reads_storage_without_ssrf(monkeypatch): + _patch_file_fetcher_config(monkeypatch) + session = _patch_session(monkeypatch) + tool_file = SimpleNamespace( + id=TOOL_FILE_ID, + file_key="tools/tenant/result.txt", + name="result.txt", + mimetype="text/plain", + size=6, + ) + monkeypatch.setattr(remote_fetcher._file_access_controller, "get_tool_file", MagicMock(return_value=tool_file)) + monkeypatch.setattr(remote_fetcher.storage, "load_once", MagicMock(return_value=b"result")) + ssrf_get = MagicMock() + monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get) + url = _signed_url( + base_url="http://localhost:5001", + path=f"/files/tools/{TOOL_FILE_ID}.txt", + payload=f"file-preview|{TOOL_FILE_ID}", + ) + + response = remote_fetcher.get(url) + + assert response.status_code == 200 + assert response.content == b"result" + assert response.headers["Content-Type"] == "text/plain" + remote_fetcher._file_access_controller.get_tool_file.assert_called_once_with( + session=session, + 
file_id=TOOL_FILE_ID, + ) + remote_fetcher.storage.load_once.assert_called_once_with("tools/tenant/result.txt") + ssrf_get.assert_not_called() + + +def test_get_signed_datasource_file_url_reads_upload_storage_without_ssrf(monkeypatch): + _patch_file_fetcher_config(monkeypatch) + _patch_session(monkeypatch) + upload_file = SimpleNamespace( + id=DATASOURCE_FILE_ID, + key="datasources/tenant/data.txt", + name="data.txt", + mime_type="text/plain", + size=4, + extension="txt", + ) + monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", MagicMock(return_value=upload_file)) + monkeypatch.setattr(remote_fetcher._file_access_controller, "get_tool_file", MagicMock(return_value=None)) + monkeypatch.setattr(remote_fetcher.storage, "load_once", MagicMock(return_value=b"data")) + ssrf_get = MagicMock() + monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get) + url = _signed_url( + base_url="http://localhost:5001", + path=f"/files/datasources/{DATASOURCE_FILE_ID}.txt", + payload=f"file-preview|{DATASOURCE_FILE_ID}", + ) + + response = remote_fetcher.get(url) + + assert response.status_code == 200 + assert response.content == b"data" + remote_fetcher.storage.load_once.assert_called_once_with("datasources/tenant/data.txt") + ssrf_get.assert_not_called() diff --git a/api/tests/unit_tests/core/helper/test_download.py b/api/tests/unit_tests/core/helper/test_download.py index 0755c25826..6845887ab9 100644 --- a/api/tests/unit_tests/core/helper/test_download.py +++ b/api/tests/unit_tests/core/helper/test_download.py @@ -17,7 +17,7 @@ class _StubResponse: def test_download_with_size_limit_returns_content(mocker: MockerFixture) -> None: response = _StubResponse(status_code=200, chunks=[b"ab", b"cd", b"ef"]) - mock_get = mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response) + mock_get = mocker.patch("core.helper.download.remote_fetcher.get", return_value=response) content = download_with_size_limit("https://example.com/a.txt", 
max_download_size=6, timeout=10) @@ -26,7 +26,7 @@ def test_download_with_size_limit_returns_content(mocker: MockerFixture) -> None def test_download_with_size_limit_raises_for_404(mocker: MockerFixture) -> None: - mocker.patch("core.helper.download.ssrf_proxy.get", return_value=_StubResponse(status_code=404, chunks=[])) + mocker.patch("core.helper.download.remote_fetcher.get", return_value=_StubResponse(status_code=404, chunks=[])) with pytest.raises(ValueError, match="file not found"): download_with_size_limit("https://example.com/missing.txt", max_download_size=10) @@ -36,7 +36,7 @@ def test_download_with_size_limit_raises_when_size_exceeds_limit( mocker: MockerFixture, ) -> None: response = _StubResponse(status_code=200, chunks=[b"abc", b"de"]) - mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response) + mocker.patch("core.helper.download.remote_fetcher.get", return_value=response) with pytest.raises(ValueError, match="Max file size reached"): download_with_size_limit("https://example.com/large.bin", max_download_size=4) @@ -46,7 +46,7 @@ def test_download_with_size_limit_accepts_content_equal_to_limit( mocker: MockerFixture, ) -> None: response = _StubResponse(status_code=200, chunks=[b"ab", b"cd"]) - mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response) + mocker.patch("core.helper.download.remote_fetcher.get", return_value=response) content = download_with_size_limit("https://example.com/exact.bin", max_download_size=4) diff --git a/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py b/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py index b4b08f57ec..3fd417a2aa 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py @@ -97,7 +97,7 @@ class TestExtractProcessorLoaders: self, monkeypatch: pytest.MonkeyPatch, url, headers, expected_suffix ): response = SimpleNamespace(headers=headers, 
content=b"body") - monkeypatch.setattr(processor_module.ssrf_proxy, "get", lambda *args, **kwargs: response) + monkeypatch.setattr(processor_module.remote_fetcher, "get", lambda *args, **kwargs: response) monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs)) captured = {} diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py index 12c5238f5e..715b83584e 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -200,7 +200,7 @@ class TestBaseIndexProcessor: mock_db.engine = Mock() with ( - patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response), + patch("core.rag.index_processor.index_processor_base.remote_fetcher.get", return_value=response), patch("core.rag.index_processor.index_processor_base.db", mock_db), patch("services.file_service.FileService") as mock_file_service, ): @@ -215,7 +215,7 @@ class TestBaseIndexProcessor: too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"} too_large.raise_for_status.return_value = None - with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large): + with patch("core.rag.index_processor.index_processor_base.remote_fetcher.get", return_value=too_large): assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None empty = Mock() @@ -223,7 +223,7 @@ class TestBaseIndexProcessor: empty.raise_for_status.return_value = None empty.iter_bytes.return_value = [] - with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty): + with patch("core.rag.index_processor.index_processor_base.remote_fetcher.get", return_value=empty): assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None 
def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None: @@ -232,7 +232,7 @@ class TestBaseIndexProcessor: response.raise_for_status.return_value = None response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)] - with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response): + with patch("core.rag.index_processor.index_processor_base.remote_fetcher.get", return_value=response): assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None def test_download_image_handles_timeout_request_and_unexpected_errors( @@ -241,19 +241,19 @@ class TestBaseIndexProcessor: request = httpx.Request("GET", "https://example.com/image.png") with patch( - "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + "core.rag.index_processor.index_processor_base.remote_fetcher.get", side_effect=httpx.TimeoutException("timeout"), ): assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None with patch( - "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + "core.rag.index_processor.index_processor_base.remote_fetcher.get", side_effect=httpx.RequestError("bad request", request=request), ): assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None with patch( - "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + "core.rag.index_processor.index_processor_base.remote_fetcher.get", side_effect=RuntimeError("unexpected"), ): assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index ccffdf16d1..0b3123c72f 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -1,8 +1,8 @@ """Unit tests for `ToolFileManager` behavior. 
-Covers signing/verification, file persistence flows, and retrieval APIs with -mocked storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to -avoid real IO. +Covers signing, file persistence flows, and retrieval APIs with mocked +storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to avoid real +IO. """ from __future__ import annotations @@ -17,18 +17,6 @@ from core.tools.tool_file_manager import ToolFileManager from graphon.file import FileTransferMethod -def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: - monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000000) - monkeypatch.setattr("core.tools.tool_file_manager.os.urandom", lambda _: b"\x01" * 16) - monkeypatch.setattr("core.tools.tool_file_manager.dify_config.SECRET_KEY", "secret") - monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_URL", "https://files.example.com") - monkeypatch.setattr("core.tools.tool_file_manager.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") - monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 100) - - url = ToolFileManager.sign_file("tf-1", ".png") - return dict(part.split("=", 1) for part in url.split("?", 1)[1].split("&")) - - def _patch_session_factory(session: Mock): session_cm = MagicMock() session_cm.__enter__.return_value = session @@ -36,27 +24,10 @@ def _patch_session_factory(session: Mock): return patch("core.tools.tool_file_manager.session_factory.create_session", return_value=session_cm) -def test_tool_file_manager_sign_verify_valid(monkeypatch: pytest.MonkeyPatch) -> None: - query = _setup_tool_file_signing(monkeypatch) +def test_tool_file_manager_sign_file_builds_url() -> None: url = ToolFileManager.sign_file("tf-1", ".png") assert "/files/tools/tf-1.png" in url - assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is True - - -def 
test_tool_file_manager_sign_verify_bad_signature(monkeypatch: pytest.MonkeyPatch) -> None: - query = _setup_tool_file_signing(monkeypatch) - - assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], "bad") is False - - -def test_tool_file_manager_sign_verify_expired_timestamp(monkeypatch: pytest.MonkeyPatch) -> None: - query = _setup_tool_file_signing(monkeypatch) - monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 0) - monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000100) - - assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is False - def test_create_file_by_raw_stores_file_and_persists_record() -> None: manager = ToolFileManager() @@ -106,7 +77,7 @@ def test_create_file_by_url_downloads_and_persists_record() -> None: patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory), patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="def")), _patch_session_factory(session), - patch("core.tools.tool_file_manager.ssrf_proxy.get", return_value=response), + patch("core.tools.tool_file_manager.remote_fetcher.get", return_value=response), ): file_model = manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1") @@ -120,7 +91,7 @@ def test_create_file_by_url_downloads_and_persists_record() -> None: def test_create_file_by_url_raises_on_timeout() -> None: manager = ToolFileManager() - with patch("core.tools.tool_file_manager.ssrf_proxy.get", side_effect=httpx.TimeoutException("timeout")): + with patch("core.tools.tool_file_manager.remote_fetcher.get", side_effect=httpx.TimeoutException("timeout")): with pytest.raises(ValueError, match="timeout when downloading file"): manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1") diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py index 
1361e16b06..c30cf05f43 100644 --- a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py +++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py @@ -58,7 +58,7 @@ def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_ headers={"Content-Type": "image/png"}, # not supported ) - monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(stub_support_types.remote_fetcher, "head", fake_head) result = get_url("https://x.test/file.png") assert result == "Unsupported content-type [image/png] of URL." @@ -82,7 +82,7 @@ def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytes assert return_text is True return "PDF extracted text" - monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(stub_support_types.remote_fetcher, "head", fake_head) monkeypatch.setattr(stub_support_types.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url)) result = get_url("https://x.test/doc.pdf") @@ -103,8 +103,8 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.Monk # chardet.detect returns utf-8 import core.tools.utils.web_reader_tool as mod - monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) - monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) + monkeypatch.setattr(mod.remote_fetcher, "head", fake_head) + monkeypatch.setattr(mod.remote_fetcher, "get", fake_get) mock_best = SimpleNamespace(encoding="utf-8") mock_from_bytes = SimpleNamespace(best=lambda: mock_best) @@ -137,8 +137,8 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest. 
import core.tools.utils.web_reader_tool as mod - monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) - monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) + monkeypatch.setattr(mod.remote_fetcher, "head", fake_head) + monkeypatch.setattr(mod.remote_fetcher, "get", fake_get) mock_best = SimpleNamespace(encoding="utf-8") mock_from_bytes = SimpleNamespace(best=lambda: mock_best) monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes) @@ -150,7 +150,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest. def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types): - """HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed.""" + """HEAD 403 → use cloudscraper.get via remote_fetcher.make_request, then proceed.""" def fake_head(url, headers=None, follow_redirects=True, timeout=None): return FakeResponse(status_code=403, headers={}) @@ -167,7 +167,7 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub import core.tools.utils.web_reader_tool as mod - monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(mod.remote_fetcher, "head", fake_head) monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper()) mock_best = SimpleNamespace(encoding="utf-8") mock_from_bytes = SimpleNamespace(best=lambda: mock_best) @@ -192,7 +192,7 @@ def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, st import core.tools.utils.web_reader_tool as mod - monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(mod.remote_fetcher, "head", fake_head) out = get_url("https://x.test/fail") assert out == "URL returned status code 500." 
@@ -214,7 +214,7 @@ def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.Monk import core.tools.utils.web_reader_tool as mod - monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(mod.remote_fetcher, "head", fake_head) monkeypatch.setattr(mod.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url)) out = get_url("https://x.test/fname") @@ -241,8 +241,8 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.Mo import core.tools.utils.web_reader_tool as mod - monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) - monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) + monkeypatch.setattr(mod.remote_fetcher, "head", fake_head) + monkeypatch.setattr(mod.remote_fetcher, "get", fake_get) mock_best = SimpleNamespace(encoding="utf-8") mock_from_bytes = SimpleNamespace(best=lambda: mock_best) diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 62e1a50291..f35624aed1 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -452,6 +452,7 @@ class TestDifyNodeFactoryCreateNode: factory._jinja2_template_renderer = sentinel.jinja2_template_renderer factory._template_transform_max_output_length = 2048 factory._http_request_http_client = sentinel.http_client + factory._remote_file_http_client = sentinel.remote_file_http_client factory._bound_tool_file_manager_factory = MagicMock(return_value=sentinel.tool_file_manager) factory._file_reference_factory = sentinel.file_reference_factory factory._prompt_message_serializer = sentinel.prompt_message_serializer @@ -596,7 +597,7 @@ class TestDifyNodeFactoryCreateNode: factory._bound_tool_file_manager_factory.assert_called_once_with() elif constructor_name == "DocumentExtractorNode": assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config - assert kwargs["http_client"] is 
sentinel.http_client + assert kwargs["http_client"] is sentinel.remote_file_http_client def test_build_llm_compatible_node_init_kwargs_preserves_structured_output_switch(self, factory): node_data = LLMNodeData.model_validate( @@ -732,7 +733,7 @@ class TestDifyNodeFactoryCreateNode: BuiltinNodeTypes.LLM, "LLMNode", { - "http_client": sentinel.http_client, + "http_client": sentinel.remote_file_http_client, "llm_file_saver": sentinel.llm_file_saver, "prompt_message_serializer": sentinel.prompt_message_serializer, "retriever_attachment_loader": sentinel.retriever_attachment_loader, @@ -743,7 +744,7 @@ class TestDifyNodeFactoryCreateNode: BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", { - "http_client": sentinel.http_client, + "http_client": sentinel.remote_file_http_client, "llm_file_saver": sentinel.llm_file_saver, "prompt_message_serializer": sentinel.prompt_message_serializer, "template_renderer": sentinel.jinja2_template_renderer, diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 661882f013..99ad7a6622 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -22,7 +22,7 @@ from graphon.variables.variables import StringVariable def _mock_ssrf_head(monkeypatch: pytest.MonkeyPatch): """Avoid any real network requests during tests. - factories.file_factory.remote.get_remote_file_info() uses ssrf_proxy.head + factories.file_factory.remote.get_remote_file_info() uses remote_fetcher.head to inspect remote files. We stub it to return a minimal response object with headers so filename/mime/size can be derived deterministically. 
@@ -46,7 +46,7 @@ def _mock_ssrf_head(monkeypatch: pytest.MonkeyPatch): } return SimpleNamespace(status_code=200, headers=headers) - monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head) + monkeypatch.setattr("factories.file_factory.remote.remote_fetcher.head", fake_head) class TestWorkflowEntry: diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index ffb151fbf4..1ae3cea0f8 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -99,7 +99,7 @@ def mock_http_head(): }, ) - with patch("factories.file_factory.remote.ssrf_proxy.head", autospec=True) as mock_head: + with patch("factories.file_factory.remote.remote_fetcher.head", autospec=True) as mock_head: mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg") yield mock_head diff --git a/api/tests/unit_tests/factories/test_file_factory.py b/api/tests/unit_tests/factories/test_file_factory.py index 293be925ae..e838f9164f 100644 --- a/api/tests/unit_tests/factories/test_file_factory.py +++ b/api/tests/unit_tests/factories/test_file_factory.py @@ -18,7 +18,7 @@ def _mock_head(monkeypatch: pytest.MonkeyPatch, headers: dict[str, str], status_ def _fake_head(url: str, follow_redirects: bool = True): return _FakeResponse(status_code=status_code, headers=headers) - monkeypatch.setattr("factories.file_factory.remote.ssrf_proxy.head", _fake_head) + monkeypatch.setattr("factories.file_factory.remote.remote_fetcher.head", _fake_head) class TestGetRemoteFileInfo: diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py index e72ebb4907..4d05a783e4 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py @@ -206,7 
+206,10 @@ def test_export_rag_pipeline_dsl_raises_when_dataset_missing() -> None: def test_import_rag_pipeline_url_fetch_error(mocker) -> None: - mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", side_effect=Exception("fetch failed")) + mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.get", + side_effect=Exception("fetch failed"), + ) service = RagPipelineDslService(session=Mock()) account = Mock(current_tenant_id="t1") @@ -813,7 +816,7 @@ def test_import_rag_pipeline_yaml_url_handles_empty_content_after_github_rewrite response = Mock() response.raise_for_status.return_value = None response.content = b"" - get_mock = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", return_value=response) + get_mock = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.get", return_value=response) service = RagPipelineDslService(session=Mock()) account = Mock(current_tenant_id="t1") @@ -880,7 +883,7 @@ def test_import_rag_pipeline_url_size_exceeds_limit(mocker) -> None: response = Mock() response.raise_for_status.return_value = None response.content = b"x" * (10 * 1024 * 1024 + 1) - mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", return_value=response) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.get", return_value=response) service = RagPipelineDslService(session=Mock()) account = Mock(current_tenant_id="t1")