From fb95136f391fac8fa4288d4f687e473675c3cdb2 Mon Sep 17 00:00:00 2001 From: Xing Hong <39619359+xingxing21@users.noreply.github.com> Date: Sat, 25 Apr 2026 15:30:15 +0900 Subject: [PATCH] Fix: validate URL scheme and resolved IP before crawling to prevent SSRF (#14090) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? The POST /upload_info?url= endpoint accepted a user-supplied URL and passed it directly to AsyncWebCrawler without any validation. There were no restrictions on URL scheme, destination hostname, or resolved IP address. This allowed any authenticated user to instruct the server to make outbound HTTP requests to internal infrastructure — including RFC 1918 private networks, loopback addresses, and cloud metadata services such as http://169.254.169.254 — effectively using the server as a proxy for internal network reconnaissance or credential theft. This PR adds an SSRF guard (_validate_url_for_crawl) that runs before any crawl is initiated. It enforces an allowlist of safe schemes (http/https), resolves the hostname at validation time, and rejects any URL whose resolved IP falls within a private or reserved network range. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- agent/component/invoke.py | 7 +- agent/tools/crawler.py | 26 +-- agent/tools/searxng.py | 52 ++---- api/apps/document_app.py | 3 + api/db/services/file_service.py | 77 +++++++- api/utils/web_utils.py | 32 +--- common/data_source/rss_connector.py | 67 +++---- common/ssrf_guard.py | 172 ++++++++++++++++++ .../test_upload_info_unit.py | 2 + .../test_file_service_upload_document.py | 156 ++++++++++++++++ 10 files changed, 485 insertions(+), 109 deletions(-) create mode 100644 common/ssrf_guard.py diff --git a/agent/component/invoke.py b/agent/component/invoke.py index 0dce464eb..4faaa7d01 100644 --- a/agent/component/invoke.py +++ b/agent/component/invoke.py @@ -179,10 +179,7 @@ class Invoke(ComponentBase, ABC): if not isinstance(headers, dict): raise ValueError("Invoke headers must be a JSON object.") - return { - key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value - for key, value in headers.items() - } + return {key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value for key, value in headers.items()} def _build_proxies(self) -> dict | None: if not re.sub(r"https?:?/?/?", "", self._param.proxy): @@ -215,7 +212,7 @@ class Invoke(ComponentBase, ABC): # HtmlParser keeps the Invoke output text-focused when the endpoint returns HTML. sections = HtmlParser()(None, response.content) return "\n".join(sections) - + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))) def _invoke(self, **kwargs): if self.check_if_canceled("Invoke processing"): diff --git a/agent/tools/crawler.py b/agent/tools/crawler.py index e4d049e1b..6558c524f 100644 --- a/agent/tools/crawler.py +++ b/agent/tools/crawler.py @@ -19,7 +19,6 @@ from crawl4ai import AsyncWebCrawler from agent.tools.base import ToolParamBase, ToolBase - class CrawlerParam(ToolParamBase): """ Define the Crawler component parameters. @@ -31,20 +30,26 @@ class CrawlerParam(ToolParamBase): self.extract_type = "markdown" def check(self): - self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content']) + self.check_valid_value(self.extract_type, "Type of content from the crawler", ["html", "markdown", "content"]) class Crawler(ToolBase, ABC): component_name = "Crawler" def _run(self, history, **kwargs): - from api.utils.web_utils import is_valid_url + from common.ssrf_guard import assert_url_is_safe, pin_dns_global + ans = self.get_input() ans = " - ".join(ans["content"]) if "content" in ans else "" - if not is_valid_url(ans): + try: + _ssrf_hostname, _ssrf_ip = assert_url_is_safe(ans) + except ValueError: return Crawler.be_output("URL not valid") try: - result = asyncio.run(self.get_web(ans)) + # pin_dns_global is used (not thread-local) because crawl4ai resolves + # DNS in asyncio executor threads that don't share thread-local state. + with pin_dns_global(_ssrf_hostname, _ssrf_ip): + result = asyncio.run(self.get_web(ans)) return Crawler.be_output(result) @@ -57,18 +62,15 @@ class Crawler(ToolBase, ABC): proxy = self._param.proxy if self._param.proxy else None async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler: - result = await crawler.arun( - url=url, - bypass_cache=True - ) + result = await crawler.arun(url=url, bypass_cache=True) if self.check_if_canceled("Crawler async operation"): return - if self._param.extract_type == 'html': + if self._param.extract_type == "html": return result.cleaned_html - elif self._param.extract_type == 'markdown': + elif self._param.extract_type == "markdown": return result.markdown - elif self._param.extract_type == 'content': + elif self._param.extract_type == "content": return result.extracted_content return result.markdown diff --git a/agent/tools/searxng.py b/agent/tools/searxng.py index fdc7bea52..ef03375b3 100644 --- a/agent/tools/searxng.py +++ b/agent/tools/searxng.py @@ -20,6 +20,7 @@ from abc import ABC import requests from agent.tools.base import ToolMeta, ToolParamBase, ToolBase from common.connection_utils import timeout +from common.ssrf_guard import assert_url_is_safe, pin_dns class SearXNGParam(ToolParamBase): @@ -36,15 +37,15 @@ class SearXNGParam(ToolParamBase): "type": "string", "description": "The search keywords to execute with SearXNG. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, }, "searxng_url": { "type": "string", "description": "The base URL of your SearXNG instance (e.g., http://localhost:4000). This is required to connect to your SearXNG server.", "required": False, - "default": "" - } - } + "default": "", + }, + }, } super().__init__() self.top_n = 10 @@ -61,17 +62,7 @@ class SearXNGParam(ToolParamBase): self.check_positive_integer(self.top_n, "Top N") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - }, - "searxng_url": { - "name": "SearXNG URL", - "type": "line", - "placeholder": "http://localhost:4000" - } - } + return {"query": {"name": "Query", "type": "line"}, "searxng_url": {"name": "SearXNG URL", "type": "line", "placeholder": "http://localhost:4000"}} class SearXNG(ToolBase, ABC): @@ -94,26 +85,22 @@ class SearXNG(ToolBase, ABC): self.set_output("formalized_content", "") return "" + try: + _ssrf_hostname, _ssrf_ip = assert_url_is_safe(searxng_url) + except ValueError as e: + self.set_output("_ERROR", str(e)) + return f"SearXNG error: SSRF guard blocked {searxng_url!r}: {e}" + last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("SearXNG processing"): return try: - search_params = { - 'q': query, - 'format': 'json', - 'categories': 'general', - 'language': 'auto', - 'safesearch': 1, - 'pageno': 1 - } + search_params = {"q": query, "format": "json", "categories": "general", "language": "auto", "safesearch": 1, "pageno": 1} - response = requests.get( - f"{searxng_url}/search", - params=search_params, - timeout=10 - ) + with pin_dns(_ssrf_hostname, _ssrf_ip): + response = requests.get(f"{searxng_url}/search", params=search_params, timeout=10) response.raise_for_status() if self.check_if_canceled("SearXNG processing"): @@ -128,15 +115,12 @@ class SearXNG(ToolBase, ABC): if not isinstance(results, list): raise ValueError("Invalid results format from SearXNG") - results = results[:self._param.top_n] + results = results[: self._param.top_n] if self.check_if_canceled("SearXNG processing"): return - self._retrieve_chunks(results, - get_title=lambda r: r.get("title", ""), - get_url=lambda r: r.get("url", ""), - get_content=lambda r: r.get("content", "")) + self._retrieve_chunks(results, get_title=lambda r: r.get("title", ""), get_url=lambda r: r.get("url", ""), get_content=lambda r: r.get("content", "")) self.set_output("json", results) return self.output("formalized_content") diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 14f662368..15ec26dd4 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -43,6 +43,7 @@ from common import settings from common.constants import SANDBOX_ARTIFACT_BUCKET, ParserType, RetCode, TaskStatus from common.file_utils import get_project_base_directory from common.misc_utils import get_uuid, thread_pool_exec +from common.ssrf_guard import assert_url_is_safe from deepdoc.parser.html_parser import RAGFlowHtmlParser from rag.nlp import search @@ -333,6 +334,7 @@ async def run(): except Exception as e: return server_error_response(e) + @manager.route("/get/", methods=["GET"]) # noqa: F821 @login_required async def get(doc_id): @@ -581,6 +583,7 @@ async def upload_info(): try: if url and not file_objs: + assert_url_is_safe(url) return get_json_result(data=FileService.upload_info(current_user.id, None, url)) if len(file_objs) == 1: diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 11940b88c..079bf4390 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -23,6 +23,8 @@ from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Union +logger = logging.getLogger(__name__) + import xxhash from peewee import fn @@ -33,6 +35,7 @@ from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from common.misc_utils import get_uuid +from common.ssrf_guard import assert_url_is_safe from common.constants import TaskStatus, FileSource, ParserType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import TaskService @@ -624,6 +627,26 @@ class FileService(CommonService): return errors + _ALLOWED_SCHEMES = {"http", "https"} + + @staticmethod + def _validate_url_for_crawl(url: str) -> tuple[str, str]: + """Raise ValueError if the URL is not safe to crawl (SSRF guard). + + Delegates to :func:`common.ssrf_guard.assert_url_is_safe`, which + validates the scheme, hostname, and every DNS-resolved address, and + returns ``(hostname, resolved_ip)`` for DNS pinning. + + Only the scheme and host (and port when present) are forwarded to the + guard so that credentials or query parameters in *url* are never + written to the log. + """ + from urllib.parse import urlparse + parsed = urlparse(url) + port_suffix = f":{parsed.port}" if parsed.port else "" + redacted = f"{parsed.scheme}://{parsed.hostname}{port_suffix}" + return assert_url_is_safe(redacted, allowed_schemes=FileService._ALLOWED_SCHEMES) + @staticmethod def upload_info(user_id, file, url: str|None=None): def structured(filename, filetype, blob, content_type): @@ -646,6 +669,53 @@ class FileService(CommonService): } if url: + import requests as _requests + from urllib.parse import urljoin as _urljoin + + _MAX_CRAWL_REDIRECTS = 10 + + # Pre-resolve the full redirect chain so that AsyncWebCrawler never + # follows a server-sent redirect to an unvalidated (potentially + # internal) host. Each hop is SSRF-checked before being followed; + # the validated (hostname, ip) pairs are pinned via Chromium's + # --host-resolver-rules so the browser cannot re-resolve any of them + # through a fresh DNS query. + current_url = url + current_hostname, current_ip = FileService._validate_url_for_crawl(current_url) + # Accumulate MAP rules for every hostname we encounter in the chain. + host_pins: dict[str, str] = {current_hostname: current_ip} + + for _ in range(_MAX_CRAWL_REDIRECTS): + try: + _resp = _requests.get( + current_url, + timeout=10, + allow_redirects=False, + ) + except _requests.RequestException as _exc: + raise ValueError(f"Failed to fetch {current_url!r}: {_exc}") from _exc + + if _resp.status_code not in (301, 302, 303, 307, 308): + break + + _location = _resp.headers.get("Location") + if not _location: + break + + _next_url = _urljoin(current_url, _location) + _next_hostname, _next_ip = FileService._validate_url_for_crawl(_next_url) + host_pins[_next_hostname] = _next_ip + current_url = _next_url + else: + raise ValueError( + f"Exceeded {_MAX_CRAWL_REDIRECTS} redirects fetching {url!r}" + ) + + # Build a single MAP rule string covering every validated hostname + # in the redirect chain. Chromium uses the pinned IP for each, + # skipping DNS entirely and eliminating the rebinding window. + _map_rules = ",".join(f"MAP {h} {ip}" for h, ip in host_pins.items()) + from crawl4ai import ( AsyncWebCrawler, BrowserConfig, @@ -659,6 +729,7 @@ class FileService(CommonService): browser_config = BrowserConfig( headless=True, verbose=False, + extra_args=[f"--host-resolver-rules={_map_rules}"], ) async with AsyncWebCrawler(config=browser_config) as crawler: crawler_config = CrawlerRunConfig( @@ -668,8 +739,10 @@ class FileService(CommonService): pdf=True, screenshot=False ) + # Use the final resolved URL so the browser starts at the + # redirect destination rather than re-following the chain. result: CrawlResult = await crawler.arun( - url=url, + url=current_url, config=crawler_config ) return result @@ -679,7 +752,7 @@ class FileService(CommonService): filename += ".pdf" return structured(filename, "pdf", page.pdf, page.response_headers["content-type"]) - return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id) + return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"]) DocumentService.check_doc_health(user_id, file.filename) return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index 4cb13ff7e..23d242186 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -15,11 +15,8 @@ # import base64 -import ipaddress import json import re -import socket -from urllib.parse import urlparse import aiosmtplib from email.mime.text import MIMEText from email.header import Header @@ -37,10 +34,10 @@ from webdriver_manager.chrome import ChromeDriverManager OTP_LENGTH = 4 -OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes -ATTEMPT_LIMIT = 5 # maximum attempts -ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes -RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute +OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes +ATTEMPT_LIMIT = 5 # maximum attempts +ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes +RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute CONTENT_TYPE_MAP = { @@ -188,29 +185,16 @@ def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_opt return base64.b64decode(result["data"]) -def is_private_ip(ip: str) -> bool: - try: - ip_obj = ipaddress.ip_address(ip) - return ip_obj.is_private - except ValueError: - return False - - def is_valid_url(url: str) -> bool: if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url): return False - parsed_url = urlparse(url) - hostname = parsed_url.hostname + from common.ssrf_guard import assert_url_is_safe - if not hostname: - return False try: - ip = socket.gethostbyname(hostname) - if is_private_ip(ip): - return False - except socket.gaierror: + assert_url_is_safe(url) + return True + except ValueError: return False - return True def safe_json_parse(data: str | dict) -> dict: diff --git a/common/data_source/rss_connector.py b/common/data_source/rss_connector.py index 85471407a..8000eaddf 100644 --- a/common/data_source/rss_connector.py +++ b/common/data_source/rss_connector.py @@ -1,11 +1,9 @@ import hashlib -import ipaddress -import socket from datetime import datetime, timezone from email.utils import parsedate_to_datetime from time import struct_time from typing import Any -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse import bs4 import feedparser @@ -14,28 +12,9 @@ import requests from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource from common.data_source.interfaces import LoadConnector, PollConnector from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.ssrf_guard import assert_url_is_safe, pin_dns as _pin_dns - -def _is_private_ip(ip: str) -> bool: - try: - ip_obj = ipaddress.ip_address(ip) - return ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_loopback - except ValueError: - return False - - -def _validate_url_no_ssrf(url: str) -> None: - parsed = urlparse(url) - hostname = parsed.hostname - if not hostname: - raise ValueError("URL must have a valid hostname") - - try: - ip = socket.gethostbyname(hostname) - if _is_private_ip(ip): - raise ValueError(f"URL resolves to private/internal IP address: {ip}") - except socket.gaierror as e: - raise ValueError(f"Failed to resolve hostname: {hostname}") from e +_MAX_REDIRECTS = 10 class RSSConnector(LoadConnector, PollConnector): @@ -87,7 +66,8 @@ class RSSConnector(LoadConnector, PollConnector): if batch: yield batch - def _validate_feed_url(self) -> None: + def _validate_feed_url(self) -> tuple[str, str]: + """Validate ``self.feed_url`` and return ``(hostname, resolved_ip)``.""" if not self.feed_url: raise ValueError("feed_url is required") @@ -95,7 +75,7 @@ class RSSConnector(LoadConnector, PollConnector): if parsed.scheme not in {"http", "https"} or not parsed.netloc: raise ValueError("feed_url must be a valid http or https URL") - _validate_url_no_ssrf(self.feed_url) + return assert_url_is_safe(self.feed_url) def _read_feed(self, require_entries: bool) -> Any: if self._cached_feed is not None: @@ -103,15 +83,38 @@ class RSSConnector(LoadConnector, PollConnector): raise ValueError("RSS feed contains no entries") return self._cached_feed - self._validate_feed_url() + # Validate once to get the pinned IP for the initial request. + current_hostname, current_ip = self._validate_feed_url() + current_url = self.feed_url + + # Follow redirects manually: each hop is validated and DNS-pinned + # *before* the connection is made, closing the TOCTOU rebinding window + # that existed when allow_redirects=True was used with post-hoc checks. + response: requests.Response | None = None + for _ in range(_MAX_REDIRECTS + 1): + with _pin_dns(current_hostname, current_ip): + response = requests.get( + current_url, + timeout=REQUEST_TIMEOUT_SECONDS, + allow_redirects=False, + ) + + if response.status_code not in (301, 302, 303, 307, 308): + break + + location = response.headers.get("Location") + if not location: + break # broken redirect; let raise_for_status() handle it + + redirect_url = urljoin(current_url, location) + # Validate redirect target before following it. + current_hostname, current_ip = assert_url_is_safe(redirect_url) + current_url = redirect_url + else: + raise ValueError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {self.feed_url!r}") - response = requests.get(self.feed_url, timeout=REQUEST_TIMEOUT_SECONDS, allow_redirects=True) response.raise_for_status() - final_url = getattr(response, "url", self.feed_url) - if final_url != self.feed_url and urlparse(final_url).hostname: - _validate_url_no_ssrf(final_url) - feed = feedparser.parse(response.content) if getattr(feed, "bozo", False) and not feed.entries: error = getattr(feed, "bozo_exception", None) diff --git a/common/ssrf_guard.py b/common/ssrf_guard.py new file mode 100644 index 000000000..b60bcd4bc --- /dev/null +++ b/common/ssrf_guard.py @@ -0,0 +1,172 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Shared SSRF-guard utilities. + +Uses only the standard library so it can be imported from both ``api/`` and +``common/`` without pulling in any heavyweight dependencies. +""" + +import ipaddress +import logging +import socket +import threading +from contextlib import contextmanager +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# DNS pinning — closes the TOCTOU / rebinding window between SSRF validation +# and the actual TCP connection. The monkey-patch is a no-op for any host +# that has no active pin, so it cannot affect unrelated code. +# --------------------------------------------------------------------------- + +_tl = threading.local() +_global_dns_pins: dict[str, str] = {} +_global_pin_lock = threading.Lock() +_orig_getaddrinfo = socket.getaddrinfo + + +def _getaddrinfo_with_pins(host, port, *args, **kwargs): + # Thread-local pins (synchronous callers: requests.get in the same thread) + local_pins: dict = getattr(_tl, "dns_pins", {}) + if host in local_pins: + ip = local_pins[host] + family = socket.AF_INET6 if ":" in ip else socket.AF_INET + return [(family, socket.SOCK_STREAM, 6, "", (ip, port or 0))] + # Process-global pins (async callers whose DNS resolves in executor threads) + with _global_pin_lock: + ip = _global_dns_pins.get(host) + if ip is not None: + family = socket.AF_INET6 if ":" in ip else socket.AF_INET + return [(family, socket.SOCK_STREAM, 6, "", (ip, port or 0))] + return _orig_getaddrinfo(host, port, *args, **kwargs) + + +socket.getaddrinfo = _getaddrinfo_with_pins + + +@contextmanager +def pin_dns(hostname: str, ip: str): + """Pin *hostname* → *ip* in the current thread for the duration of this context. + + Use for synchronous ``requests.get()`` callers to prevent DNS rebinding + between SSRF validation and the actual TCP connection. + """ + pins = _tl.__dict__.setdefault("dns_pins", {}) + pins[hostname] = ip + try: + yield + finally: + pins.pop(hostname, None) + + +@contextmanager +def pin_dns_global(hostname: str, ip: str): + """Pin *hostname* → *ip* across all threads for the duration of this context. + + Use for async callers (e.g. asyncio-based crawlers) where DNS resolution + may happen in thread-pool executor threads rather than the calling thread. + """ + with _global_pin_lock: + _global_dns_pins[hostname] = ip + try: + yield + finally: + with _global_pin_lock: + _global_dns_pins.pop(hostname, None) + + +_DEFAULT_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"}) + + +def _effective_ip( + ip: ipaddress.IPv4Address | ipaddress.IPv6Address, +) -> ipaddress.IPv4Address | ipaddress.IPv6Address: + """Return the IPv4 equivalent for IPv4-mapped IPv6 addresses, unchanged otherwise. + + Without this normalization ``::ffff:127.0.0.1`` would pass ``is_global`` + as an IPv6Address in some Python versions, bypassing the loopback check. + """ + if isinstance(ip, ipaddress.IPv6Address): + mapped = ip.ipv4_mapped + if mapped is not None: + return mapped + return ip + + +def assert_url_is_safe( + url: str, + *, + allowed_schemes: frozenset[str] = _DEFAULT_ALLOWED_SCHEMES, +) -> tuple[str, str]: + """Raise ``ValueError`` if *url* is not safe to fetch (SSRF guard). + + Checks performed in order: + + 1. Scheme is in *allowed_schemes*. + 2. Hostname is present. + 3. **Every** address returned by ``getaddrinfo`` is globally routable + (``ip.is_global``). This is an allowlist approach: it catches private, + loopback, link-local, reserved, multicast, and all other + special-purpose ranges rather than individual deny-list flags. + IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) are normalised + to their IPv4 form via :func:`_effective_ip` before the check. + + Returns ``(hostname, resolved_ip)`` — the first validated public IP string + — so the caller can **pin** that address in its HTTP client and prevent + DNS-rebinding attacks (the hostname is resolved exactly once). + """ + parsed = urlparse(url) + scheme = parsed.scheme + if scheme not in allowed_schemes: + logger.warning( + "SSRF guard blocked URL with disallowed scheme: scheme=%r url=%r", + scheme, + url, + ) + raise ValueError(f"Disallowed URL scheme: {scheme!r}. Only {sorted(allowed_schemes)} are allowed.") + + hostname = parsed.hostname + if not hostname: + logger.warning("SSRF guard blocked URL with missing host: url=%r", url) + raise ValueError("URL is missing a host.") + + try: + addr_infos = socket.getaddrinfo(hostname, None) + except socket.gaierror as exc: + logger.warning("SSRF guard could not resolve hostname=%r reason=%s", hostname, exc) + raise ValueError(f"Could not resolve hostname {hostname!r}: {exc}") from exc + + resolved_ip: str | None = None + for _family, _type, _proto, _canonname, sockaddr in addr_infos: + raw_ip = ipaddress.ip_address(sockaddr[0]) + eff_ip = _effective_ip(raw_ip) + if not eff_ip.is_global: + logger.warning( + "SSRF guard blocked URL: hostname=%r resolved to non-public address=%s", + hostname, + raw_ip, + ) + raise ValueError(f"URL resolves to a non-public address ({raw_ip}), which is not allowed.") + if resolved_ip is None: + resolved_ip = str(raw_ip) + + if resolved_ip is None: + logger.warning("SSRF guard blocked URL: hostname=%r resolved to no addresses", hostname) + raise ValueError(f"Hostname {hostname!r} resolved to no addresses.") + + return hostname, resolved_ip diff --git a/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py b/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py index 0e5511039..36c736166 100644 --- a/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py +++ b/test/testcases/test_web_api/test_document_app/test_upload_info_unit.py @@ -79,6 +79,7 @@ def _load_document_app_module(monkeypatch): @pytest.mark.p2 def test_upload_info_rejects_mixed_inputs(monkeypatch): module = _load_document_app_module(monkeypatch) + monkeypatch.setattr(module, "assert_url_is_safe", lambda url: ("example.com", "93.184.216.34")) files = _DummyFiles({"file": [_DummyFile("a.txt")]}) monkeypatch.setattr(module, "request", _DummyRequest(files=files, args={"url": "https://example.com/a.txt"})) @@ -100,6 +101,7 @@ def test_upload_info_requires_file_or_url(monkeypatch): @pytest.mark.p2 def test_upload_info_supports_url_single_and_multiple_files(monkeypatch): module = _load_document_app_module(monkeypatch) + monkeypatch.setattr(module, "assert_url_is_safe", lambda url: ("example.com", "93.184.216.34")) captured = [] def fake_upload_info(user_id, file_obj, url=None): diff --git a/test/unit_test/api/db/services/test_file_service_upload_document.py b/test/unit_test/api/db/services/test_file_service_upload_document.py index 12558cc8f..8962ae8a7 100644 --- a/test/unit_test/api/db/services/test_file_service_upload_document.py +++ b/test/unit_test/api/db/services/test_file_service_upload_document.py @@ -14,6 +14,7 @@ # limitations under the License. # import importlib.util +import socket import sys import types import warnings @@ -120,3 +121,158 @@ def test_upload_document_skips_cross_kb_document_id_collision(monkeypatch): assert len(err) == 1 assert err[0].startswith("collision.txt: ") assert "Existing document id collision with another knowledge base; skipping update." in err[0] + + +# --------------------------------------------------------------------------- +# Helpers shared by TestValidateUrlForCrawl +# --------------------------------------------------------------------------- + +def _addrinfo(ip_str: str) -> list: + """Build a minimal getaddrinfo-style result for a single address string.""" + family = socket.AF_INET6 if ":" in ip_str else socket.AF_INET + return [(family, socket.SOCK_STREAM, 6, "", (ip_str, 0))] + + +# --------------------------------------------------------------------------- +# _validate_url_for_crawl SSRF-guard tests +# --------------------------------------------------------------------------- + +@pytest.mark.p2 +class TestValidateUrlForCrawl: + """Focused regression suite for the SSRF guard on the URL-crawl path. + + All DNS lookups are monkeypatched so the tests are deterministic and + require no network access. + """ + + # -- scheme checks ------------------------------------------------------- + + def test_rejects_ftp_scheme(self): + with pytest.raises(ValueError, match="scheme"): + FileService._validate_url_for_crawl("ftp://example.com/file.txt") + + def test_rejects_file_scheme(self): + with pytest.raises(ValueError, match="scheme"): + FileService._validate_url_for_crawl("file:///etc/passwd") + + def test_rejects_javascript_scheme(self): + with pytest.raises(ValueError, match="scheme"): + FileService._validate_url_for_crawl("javascript:alert(1)") + + # -- host checks --------------------------------------------------------- + + def test_rejects_missing_host(self): + with pytest.raises(ValueError, match="host"): + FileService._validate_url_for_crawl("http:///path") + + def test_rejects_dns_resolution_failure(self, monkeypatch): + def _raise(h, p): + raise socket.gaierror("NXDOMAIN") + + monkeypatch.setattr(socket, "getaddrinfo", _raise) + with pytest.raises(ValueError, match="Could not resolve"): + FileService._validate_url_for_crawl("http://nxdomain.invalid/") + + # -- blocked address families -------------------------------------------- + + def test_rejects_loopback_ipv4(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("127.0.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://localhost/") + + def test_rejects_private_class_a(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("10.0.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://internal.example/") + + def test_rejects_private_class_b(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("172.16.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://internal.example/") + + def test_rejects_private_class_c(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("192.168.1.100")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://internal.example/") + + def test_rejects_link_local_ipv4(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("169.254.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://link-local.example/") + + def test_rejects_reserved_ipv4(self, monkeypatch): + # 240.0.0.0/4 is IANA reserved — not globally routable + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("240.0.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://reserved.example/") + + def test_rejects_ipv4_mapped_loopback(self, monkeypatch): + """::ffff:127.0.0.1 must not bypass the loopback check.""" + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("::ffff:127.0.0.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://mapped-loopback.example/") + + def test_rejects_ipv4_mapped_private(self, monkeypatch): + """::ffff:192.168.1.1 must not bypass the private-range check.""" + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("::ffff:192.168.1.1")) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://mapped-private.example/") + + def test_rejects_when_any_record_is_private(self, monkeypatch): + """All DNS records must pass; one private record is enough to block.""" + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda h, p: _addrinfo("93.184.216.34") + _addrinfo("10.0.0.1"), + ) + with pytest.raises(ValueError, match="non-public"): + FileService._validate_url_for_crawl("http://mixed.example/") + + # -- allowed cases ------------------------------------------------------- + + def test_allows_public_ipv4(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("93.184.216.34")) + hostname, resolved_ip = FileService._validate_url_for_crawl("https://example.com/doc.pdf") + assert hostname == "example.com" + assert resolved_ip == "93.184.216.34" + + def test_allows_public_ipv6(self, monkeypatch): + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda h, p: _addrinfo("2606:2800:220:1:248:1893:25c8:1946"), + ) + hostname, resolved_ip = FileService._validate_url_for_crawl("https://example.com/") + assert hostname == "example.com" + assert resolved_ip == "2606:2800:220:1:248:1893:25c8:1946" + + def test_allows_http_scheme(self, monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: _addrinfo("1.2.3.4")) + hostname, _ = FileService._validate_url_for_crawl("http://example.com/") + assert hostname == "example.com" + + # -- multi-record behaviour ---------------------------------------------- + + def test_returns_first_ip_for_multi_record_host(self, monkeypatch): + """The first public IP is returned as the DNS pin value.""" + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda h, p: _addrinfo("1.2.3.4") + _addrinfo("5.6.7.8"), + ) + _, resolved_ip = FileService._validate_url_for_crawl("http://multi.example/") + assert resolved_ip == "1.2.3.4" + + def test_allows_dual_stack_host(self, monkeypatch): + """A host with both public IPv4 and public IPv6 records is allowed.""" + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda h, p: ( + _addrinfo("93.184.216.34") + + _addrinfo("2606:2800:220:1:248:1893:25c8:1946") + ), + ) + hostname, resolved_ip = FileService._validate_url_for_crawl("https://example.com/") + assert hostname == "example.com" + assert resolved_ip == "93.184.216.34"