Files
ragflow/common/data_source/rss_connector.py
Xing Hong fb95136f39 Fix: validate URL scheme and resolved IP before crawling to prevent SSRF (#14090)
### What problem does this PR solve?

The POST /upload_info?url=<url> endpoint accepted a user-supplied URL
and passed it directly to AsyncWebCrawler without any validation. There
were no restrictions on URL scheme, destination hostname, or resolved IP
address. This allowed any authenticated user to instruct the server to
make outbound HTTP requests to internal infrastructure — including RFC
1918 private networks, loopback addresses, and cloud metadata services
such as http://169.254.169.254 — effectively using the server as a proxy
for internal network reconnaissance or credential theft.

This PR adds an SSRF guard (_validate_url_for_crawl) that runs before
any crawl is initiated. It enforces an allowlist of safe schemes
(http/https), resolves the hostname at validation time, and rejects any
URL whose resolved IP falls within a private or reserved network range.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-04-25 14:30:15 +08:00

212 lines
7.9 KiB
Python

import hashlib
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from time import struct_time
from typing import Any
from urllib.parse import urljoin, urlparse
import bs4
import feedparser
import requests
from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
from common.ssrf_guard import assert_url_is_safe, pin_dns as _pin_dns
_MAX_REDIRECTS = 10
class RSSConnector(LoadConnector, PollConnector):
def __init__(self, feed_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.feed_url = feed_url.strip()
self.batch_size = batch_size
self.credentials: dict[str, Any] = {}
self._cached_feed: Any | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.credentials = credentials or {}
return None
def validate_connector_settings(self) -> None:
self._validate_feed_url()
if self.batch_size < 1:
raise ValueError("batch_size must be greater than 0")
self._read_feed(require_entries=True)
def load_from_state(self) -> GenerateDocumentsOutput:
yield from self._load_entries()
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
yield from self._load_entries(start=start, end=end)
def _load_entries(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
feed = self._read_feed(require_entries=False)
batch: list[Document] = []
for entry in feed.entries:
updated_at = self._resolve_entry_time(entry)
ts = updated_at.timestamp()
if start is not None and ts <= start:
continue
if end is not None and ts > end:
continue
batch.append(self._build_document(entry, updated_at))
if len(batch) >= self.batch_size:
yield batch
batch = []
if batch:
yield batch
def _validate_feed_url(self) -> tuple[str, str]:
"""Validate ``self.feed_url`` and return ``(hostname, resolved_ip)``."""
if not self.feed_url:
raise ValueError("feed_url is required")
parsed = urlparse(self.feed_url)
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise ValueError("feed_url must be a valid http or https URL")
return assert_url_is_safe(self.feed_url)
def _read_feed(self, require_entries: bool) -> Any:
if self._cached_feed is not None:
if require_entries and not self._cached_feed.entries:
raise ValueError("RSS feed contains no entries")
return self._cached_feed
# Validate once to get the pinned IP for the initial request.
current_hostname, current_ip = self._validate_feed_url()
current_url = self.feed_url
# Follow redirects manually: each hop is validated and DNS-pinned
# *before* the connection is made, closing the TOCTOU rebinding window
# that existed when allow_redirects=True was used with post-hoc checks.
response: requests.Response | None = None
for _ in range(_MAX_REDIRECTS + 1):
with _pin_dns(current_hostname, current_ip):
response = requests.get(
current_url,
timeout=REQUEST_TIMEOUT_SECONDS,
allow_redirects=False,
)
if response.status_code not in (301, 302, 303, 307, 308):
break
location = response.headers.get("Location")
if not location:
break # broken redirect; let raise_for_status() handle it
redirect_url = urljoin(current_url, location)
# Validate redirect target before following it.
current_hostname, current_ip = assert_url_is_safe(redirect_url)
current_url = redirect_url
else:
raise ValueError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {self.feed_url!r}")
response.raise_for_status()
feed = feedparser.parse(response.content)
if getattr(feed, "bozo", False) and not feed.entries:
error = getattr(feed, "bozo_exception", None)
if error:
raise ValueError(f"Failed to parse RSS feed: {error}") from error
raise ValueError("Failed to parse RSS feed")
if require_entries and not feed.entries:
raise ValueError("RSS feed contains no entries")
self._cached_feed = feed
return feed
def _build_document(self, entry: Any, updated_at: datetime) -> Document:
link = (entry.get("link") or "").strip()
title = (entry.get("title") or "").strip()
stable_key = (entry.get("id") or link or title or self.feed_url).strip()
semantic_identifier = title or link or stable_key
content = self._build_content(entry, semantic_identifier)
blob = content.encode("utf-8")
metadata: dict[str, Any] = {"feed_url": self.feed_url}
if link:
metadata["link"] = link
if entry.get("author"):
metadata["author"] = entry.get("author")
categories = []
for tag in entry.get("tags", []):
if not isinstance(tag, dict):
continue
term = tag.get("term")
if isinstance(term, str) and term:
categories.append(term)
if categories:
metadata["categories"] = categories
return Document(
id=f"rss:{hashlib.md5(stable_key.encode('utf-8')).hexdigest()}",
source=DocumentSource.RSS,
semantic_identifier=semantic_identifier,
extension=".txt",
blob=blob,
doc_updated_at=updated_at,
size_bytes=len(blob),
metadata=metadata,
)
def _build_content(self, entry: Any, semantic_identifier: str) -> str:
parts = [semantic_identifier]
content_blocks = entry.get("content") or []
for block in content_blocks:
value = block.get("value") if isinstance(block, dict) else None
normalized = self._normalize_text(value)
if normalized:
parts.append(normalized)
if len(parts) == 1:
fallback = entry.get("summary") or entry.get("description") or ""
normalized = self._normalize_text(fallback)
if normalized:
parts.append(normalized)
return "\n\n".join(part for part in parts if part).strip()
def _resolve_entry_time(self, entry: Any) -> datetime:
for field in ("updated_parsed", "published_parsed"):
value = entry.get(field)
if value:
return self._struct_time_to_utc(value)
for field in ("updated", "published"):
value = entry.get(field)
if isinstance(value, str) and value.strip():
try:
parsed = parsedate_to_datetime(value)
except (TypeError, ValueError, IndexError):
continue
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
return datetime.now(timezone.utc)
@staticmethod
def _normalize_text(value: Any) -> str:
if not isinstance(value, str):
return ""
return bs4.BeautifulSoup(value, "html.parser").get_text("\n", strip=True)
@staticmethod
def _struct_time_to_utc(value: struct_time | tuple[Any, ...]) -> datetime:
dt = datetime(*value[:6], tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)