mirror of
https://github.com/langgenius/dify.git
synced 2026-05-27 04:16:16 +08:00
fix(api): centralize remote file retrieval
Introduce a unified remote file fetcher that resolves first-party signed file URLs through database records and storage before falling back to the SSRF-protected HTTP client. Route backend remote-file call sites through the new boundary, remove obsolete file signature verification helpers, and document when to use remote_fetcher versus ssrf_proxy.
This commit is contained in:
@ -13,7 +13,7 @@ from controllers.common.errors import (
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
resp = remote_fetcher.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp = remote_fetcher.get(decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
return RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
|
||||
|
||||
# Try to fetch remote file metadata/content first
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
resp = remote_fetcher.head(url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
resp = remote_fetcher.get(url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# Normalize into a user-friendly error message expected by tests
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
|
||||
raise FileTooLargeError()
|
||||
|
||||
# Load content if needed
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
content = resp.content if resp.request.method == "GET" else remote_fetcher.get(url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
|
||||
@ -9,7 +9,7 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -60,10 +60,10 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
HTTPException: If the remote file cannot be accessed
|
||||
"""
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
resp = remote_fetcher.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# failed back to get method
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp = remote_fetcher.get(decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
info = RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
@ -112,9 +112,9 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
url = str(payload.url)
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
resp = remote_fetcher.head(url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
resp = remote_fetcher.get(url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
except httpx.RequestError as e:
|
||||
@ -125,7 +125,7 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
raise FileTooLargeError
|
||||
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
content = resp.content if resp.request.method == "GET" else remote_fetcher.get(url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
|
||||
@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Literal, override
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from extensions.ext_storage import storage
|
||||
@ -46,7 +46,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
|
||||
@override
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
return remote_fetcher.graphon_remote_file_fetcher.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
@override
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
|
||||
@ -12,7 +12,7 @@ from uuid import uuid4
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
@ -44,26 +44,6 @@ class DatasourceFileManager:
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@staticmethod
|
||||
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
recalculated_sign = hmac.new(
|
||||
dify_config.SECRET_KEY.encode(),
|
||||
data_to_sign.encode(),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(
|
||||
*,
|
||||
@ -117,7 +97,7 @@ class DatasourceFileManager:
|
||||
) -> ToolFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response = remote_fetcher.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except httpx.TimeoutException:
|
||||
|
||||
5
api/core/file/__init__.py
Normal file
5
api/core/file/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""File retrieval helpers shared by backend file-oriented workflows."""
|
||||
|
||||
from . import remote_fetcher
|
||||
|
||||
__all__ = ["remote_fetcher"]
|
||||
363
api/core/file/remote_fetcher.py
Normal file
363
api/core/file/remote_fetcher.py
Normal file
@ -0,0 +1,363 @@
|
||||
"""Unified remote-file retrieval with Dify signed file URL resolution.
|
||||
|
||||
Use this module for backend workflows whose intent is to fetch remote file content
|
||||
or remote file metadata from a URL, even when the URL originally came from a user
|
||||
upload, a workflow variable, a tool/datasource file, or an app DSL. GET/HEAD
|
||||
requests can resolve Dify-signed file URLs locally through DB + storage before
|
||||
falling back to the SSRF-protected network client.
|
||||
|
||||
Use `core.helper.ssrf_proxy` directly only for generic outbound HTTP where the
|
||||
URL is not being treated as a remote file, such as HTTP Request nodes, external
|
||||
API integrations, auth discovery, or user-configured tool calls. Those calls must
|
||||
stay as real network requests and should not reinterpret Dify file URLs as stored
|
||||
files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import (
|
||||
SSRF_DEFAULT_MAX_RETRIES,
|
||||
_to_graphon_http_response,
|
||||
max_retries_exceeded_error,
|
||||
request_error,
|
||||
)
|
||||
from extensions.ext_storage import storage
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
_UPLOAD_FILE_PATH_PATTERN = re.compile(
|
||||
r"^/files/(?P<file_id>[a-fA-F0-9-]+)/(?P<preview_kind>file-preview|image-preview)$"
|
||||
)
|
||||
_TOOL_FILE_PATH_PATTERN = re.compile(r"^/files/tools/(?P<file_id>[a-fA-F0-9-]+)\.(?P<extension>[^/]+)$")
|
||||
_DATASOURCE_FILE_PATH_PATTERN = re.compile(
|
||||
r"^/files/datasources/(?P<file_id>[a-fA-F0-9-]+)\.(?P<extension>[^/]+)$"
|
||||
)
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _SignedFileUrl:
|
||||
file_id: str
|
||||
preview_kind: Literal["file-preview", "image-preview"]
|
||||
record_kind: Literal["upload", "tool", "datasource"]
|
||||
|
||||
|
||||
def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
"""Fetch remote file content or metadata.
|
||||
|
||||
GET and HEAD requests for Dify-owned signed file URLs are served from local
|
||||
storage. Every other request is delegated unchanged to the SSRF proxy.
|
||||
"""
|
||||
|
||||
normalized_method = method.upper()
|
||||
if normalized_method in {"GET", "HEAD"}:
|
||||
response = _resolve_dify_signed_file_url(normalized_method, url)
|
||||
if response is not None:
|
||||
return response
|
||||
return ssrf_proxy.make_request(method=method, url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
"""Fetch remote file content, resolving Dify-owned signed file URLs locally."""
|
||||
|
||||
response = _resolve_dify_signed_file_url("GET", url)
|
||||
if response is not None:
|
||||
return response
|
||||
return ssrf_proxy.get(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
"""Fetch remote file metadata, resolving Dify-owned signed file URLs locally."""
|
||||
|
||||
response = _resolve_dify_signed_file_url("HEAD", url)
|
||||
if response is not None:
|
||||
return response
|
||||
return ssrf_proxy.head(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return ssrf_proxy.post(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return ssrf_proxy.put(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return ssrf_proxy.delete(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
|
||||
return ssrf_proxy.patch(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
class GraphonRemoteFileFetcher:
|
||||
"""Graphon HTTP-client adapter backed by the unified remote-file fetcher."""
|
||||
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]:
|
||||
return max_retries_exceeded_error
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]:
|
||||
return request_error
|
||||
|
||||
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
|
||||
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
|
||||
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
|
||||
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
|
||||
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
|
||||
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any):
|
||||
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
|
||||
def _resolve_dify_signed_file_url(method: Literal["GET", "HEAD"], url: str) -> httpx.Response | None:
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
if not _is_dify_file_origin(parsed_url):
|
||||
return None
|
||||
|
||||
signed_file_url = _parse_signed_file_path(parsed_url.path)
|
||||
if signed_file_url is None:
|
||||
return None
|
||||
|
||||
query = urllib.parse.parse_qs(parsed_url.query, keep_blank_values=True)
|
||||
timestamp = _single_query_value(query, "timestamp")
|
||||
nonce = _single_query_value(query, "nonce")
|
||||
sign = _single_query_value(query, "sign")
|
||||
if timestamp is None or nonce is None or sign is None:
|
||||
return None
|
||||
|
||||
if not _verify_signed_file_url(
|
||||
signed_file_url=signed_file_url,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
):
|
||||
return None
|
||||
|
||||
if signed_file_url.record_kind == "upload":
|
||||
return _build_upload_file_response(method=method, url=url, file_id=signed_file_url.file_id)
|
||||
if signed_file_url.record_kind == "tool":
|
||||
return _build_tool_file_response(method=method, url=url, file_id=signed_file_url.file_id)
|
||||
return _build_datasource_file_response(method=method, url=url, file_id=signed_file_url.file_id)
|
||||
|
||||
|
||||
def _parse_signed_file_path(path: str) -> _SignedFileUrl | None:
|
||||
upload_match = _UPLOAD_FILE_PATH_PATTERN.match(path)
|
||||
if upload_match:
|
||||
return _SignedFileUrl(
|
||||
file_id=upload_match.group("file_id"),
|
||||
preview_kind=upload_match.group("preview_kind"),
|
||||
record_kind="upload",
|
||||
)
|
||||
|
||||
tool_match = _TOOL_FILE_PATH_PATTERN.match(path)
|
||||
if tool_match:
|
||||
return _SignedFileUrl(
|
||||
file_id=tool_match.group("file_id"),
|
||||
preview_kind="file-preview",
|
||||
record_kind="tool",
|
||||
)
|
||||
|
||||
datasource_match = _DATASOURCE_FILE_PATH_PATTERN.match(path)
|
||||
if datasource_match:
|
||||
return _SignedFileUrl(
|
||||
file_id=datasource_match.group("file_id"),
|
||||
preview_kind="file-preview",
|
||||
record_kind="datasource",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_dify_file_origin(parsed_url: urllib.parse.ParseResult) -> bool:
|
||||
if parsed_url.scheme not in {"http", "https"} or not parsed_url.hostname:
|
||||
return False
|
||||
|
||||
url_origin = _origin_parts(parsed_url)
|
||||
if url_origin is None:
|
||||
return False
|
||||
|
||||
allowed_origins = {
|
||||
origin
|
||||
for configured_url in [dify_config.FILES_URL, dify_config.INTERNAL_FILES_URL]
|
||||
if configured_url and (origin := _origin_parts(urllib.parse.urlparse(configured_url))) is not None
|
||||
}
|
||||
return url_origin in allowed_origins
|
||||
|
||||
|
||||
def _origin_parts(parsed_url: urllib.parse.ParseResult) -> tuple[str, str, int] | None:
|
||||
if parsed_url.scheme not in {"http", "https"} or not parsed_url.hostname:
|
||||
return None
|
||||
return parsed_url.scheme, parsed_url.hostname.lower(), parsed_url.port or _default_port(parsed_url.scheme)
|
||||
|
||||
|
||||
def _default_port(scheme: str) -> int:
|
||||
return 443 if scheme == "https" else 80
|
||||
|
||||
|
||||
def _single_query_value(query: dict[str, list[str]], key: str) -> str | None:
|
||||
values = query.get(key)
|
||||
if not values or len(values) != 1:
|
||||
return None
|
||||
return values[0]
|
||||
|
||||
|
||||
def _verify_signed_file_url(
|
||||
*,
|
||||
signed_file_url: _SignedFileUrl,
|
||||
timestamp: str,
|
||||
nonce: str,
|
||||
sign: str,
|
||||
) -> bool:
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
signed_at = int(timestamp)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if current_time - signed_at > dify_config.FILES_ACCESS_TIMEOUT:
|
||||
return False
|
||||
|
||||
payload = f"{signed_file_url.preview_kind}|{signed_file_url.file_id}|{timestamp}|{nonce}"
|
||||
recalculated = hmac.new(dify_config.SECRET_KEY.encode(), payload.encode(), hashlib.sha256).digest()
|
||||
expected = base64.urlsafe_b64encode(recalculated).decode()
|
||||
return hmac.compare_digest(sign, expected)
|
||||
|
||||
|
||||
def _build_upload_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = _file_access_controller.get_upload_file(session=session, file_id=file_id)
|
||||
if upload_file is None:
|
||||
return _build_response(method=method, url=url, status_code=404)
|
||||
|
||||
content = b"" if method == "HEAD" else storage.load_once(upload_file.key)
|
||||
return _build_response(
|
||||
method=method,
|
||||
url=url,
|
||||
status_code=200,
|
||||
content=content,
|
||||
content_length=upload_file.size,
|
||||
content_type=upload_file.mime_type,
|
||||
filename=upload_file.name,
|
||||
)
|
||||
|
||||
|
||||
def _build_tool_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = _file_access_controller.get_tool_file(session=session, file_id=file_id)
|
||||
if tool_file is None:
|
||||
return _build_response(method=method, url=url, status_code=404)
|
||||
|
||||
content = b"" if method == "HEAD" else storage.load_once(tool_file.file_key)
|
||||
return _build_response(
|
||||
method=method,
|
||||
url=url,
|
||||
status_code=200,
|
||||
content=content,
|
||||
content_length=tool_file.size,
|
||||
content_type=tool_file.mimetype,
|
||||
filename=tool_file.name,
|
||||
)
|
||||
|
||||
|
||||
def _build_datasource_file_response(*, method: Literal["GET", "HEAD"], url: str, file_id: str) -> httpx.Response:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = _file_access_controller.get_upload_file(session=session, file_id=file_id)
|
||||
if upload_file is not None:
|
||||
return _build_upload_file_record_response(method=method, url=url, upload_file=upload_file)
|
||||
|
||||
tool_file = _file_access_controller.get_tool_file(session=session, file_id=file_id)
|
||||
if tool_file is not None:
|
||||
return _build_tool_file_record_response(method=method, url=url, tool_file=tool_file)
|
||||
|
||||
return _build_response(method=method, url=url, status_code=404)
|
||||
|
||||
|
||||
def _build_upload_file_record_response(
|
||||
*,
|
||||
method: Literal["GET", "HEAD"],
|
||||
url: str,
|
||||
upload_file: UploadFile,
|
||||
) -> httpx.Response:
|
||||
content = b"" if method == "HEAD" else storage.load_once(upload_file.key)
|
||||
return _build_response(
|
||||
method=method,
|
||||
url=url,
|
||||
status_code=200,
|
||||
content=content,
|
||||
content_length=upload_file.size,
|
||||
content_type=upload_file.mime_type,
|
||||
filename=upload_file.name,
|
||||
)
|
||||
|
||||
|
||||
def _build_tool_file_record_response(
|
||||
*,
|
||||
method: Literal["GET", "HEAD"],
|
||||
url: str,
|
||||
tool_file: ToolFile,
|
||||
) -> httpx.Response:
|
||||
content = b"" if method == "HEAD" else storage.load_once(tool_file.file_key)
|
||||
return _build_response(
|
||||
method=method,
|
||||
url=url,
|
||||
status_code=200,
|
||||
content=content,
|
||||
content_length=tool_file.size,
|
||||
content_type=tool_file.mimetype,
|
||||
filename=tool_file.name,
|
||||
)
|
||||
|
||||
|
||||
def _build_response(
|
||||
*,
|
||||
method: Literal["GET", "HEAD"],
|
||||
url: str,
|
||||
status_code: int,
|
||||
content: bytes = b"",
|
||||
content_length: int | None = None,
|
||||
content_type: str | None = None,
|
||||
filename: str | None = None,
|
||||
) -> httpx.Response:
|
||||
headers: dict[str, str] = {}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
if content_length is not None and content_length >= 0:
|
||||
headers["Content-Length"] = str(content_length)
|
||||
if filename:
|
||||
headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{urllib.parse.quote(filename)}"
|
||||
return httpx.Response(
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
content=content,
|
||||
request=httpx.Request(method, url),
|
||||
)
|
||||
|
||||
|
||||
graphon_remote_file_fetcher = GraphonRemoteFileFetcher()
|
||||
@ -1,8 +1,8 @@
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
|
||||
|
||||
def download_with_size_limit(url, max_download_size: int, **kwargs):
|
||||
response = ssrf_proxy.get(url, follow_redirects=True, **kwargs)
|
||||
response = remote_fetcher.get(url, follow_redirects=True, **kwargs)
|
||||
if response.status_code == 404:
|
||||
raise ValueError("file not found")
|
||||
|
||||
|
||||
@ -1,5 +1,13 @@
|
||||
"""
|
||||
Proxy requests to avoid SSRF
|
||||
"""SSRF-protected HTTP client for generic outbound requests.
|
||||
|
||||
Use this module when the URL represents a normal external HTTP interaction that
|
||||
must go through network/proxy policy exactly as requested, such as HTTP Request
|
||||
nodes, provider/API integrations, auth discovery, or custom tool calls.
|
||||
|
||||
Do not use this directly for "remote file" retrieval. File downloads, probes,
|
||||
and metadata checks should use `core.file.remote_fetcher` instead so Dify-signed
|
||||
file URLs can be resolved through DB + storage before falling back to this SSRF
|
||||
client.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Union
|
||||
from urllib.parse import unquote
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
@ -55,7 +55,7 @@ class ExtractProcessor:
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
|
||||
response = remote_fetcher.get(url, headers={"User-Agent": USER_AGENT})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Word (.docx) document extractor used for RAG ingestion.
|
||||
|
||||
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
|
||||
Supports local file paths and remote URLs downloaded through the unified remote-file fetcher.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
@ -17,7 +17,7 @@ from docx.oxml.ns import qn
|
||||
from docx.text.run import Run
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
@ -51,7 +51,7 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
# If the file is a web path, download it to a temporary file, and use that
|
||||
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
|
||||
response = ssrf_proxy.get(self.file_path)
|
||||
response = remote_fetcher.get(self.file_path)
|
||||
|
||||
if response.status_code != 200:
|
||||
response.close()
|
||||
@ -120,7 +120,7 @@ class WordExtractor(BaseExtractor):
|
||||
if not self._is_valid_url(url):
|
||||
continue
|
||||
try:
|
||||
response = ssrf_proxy.get(url)
|
||||
response = remote_fetcher.get(url)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download image from URL: %s: %s", url, str(e))
|
||||
continue
|
||||
|
||||
@ -15,7 +15,7 @@ from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -243,7 +243,7 @@ class BaseIndexProcessor(ABC):
|
||||
|
||||
try:
|
||||
# Download with timeout
|
||||
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
|
||||
response = remote_fetcher.get(image_url, timeout=DOWNLOAD_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
|
||||
# Check Content-Length header if available
|
||||
|
||||
@ -13,7 +13,7 @@ from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
|
||||
@ -60,26 +60,6 @@ class ToolFileManager:
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@staticmethod
|
||||
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
recalculated_sign = hmac.new(
|
||||
dify_config.SECRET_KEY.encode(),
|
||||
data_to_sign.encode(),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
def create_file_by_raw(
|
||||
self,
|
||||
*,
|
||||
@ -129,7 +109,7 @@ class ToolFileManager:
|
||||
) -> ToolFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response = remote_fetcher.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except httpx.TimeoutException:
|
||||
|
||||
@ -9,7 +9,7 @@ import charset_normalizer
|
||||
import cloudscraper
|
||||
from readabilipy import simple_json_from_html_string
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from core.rag.extractor import extract_processor
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
|
||||
@ -38,7 +38,7 @@ def get_url(url: str, user_agent: str | None = None) -> str:
|
||||
|
||||
main_content_type = None
|
||||
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
|
||||
response = remote_fetcher.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
|
||||
|
||||
if response.status_code == 200:
|
||||
# check content-type
|
||||
@ -60,10 +60,10 @@ def get_url(url: str, user_agent: str | None = None) -> str:
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
|
||||
|
||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
response = remote_fetcher.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
scraper = cloudscraper.create_scraper()
|
||||
scraper.perform_request = ssrf_proxy.make_request
|
||||
scraper.perform_request = remote_fetcher.make_request
|
||||
response = scraper.get(url, headers=headers, timeout=(120, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
|
||||
@ -11,6 +11,7 @@ from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.llm.model_access import build_dify_model_access, fetch_model_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.file import remote_fetcher
|
||||
from core.helper.code_executor.code_executor import (
|
||||
CodeExecutionError,
|
||||
CodeExecutor,
|
||||
@ -307,6 +308,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
self._http_request_http_client = graphon_ssrf_proxy
|
||||
self._remote_file_http_client = remote_fetcher.graphon_remote_file_fetcher
|
||||
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
|
||||
self._dify_context,
|
||||
conversation_id_getter=self._conversation_id,
|
||||
@ -318,7 +320,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
self._llm_file_saver = build_dify_llm_file_saver(
|
||||
run_context=self._dify_context,
|
||||
http_client=self._http_request_http_client,
|
||||
http_client=self._remote_file_http_client,
|
||||
conversation_id_getter=self._conversation_id,
|
||||
)
|
||||
self._human_input_runtime = DifyHumanInputNodeRuntime(
|
||||
@ -416,7 +418,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
),
|
||||
BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
|
||||
"unstructured_api_config": self._document_extractor_unstructured_api_config,
|
||||
"http_client": self._http_request_http_client,
|
||||
"http_client": self._remote_file_http_client,
|
||||
},
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
@ -528,7 +530,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER:
|
||||
node_init_kwargs["template_renderer"] = self._jinja2_template_renderer
|
||||
if include_http_client:
|
||||
node_init_kwargs["http_client"] = self._http_request_http_client
|
||||
node_init_kwargs["http_client"] = self._remote_file_http_client
|
||||
if include_llm_file_saver:
|
||||
node_init_kwargs["llm_file_saver"] = self._llm_file_saver
|
||||
if include_prompt_message_serializer:
|
||||
|
||||
@ -16,7 +16,7 @@ import uuid
|
||||
import httpx
|
||||
from werkzeug.http import parse_options_header
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
|
||||
|
||||
def extract_filename(url_or_path: str, content_disposition: str | None) -> str | None:
|
||||
@ -81,7 +81,7 @@ def get_remote_file_info(url: str) -> tuple[str, str, int]:
|
||||
filename = os.path.basename(url_path)
|
||||
mime_type = _guess_mime_type(filename)
|
||||
|
||||
resp = ssrf_proxy.head(url, follow_redirects=True)
|
||||
resp = remote_fetcher.head(url, follow_redirects=True)
|
||||
if resp.status_code == httpx.codes.OK:
|
||||
content_disposition = resp.headers.get("Content-Disposition")
|
||||
extracted_filename = extract_filename(url_path, content_disposition)
|
||||
|
||||
@ -17,7 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.dsl_version import CURRENT_APP_DSL_VERSION
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from core.trigger.constants import (
|
||||
TRIGGER_PLUGIN_NODE_TYPE,
|
||||
@ -126,7 +126,7 @@ class AppDslService:
|
||||
):
|
||||
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
||||
yaml_url = yaml_url.replace("/blob/", "/")
|
||||
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
||||
response = remote_fetcher.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
||||
response.raise_for_status()
|
||||
content = response.content.decode()
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
@ -125,7 +125,7 @@ class RagPipelineDslService:
|
||||
):
|
||||
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
||||
yaml_url = yaml_url.replace("/blob/", "/")
|
||||
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
||||
response = remote_fetcher.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
||||
response.raise_for_status()
|
||||
content = response.content.decode()
|
||||
|
||||
|
||||
@ -316,7 +316,7 @@ class TestAppDslService:
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.ssrf_proxy,
|
||||
app_dsl_service.remote_fetcher,
|
||||
"get",
|
||||
lambda _url, **_kw: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
@ -336,7 +336,7 @@ class TestAppDslService:
|
||||
response = MagicMock()
|
||||
response.content = b""
|
||||
response.raise_for_status.return_value = None
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response)
|
||||
monkeypatch.setattr(app_dsl_service.remote_fetcher, "get", lambda _url, **_kw: response)
|
||||
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
@ -353,7 +353,7 @@ class TestAppDslService:
|
||||
response = MagicMock()
|
||||
response.content = b"x" * (DSL_MAX_SIZE + 1)
|
||||
response.raise_for_status.return_value = None
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response)
|
||||
monkeypatch.setattr(app_dsl_service.remote_fetcher, "get", lambda _url, **_kw: response)
|
||||
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
@ -379,7 +379,7 @@ class TestAppDslService:
|
||||
response.raise_for_status.return_value = None
|
||||
return response
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
monkeypatch.setattr(app_dsl_service.remote_fetcher, "get", fake_get)
|
||||
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
@ -409,7 +409,7 @@ class TestAppDslService:
|
||||
response.raise_for_status.return_value = None
|
||||
return response
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
monkeypatch.setattr(app_dsl_service.remote_fetcher, "get", fake_get)
|
||||
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
|
||||
@ -100,8 +100,8 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest
|
||||
)
|
||||
head_mock = MagicMock(return_value=head_resp)
|
||||
get_mock = MagicMock()
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "head", head_mock)
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "get", get_mock)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
payload = handler(api, url=encoded_url)
|
||||
@ -123,8 +123,8 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch:
|
||||
method="HEAD",
|
||||
)
|
||||
head_mock = MagicMock(return_value=head_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "head", head_mock)
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "get", MagicMock())
|
||||
|
||||
with app.test_request_context(f"/remote-files/{target_url}?{query}", method="GET"):
|
||||
payload = handler(api, url=target_url)
|
||||
@ -139,9 +139,13 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503)))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.remote_fetcher,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=503)),
|
||||
)
|
||||
get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET"))
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "get", get_mock)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
payload = handler(api, url=encoded_url)
|
||||
@ -155,10 +159,14 @@ def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatc
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/report.txt"
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404)))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.remote_fetcher,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=404)),
|
||||
)
|
||||
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
|
||||
get_mock = MagicMock(return_value=get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "get", get_mock)
|
||||
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
@ -196,13 +204,13 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
url = "https://example.com/photo.jpg"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
remote_files_module.remote_fetcher,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")),
|
||||
)
|
||||
extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content")
|
||||
get_mock = MagicMock(return_value=extra_get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "get", get_mock)
|
||||
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
@ -230,9 +238,13 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500)))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
remote_files_module.remote_fetcher,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=500)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.remote_fetcher,
|
||||
"get",
|
||||
MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")),
|
||||
)
|
||||
@ -249,7 +261,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
|
||||
|
||||
request = httpx.Request("HEAD", url)
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
remote_files_module.remote_fetcher,
|
||||
"head",
|
||||
MagicMock(side_effect=httpx.RequestError("network down", request=request)),
|
||||
)
|
||||
@ -265,11 +277,11 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
remote_files_module.remote_fetcher,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "get", MagicMock())
|
||||
|
||||
_, current_user = _mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
|
||||
|
||||
@ -284,11 +296,11 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
remote_files_module.remote_fetcher,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "get", MagicMock())
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
|
||||
|
||||
@ -303,11 +315,11 @@ def test_remote_file_upload_translates_service_unsupported_type_error(app, monke
|
||||
url = "https://example.com/file.exe"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
remote_files_module.remote_fetcher,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
monkeypatch.setattr(remote_files_module.remote_fetcher, "get", MagicMock())
|
||||
file_service_cls, current_user = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
|
||||
|
||||
|
||||
@ -351,7 +351,11 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
|
||||
|
||||
assert runtime.multimodal_send_format == "url"
|
||||
|
||||
with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
|
||||
with patch.object(
|
||||
file_runtime.remote_fetcher.graphon_remote_file_fetcher,
|
||||
"get",
|
||||
return_value="response",
|
||||
) as mock_get:
|
||||
assert runtime.http_get("http://example", follow_redirects=False) == "response"
|
||||
mock_get.assert_called_once_with("http://example", follow_redirects=False)
|
||||
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
@ -34,34 +31,6 @@ class TestDatasourceFileManager:
|
||||
assert f"nonce={mock_urandom.return_value.hex()}" in signed_url
|
||||
assert "sign=" in signed_url
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_verify_file(self, mock_config, mock_time):
|
||||
# Setup
|
||||
mock_config.SECRET_KEY = "test_secret"
|
||||
mock_config.FILES_ACCESS_TIMEOUT = 300
|
||||
mock_time.return_value = 1700000000
|
||||
|
||||
datasource_file_id = "file_id_123"
|
||||
timestamp = "1699999800" # 200 seconds ago
|
||||
nonce = "some_nonce"
|
||||
|
||||
# Manually calculate sign
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = b"test_secret"
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
# Execute & Verify Success
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True
|
||||
|
||||
# Verify Failure - Wrong Sign
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False
|
||||
|
||||
# Verify Failure - Timeout
|
||||
mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout)
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@ -170,7 +139,7 @@ class TestDatasourceFileManager:
|
||||
assert upload_file.name == "unique_hex.pdf"
|
||||
assert upload_file.extension == ".pdf"
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.remote_fetcher")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@ -190,7 +159,7 @@ class TestDatasourceFileManager:
|
||||
# Verify
|
||||
assert tool_file.mimetype == "image/png" # Guessed from .png in URL
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.remote_fetcher")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@ -212,7 +181,7 @@ class TestDatasourceFileManager:
|
||||
# Verify
|
||||
assert tool_file.mimetype == "application/octet-stream"
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.remote_fetcher")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@ -235,7 +204,7 @@ class TestDatasourceFileManager:
|
||||
assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg"
|
||||
mock_storage.save.assert_called_once()
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.remote_fetcher")
|
||||
def test_create_file_by_url_timeout(self, mock_ssrf):
|
||||
# Setup
|
||||
mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout")
|
||||
|
||||
258
api/tests/unit_tests/core/file/test_remote_fetcher.py
Normal file
258
api/tests/unit_tests/core/file/test_remote_fetcher.py
Normal file
@ -0,0 +1,258 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import urllib.parse
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
|
||||
from core.file import remote_fetcher
|
||||
|
||||
UPLOAD_FILE_ID = "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
TOOL_FILE_ID = "2602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
DATASOURCE_FILE_ID = "3602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
|
||||
|
||||
def _signed_url(*, base_url: str, path: str, payload: str, secret: str = "test-secret") -> str:
|
||||
timestamp = "1700000000"
|
||||
nonce = "nonce"
|
||||
signature = hmac.new(
|
||||
secret.encode(),
|
||||
f"{payload}|{timestamp}|{nonce}".encode(),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
query = urllib.parse.urlencode(
|
||||
{
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
"sign": base64.urlsafe_b64encode(signature).decode(),
|
||||
}
|
||||
)
|
||||
return f"{base_url}{path}?{query}"
|
||||
|
||||
|
||||
def _patch_file_fetcher_config(monkeypatch):
|
||||
monkeypatch.setattr(remote_fetcher.dify_config, "FILES_URL", "http://localhost:5001")
|
||||
monkeypatch.setattr(remote_fetcher.dify_config, "INTERNAL_FILES_URL", "http://api:5001")
|
||||
monkeypatch.setattr(remote_fetcher.dify_config, "SECRET_KEY", "test-secret")
|
||||
monkeypatch.setattr(remote_fetcher.dify_config, "FILES_ACCESS_TIMEOUT", 3600)
|
||||
monkeypatch.setattr(remote_fetcher.time, "time", lambda: 1700000100)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch):
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = False
|
||||
monkeypatch.setattr(remote_fetcher.session_factory, "create_session", MagicMock(return_value=session_cm))
|
||||
return session
|
||||
|
||||
|
||||
def test_get_signed_upload_file_url_reads_storage_without_ssrf(monkeypatch):
|
||||
_patch_file_fetcher_config(monkeypatch)
|
||||
session = _patch_session(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id=UPLOAD_FILE_ID,
|
||||
key="upload_files/tenant/hello.txt",
|
||||
name="hello.txt",
|
||||
mime_type="text/plain",
|
||||
size=5,
|
||||
extension="txt",
|
||||
)
|
||||
monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", MagicMock(return_value=upload_file))
|
||||
monkeypatch.setattr(remote_fetcher.storage, "load_once", MagicMock(return_value=b"hello"))
|
||||
ssrf_get = MagicMock()
|
||||
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get)
|
||||
url = _signed_url(
|
||||
base_url="http://localhost:5001",
|
||||
path=f"/files/{UPLOAD_FILE_ID}/file-preview",
|
||||
payload=f"file-preview|{UPLOAD_FILE_ID}",
|
||||
)
|
||||
|
||||
response = remote_fetcher.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"hello"
|
||||
assert response.headers["Content-Type"] == "text/plain"
|
||||
assert response.headers["Content-Length"] == "5"
|
||||
assert response.request.method == "GET"
|
||||
remote_fetcher._file_access_controller.get_upload_file.assert_called_once_with(
|
||||
session=session,
|
||||
file_id=UPLOAD_FILE_ID,
|
||||
)
|
||||
remote_fetcher.storage.load_once.assert_called_once_with("upload_files/tenant/hello.txt")
|
||||
ssrf_get.assert_not_called()
|
||||
|
||||
|
||||
def test_head_signed_upload_file_url_returns_metadata_without_storage_content(monkeypatch):
|
||||
_patch_file_fetcher_config(monkeypatch)
|
||||
session = _patch_session(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id=UPLOAD_FILE_ID,
|
||||
key="upload_files/tenant/hello.txt",
|
||||
name="hello.txt",
|
||||
mime_type="text/plain",
|
||||
size=5,
|
||||
extension="txt",
|
||||
)
|
||||
monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", MagicMock(return_value=upload_file))
|
||||
load_once = MagicMock(return_value=b"hello")
|
||||
monkeypatch.setattr(remote_fetcher.storage, "load_once", load_once)
|
||||
ssrf_head = MagicMock()
|
||||
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "head", ssrf_head)
|
||||
url = _signed_url(
|
||||
base_url="http://localhost:5001",
|
||||
path=f"/files/{UPLOAD_FILE_ID}/file-preview",
|
||||
payload=f"file-preview|{UPLOAD_FILE_ID}",
|
||||
)
|
||||
|
||||
response = remote_fetcher.head(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b""
|
||||
assert response.headers["Content-Type"] == "text/plain"
|
||||
assert response.headers["Content-Length"] == "5"
|
||||
assert response.request.method == "HEAD"
|
||||
remote_fetcher._file_access_controller.get_upload_file.assert_called_once_with(
|
||||
session=session,
|
||||
file_id=UPLOAD_FILE_ID,
|
||||
)
|
||||
load_once.assert_not_called()
|
||||
ssrf_head.assert_not_called()
|
||||
|
||||
|
||||
def test_invalid_signature_delegates_to_ssrf_proxy(monkeypatch):
|
||||
_patch_file_fetcher_config(monkeypatch)
|
||||
proxy_response = httpx.Response(403, request=httpx.Request("GET", "http://localhost:5001/bad"))
|
||||
ssrf_get = MagicMock(return_value=proxy_response)
|
||||
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get)
|
||||
url = (
|
||||
f"http://localhost:5001/files/{UPLOAD_FILE_ID}/file-preview"
|
||||
"?timestamp=1700000000&nonce=nonce&sign=bad"
|
||||
)
|
||||
|
||||
response = remote_fetcher.get(url, timeout=3)
|
||||
|
||||
assert response is proxy_response
|
||||
ssrf_get.assert_called_once_with(url=url, max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES, timeout=3)
|
||||
|
||||
|
||||
def test_host_mismatch_delegates_to_ssrf_proxy(monkeypatch):
|
||||
_patch_file_fetcher_config(monkeypatch)
|
||||
url = _signed_url(
|
||||
base_url="http://example.com",
|
||||
path=f"/files/{UPLOAD_FILE_ID}/file-preview",
|
||||
payload=f"file-preview|{UPLOAD_FILE_ID}",
|
||||
)
|
||||
proxy_response = httpx.Response(200, request=httpx.Request("GET", url), content=b"remote")
|
||||
ssrf_get = MagicMock(return_value=proxy_response)
|
||||
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get)
|
||||
|
||||
response = remote_fetcher.get(url)
|
||||
|
||||
assert response is proxy_response
|
||||
ssrf_get.assert_called_once_with(url=url, max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES)
|
||||
|
||||
|
||||
def test_unsupported_dify_path_delegates_to_ssrf_proxy(monkeypatch):
|
||||
_patch_file_fetcher_config(monkeypatch)
|
||||
url = _signed_url(
|
||||
base_url="http://localhost:5001",
|
||||
path=f"/files/{UPLOAD_FILE_ID}/not-preview",
|
||||
payload=f"file-preview|{UPLOAD_FILE_ID}",
|
||||
)
|
||||
proxy_response = httpx.Response(404, request=httpx.Request("HEAD", url))
|
||||
ssrf_head = MagicMock(return_value=proxy_response)
|
||||
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "head", ssrf_head)
|
||||
|
||||
response = remote_fetcher.head(url, follow_redirects=True)
|
||||
|
||||
assert response is proxy_response
|
||||
ssrf_head.assert_called_once_with(
|
||||
url=url,
|
||||
max_retries=remote_fetcher.SSRF_DEFAULT_MAX_RETRIES,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
|
||||
def test_signed_upload_file_url_returns_404_when_record_missing(monkeypatch):
|
||||
_patch_file_fetcher_config(monkeypatch)
|
||||
_patch_session(monkeypatch)
|
||||
monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", MagicMock(return_value=None))
|
||||
ssrf_get = MagicMock()
|
||||
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get)
|
||||
url = _signed_url(
|
||||
base_url="http://localhost:5001",
|
||||
path=f"/files/{UPLOAD_FILE_ID}/file-preview",
|
||||
payload=f"file-preview|{UPLOAD_FILE_ID}",
|
||||
)
|
||||
|
||||
response = remote_fetcher.get(url)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.content == b""
|
||||
ssrf_get.assert_not_called()
|
||||
|
||||
|
||||
def test_get_signed_tool_file_url_reads_storage_without_ssrf(monkeypatch):
|
||||
_patch_file_fetcher_config(monkeypatch)
|
||||
session = _patch_session(monkeypatch)
|
||||
tool_file = SimpleNamespace(
|
||||
id=TOOL_FILE_ID,
|
||||
file_key="tools/tenant/result.txt",
|
||||
name="result.txt",
|
||||
mimetype="text/plain",
|
||||
size=6,
|
||||
)
|
||||
monkeypatch.setattr(remote_fetcher._file_access_controller, "get_tool_file", MagicMock(return_value=tool_file))
|
||||
monkeypatch.setattr(remote_fetcher.storage, "load_once", MagicMock(return_value=b"result"))
|
||||
ssrf_get = MagicMock()
|
||||
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get)
|
||||
url = _signed_url(
|
||||
base_url="http://localhost:5001",
|
||||
path=f"/files/tools/{TOOL_FILE_ID}.txt",
|
||||
payload=f"file-preview|{TOOL_FILE_ID}",
|
||||
)
|
||||
|
||||
response = remote_fetcher.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"result"
|
||||
assert response.headers["Content-Type"] == "text/plain"
|
||||
remote_fetcher._file_access_controller.get_tool_file.assert_called_once_with(
|
||||
session=session,
|
||||
file_id=TOOL_FILE_ID,
|
||||
)
|
||||
remote_fetcher.storage.load_once.assert_called_once_with("tools/tenant/result.txt")
|
||||
ssrf_get.assert_not_called()
|
||||
|
||||
|
||||
def test_get_signed_datasource_file_url_reads_upload_storage_without_ssrf(monkeypatch):
|
||||
_patch_file_fetcher_config(monkeypatch)
|
||||
_patch_session(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id=DATASOURCE_FILE_ID,
|
||||
key="datasources/tenant/data.txt",
|
||||
name="data.txt",
|
||||
mime_type="text/plain",
|
||||
size=4,
|
||||
extension="txt",
|
||||
)
|
||||
monkeypatch.setattr(remote_fetcher._file_access_controller, "get_upload_file", MagicMock(return_value=upload_file))
|
||||
monkeypatch.setattr(remote_fetcher._file_access_controller, "get_tool_file", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(remote_fetcher.storage, "load_once", MagicMock(return_value=b"data"))
|
||||
ssrf_get = MagicMock()
|
||||
monkeypatch.setattr(remote_fetcher.ssrf_proxy, "get", ssrf_get)
|
||||
url = _signed_url(
|
||||
base_url="http://localhost:5001",
|
||||
path=f"/files/datasources/{DATASOURCE_FILE_ID}.txt",
|
||||
payload=f"file-preview|{DATASOURCE_FILE_ID}",
|
||||
)
|
||||
|
||||
response = remote_fetcher.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"data"
|
||||
remote_fetcher.storage.load_once.assert_called_once_with("datasources/tenant/data.txt")
|
||||
ssrf_get.assert_not_called()
|
||||
@ -17,7 +17,7 @@ class _StubResponse:
|
||||
|
||||
def test_download_with_size_limit_returns_content(mocker: MockerFixture) -> None:
|
||||
response = _StubResponse(status_code=200, chunks=[b"ab", b"cd", b"ef"])
|
||||
mock_get = mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
|
||||
mock_get = mocker.patch("core.helper.download.remote_fetcher.get", return_value=response)
|
||||
|
||||
content = download_with_size_limit("https://example.com/a.txt", max_download_size=6, timeout=10)
|
||||
|
||||
@ -26,7 +26,7 @@ def test_download_with_size_limit_returns_content(mocker: MockerFixture) -> None
|
||||
|
||||
|
||||
def test_download_with_size_limit_raises_for_404(mocker: MockerFixture) -> None:
|
||||
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=_StubResponse(status_code=404, chunks=[]))
|
||||
mocker.patch("core.helper.download.remote_fetcher.get", return_value=_StubResponse(status_code=404, chunks=[]))
|
||||
|
||||
with pytest.raises(ValueError, match="file not found"):
|
||||
download_with_size_limit("https://example.com/missing.txt", max_download_size=10)
|
||||
@ -36,7 +36,7 @@ def test_download_with_size_limit_raises_when_size_exceeds_limit(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
response = _StubResponse(status_code=200, chunks=[b"abc", b"de"])
|
||||
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
|
||||
mocker.patch("core.helper.download.remote_fetcher.get", return_value=response)
|
||||
|
||||
with pytest.raises(ValueError, match="Max file size reached"):
|
||||
download_with_size_limit("https://example.com/large.bin", max_download_size=4)
|
||||
@ -46,7 +46,7 @@ def test_download_with_size_limit_accepts_content_equal_to_limit(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
response = _StubResponse(status_code=200, chunks=[b"ab", b"cd"])
|
||||
mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response)
|
||||
mocker.patch("core.helper.download.remote_fetcher.get", return_value=response)
|
||||
|
||||
content = download_with_size_limit("https://example.com/exact.bin", max_download_size=4)
|
||||
|
||||
|
||||
@ -97,7 +97,7 @@ class TestExtractProcessorLoaders:
|
||||
self, monkeypatch: pytest.MonkeyPatch, url, headers, expected_suffix
|
||||
):
|
||||
response = SimpleNamespace(headers=headers, content=b"body")
|
||||
monkeypatch.setattr(processor_module.ssrf_proxy, "get", lambda *args, **kwargs: response)
|
||||
monkeypatch.setattr(processor_module.remote_fetcher, "get", lambda *args, **kwargs: response)
|
||||
monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
captured = {}
|
||||
|
||||
@ -200,7 +200,7 @@ class TestBaseIndexProcessor:
|
||||
mock_db.engine = Mock()
|
||||
|
||||
with (
|
||||
patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response),
|
||||
patch("core.rag.index_processor.index_processor_base.remote_fetcher.get", return_value=response),
|
||||
patch("core.rag.index_processor.index_processor_base.db", mock_db),
|
||||
patch("services.file_service.FileService") as mock_file_service,
|
||||
):
|
||||
@ -215,7 +215,7 @@ class TestBaseIndexProcessor:
|
||||
too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"}
|
||||
too_large.raise_for_status.return_value = None
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large):
|
||||
with patch("core.rag.index_processor.index_processor_base.remote_fetcher.get", return_value=too_large):
|
||||
assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None
|
||||
|
||||
empty = Mock()
|
||||
@ -223,7 +223,7 @@ class TestBaseIndexProcessor:
|
||||
empty.raise_for_status.return_value = None
|
||||
empty.iter_bytes.return_value = []
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty):
|
||||
with patch("core.rag.index_processor.index_processor_base.remote_fetcher.get", return_value=empty):
|
||||
assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None
|
||||
|
||||
def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
@ -232,7 +232,7 @@ class TestBaseIndexProcessor:
|
||||
response.raise_for_status.return_value = None
|
||||
response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)]
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response):
|
||||
with patch("core.rag.index_processor.index_processor_base.remote_fetcher.get", return_value=response):
|
||||
assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None
|
||||
|
||||
def test_download_image_handles_timeout_request_and_unexpected_errors(
|
||||
@ -241,19 +241,19 @@ class TestBaseIndexProcessor:
|
||||
request = httpx.Request("GET", "https://example.com/image.png")
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
|
||||
"core.rag.index_processor.index_processor_base.remote_fetcher.get",
|
||||
side_effect=httpx.TimeoutException("timeout"),
|
||||
):
|
||||
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
|
||||
"core.rag.index_processor.index_processor_base.remote_fetcher.get",
|
||||
side_effect=httpx.RequestError("bad request", request=request),
|
||||
):
|
||||
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
|
||||
|
||||
with patch(
|
||||
"core.rag.index_processor.index_processor_base.ssrf_proxy.get",
|
||||
"core.rag.index_processor.index_processor_base.remote_fetcher.get",
|
||||
side_effect=RuntimeError("unexpected"),
|
||||
):
|
||||
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
"""Unit tests for `ToolFileManager` behavior.
|
||||
|
||||
Covers signing/verification, file persistence flows, and retrieval APIs with
|
||||
mocked storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to
|
||||
avoid real IO.
|
||||
Covers signing, file persistence flows, and retrieval APIs with mocked
|
||||
storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to avoid real
|
||||
IO.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -17,18 +17,6 @@ from core.tools.tool_file_manager import ToolFileManager
|
||||
from graphon.file import FileTransferMethod
|
||||
|
||||
|
||||
def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]:
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000000)
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.os.urandom", lambda _: b"\x01" * 16)
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.SECRET_KEY", "secret")
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_URL", "https://files.example.com")
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.INTERNAL_FILES_URL", "https://internal.example.com")
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 100)
|
||||
|
||||
url = ToolFileManager.sign_file("tf-1", ".png")
|
||||
return dict(part.split("=", 1) for part in url.split("?", 1)[1].split("&"))
|
||||
|
||||
|
||||
def _patch_session_factory(session: Mock):
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
@ -36,27 +24,10 @@ def _patch_session_factory(session: Mock):
|
||||
return patch("core.tools.tool_file_manager.session_factory.create_session", return_value=session_cm)
|
||||
|
||||
|
||||
def test_tool_file_manager_sign_verify_valid(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = _setup_tool_file_signing(monkeypatch)
|
||||
def test_tool_file_manager_sign_file_builds_url() -> None:
|
||||
url = ToolFileManager.sign_file("tf-1", ".png")
|
||||
assert "/files/tools/tf-1.png" in url
|
||||
|
||||
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is True
|
||||
|
||||
|
||||
def test_tool_file_manager_sign_verify_bad_signature(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = _setup_tool_file_signing(monkeypatch)
|
||||
|
||||
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], "bad") is False
|
||||
|
||||
|
||||
def test_tool_file_manager_sign_verify_expired_timestamp(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = _setup_tool_file_signing(monkeypatch)
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 0)
|
||||
monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000100)
|
||||
|
||||
assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is False
|
||||
|
||||
|
||||
def test_create_file_by_raw_stores_file_and_persists_record() -> None:
|
||||
manager = ToolFileManager()
|
||||
@ -106,7 +77,7 @@ def test_create_file_by_url_downloads_and_persists_record() -> None:
|
||||
patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory),
|
||||
patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="def")),
|
||||
_patch_session_factory(session),
|
||||
patch("core.tools.tool_file_manager.ssrf_proxy.get", return_value=response),
|
||||
patch("core.tools.tool_file_manager.remote_fetcher.get", return_value=response),
|
||||
):
|
||||
file_model = manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1")
|
||||
|
||||
@ -120,7 +91,7 @@ def test_create_file_by_url_downloads_and_persists_record() -> None:
|
||||
def test_create_file_by_url_raises_on_timeout() -> None:
|
||||
manager = ToolFileManager()
|
||||
|
||||
with patch("core.tools.tool_file_manager.ssrf_proxy.get", side_effect=httpx.TimeoutException("timeout")):
|
||||
with patch("core.tools.tool_file_manager.remote_fetcher.get", side_effect=httpx.TimeoutException("timeout")):
|
||||
with pytest.raises(ValueError, match="timeout when downloading file"):
|
||||
manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1")
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_
|
||||
headers={"Content-Type": "image/png"}, # not supported
|
||||
)
|
||||
|
||||
monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(stub_support_types.remote_fetcher, "head", fake_head)
|
||||
|
||||
result = get_url("https://x.test/file.png")
|
||||
assert result == "Unsupported content-type [image/png] of URL."
|
||||
@ -82,7 +82,7 @@ def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytes
|
||||
assert return_text is True
|
||||
return "PDF extracted text"
|
||||
|
||||
monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(stub_support_types.remote_fetcher, "head", fake_head)
|
||||
monkeypatch.setattr(stub_support_types.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))
|
||||
|
||||
result = get_url("https://x.test/doc.pdf")
|
||||
@ -103,8 +103,8 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.Monk
|
||||
# chardet.detect returns utf-8
|
||||
import core.tools.utils.web_reader_tool as mod
|
||||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "head", fake_head)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "get", fake_get)
|
||||
|
||||
mock_best = SimpleNamespace(encoding="utf-8")
|
||||
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
|
||||
@ -137,8 +137,8 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.
|
||||
|
||||
import core.tools.utils.web_reader_tool as mod
|
||||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "head", fake_head)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "get", fake_get)
|
||||
mock_best = SimpleNamespace(encoding="utf-8")
|
||||
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
|
||||
monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes)
|
||||
@ -150,7 +150,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.
|
||||
|
||||
|
||||
def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
"""HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed."""
|
||||
"""HEAD 403 → use cloudscraper.get via remote_fetcher.make_request, then proceed."""
|
||||
|
||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||
return FakeResponse(status_code=403, headers={})
|
||||
@ -167,7 +167,7 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub
|
||||
|
||||
import core.tools.utils.web_reader_tool as mod
|
||||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "head", fake_head)
|
||||
monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper())
|
||||
mock_best = SimpleNamespace(encoding="utf-8")
|
||||
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
|
||||
@ -192,7 +192,7 @@ def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, st
|
||||
|
||||
import core.tools.utils.web_reader_tool as mod
|
||||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "head", fake_head)
|
||||
|
||||
out = get_url("https://x.test/fail")
|
||||
assert out == "URL returned status code 500."
|
||||
@ -214,7 +214,7 @@ def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.Monk
|
||||
|
||||
import core.tools.utils.web_reader_tool as mod
|
||||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "head", fake_head)
|
||||
monkeypatch.setattr(mod.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))
|
||||
|
||||
out = get_url("https://x.test/fname")
|
||||
@ -241,8 +241,8 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.Mo
|
||||
|
||||
import core.tools.utils.web_reader_tool as mod
|
||||
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
|
||||
monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "head", fake_head)
|
||||
monkeypatch.setattr(mod.remote_fetcher, "get", fake_get)
|
||||
|
||||
mock_best = SimpleNamespace(encoding="utf-8")
|
||||
mock_from_bytes = SimpleNamespace(best=lambda: mock_best)
|
||||
|
||||
@ -452,6 +452,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
factory._jinja2_template_renderer = sentinel.jinja2_template_renderer
|
||||
factory._template_transform_max_output_length = 2048
|
||||
factory._http_request_http_client = sentinel.http_client
|
||||
factory._remote_file_http_client = sentinel.remote_file_http_client
|
||||
factory._bound_tool_file_manager_factory = MagicMock(return_value=sentinel.tool_file_manager)
|
||||
factory._file_reference_factory = sentinel.file_reference_factory
|
||||
factory._prompt_message_serializer = sentinel.prompt_message_serializer
|
||||
@ -596,7 +597,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
factory._bound_tool_file_manager_factory.assert_called_once_with()
|
||||
elif constructor_name == "DocumentExtractorNode":
|
||||
assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config
|
||||
assert kwargs["http_client"] is sentinel.http_client
|
||||
assert kwargs["http_client"] is sentinel.remote_file_http_client
|
||||
|
||||
def test_build_llm_compatible_node_init_kwargs_preserves_structured_output_switch(self, factory):
|
||||
node_data = LLMNodeData.model_validate(
|
||||
@ -732,7 +733,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
BuiltinNodeTypes.LLM,
|
||||
"LLMNode",
|
||||
{
|
||||
"http_client": sentinel.http_client,
|
||||
"http_client": sentinel.remote_file_http_client,
|
||||
"llm_file_saver": sentinel.llm_file_saver,
|
||||
"prompt_message_serializer": sentinel.prompt_message_serializer,
|
||||
"retriever_attachment_loader": sentinel.retriever_attachment_loader,
|
||||
@ -743,7 +744,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
"QuestionClassifierNode",
|
||||
{
|
||||
"http_client": sentinel.http_client,
|
||||
"http_client": sentinel.remote_file_http_client,
|
||||
"llm_file_saver": sentinel.llm_file_saver,
|
||||
"prompt_message_serializer": sentinel.prompt_message_serializer,
|
||||
"template_renderer": sentinel.jinja2_template_renderer,
|
||||
|
||||
@ -22,7 +22,7 @@ from graphon.variables.variables import StringVariable
|
||||
def _mock_ssrf_head(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Avoid any real network requests during tests.
|
||||
|
||||
factories.file_factory.remote.get_remote_file_info() uses ssrf_proxy.head
|
||||
factories.file_factory.remote.get_remote_file_info() uses remote_fetcher.head
|
||||
to inspect
|
||||
remote files. We stub it to return a minimal response object with
|
||||
headers so filename/mime/size can be derived deterministically.
|
||||
@ -46,7 +46,7 @@ def _mock_ssrf_head(monkeypatch: pytest.MonkeyPatch):
|
||||
}
|
||||
return SimpleNamespace(status_code=200, headers=headers)
|
||||
|
||||
monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head)
|
||||
monkeypatch.setattr("factories.file_factory.remote.remote_fetcher.head", fake_head)
|
||||
|
||||
|
||||
class TestWorkflowEntry:
|
||||
|
||||
@ -99,7 +99,7 @@ def mock_http_head():
|
||||
},
|
||||
)
|
||||
|
||||
with patch("factories.file_factory.remote.ssrf_proxy.head", autospec=True) as mock_head:
|
||||
with patch("factories.file_factory.remote.remote_fetcher.head", autospec=True) as mock_head:
|
||||
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
|
||||
yield mock_head
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ def _mock_head(monkeypatch: pytest.MonkeyPatch, headers: dict[str, str], status_
|
||||
def _fake_head(url: str, follow_redirects: bool = True):
|
||||
return _FakeResponse(status_code=status_code, headers=headers)
|
||||
|
||||
monkeypatch.setattr("factories.file_factory.remote.ssrf_proxy.head", _fake_head)
|
||||
monkeypatch.setattr("factories.file_factory.remote.remote_fetcher.head", _fake_head)
|
||||
|
||||
|
||||
class TestGetRemoteFileInfo:
|
||||
|
||||
@ -206,7 +206,10 @@ def test_export_rag_pipeline_dsl_raises_when_dataset_missing() -> None:
|
||||
|
||||
|
||||
def test_import_rag_pipeline_url_fetch_error(mocker) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", side_effect=Exception("fetch failed"))
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.get",
|
||||
side_effect=Exception("fetch failed"),
|
||||
)
|
||||
service = RagPipelineDslService(session=Mock())
|
||||
account = Mock(current_tenant_id="t1")
|
||||
|
||||
@ -813,7 +816,7 @@ def test_import_rag_pipeline_yaml_url_handles_empty_content_after_github_rewrite
|
||||
response = Mock()
|
||||
response.raise_for_status.return_value = None
|
||||
response.content = b""
|
||||
get_mock = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", return_value=response)
|
||||
get_mock = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.get", return_value=response)
|
||||
service = RagPipelineDslService(session=Mock())
|
||||
account = Mock(current_tenant_id="t1")
|
||||
|
||||
@ -880,7 +883,7 @@ def test_import_rag_pipeline_url_size_exceeds_limit(mocker) -> None:
|
||||
response = Mock()
|
||||
response.raise_for_status.return_value = None
|
||||
response.content = b"x" * (10 * 1024 * 1024 + 1)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.ssrf_proxy.get", return_value=response)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.remote_fetcher.get", return_value=response)
|
||||
service = RagPipelineDslService(session=Mock())
|
||||
account = Mock(current_tenant_id="t1")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user