Files
ragflow/common/ssrf_guard.py
Xing Hong fb95136f39 Fix: validate URL scheme and resolved IP before crawling to prevent SSRF (#14090)
### What problem does this PR solve?

The POST /upload_info?url=<url> endpoint accepted a user-supplied URL
and passed it directly to AsyncWebCrawler without any validation. There
were no restrictions on URL scheme, destination hostname, or resolved IP
address. This allowed any authenticated user to instruct the server to
make outbound HTTP requests to internal infrastructure — including RFC
1918 private networks, loopback addresses, and cloud metadata services
such as http://169.254.169.254 — effectively using the server as a proxy
for internal network reconnaissance or credential theft.

This PR adds an SSRF guard (_validate_url_for_crawl) that runs before
any crawl is initiated. It enforces an allowlist of safe schemes
(http/https), resolves the hostname at validation time, and rejects any
URL whose resolved IP falls within a private or reserved network range.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-04-25 14:30:15 +08:00

173 lines
6.3 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Shared SSRF-guard utilities.
Uses only the standard library so it can be imported from both ``api/`` and
``common/`` without pulling in any heavyweight dependencies.
"""
import ipaddress
import logging
import socket
import threading
from contextlib import contextmanager
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DNS pinning — closes the TOCTOU / rebinding window between SSRF validation
# and the actual TCP connection. The monkey-patch is a no-op for any host
# that has no active pin, so it cannot affect unrelated code.
# ---------------------------------------------------------------------------
# Per-thread pin table lives on ``_tl.dns_pins``; the process-wide table is
# ``_global_dns_pins``. Both map a canonical host key (see ``_pin_key``) to
# the IP string that lookups for that host must return.
_tl = threading.local()
_global_dns_pins: dict[str, str] = {}
# Number of live ``pin_dns_global`` contexts per host key, so that one context
# finishing does not unpin a host another context still depends on.
_global_pin_refs: dict[str, int] = {}
_global_pin_lock = threading.Lock()
_orig_getaddrinfo = socket.getaddrinfo
# Sentinel meaning "this thread had no earlier pin for the host".
_UNSET = object()


def _pin_key(host):
    """Return the canonical pin-table key for *host*.

    The name handed to ``getaddrinfo`` by an HTTP client may differ from the
    pinned name in letter case, by a trailing dot, or by being IDNA-encoded.
    Registration and lookup both go through this function so that such
    variants still hit the pin. Values that are neither ``str`` nor
    ASCII-decodable bytes are returned unchanged (hashable) and simply never
    match a normal pin.
    """
    if isinstance(host, (bytes, bytearray)):
        host = bytes(host)
        try:
            host = host.decode("ascii")
        except UnicodeDecodeError:
            return host
    if not isinstance(host, str):
        return host
    name = host.rstrip(".").lower()
    try:
        # Same codec ``socket`` applies to ``str`` host names before resolving.
        return name.encode("idna").decode("ascii")
    except UnicodeError:
        return name


def _port_number(port):
    """Convert a ``getaddrinfo`` *port* argument to the integer used in a sockaddr."""
    if not port:
        return 0
    if isinstance(port, int):
        return port
    if isinstance(port, (bytes, bytearray)):
        port = bytes(port).decode("ascii", "replace")
    try:
        return int(port)
    except ValueError:
        # Service name such as "http".
        return socket.getservbyname(port)


def _pinned_addrinfo(ip: str, port):
    """Build the single-entry ``getaddrinfo`` result for a pinned *ip*."""
    port_num = _port_number(port)
    if ":" in ip:
        # AF_INET6 sockaddrs are documented as (addr, port, flowinfo, scope_id);
        # consumers that rely on that shape would mishandle a 2-tuple.
        return [(socket.AF_INET6, socket.SOCK_STREAM, 6, "", (ip, port_num, 0, 0))]
    return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, port_num))]


def _getaddrinfo_with_pins(host, port, *args, **kwargs):
    """Replacement for ``socket.getaddrinfo`` that honours active DNS pins.

    A thread-local pin takes precedence over a process-global one. When *host*
    has no pin the call is forwarded unchanged to the original
    ``socket.getaddrinfo``. For a pinned host the requested family / type /
    proto / flags are ignored and one TCP stream entry for the pinned IP is
    returned.
    """
    key = _pin_key(host)
    # Thread-local pins (synchronous callers: requests.get in the same thread)
    local_pins = getattr(_tl, "dns_pins", None)
    ip = local_pins.get(key) if local_pins else None
    if ip is None:
        # Process-global pins (async callers whose DNS resolves in executor threads)
        with _global_pin_lock:
            ip = _global_dns_pins.get(key)
    if ip is None:
        return _orig_getaddrinfo(host, port, *args, **kwargs)
    return _pinned_addrinfo(ip, port)


socket.getaddrinfo = _getaddrinfo_with_pins


@contextmanager
def pin_dns(hostname: str, ip: str):
    """Pin *hostname* → *ip* in the current thread for the duration of this context.

    Use for synchronous ``requests.get()`` callers to prevent DNS rebinding
    between SSRF validation and the actual TCP connection.

    The host is matched case-insensitively and without regard to a trailing
    dot. If the same host is already pinned in this thread (nested use), the
    earlier pin is restored on exit instead of being dropped.
    """
    pins = _tl.__dict__.setdefault("dns_pins", {})
    key = _pin_key(hostname)
    previous = pins.get(key, _UNSET)
    pins[key] = ip
    try:
        yield
    finally:
        if previous is _UNSET:
            pins.pop(key, None)
        else:
            pins[key] = previous


@contextmanager
def pin_dns_global(hostname: str, ip: str):
    """Pin *hostname* → *ip* across all threads for the duration of this context.

    Use for async callers (e.g. asyncio-based crawlers) where DNS resolution
    may happen in thread-pool executor threads rather than the calling thread.

    Overlapping contexts for the same host are reference-counted: the pin is
    removed only when the last of them exits, so one request finishing cannot
    unpin a host that a concurrent request still relies on. While several
    contexts overlap, the most recently pinned IP is the one returned.
    """
    key = _pin_key(hostname)
    with _global_pin_lock:
        _global_dns_pins[key] = ip
        _global_pin_refs[key] = _global_pin_refs.get(key, 0) + 1
    try:
        yield
    finally:
        with _global_pin_lock:
            remaining = _global_pin_refs.get(key, 0) - 1
            if remaining > 0:
                _global_pin_refs[key] = remaining
            else:
                _global_pin_refs.pop(key, None)
                _global_dns_pins.pop(key, None)
# URL schemes accepted by ``assert_url_is_safe`` when the caller does not
# supply its own ``allowed_schemes``.
_DEFAULT_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})
def _effective_ip(
ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address:
"""Return the IPv4 equivalent for IPv4-mapped IPv6 addresses, unchanged otherwise.
Without this normalization ``::ffff:127.0.0.1`` would pass ``is_global``
as an IPv6Address in some Python versions, bypassing the loopback check.
"""
if isinstance(ip, ipaddress.IPv6Address):
mapped = ip.ipv4_mapped
if mapped is not None:
return mapped
return ip
# NAT64 well-known prefix (RFC 6052): the low 32 bits carry an IPv4 destination.
_NAT64_WELL_KNOWN_PREFIX = ipaddress.ip_network("64:ff9b::/96")


def _embedded_ipv4(ip):
    """Return the IPv4 address carried inside a 6to4 or NAT64 IPv6 address, else ``None``."""
    if not isinstance(ip, ipaddress.IPv6Address):
        return None
    six_to_four = ip.sixtofour
    if six_to_four is not None:
        return six_to_four
    if ip in _NAT64_WELL_KNOWN_PREFIX:
        return ipaddress.IPv4Address(int(ip) & 0xFFFFFFFF)
    return None


def assert_url_is_safe(
    url: str,
    *,
    allowed_schemes: frozenset[str] = _DEFAULT_ALLOWED_SCHEMES,
) -> tuple[str, str]:
    """Raise ``ValueError`` if *url* is not safe to fetch (SSRF guard).

    Checks performed in order:

    1. The URL parses and its scheme is in *allowed_schemes*.
    2. Hostname is present.
    3. **Every** address returned by ``getaddrinfo`` is globally routable
       (``ip.is_global``). This is an allowlist approach: it catches private,
       loopback, link-local, reserved, multicast, and all other
       special-purpose ranges rather than individual deny-list flags.
       IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) are normalised
       to their IPv4 form via :func:`_effective_ip` before the check.
       IPv6 addresses that embed an IPv4 destination (6to4 ``2002::/16`` and
       the NAT64 prefix ``64:ff9b::/96``) are additionally rejected when the
       embedded IPv4 address is not globally routable.

    Every rejection is logged at WARNING level before the exception is raised.

    Returns ``(hostname, resolved_ip)`` — the first validated public IP string
    — so the caller can **pin** that address in its HTTP client and prevent
    DNS-rebinding attacks (the hostname is resolved exactly once).

    Raises:
        ValueError: the URL is malformed, uses a disallowed scheme, has no
            host, cannot be resolved, or resolves to any non-public address.
    """
    try:
        parsed = urlparse(url)
    except ValueError as exc:
        # e.g. an unbalanced "[" in the netloc; same exception, but now logged.
        logger.warning("SSRF guard blocked unparseable URL: url=%r reason=%s", url, exc)
        raise

    scheme = parsed.scheme
    if scheme not in allowed_schemes:
        logger.warning(
            "SSRF guard blocked URL with disallowed scheme: scheme=%r url=%r",
            scheme,
            url,
        )
        raise ValueError(f"Disallowed URL scheme: {scheme!r}. Only {sorted(allowed_schemes)} are allowed.")

    hostname = parsed.hostname
    if not hostname:
        logger.warning("SSRF guard blocked URL with missing host: url=%r", url)
        raise ValueError("URL is missing a host.")

    try:
        # One socket type is enough: the address set is the same for every
        # type, and this avoids validating each address several times.
        addr_infos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
    except (socket.gaierror, UnicodeError) as exc:
        # UnicodeError: the host name could not be IDNA-encoded (for example a
        # label that is empty or too long) — treat it like any other
        # resolution failure instead of letting it escape unlogged.
        logger.warning("SSRF guard could not resolve hostname=%r reason=%s", hostname, exc)
        raise ValueError(f"Could not resolve hostname {hostname!r}: {exc}") from exc

    resolved_ip: str | None = None
    for _family, _type, _proto, _canonname, sockaddr in addr_infos:
        raw_ip = ipaddress.ip_address(sockaddr[0])
        eff_ip = _effective_ip(raw_ip)
        embedded = _embedded_ipv4(eff_ip)
        is_public = eff_ip.is_global and (embedded is None or embedded.is_global)
        if not is_public:
            logger.warning(
                "SSRF guard blocked URL: hostname=%r resolved to non-public address=%s",
                hostname,
                raw_ip,
            )
            raise ValueError(f"URL resolves to a non-public address ({raw_ip}), which is not allowed.")
        if resolved_ip is None:
            resolved_ip = str(raw_ip)

    if resolved_ip is None:
        logger.warning("SSRF guard blocked URL: hostname=%r resolved to no addresses", hostname)
        raise ValueError(f"Hostname {hostname!r} resolved to no addresses.")
    return hostname, resolved_ip