mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-20 00:07:00 +08:00
### What problem does this PR solve? Partially addresses #14362. This PR enables syncing deleted files for RSS data sources. Previously, RSS incremental sync only returned feed entries whose timestamps were inside the poll window. If an entry was removed from the RSS feed, RAGFlow had no full current RSS snapshot to pass into the shared stale-document cleanup path, so the deleted remote entry could remain in the knowledge base. This PR: - adds `retrieve_all_slim_docs_perm_sync()` to `RSSConnector` - reuses the same `rss:<md5(stable_key)>` document ID derivation used by normal RSS ingest - returns `(document_generator, file_list)` for incremental RSS sync when `sync_deleted_files` is enabled - captures the poll end timestamp before snapshot/poll so cleanup does not race against the same sync window - adds start/end logs around RSS slim snapshot collection - exposes the deleted-file sync toggle for RSS in the data source UI Per maintainer request on related datasource PRs, this PR contains no test-case changes. Local verification was run with an external script. Validation: - `uv run ruff check common/data_source/rss_connector.py rag/svr/sync_data_source.py` - `uv run pytest test/unit_test/rag/test_sync_data_source.py -q` - `./node_modules/.bin/eslint src/pages/user-setting/data-source/constant/index.tsx` - `git diff --check` - `uv run python /tmp/verify_rss_deleted_sync.py --repo /root/74/ragflow` ### Type of change - [x] New Feature (non-breaking change which adds functionality)
246 lines
8.8 KiB
Python
246 lines
8.8 KiB
Python
import hashlib
|
|
from datetime import datetime, timezone
|
|
from email.utils import parsedate_to_datetime
|
|
from time import struct_time
|
|
from typing import Any
|
|
from urllib.parse import urljoin, urlparse
|
|
|
|
import bs4
|
|
import feedparser
|
|
import requests
|
|
|
|
from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource
|
|
from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync
|
|
from common.data_source.models import (
|
|
Document,
|
|
GenerateDocumentsOutput,
|
|
GenerateSlimDocumentOutput,
|
|
SecondsSinceUnixEpoch,
|
|
SlimDocument,
|
|
)
|
|
from common.ssrf_guard import assert_url_is_safe, pin_dns as _pin_dns
|
|
|
|
# Maximum number of HTTP redirect hops followed while fetching the feed;
# one more redirect than this makes the fetch fail with ValueError.
_MAX_REDIRECTS = 10
|
|
|
|
|
|
class RSSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    """Connector that turns the entries of one RSS feed into documents.

    The feed is downloaded at most once per connector instance: the parsed
    result is cached and shared by full loads, incremental polls and the slim
    (ID-only) snapshot. Every HTTP hop, including redirects, is validated by
    the SSRF guard and DNS-pinned before the connection is opened.
    """

    def __init__(self, feed_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None:
        """Store the (whitespace-stripped) feed URL and the batch size."""
        self.feed_url = feed_url.strip()
        self.batch_size = batch_size
        self.credentials: dict[str, Any] = {}
        # Parsed feed, populated by the first successful ``_read_feed`` call.
        self._cached_feed: Any | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Remember *credentials* (``None`` becomes ``{}``) and return ``None``.

        Nothing in this class reads the stored credentials.
        """
        self.credentials = credentials or {}
        return None

    def validate_connector_settings(self) -> None:
        """Check the URL and batch size, then require a feed with entries.

        Raises:
            ValueError: if the URL is missing/invalid, ``batch_size`` is below
                1, or the feed cannot be parsed or contains no entries.
        """
        self._validate_feed_url()
        if self.batch_size < 1:
            raise ValueError("batch_size must be greater than 0")
        self._read_feed(require_entries=True)

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Yield every feed entry as documents, in batches."""
        yield from self._load_entries()

    def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
        """Yield documents for entries whose timestamp is in ``(start, end]``."""
        yield from self._load_entries(start=start, end=end)

    def retrieve_all_slim_docs_perm_sync(
        self,
        callback: Any = None,
    ) -> GenerateSlimDocumentOutput:
        """Yield ID-only documents for every entry currently in the feed.

        IDs come from ``_build_document_id``, i.e. the same derivation used
        for full documents, so the two views can be compared directly.
        *callback* is accepted for interface compatibility and ignored.
        """
        del callback

        feed = self._read_feed(require_entries=False)
        batch: list[SlimDocument] = []

        for entry in feed.entries:
            batch.append(SlimDocument(id=self._build_document_id(entry)))

            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    def _load_entries(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        """Yield batches of documents, optionally filtered by entry time.

        An entry is kept when ``start < timestamp <= end``; a bound that is
        ``None`` is not applied.
        """
        feed = self._read_feed(require_entries=False)
        batch: list[Document] = []

        for entry in feed.entries:
            updated_at = self._resolve_entry_time(entry)
            ts = updated_at.timestamp()

            # Half-open window: exclusive at ``start``, inclusive at ``end``.
            if start is not None and ts <= start:
                continue
            if end is not None and ts > end:
                continue

            batch.append(self._build_document(entry, updated_at))

            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    def _validate_feed_url(self) -> tuple[str, str]:
        """Validate ``self.feed_url`` and return ``(hostname, resolved_ip)``."""
        if not self.feed_url:
            raise ValueError("feed_url is required")

        parsed = urlparse(self.feed_url)
        if parsed.scheme not in {"http", "https"} or not parsed.netloc:
            raise ValueError("feed_url must be a valid http or https URL")

        return assert_url_is_safe(self.feed_url)

    def _read_feed(self, require_entries: bool) -> Any:
        """Fetch and parse the feed, reusing the cached result when present.

        Args:
            require_entries: when true, a feed without entries is an error.

        Returns:
            The object returned by ``feedparser.parse``.

        Raises:
            ValueError: on an invalid/unsafe URL, too many redirects, a
                redirect without a ``Location`` header, an unparsable feed,
                or (with *require_entries*) an empty feed.
            requests.HTTPError: when the final response is a 4xx/5xx.
        """
        if self._cached_feed is not None:
            if require_entries and not self._cached_feed.entries:
                raise ValueError("RSS feed contains no entries")
            return self._cached_feed

        # Validate once to get the pinned IP for the initial request.
        current_hostname, current_ip = self._validate_feed_url()
        current_url = self.feed_url

        # Follow redirects manually: each hop is validated and DNS-pinned
        # *before* the connection is made, closing the TOCTOU rebinding window
        # that existed when allow_redirects=True was used with post-hoc checks.
        response: requests.Response | None = None
        for _ in range(_MAX_REDIRECTS + 1):
            with _pin_dns(current_hostname, current_ip):
                response = requests.get(
                    current_url,
                    timeout=REQUEST_TIMEOUT_SECONDS,
                    allow_redirects=False,
                )

            if response.status_code not in (301, 302, 303, 307, 308):
                break

            location = response.headers.get("Location")
            if not location:
                # raise_for_status() only raises for 4xx/5xx, so a redirect
                # without a target must be rejected here; otherwise its body
                # would be handed to the parser as if it were the feed.
                raise ValueError(
                    f"Redirect response ({response.status_code}) from {current_url!r} has no Location header"
                )

            redirect_url = urljoin(current_url, location)
            # Validate redirect target before following it.
            current_hostname, current_ip = assert_url_is_safe(redirect_url)
            current_url = redirect_url
        else:
            raise ValueError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {self.feed_url!r}")

        response.raise_for_status()

        feed = feedparser.parse(response.content)
        # A parse problem is only fatal when nothing usable was extracted.
        if getattr(feed, "bozo", False) and not feed.entries:
            error = getattr(feed, "bozo_exception", None)
            if error:
                raise ValueError(f"Failed to parse RSS feed: {error}") from error
            raise ValueError("Failed to parse RSS feed")
        if require_entries and not feed.entries:
            raise ValueError("RSS feed contains no entries")

        self._cached_feed = feed
        return feed

    def _build_document(self, entry: Any, updated_at: datetime) -> Document:
        """Build the full ``Document`` for one feed entry.

        The blob is the UTF-8 text from ``_build_content``; metadata carries
        the feed URL plus the entry's link, author and tag terms when present.
        """
        link = (entry.get("link") or "").strip()
        title = (entry.get("title") or "").strip()
        stable_key = self._resolve_stable_key(entry)
        semantic_identifier = title or link or stable_key
        content = self._build_content(entry, semantic_identifier)
        blob = content.encode("utf-8")

        metadata: dict[str, Any] = {"feed_url": self.feed_url}
        if link:
            metadata["link"] = link
        if entry.get("author"):
            metadata["author"] = entry.get("author")

        # ``or []`` also covers a "tags" key that is present but None.
        categories = []
        for tag in entry.get("tags") or []:
            if not isinstance(tag, dict):
                continue
            term = tag.get("term")
            if isinstance(term, str) and term:
                categories.append(term)
        if categories:
            metadata["categories"] = categories

        return Document(
            id=self._build_document_id(entry),
            source=DocumentSource.RSS,
            semantic_identifier=semantic_identifier,
            extension=".txt",
            blob=blob,
            doc_updated_at=updated_at,
            size_bytes=len(blob),
            metadata=metadata,
        )

    def _build_content(self, entry: Any, semantic_identifier: str) -> str:
        """Return the entry's plain text, headed by *semantic_identifier*.

        Uses the entry's ``content`` blocks; only when none of them yields
        text does it fall back to ``summary`` and then ``description``.
        Parts are joined with blank lines.
        """
        parts = [semantic_identifier]
        content_blocks = entry.get("content") or []

        for block in content_blocks:
            value = block.get("value") if isinstance(block, dict) else None
            normalized = self._normalize_text(value)
            if normalized:
                parts.append(normalized)

        # Only the identifier so far: no content block produced any text.
        if len(parts) == 1:
            fallback = entry.get("summary") or entry.get("description") or ""
            normalized = self._normalize_text(fallback)
            if normalized:
                parts.append(normalized)

        return "\n\n".join(part for part in parts if part).strip()

    def _build_document_id(self, entry: Any) -> str:
        """Return ``rss:<md5 hex digest of the entry's stable key>``."""
        stable_key = self._resolve_stable_key(entry)
        # MD5 is only a fingerprint here; flagging it as non-security keeps it
        # usable on builds that restrict MD5 and does not change the digest.
        digest = hashlib.md5(stable_key.encode("utf-8"), usedforsecurity=False).hexdigest()
        return f"rss:{digest}"

    def _resolve_stable_key(self, entry: Any) -> str:
        """Return the first non-blank of entry id, link, title, feed URL.

        Each candidate is stripped *before* the choice is made, so an id made
        only of whitespace falls through to the next candidate instead of
        collapsing to an empty key shared by every such entry.
        """
        entry_id = (entry.get("id") or "").strip()
        link = (entry.get("link") or "").strip()
        title = (entry.get("title") or "").strip()
        return entry_id or link or title or self.feed_url

    def _resolve_entry_time(self, entry: Any) -> datetime:
        """Return the entry's timestamp as a timezone-aware UTC datetime.

        Preference order: ``updated_parsed``, ``published_parsed``, then the
        raw ``updated`` / ``published`` strings parsed as e-mail style dates
        (naive results are taken to be UTC). If nothing is usable the current
        time is returned.
        """
        # NOTE(review): with the current-time fallback, an undated entry is
        # filtered out whenever the poll window's ``end`` lies in the past --
        # confirm that this is acceptable for callers.
        for field in ("updated_parsed", "published_parsed"):
            value = entry.get(field)
            if value:
                return self._struct_time_to_utc(value)

        for field in ("updated", "published"):
            value = entry.get(field)
            if isinstance(value, str) and value.strip():
                try:
                    parsed = parsedate_to_datetime(value)
                except (TypeError, ValueError, IndexError):
                    continue
                if parsed.tzinfo is None:
                    parsed = parsed.replace(tzinfo=timezone.utc)
                return parsed.astimezone(timezone.utc)

        return datetime.now(timezone.utc)

    @staticmethod
    def _normalize_text(value: Any) -> str:
        """Strip HTML markup from *value*; non-string input yields ``""``."""
        if not isinstance(value, str):
            return ""
        return bs4.BeautifulSoup(value, "html.parser").get_text("\n", strip=True)

    @staticmethod
    def _struct_time_to_utc(value: struct_time | tuple[Any, ...]) -> datetime:
        """Build a UTC-aware datetime from the first six fields of *value*.

        The fields (year, month, day, hour, minute, second) are interpreted
        as UTC.
        """
        return datetime(*value[:6], tzinfo=timezone.utc)
|