diff --git a/agent/component/invoke.py b/agent/component/invoke.py index 0dce464eb..4faaa7d01 100644 --- a/agent/component/invoke.py +++ b/agent/component/invoke.py @@ -179,10 +179,7 @@ class Invoke(ComponentBase, ABC): if not isinstance(headers, dict): raise ValueError("Invoke headers must be a JSON object.") - return { - key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value - for key, value in headers.items() - } + return {key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value for key, value in headers.items()} def _build_proxies(self) -> dict | None: if not re.sub(r"https?:?/?/?", "", self._param.proxy): @@ -215,7 +212,7 @@ class Invoke(ComponentBase, ABC): # HtmlParser keeps the Invoke output text-focused when the endpoint returns HTML. sections = HtmlParser()(None, response.content) return "\n".join(sections) - + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))) def _invoke(self, **kwargs): if self.check_if_canceled("Invoke processing"): diff --git a/agent/tools/crawler.py b/agent/tools/crawler.py index e4d049e1b..6558c524f 100644 --- a/agent/tools/crawler.py +++ b/agent/tools/crawler.py @@ -19,7 +19,6 @@ from crawl4ai import AsyncWebCrawler from agent.tools.base import ToolParamBase, ToolBase - class CrawlerParam(ToolParamBase): """ Define the Crawler component parameters. @@ -31,20 +30,26 @@ class CrawlerParam(ToolParamBase): self.extract_type = "markdown" def check(self): - self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content']) + self.check_valid_value(self.extract_type, "Type of content from the crawler", ["html", "markdown", "content"]) class Crawler(ToolBase, ABC): component_name = "Crawler" def _run(self, history, **kwargs): - from api.utils.web_utils import is_valid_url + from common.ssrf_guard import assert_url_is_safe, pin_dns_global + ans = self.get_input() ans = " - ".join(ans["content"]) if "content" in ans else "" - if not is_valid_url(ans): + try: + _ssrf_hostname, _ssrf_ip = assert_url_is_safe(ans) + except ValueError: return Crawler.be_output("URL not valid") try: - result = asyncio.run(self.get_web(ans)) + # pin_dns_global is used (not thread-local) because crawl4ai resolves + # DNS in asyncio executor threads that don't share thread-local state. + with pin_dns_global(_ssrf_hostname, _ssrf_ip): + result = asyncio.run(self.get_web(ans)) return Crawler.be_output(result) @@ -57,18 +62,15 @@ class Crawler(ToolBase, ABC): proxy = self._param.proxy if self._param.proxy else None async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler: - result = await crawler.arun( - url=url, - bypass_cache=True - ) + result = await crawler.arun(url=url, bypass_cache=True) if self.check_if_canceled("Crawler async operation"): return - if self._param.extract_type == 'html': + if self._param.extract_type == "html": return result.cleaned_html - elif self._param.extract_type == 'markdown': + elif self._param.extract_type == "markdown": return result.markdown - elif self._param.extract_type == 'content': + elif self._param.extract_type == "content": return result.extracted_content return result.markdown diff --git a/agent/tools/searxng.py b/agent/tools/searxng.py index fdc7bea52..ef03375b3 100644 --- a/agent/tools/searxng.py +++ b/agent/tools/searxng.py @@ -20,6 +20,7 @@ from abc import ABC import requests from agent.tools.base import ToolMeta, ToolParamBase, ToolBase from common.connection_utils import timeout +from common.ssrf_guard import assert_url_is_safe, pin_dns class SearXNGParam(ToolParamBase): @@ -36,15 +37,15 @@ class SearXNGParam(ToolParamBase): "type": "string", "description": "The search keywords to execute with SearXNG. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, }, "searxng_url": { "type": "string", "description": "The base URL of your SearXNG instance (e.g., http://localhost:4000). This is required to connect to your SearXNG server.", "required": False, - "default": "" - } - } + "default": "", + }, + }, } super().__init__() self.top_n = 10 @@ -61,17 +62,7 @@ class SearXNGParam(ToolParamBase): self.check_positive_integer(self.top_n, "Top N") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - }, - "searxng_url": { - "name": "SearXNG URL", - "type": "line", - "placeholder": "http://localhost:4000" - } - } + return {"query": {"name": "Query", "type": "line"}, "searxng_url": {"name": "SearXNG URL", "type": "line", "placeholder": "http://localhost:4000"}} class SearXNG(ToolBase, ABC): @@ -94,26 +85,22 @@ class SearXNG(ToolBase, ABC): self.set_output("formalized_content", "") return "" + try: + _ssrf_hostname, _ssrf_ip = assert_url_is_safe(searxng_url) + except ValueError as e: + self.set_output("_ERROR", str(e)) + return f"SearXNG error: SSRF guard blocked {searxng_url!r}: {e}" + last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("SearXNG processing"): return try: - search_params = { - 'q': query, - 'format': 'json', - 'categories': 'general', - 'language': 'auto', - 'safesearch': 1, - 'pageno': 1 - } + search_params = {"q": query, "format": "json", "categories": "general", "language": "auto", "safesearch": 1, "pageno": 1} - response = requests.get( - f"{searxng_url}/search", - params=search_params, - timeout=10 - ) + with pin_dns(_ssrf_hostname, _ssrf_ip): + response = requests.get(f"{searxng_url}/search", params=search_params, timeout=10) response.raise_for_status() if self.check_if_canceled("SearXNG processing"): @@ -128,15 +115,12 @@ class SearXNG(ToolBase, ABC): if not isinstance(results, list): raise ValueError("Invalid results format from SearXNG") - results = results[:self._param.top_n] + results = results[: self._param.top_n] if self.check_if_canceled("SearXNG processing"): return - self._retrieve_chunks(results, - get_title=lambda r: r.get("title", ""), - get_url=lambda r: r.get("url", ""), - get_content=lambda r: r.get("content", "")) + self._retrieve_chunks(results, get_title=lambda r: r.get("title", ""), get_url=lambda r: r.get("url", ""), get_content=lambda r: r.get("content", "")) self.set_output("json", results) return self.output("formalized_content") diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 14f662368..15ec26dd4 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -43,6 +43,7 @@ from common import settings from common.constants import SANDBOX_ARTIFACT_BUCKET, ParserType, RetCode, TaskStatus from common.file_utils import get_project_base_directory from common.misc_utils import get_uuid, thread_pool_exec +from common.ssrf_guard import assert_url_is_safe from deepdoc.parser.html_parser import RAGFlowHtmlParser from rag.nlp import search @@ -333,6 +334,7 @@ async def run(): except Exception as e: return server_error_response(e) + @manager.route("/get/", methods=["GET"]) # noqa: F821 @login_required async def get(doc_id): @@ -581,6 +583,7 @@ async def upload_info(): try: if url and not file_objs: + assert_url_is_safe(url) return get_json_result(data=FileService.upload_info(current_user.id, None, url)) if len(file_objs) == 1: diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 11940b88c..079bf4390 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -23,6 +23,8 @@ from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Union +logger = logging.getLogger(__name__) + import xxhash from peewee import fn @@ -33,6 +35,7 @@ from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from common.misc_utils import get_uuid +from common.ssrf_guard import assert_url_is_safe from common.constants import TaskStatus, FileSource, ParserType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import TaskService @@ -624,6 +627,26 @@ class FileService(CommonService): return errors + _ALLOWED_SCHEMES = {"http", "https"} + + @staticmethod + def _validate_url_for_crawl(url: str) -> tuple[str, str]: + """Raise ValueError if the URL is not safe to crawl (SSRF guard). + + Delegates to :func:`common.ssrf_guard.assert_url_is_safe`, which + validates the scheme, hostname, and every DNS-resolved address, and + returns ``(hostname, resolved_ip)`` for DNS pinning. + + Only the scheme and host (and port when present) are forwarded to the + guard so that credentials or query parameters in *url* are never + written to the log. + """ + from urllib.parse import urlparse + parsed = urlparse(url) + port_suffix = f":{parsed.port}" if parsed.port else "" + redacted = f"{parsed.scheme}://{parsed.hostname}{port_suffix}" + return assert_url_is_safe(redacted, allowed_schemes=FileService._ALLOWED_SCHEMES) + @staticmethod def upload_info(user_id, file, url: str|None=None): def structured(filename, filetype, blob, content_type): @@ -646,6 +669,53 @@ class FileService(CommonService): } if url: + import requests as _requests + from urllib.parse import urljoin as _urljoin + + _MAX_CRAWL_REDIRECTS = 10 + + # Pre-resolve the full redirect chain so that AsyncWebCrawler never + # follows a server-sent redirect to an unvalidated (potentially + # internal) host. Each hop is SSRF-checked before being followed; + # the validated (hostname, ip) pairs are pinned via Chromium's + # --host-resolver-rules so the browser cannot re-resolve any of them + # through a fresh DNS query. + current_url = url + current_hostname, current_ip = FileService._validate_url_for_crawl(current_url) + # Accumulate MAP rules for every hostname we encounter in the chain. + host_pins: dict[str, str] = {current_hostname: current_ip} + + for _ in range(_MAX_CRAWL_REDIRECTS): + try: + _resp = _requests.get( + current_url, + timeout=10, + allow_redirects=False, + ) + except _requests.RequestException as _exc: + raise ValueError(f"Failed to fetch {current_url!r}: {_exc}") from _exc + + if _resp.status_code not in (301, 302, 303, 307, 308): + break + + _location = _resp.headers.get("Location") + if not _location: + break + + _next_url = _urljoin(current_url, _location) + _next_hostname, _next_ip = FileService._validate_url_for_crawl(_next_url) + host_pins[_next_hostname] = _next_ip + current_url = _next_url + else: + raise ValueError( + f"Exceeded {_MAX_CRAWL_REDIRECTS} redirects fetching {url!r}" + ) + + # Build a single MAP rule string covering every validated hostname + # in the redirect chain. Chromium uses the pinned IP for each, + # skipping DNS entirely and eliminating the rebinding window. + _map_rules = ",".join(f"MAP {h} {ip}" for h, ip in host_pins.items()) + from crawl4ai import ( AsyncWebCrawler, BrowserConfig, @@ -659,6 +729,7 @@ class FileService(CommonService): browser_config = BrowserConfig( headless=True, verbose=False, + extra_args=[f"--host-resolver-rules={_map_rules}"], ) async with AsyncWebCrawler(config=browser_config) as crawler: crawler_config = CrawlerRunConfig( @@ -668,8 +739,10 @@ class FileService(CommonService): pdf=True, screenshot=False ) + # Use the final resolved URL so the browser starts at the + # redirect destination rather than re-following the chain. result: CrawlResult = await crawler.arun( - url=url, + url=current_url, config=crawler_config ) return result @@ -679,7 +752,7 @@ class FileService(CommonService): filename += ".pdf" return structured(filename, "pdf", page.pdf, page.response_headers["content-type"]) - return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id) + return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"]) DocumentService.check_doc_health(user_id, file.filename) return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index 4cb13ff7e..23d242186 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -15,11 +15,8 @@ # import base64 -import ipaddress import json import re -import socket -from urllib.parse import urlparse import aiosmtplib from email.mime.text import MIMEText from email.header import Header @@ -37,10 +34,10 @@ from webdriver_manager.chrome import ChromeDriverManager OTP_LENGTH = 4 -OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes -ATTEMPT_LIMIT = 5 # maximum attempts -ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes -RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute +OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes +ATTEMPT_LIMIT = 5 # maximum attempts +ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes +RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute CONTENT_TYPE_MAP = { @@ -188,29 +185,16 @@ def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_opt return base64.b64decode(result["data"]) -def is_private_ip(ip: str) -> bool: - try: - ip_obj = ipaddress.ip_address(ip) - return ip_obj.is_private - except ValueError: - return False - - def is_valid_url(url: str) -> bool: if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url): return False - parsed_url = urlparse(url) - hostname = parsed_url.hostname + from common.ssrf_guard import assert_url_is_safe - if not hostname: - return False try: - ip = socket.gethostbyname(hostname) - if is_private_ip(ip): - return False - except socket.gaierror: + assert_url_is_safe(url) + return True + except ValueError: return False - return True def safe_json_parse(data: str | dict) -> dict: diff --git a/common/data_source/rss_connector.py b/common/data_source/rss_connector.py index 85471407a..8000eaddf 100644 --- a/common/data_source/rss_connector.py +++ b/common/data_source/rss_connector.py @@ -1,11 +1,9 @@ import hashlib -import ipaddress -import socket from datetime import datetime, timezone from email.utils import parsedate_to_datetime from time import struct_time from typing import Any -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse import bs4 import feedparser @@ -14,28 +12,9 @@ import requests from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource from common.data_source.interfaces import LoadConnector, PollConnector from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.ssrf_guard import assert_url_is_safe, pin_dns as _pin_dns - -def _is_private_ip(ip: str) -> bool: - try: - ip_obj = ipaddress.ip_address(ip) - return ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_loopback - except ValueError: - return False - - -def _validate_url_no_ssrf(url: str) -> None: - parsed = urlparse(url) - hostname = parsed.hostname - if not hostname: - raise ValueError("URL must have a valid hostname") - - try: - ip = socket.gethostbyname(hostname) - if _is_private_ip(ip): - raise ValueError(f"URL resolves to private/internal IP address: {ip}") - except socket.gaierror as e: - raise ValueError(f"Failed to resolve hostname: {hostname}") from e +_MAX_REDIRECTS = 10 class RSSConnector(LoadConnector, PollConnector): @@ -87,7 +66,8 @@ class RSSConnector(LoadConnector, PollConnector): if batch: yield batch - def _validate_feed_url(self) -> None: + def _validate_feed_url(self) -> tuple[str, str]: + """Validate ``self.feed_url`` and return ``(hostname, resolved_ip)``.""" if not self.feed_url: raise ValueError("feed_url is required") @@ -95,7 +75,7 @@ class RSSConnector(LoadConnector, PollConnector): if parsed.scheme not in {"http", "https"} or not parsed.netloc: raise ValueError("feed_url must be a valid http or https URL") - _validate_url_no_ssrf(self.feed_url) + return assert_url_is_safe(self.feed_url) def _read_feed(self, require_entries: bool) -> Any: if self._cached_feed is not None: @@ -103,15 +83,38 @@ class RSSConnector(LoadConnector, PollConnector): raise ValueError("RSS feed contains no entries") return self._cached_feed - self._validate_feed_url() + # Validate once to get the pinned IP for the initial request. + current_hostname, current_ip = self._validate_feed_url() + current_url = self.feed_url + + # Follow redirects manually: each hop is validated and DNS-pinned + # *before* the connection is made, closing the TOCTOU rebinding window + # that existed when allow_redirects=True was used with post-hoc checks. + response: requests.Response | None = None + for _ in range(_MAX_REDIRECTS + 1): + with _pin_dns(current_hostname, current_ip): + response = requests.get( + current_url, + timeout=REQUEST_TIMEOUT_SECONDS, + allow_redirects=False, + ) + + if response.status_code not in (301, 302, 303, 307, 308): + break + + location = response.headers.get("Location") + if not location: + break # broken redirect; let raise_for_status() handle it + + redirect_url = urljoin(current_url, location) + # Validate redirect target before following it. + current_hostname, current_ip = assert_url_is_safe(redirect_url) + current_url = redirect_url + else: + raise ValueError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {self.feed_url!r}") - response = requests.get(self.feed_url, timeout=REQUEST_TIMEOUT_SECONDS, allow_redirects=True) response.raise_for_status() - final_url = getattr(response, "url", self.feed_url) - if final_url != self.feed_url and urlparse(final_url).hostname: - _validate_url_no_ssrf(final_url) - feed = feedparser.parse(response.content) if getattr(feed, "bozo", False) and not feed.entries: error = getattr(feed, "bozo_exception", None) diff --git a/common/ssrf_guard.py b/common/ssrf_guard.py new file mode 100644 index 000000000..b60bcd4bc --- /dev/null +++ b/common/ssrf_guard.py @@ -0,0 +1,172 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Shared SSRF-guard utilities. + +Uses only the standard library so it can be imported from both ``api/`` and +``common/`` without pulling in any heavyweight dependencies. +""" + +import ipaddress +import logging +import socket +import threading +from contextlib import contextmanager +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# DNS pinning — closes the TOCTOU / rebinding window between SSRF validation +# and the actual TCP connection. The monkey-patch is a no-op for any host +# that has no active pin, so it cannot affect unrelated code. +# --------------------------------------------------------------------------- + +_tl = threading.local() +_global_dns_pins: dict[str, str] = {} +_global_pin_lock = threading.Lock() +_orig_getaddrinfo = socket.getaddrinfo + + +def _getaddrinfo_with_pins(host, port, *args, **kwargs): + # Thread-local pins (synchronous callers: requests.get in the same thread) + local_pins: dict = getattr(_tl, "dns_pins", {}) + if host in local_pins: + ip = local_pins[host] + family = socket.AF_INET6 if ":" in ip else socket.AF_INET + return [(family, socket.SOCK_STREAM, 6, "", (ip, port or 0))] + # Process-global pins (async callers whose DNS resolves in executor threads) + with _global_pin_lock: + ip = _global_dns_pins.get(host) + if ip is not None: + family = socket.AF_INET6 if ":" in ip else socket.AF_INET + return [(family, socket.SOCK_STREAM, 6, "", (ip, port or 0))] + return _orig_getaddrinfo(host, port, *args, **kwargs) + + +socket.getaddrinfo = _getaddrinfo_with_pins + + +@contextmanager +def pin_dns(hostname: str, ip: str): + """Pin *hostname* → *ip* in the current thread for the duration of this context. + + Use for synchronous ``requests.get()`` callers to prevent DNS rebinding + between SSRF validation and the actual TCP connection. + """ + pins = _tl.__dict__.setdefault("dns_pins", {}) + pins[hostname] = ip + try: + yield + finally: + pins.pop(hostname, None) + + +@contextmanager +def pin_dns_global(hostname: str, ip: str): + """Pin *hostname* → *ip* across all threads for the duration of this context. + + Use for async callers (e.g. asyncio-based crawlers) where DNS resolution + may happen in thread-pool executor threads rather than the calling thread. + """ + with _global_pin_lock: + _global_dns_pins[hostname] = ip + try: + yield + finally: + with _global_pin_lock: + _global_dns_pins.pop(hostname, None) + + +_DEFAULT_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"}) + + +def _effective_ip( + ip: ipaddress.IPv4Address | ipaddress.IPv6Address, +) -> ipaddress.IPv4Address | ipaddress.IPv6Address: + """Return the IPv4 equivalent for IPv4-mapped IPv6 addresses, unchanged otherwise. + + Without this normalization ``::ffff:127.0.0.1`` would pass ``is_global`` + as an IPv6Address in some Python versions, bypassing the loopback check. + """ + if isinstance(ip, ipaddress.IPv6Address): + mapped = ip.ipv4_mapped + if mapped is not None: + return mapped + return ip + + +def assert_url_is_safe( + url: str, + *, + allowed_schemes: frozenset[str] = _DEFAULT_ALLOWED_SCHEMES, +) -> tuple[str, str]: + """Raise ``ValueError`` if *url* is not safe to fetch (SSRF guard). + + Checks performed in order: + + 1. Scheme is in *allowed_schemes*. + 2. Hostname is present. + 3. **Every** address returned by ``getaddrinfo`` is globally routable + (``ip.is_global``). This is an allowlist approach: it catches private, + loopback, link-local, reserved, multicast, and all other + special-purpose ranges rather than individual deny-list flags. + IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) are normalised + to their IPv4 form via :func:`_effective_ip` before the check. + + Returns ``(hostname, resolved_ip)`` — the first validated public IP string + — so the caller can **pin** that address in its HTTP client and prevent + DNS-rebinding attacks (the hostname is resolved exactly once). + """ + parsed = urlparse(url) + scheme = parsed.scheme + if scheme not in allowed_schemes: + logger.warning( + "SSRF guard blocked URL with disallowed scheme: scheme=%r url=%r", + scheme, + url, + ) + raise ValueError(f"Disallowed URL scheme: {scheme!r}. Only {sorted(allowed_schemes)} are allowed.") + + hostname = parsed.hostname + if not hostname: + logger.warning("SSRF guard blocked URL with missing host: url=%r", url) + raise ValueError("URL is missing a host.") + + try: + addr_infos = socket.getaddrinfo(hostname, None) + except socket.gaierror as exc: + logger.warning("SSRF guard could not resolve hostname=%r reason=%s", hostname, exc) + raise ValueError(f"Could not resolve hostname {hostname!r}: {exc}") from exc + + resolved_ip: str | None = None + for _family, _type, _proto, _canonname, sockaddr in addr_infos: + raw_ip = ipaddress.ip_address(sockaddr[0]) + eff_ip = _effective_ip(raw_ip) + if not eff_ip.is_global: + logger.warning( + "SSRF guard blocked URL: hostname=%r resolved to non-public address=%s", + hostname, + raw_ip, + ) + raise ValueError(f"URL resolves to a non-public address ({raw_ip}), which is not allowed.") + if resolved_ip is None: + resolved_ip = str(raw_ip) + + if resolved_ip is None: + logger.warning("SSRF guard blocked URL: hostname=%r resolved to no addresses", hostname) + raise ValueError(f"Hostname {hostname!r} resolved to no addresses.") + + return hostname, resolved_ip diff --git a/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py b/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py index 0e5511039..36c736166 100644 --- a/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py +++ b/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py @@ -79,6 +79,7 @@ def _load_document_app_module(monkeypatch): @pytest.mark.p2 def test_upload_info_rejects_mixed_inputs(monkeypatch): module = _load_document_app_module(monkeypatch) + monkeypatch.setattr(module, "assert_url_is_safe", lambda url: ("example.com", "93.184.216.34")) files = _DummyFiles({"file": [_DummyFile("a.txt")]}) monkeypatch.setattr(module, "request", _DummyRequest(files=files, args={"url": "https://example.com/a.txt"})) @@ -100,6 +101,7 @@ def test_upload_info_requires_file_or_url(monkeypatch): @pytest.mark.p2 def test_upload_info_supports_url_single_and_multiple_files(monkeypatch): module = _load_document_app_module(monkeypatch) + monkeypatch.setattr(module, "assert_url_is_safe", lambda url: ("example.com", "93.184.216.34")) captured = [] def fake_upload_info(user_id, file_obj, url=None): diff --git a/test/unit_test/api/db/services/test_file_service_upload_document.py b/test/unit_test/api/db/services/test_file_service_upload_document.py index 12558cc8f..8962ae8a7 100644 --- a/test/unit_test/api/db/services/test_file_service_upload_document.py +++ b/test/unit_test/api/db/services/test_file_service_upload_document.py @@ -14,6 +14,7 @@ # limitations under the License. # import importlib.util +import socket import sys import types import warnings @@ -120,3 +121,158 @@ def test_upload_document_skips_cross_kb_document_id_collision(monkeypatch): assert len(err) == 1 assert err[0].startswith("collision.txt: ") assert "Existing document id collision with another knowledge base; skipping update." in err[0] + + +# --------------------------------------------------------------------------- +# Helpers shared by TestValidateUrlForCrawl +# --------------------------------------------------------------------------- + +def _addrinfo(ip_str: str) -> list: + """Build a minimal getaddrinfo-style result for a single address string.""" + family = socket.AF_INET6 if ":" in ip_str else socket.AF_INET + return [(family, socket.SOCK_STREAM, 6, "", (ip_str, 0))] + + +# --------------------------------------------------------------------------- +# _validate_url_for_crawl SSRF-guard tests +# --------------------------------------------------------------------------- + +@pytest.mark.p2 +class TestValidateUrlForCrawl: + """Focused regression suite for the SSRF guard on the URL-crawl path. + + All DNS lookups are monkeypatched so the tests are deterministic and + require no network access. + """ + + # -- scheme checks ------------------------------------------------------- + + def test_rejects_ftp_scheme(self): + with pytest.raises(ValueError, match="scheme"): + FileService._validate_url_for_crawl("ftp://example.com/file.txt") + + def test_rejects_file_scheme(self): + with pytest.raises(ValueError, match="scheme"): + FileService._validate_url_for_crawl("file:///etc/passwd") + + def test_rejects_javascript_scheme(self): + with pytest.raises(ValueError, match="scheme"): + FileService._validate_url_for_crawl("javascript:alert(1)") + + # -- host checks --------------------------------------------------------- + + def test_rejects_missing_host(self): + with pytest.raises(ValueError, match="host"): + FileService._validate_url_for_crawl("http:///path") + + def test_rejects_dns_resolution_failure(self, monkeypatch): + def _raise(h, p): + raise socket.gaierror("NXDOMAIN") + + monkeypatch.setattr(socket, "getaddrinfo", _raise) + with pytest.raises(ValueError, match="Could not resolve"): + FileService._validate_url_for_crawl("http://nxdomain.invalid/") + + # -- blocked address families -------------------------------------------- + + def test_rejects_loopback_ipv4(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("127.0.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://localhost/") + + def test_rejects_private_class_a(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("10.0.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://internal.example/") + + def test_rejects_private_class_b(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("172.16.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://internal.example/") + + def test_rejects_private_class_c(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("192.168.1.100")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://internal.example/") + + def test_rejects_link_local_ipv4(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("169.254.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://link-local.example/") + + def test_rejects_reserved_ipv4(self, monkeypatch): + # 240.0.0.0/4 is IANA reserved — not globally routable + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("240.0.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://reserved.example/") + + def test_rejects_ipv4_mapped_loopback(self, monkeypatch): + """::ffff:127.0.0.1 must not bypass the loopback check.""" + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("::ffff:127.0.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://mapped-loopback.example/") + + def test_rejects_ipv4_mapped_private(self, monkeypatch): + """::ffff:192.168.1.1 must not bypass the private-range check.""" + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("::ffff:192.168.1.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://mapped-private.example/") + + def test_rejects_when_any_record_is_private(self, monkeypatch): + """All DNS records must pass; one private record is enough to block.""" + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda h, p: _addrinfo("93.184.216.34") + _addrinfo("10.0.0.1"), + ) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://mixed.example/") + + # -- allowed cases ------------------------------------------------------- + + def test_allows_public_ipv4(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("93.184.216.34")) + hostname, resolved_ip = FileService._validate_url_for_crawl("https://example.com/doc.pdf") + assert hostname == "example.com" + assert resolved_ip == "93.184.216.34" + + def test_allows_public_ipv6(self, monkeypatch): + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda h, p: _addrinfo("2606:2800:220:1:248:1893:25c8:1946"), + ) + hostname, resolved_ip = FileService._validate_url_for_crawl("https://example.com/") + assert hostname == "example.com" + assert resolved_ip == "2606:2800:220:1:248:1893:25c8:1946" + + def test_allows_http_scheme(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("1.2.3.4")) + hostname, _ = FileService._validate_url_for_crawl("http://example.com/") + assert hostname == "example.com" + + # -- multi-record behaviour ---------------------------------------------- + + def test_returns_first_ip_for_multi_record_host(self, monkeypatch): + """The first public IP is returned as the DNS pin value.""" + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda h, p: _addrinfo("1.2.3.4") + _addrinfo("5.6.7.8"), + ) + _, resolved_ip = FileService._validate_url_for_crawl("http://multi.example/") + assert resolved_ip == "1.2.3.4" + + def test_allows_dual_stack_host(self, monkeypatch): + """A host with both public IPv4 and public IPv6 records is allowed.""" + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda h, p: ( + _addrinfo("93.184.216.34") + + _addrinfo("2606:2800:220:1:248:1893:25c8:1946") + ), + ) + hostname, resolved_ip = FileService._validate_url_for_crawl("https://example.com/") + assert hostname == "example.com" + assert resolved_ip == "93.184.216.34"