mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-21 00:36:43 +08:00
# feat: Add Generic REST API Connector
## What problem does this PR solve?
RAGFlow supports many specific data source connectors (MySQL, Slack,
Google Drive, etc.), but there was no way to connect an arbitrary REST
API as a data source. Users with custom or third-party APIs had to write
a new connector class for each one.
This PR adds a **generic, configuration-driven REST API connector** that
lets users connect any REST API as a data source entirely through the UI
— no code changes needed per API.
---
## Features
### Core Connector (`common/data_source/rest_api_connector.py`)
- Implements `LoadConnector` and `PollConnector` interfaces for full and
incremental sync
- **Configurable authentication:** None, API Key (custom header), Bearer
Token, Basic Auth
- **Pluggable pagination:** Page-based, Offset-based, Cursor-based, or
None
- Smart page-size inference from user's query parameters to avoid
duplicate/conflicting params
- Configurable request delay between pages to prevent API rate limiting
- Auto-detection of the items array in JSON responses (`items`,
`results`, `data`, `records`, or first list found)
- **Advanced field mapping** with dot-notation (`country.name`), array
wildcards (`newsType[*].name`), type hints, and default values
- Optional content template rendering (`"Title: {title}\nBody: {body}"`)
- HTML stripping for content fields
- Stable document IDs via `hash128` from a configurable ID field or
auto-generated from item content
- Pydantic configuration schema with automatic coercion of UI string
inputs to dicts/lists
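
The dot-notation and array-wildcard paths listed above can be illustrated with a small resolver. This is a minimal sketch, not the connector's actual implementation: `extract_values` and `SEGMENT_RE` are hypothetical names, and skipping missing keys silently is an assumption made for the example.

```python
import re
from typing import Any, List

# One path segment: a key, optionally followed by [N] or [*].
SEGMENT_RE = re.compile(r'^(?P<key>[^\[\]]+)(\[(?P<index>\d+|\*)\])?$')


def extract_values(item: Any, path: str) -> List[Any]:
    """Return every value reachable from `item` along a dotted path (hypothetical helper)."""
    current: List[Any] = [item]
    for segment in path.split("."):
        match = SEGMENT_RE.match(segment)
        if match is None:
            return []  # malformed segment: treat as "no match"
        key, index = match.group("key"), match.group("index")
        next_values: List[Any] = []
        for node in current:
            if not isinstance(node, dict) or key not in node:
                continue  # assumption: missing keys are skipped, not errors
            value = node[key]
            if index is None:
                next_values.append(value)
            elif not isinstance(value, list):
                continue  # an index was requested but the value is not a list
            elif index == "*":
                next_values.extend(value)  # wildcard fans out over all elements
            elif int(index) < len(value):
                next_values.append(value[int(index)])
        current = next_values
    return current


item = {"country": {"name": "PK"}, "newsType": [{"name": "a"}, {"name": "b"}]}
print(extract_values(item, "country.name"))       # ['PK']
print(extract_values(item, "newsType[*].name"))   # ['a', 'b']
```

A wildcard path can return several values, so a caller has to decide whether to join them into one string or keep them as a list.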
### Backend Registration (`rag/svr/sync_data_source.py`,
`common/constants.py`, `common/data_source/config.py`)
- `REST_API` sync class wired into RAGFlow's `func_factory`
- Full sync (`load_from_state`) and incremental polling (`poll_source`)
support
- Credentials and config passed from task to connector following
existing patterns (MySQL, SeaFile, etc.)
### Test Connection Endpoint (`api/apps/connector_app.py`)
- `POST /v1/connector/<id>/test` validates config schema,
authentication, and API connectivity without triggering a sync
- Clear error messages for auth failures vs. config issues
### Frontend UI (`web/src/pages/user-setting/data-source/constant/`)
- **Postman-style configuration:** Base URL, Query Parameters (key=value
per line), Auth, Content Fields, Metadata Fields, Pagination Type
- Auth-type-aware form: fields for API key header/value, Bearer token,
or Basic username/password appear only when relevant
- **Advanced Settings** toggle for: Custom Headers, Max Pages, Request
Delay, Poll Timestamp Field, Request Body (POST)
- Connector icon (SVG) and i18n strings (English)
- **"Test Connection"** button to validate before syncing
---
## Controls & Safety
- Configurable max pages safety cap (default: 1000, adjustable in UI)
- Configurable request delay between pages (default: 0.5s, adjustable in
UI)
- Auth errors (401/403) fail immediately without retries; transient
errors retry with exponential backoff
- Diagnostic logging: auth setup confirmation, request details on
failure, content field extraction status
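
The retry policy described above can be sketched as a status classifier plus a backoff schedule. This is a simplified illustration under stated assumptions: `classify_status` and `backoff_delays` are hypothetical helpers, treating 429 and 5xx as transient is an assumption about what counts as "transient", and the default numbers are illustrative rather than a statement of how RAGFlow's retry helper schedules its waits.

```python
from typing import Iterator


def classify_status(status: int) -> str:
    """Map an HTTP status code to a handling decision (hypothetical helper)."""
    if status < 400:
        return "ok"
    if status in (401, 403):
        return "auth_error"      # fail immediately, never retried
    if status < 500 and status != 429:
        return "client_error"    # assumed non-retriable (bad request, not found, ...)
    return "retry"               # 429 and 5xx treated as transient (assumption)


def backoff_delays(tries: int = 5, delay: float = 1.0,
                   backoff: float = 2.0, max_delay: float = 30.0) -> Iterator[float]:
    """Yield the wait before each retry: delay, delay*backoff, ... capped at max_delay."""
    wait = delay
    for _ in range(tries - 1):   # the first attempt happens without waiting
        yield min(wait, max_delay)
        wait *= backoff


print(classify_status(403))    # auth_error
print(list(backoff_delays()))  # [1.0, 2.0, 4.0, 8.0]
```

Separating the decision from the schedule keeps authentication failures from consuming retry attempts and lets each part be tested on its own.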
---
## Type of change
- [x] New Feature (non-breaking change which adds functionality)
## Visual Screenshots of Features
<img width="482" height="510" alt="Screenshot 2026-03-11 at 5 19 52 PM"
src="https://github.com/user-attachments/assets/dcb7ab4a-1622-44f3-bb02-d6f0527314c4"
/>
(Connector can be configured within the external data sources tab)
Configuration Parameters:
<img width="661" height="682" alt="Screenshot 2026-03-11 at 5 20 46 PM"
src="https://github.com/user-attachments/assets/5e154e71-4ab5-4872-bfb2-04f02b73c18a"
/>
<img width="661" height="682" alt="Screenshot 2026-03-11 at 5 20 54 PM"
src="https://github.com/user-attachments/assets/00cb14b7-0bcf-4b94-9d71-34e93369ecb2"
/>
Connection can be tested before attaching to dataset:
<img width="981" height="681" alt="Screenshot 2026-03-11 at 5 21 40 PM"
src="https://github.com/user-attachments/assets/aaa6eeeb-89a7-4349-bc34-2423bf8be9ee"
/>
Ingestion was tested with the REST API connector and completed successfully:
<img width="1062" height="705" alt="Screenshot 2026-03-11 at 5 22 30 PM"
src="https://github.com/user-attachments/assets/afcd0d58-cadd-4152-badc-d2f14d96fbec"
/>
Search and retrieval also work, with metadata flowing through to the results:
<img width="1062" height="705" alt="Screenshot 2026-03-11 at 5 23 05 PM"
src="https://github.com/user-attachments/assets/d41ee935-dcf7-4456-b317-22a76ca032c0"
/>
---------
Co-authored-by: Ahmad Intisar <ahmadintisar@Ahmads-MacBook-M4-Pro.local>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1013 lines
39 KiB
Python
"""Generic, configuration-driven REST API data source connector.

Connect any REST API as a RAGFlow data source without code changes.
All behaviour (URL, auth, pagination, field mapping) is controlled
via the ``RestAPIConnectorConfig`` schema exposed by the UI.
"""

from __future__ import annotations

import json
import logging
import re
import time
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional
from urllib.parse import parse_qs, urlparse, urlunparse

import ipaddress
import socket
import requests
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, ValidationError, field_validator

logger = logging.getLogger(__name__)

from api.utils.common import hash128
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
)
from common.data_source.interfaces import (
    LoadConnector,
    PollConnector,
    SecondsSinceUnixEpoch,
)
from common.data_source.models import Document
from common.data_source.utils import rl_requests, retry_builder

try:
    from jsonpath import jsonpath as _jsonpath  # type: ignore[import]
except Exception:  # pragma: no cover
    _jsonpath = None

_FIELD_SEGMENT_RE = re.compile(r'^(?P<key>[^\[\]]+)(\[(?P<index>\d+|\*)\])?$')
_DEFAULT_MAX_PAGES = 1000


class AuthType:
    NONE = "none"
    API_KEY_HEADER = "api_key_header"
    BEARER = "bearer"
    BASIC = "basic"


class PaginationType:
    NONE = "none"
    PAGE = "page"
    OFFSET = "offset"
    CURSOR = "cursor"


def _text_to_dict(v: Any) -> Dict[str, str]:
    """Parse a dict, JSON string, or ``key=value`` text (one per line) into a dict.

    This is module-level because Pydantic ``@field_validator`` classmethods
    on ``RestAPIConnectorConfig`` need to call it before any instance exists.
    """
    if v is None or v == "":
        return {}
    if isinstance(v, dict):
        return {str(k): str(vv) for k, vv in v.items()}
    if isinstance(v, str):
        try:
            parsed = json.loads(v)
            if isinstance(parsed, dict):
                return {str(k): str(vv) for k, vv in parsed.items()}
        except Exception:
            pass
        result: Dict[str, str] = {}
        for line in v.splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if "=" in line:
                k, _, val = line.partition("=")
                result[k.strip()] = val.strip()
        return result
    return {}


class RestAPIConnectorConfig(BaseModel):
    """Validated schema for the REST API connector configuration."""

    model_config = ConfigDict(extra="ignore")

    url: HttpUrl
    method: str = "GET"
    headers: Dict[str, str] = Field(default_factory=dict)
    query_params: Dict[str, str] = Field(default_factory=dict)

    auth_type: str = AuthType.NONE
    auth_config: Dict[str, Any] = Field(default_factory=dict)

    items_path: Optional[str] = None
    id_field: Optional[str] = None
    content_fields: List[str] = Field(default_factory=list)
    metadata_fields: List[str] = Field(default_factory=list)

    pagination_type: str = PaginationType.NONE
    pagination_config: Dict[str, Any] = Field(default_factory=dict)

    poll_timestamp_field: Optional[str] = None
    request_body: Optional[Dict[str, Any]] = None

    field_type_hints: Dict[str, str] = Field(default_factory=dict)
    field_default_values: Dict[str, Any] = Field(default_factory=dict)
    content_template: Optional[str] = None

    batch_size: int = INDEX_BATCH_SIZE
    max_pages: int = _DEFAULT_MAX_PAGES
    request_delay: float = 0.5

    @field_validator("headers", mode="before")
    @classmethod
    def _coerce_headers(cls, v: Any) -> Dict[str, str]:
        return _text_to_dict(v)

    @field_validator("query_params", mode="before")
    @classmethod
    def _coerce_query_params(cls, v: Any) -> Dict[str, str]:
        return _text_to_dict(v)

    @field_validator("content_fields", "metadata_fields", mode="before")
    @classmethod
    def _coerce_field_list(cls, v: Any) -> List[str]:
        if v is None or v == "":
            return []
        if isinstance(v, str):
            return [p.strip() for p in v.split(",") if p.strip()]
        if isinstance(v, list):
            return [str(p).strip() for p in v if str(p).strip()]
        return []

    def normalized_method(self) -> str:
        m = (self.method or "GET").upper()
        if m not in {"GET", "POST"}:
            raise ConnectorValidationError(f"Unsupported HTTP method '{m}'.")
        return m

    def normalized_auth_type(self) -> str:
        if self.auth_type not in {AuthType.NONE, AuthType.API_KEY_HEADER, AuthType.BEARER, AuthType.BASIC}:
            raise ConnectorValidationError(f"Unsupported auth_type '{self.auth_type}'.")
        return self.auth_type

    def normalized_pagination_type(self) -> str:
        if self.pagination_type not in {PaginationType.NONE, PaginationType.PAGE, PaginationType.OFFSET, PaginationType.CURSOR}:
            raise ConnectorValidationError(f"Unsupported pagination_type '{self.pagination_type}'.")
        return self.pagination_type

    def ensure_required_fields(self) -> None:
        if not self.content_fields:
            raise ConnectorValidationError("At least one content field must be configured (content_fields).")


class RestAPIConnector(LoadConnector, PollConnector):
    """Configuration-driven REST API connector.

    Implements ``LoadConnector`` and ``PollConnector`` to fetch documents
    from any REST API using user-provided configuration (URL, auth,
    pagination, field mapping).
    """

    @staticmethod
    def _validate_url_for_ssrf(url: str) -> None:
        """Validate that the URL does not point to localhost or private/internal networks.

        Raises:
            ConnectorValidationError: If the URL is considered unsafe.
        """
        parsed = urlparse(str(url))

        if parsed.scheme not in ("http", "https"):
            msg = f"Unsupported URL scheme for REST API connector: {parsed.scheme!r}. Only http/https are allowed."
            logger.warning(msg)
            raise ConnectorValidationError(msg)

        hostname = parsed.hostname
        if not hostname:
            msg = "REST API connector URL must include a hostname."
            logger.warning(msg)
            raise ConnectorValidationError(msg)

        # Quick checks for obvious localhost-style hostnames.
        lower_host = hostname.lower()
        if lower_host in ("localhost",):
            msg = f"REST API connector URL hostname {hostname!r} is not allowed (localhost is blocked)."
            logger.warning(msg)
            raise ConnectorValidationError(msg)

        try:
            addrinfo_list = socket.getaddrinfo(hostname, None)
        except OSError as exc:
            # If resolution fails, log and let higher-level validation (if any) decide.
            # We do not treat this as an SSRF condition by itself.
            logger.info("DNS resolution failed for REST API connector URL %r: %s", url, exc)
            return

        for family, _, _, _, sockaddr in addrinfo_list:
            ip_str = sockaddr[0]
            try:
                ip_obj = ipaddress.ip_address(ip_str)
            except ValueError:
                # Not an IP address we understand; skip.
                logger.debug("Skipping non-IP address resolved from %r: %r", hostname, ip_str)
                continue

            if (
                ip_obj.is_loopback
                or ip_obj.is_private
                or ip_obj.is_link_local
                or ip_obj.is_reserved
                or ip_obj.is_multicast
            ):
                msg = (
                    f"REST API connector URL {url!r} resolves to disallowed address {ip_str} "
                    "(localhost, private, link-local, reserved, or multicast addresses are blocked)."
                )
                logger.warning(msg)
                raise ConnectorValidationError(msg)

        logger.debug("REST API connector URL %r passed SSRF safety validation.", url)

    def __init__(
        self,
        url: str,
        method: str = "GET",
        headers: Optional[Dict[str, str]] = None,
        query_params: Optional[Dict[str, str]] = None,
        auth_type: str = AuthType.NONE,
        auth_config: Optional[Dict[str, Any]] = None,
        items_path: Optional[str] = None,
        id_field: Optional[str] = None,
        content_fields: Optional[List[str]] = None,
        metadata_fields: Optional[List[str]] = None,
        pagination_type: str = PaginationType.NONE,
        pagination_config: Optional[Dict[str, Any]] = None,
        poll_timestamp_field: Optional[str] = None,
        batch_size: int = INDEX_BATCH_SIZE,
        max_pages: int = _DEFAULT_MAX_PAGES,
        request_delay: float = 0.5,
        request_body: Optional[Dict[str, Any]] = None,
        field_type_hints: Optional[Dict[str, str]] = None,
        field_default_values: Optional[Dict[str, Any]] = None,
        content_template: Optional[str] = None,
    ) -> None:
        # Validate URL against SSRF-style targets (localhost, private/internal ranges, etc.)
        self._validate_url_for_ssrf(url)

        parsed = urlparse(str(url))
        self._base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, "", "", ""))
        self._url_params: Dict[str, str] = {}
        if parsed.query:
            for k, v_list in parse_qs(parsed.query, keep_blank_values=True).items():
                self._url_params[k] = v_list[-1]

        self._explicit_query_params: Dict[str, str] = (
            _text_to_dict(query_params) if isinstance(query_params, str) else (query_params or {})
        )
        self.url = self._base_url
        self.method = (method or "GET").upper()
        self._base_headers: Dict[str, str] = (
            _text_to_dict(headers) if isinstance(headers, str) else (headers or {})
        )
        self.auth_type = auth_type or AuthType.NONE
        self.auth_config: Dict[str, Any] = auth_config or {}
        self.items_path = items_path
        self.id_field = id_field
        self.content_fields: List[str] = content_fields or []
        self.metadata_fields: List[str] = metadata_fields or []
        self.pagination_type = pagination_type or PaginationType.NONE
        self.pagination_config: Dict[str, Any] = pagination_config or {}
        self._static_request_body: Dict[str, Any] = (
            request_body if request_body is not None
            else self.pagination_config.get("request_body") or {}
        )
        self.poll_timestamp_field = poll_timestamp_field
        self.batch_size = batch_size
        self.max_pages = max_pages
        self.request_delay = max(request_delay, 0.0)
        self.field_type_hints: Dict[str, str] = field_type_hints or {}
        self.field_default_values: Dict[str, Any] = field_default_values or {}
        self.content_template = content_template

        self._credentials: Dict[str, Any] = {}
        self._auth_headers: Dict[str, str] = {}
        self._basic_auth: Optional[requests.auth.HTTPBasicAuth] = None

    # -- Credentials --------------------------------------------------------

    def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
        """Apply authentication credentials (no network call).

        Use ``validate_config()`` to perform a live connectivity check.
        """
        self._credentials = credentials or {}
        self._build_auth()
        return None

    def _build_auth(self) -> None:
        """Derive auth headers / basic-auth object from credentials."""
        self._auth_headers = {}
        self._basic_auth = None

        if self.auth_type == AuthType.NONE:
            logging.info("REST API auth_type=none, no authentication configured.")
            return

        if self.auth_type == AuthType.API_KEY_HEADER:
            header_name = self.auth_config.get("header_name")
            api_key = (
                self._credentials.get("api_key")
                or self.auth_config.get("api_key_value")
                or self.auth_config.get("api_key")
            )
            if not header_name or not api_key:
                logging.warning(
                    "REST API auth setup failed: header_name=%s, api_key present=%s, "
                    "credentials keys=%s, auth_config keys=%s",
                    header_name, bool(api_key),
                    list(self._credentials.keys()), list(self.auth_config.keys()),
                )
                raise ConnectorMissingCredentialError(
                    "REST API (api_key_header) requires 'header_name' in auth_config and 'api_key' in credentials"
                )
            self._auth_headers[header_name] = str(api_key)
            logging.info("REST API auth configured: header '%s' set.", header_name)
            return

        if self.auth_type == AuthType.BEARER:
            token = self._credentials.get("token") or self.auth_config.get("token")
            if not token:
                raise ConnectorMissingCredentialError("REST API (bearer) requires 'token' in credentials")
            self._auth_headers["Authorization"] = f"Bearer {token}"
            logging.info("REST API auth configured: Bearer token set.")
            return

        if self.auth_type == AuthType.BASIC:
            username = self._credentials.get("username") or self.auth_config.get("username")
            password = self._credentials.get("password") or self.auth_config.get("password")
            if not username or password is None:
                raise ConnectorMissingCredentialError("REST API (basic) requires 'username' and 'password'")
            self._basic_auth = requests.auth.HTTPBasicAuth(str(username), str(password))
            logging.info("REST API auth configured: Basic auth for user '%s'.", username)
            return

        raise ConnectorValidationError(f"Unsupported auth_type: {self.auth_type}")

    # -- Config validation (test connection) --------------------------------

    @classmethod
    def parse_storage_config(cls, raw: Dict[str, Any]) -> RestAPIConnectorConfig:
        """Parse connector config as stored on the connector row (no network I/O).

        ``credentials`` live under ``raw`` but are excluded from the schema and
        must be applied via ``load_credentials`` separately.
        """
        body = {k: v for k, v in raw.items() if k != "credentials"}
        try:
            cfg = RestAPIConnectorConfig(**body)
        except ValidationError as exc:
            raise ConnectorValidationError(f"Invalid REST API config: {exc}") from exc
        cfg.normalized_method()
        cfg.normalized_auth_type()
        cfg.normalized_pagination_type()
        cfg.ensure_required_fields()
        return cfg

    @classmethod
    def from_parsed_config(
        cls,
        cfg: RestAPIConnectorConfig,
        *,
        max_pages: Optional[int] = None,
    ) -> RestAPIConnector:
        """Build a connector from validated config (``__init__`` runs SSRF validation)."""
        return cls(
            url=str(cfg.url),
            method=cfg.normalized_method(),
            headers=cfg.headers,
            query_params=cfg.query_params,
            auth_type=cfg.normalized_auth_type(),
            auth_config=cfg.auth_config,
            items_path=cfg.items_path,
            id_field=cfg.id_field,
            content_fields=cfg.content_fields,
            metadata_fields=cfg.metadata_fields,
            pagination_type=cfg.normalized_pagination_type(),
            pagination_config=cfg.pagination_config,
            poll_timestamp_field=cfg.poll_timestamp_field,
            batch_size=cfg.batch_size,
            max_pages=max_pages if max_pages is not None else cfg.max_pages,
            request_delay=cfg.request_delay,
            request_body=cfg.request_body,
            field_type_hints=cfg.field_type_hints,
            field_default_values=cfg.field_default_values,
            content_template=cfg.content_template,
        )

    @classmethod
    def validate_config(
        cls,
        config: Dict[str, Any],
        credentials: Optional[Dict[str, Any]] = None,
    ) -> RestAPIConnectorConfig:
        """Validate config schema and optionally perform a live API call.

        Args:
            config: Raw config dict from the UI / database.
            credentials: Optional credentials dict; when provided a live
                connectivity check is performed.

        Returns:
            The validated ``RestAPIConnectorConfig`` instance.

        Raises:
            ConnectorValidationError: On schema or connectivity failure.
        """
        cfg = cls.parse_storage_config(config)
        validation_cap = min(cfg.max_pages, 10)
        connector = cls.from_parsed_config(cfg, max_pages=validation_cap)

        if credentials is None and cfg.auth_type != AuthType.NONE:
            return cfg

        if credentials is not None:
            connector.load_credentials(credentials)
        else:
            connector._credentials = {}
            connector._build_auth()

        try:
            logging.info("Validating REST API connector by fetching first page")
            _ = next(connector._page_iter_for_validation())
        except StopIteration:
            pass

        return cfg

    # -- LoadConnector / PollConnector interface -----------------------------

    def load_from_state(self) -> Generator[List[Document], None, None]:
        """Full fetch with pagination."""
        return self._yield_documents(time_window=None)

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[List[Document], None, None]:
        """Incremental fetch; filters by ``poll_timestamp_field`` if configured."""
        if not self.poll_timestamp_field:
            logging.warning(
                "poll_source called without poll_timestamp_field; "
                "falling back to full fetch with in-memory filtering."
            )
        return self._yield_documents(
            time_window=(
                datetime.fromtimestamp(start, tz=timezone.utc),
                datetime.fromtimestamp(end, tz=timezone.utc),
            )
        )

    # -- Document generation ------------------------------------------------

    def _yield_documents(
        self,
        time_window: tuple[datetime, datetime] | None,
    ) -> Generator[List[Document], None, None]:
        batch: List[Document] = []
        for item in self._iter_items():
            try:
                doc = self._item_to_document(item)
            except Exception as exc:
                logging.warning("Failed to convert REST API item to Document: %s", exc)
                continue

            if time_window is not None and not self._doc_in_time_window(doc, *time_window):
                continue

            batch.append(doc)
            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    # -- Pagination & page fetching -----------------------------------------

    def _iter_items(self) -> Iterable[Mapping[str, Any]]:
        """Iterate over raw items across all pages."""
        page_count = 0

        page = int(self.pagination_config.get("start_page", 1))
        per_page = self._resolve_page_size()

        offset = int(self.pagination_config.get("start_offset", 0))
        limit = int(self.pagination_config.get("limit", per_page))
        if limit <= 0:
            limit = per_page

        cursor: Optional[str] = self.pagination_config.get("initial_cursor")

        while True:
            if page_count >= self.max_pages:
                logging.warning("REST API connector reached max_pages=%d, stopping.", self.max_pages)
                break

            params: Dict[str, Any] = {}
            if self.pagination_type == PaginationType.PAGE:
                self._apply_page_pagination(params, page, per_page)
            elif self.pagination_type == PaginationType.OFFSET:
                self._apply_offset_pagination(params, offset, limit)
            elif self.pagination_type == PaginationType.CURSOR and cursor is not None:
                self._apply_cursor_pagination(params, cursor)

            if page_count > 0 and self.request_delay > 0:
                time.sleep(self.request_delay)

            try:
                response_json = self._fetch_page(params)
            except (ConnectorValidationError, ConnectorMissingCredentialError):
                raise
            except Exception as exc:
                raise ConnectorValidationError(f"REST API page fetch failed: {exc}") from exc

            items = self._extract_items(response_json)
            if not items:
                break

            for item in items:
                if isinstance(item, Mapping):
                    yield item

            page_count += 1

            if self.pagination_type == PaginationType.NONE:
                break
            elif self.pagination_type == PaginationType.PAGE:
                if len(items) < per_page:
                    break
                page += 1
            elif self.pagination_type == PaginationType.OFFSET:
                if len(items) < limit:
                    break
                offset += limit
            elif self.pagination_type == PaginationType.CURSOR:
                next_cursor = self._extract_next_cursor(response_json)
                if not next_cursor:
                    break
                cursor = next_cursor

    def _page_iter_for_validation(self) -> Iterable[Mapping[str, Any]]:
        """Single-page iterator used for connectivity checks."""
        params: Dict[str, Any] = {}
        if self.pagination_type == PaginationType.PAGE:
            page = int(self.pagination_config.get("start_page", 1))
            per_page = self._resolve_page_size()
            self._apply_page_pagination(params, page, per_page)
        elif self.pagination_type == PaginationType.OFFSET:
            per_page = self._resolve_page_size()
            offset = int(self.pagination_config.get("start_offset", 0))
            limit = int(self.pagination_config.get("limit", per_page))
            if limit <= 0:
                limit = per_page
            self._apply_offset_pagination(params, offset, limit)
        elif self.pagination_type == PaginationType.CURSOR:
            cursor = self.pagination_config.get("initial_cursor")
            if cursor is not None:
                self._apply_cursor_pagination(params, cursor)

        response_json = self._fetch_page(params=params)
        for item in self._extract_items(response_json):
            yield item

    @retry_builder(
        tries=5, delay=1, max_delay=30, backoff=2,
        exceptions=(requests.ConnectionError, requests.Timeout, requests.HTTPError),
    )
    def _fetch_page(self, params: Dict[str, Any]) -> Any:
        """Fetch a single page with retry and exponential backoff."""
        headers = {**self._base_headers, **self._auth_headers}

        merged: Dict[str, Any] = {**self._url_params}
        merged.update(self._explicit_query_params)
        merged.update(params)

        url, query_params = self._build_url_with_templates(merged)

        sensitive = {"authorization", "apikey", "api-key", "x-api-key"}
        logging.debug(
            "REST API request: %s %s | params=%s | headers=%s",
            self.method, url,
            {k: ("***" if k.lower() in sensitive else v) for k, v in query_params.items()},
            {k: ("***" if k.lower() in sensitive else v) for k, v in headers.items()},
        )

        if self.method == "GET":
            resp = rl_requests.get(url, headers=headers, params=query_params, auth=self._basic_auth, timeout=60)
        elif self.method == "POST":
            resp = rl_requests.post(
                url, headers=headers, params=query_params,
                json=self._static_request_body or {}, auth=self._basic_auth, timeout=60,
            )
        else:
            raise ConnectorValidationError(f"Unsupported HTTP method: {self.method}")

        try:
            resp.raise_for_status()
        except requests.HTTPError as exc:
            status = exc.response.status_code if exc.response is not None else None
            if status in (401, 403):
                sensitive = {"authorization", "apikey", "api-key", "x-api-key"}
                logging.warning(
                    "REST API %d for %s %s | auth_type=%s | "
                    "request header keys=%s | auth_header keys=%s",
                    status, self.method, resp.url,
                    self.auth_type,
                    [k for k in headers],
                    [k for k in self._auth_headers],
                )
                raise ConnectorMissingCredentialError(
                    f"REST API authentication failed with status {status}"
                ) from exc
            if status is not None and 400 <= status < 500 and status != 429:
                logging.warning(
                    "REST API client error %d for %s %s; not retrying.",
                    status,
                    self.method,
                    resp.url,
                )
                raise ConnectorValidationError(
                    f"REST API request failed with non-retriable client error status {status}"
                ) from exc
            raise

        try:
            return resp.json()
        except ValueError as exc:
            raise ConnectorValidationError("REST API response is not valid JSON") from exc

    def _build_url_with_templates(self, params: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
        """Substitute ``{key}`` placeholders in the URL; return remaining query params."""
        url = self.url
        query_params = dict(params)
        used_keys: List[str] = []
        for key, value in list(query_params.items()):
            placeholder = "{" + key + "}"
            if placeholder in url:
                url = url.replace(placeholder, str(value))
                used_keys.append(key)
        for key in used_keys:
            query_params.pop(key, None)
        return url, query_params

    # -- Pagination helpers -------------------------------------------------

    def _resolve_page_size(self) -> int:
        """Determine per-page size from config, query params, or batch_size fallback.

        Priority: explicit ``page_size`` in pagination_config > value already
        present in user query params for the same param name > batch_size.
        """
        explicit = self.pagination_config.get("page_size")
        if explicit is not None:
            val = int(explicit)
            if val > 0:
                return val

        size_param = self.pagination_config.get("page_size_param") or self.pagination_config.get("limit_param")
        if size_param:
            for source in (self._explicit_query_params, self._url_params):
                if size_param in source:
                    try:
                        val = int(source[size_param])
                        if val > 0:
                            return val
                    except (ValueError, TypeError):
                        pass

        return self.batch_size

    def _apply_page_pagination(self, params: Dict[str, Any], page: int, per_page: int) -> None:
        params[self.pagination_config.get("page_param", "page")] = page
        size_param = self.pagination_config.get("page_size_param")
        if size_param:
            params[size_param] = per_page

    def _apply_offset_pagination(self, params: Dict[str, Any], offset: int, limit: int) -> None:
        params[self.pagination_config.get("offset_param", "offset")] = offset
        limit_param = self.pagination_config.get("limit_param")
        if limit_param:
            params[limit_param] = limit

    def _apply_cursor_pagination(self, params: Dict[str, Any], cursor: str) -> None:
        params[self.pagination_config.get("cursor_param", "cursor")] = cursor

    # -- JSON extraction ----------------------------------------------------

    def _extract_items(self, response_json: Any) -> List[Mapping[str, Any]]:
        """Extract the items array from a JSON response."""
        if self.items_path and _jsonpath is not None:
            try:
                matches = _jsonpath(response_json, self.items_path)
            except Exception as exc:
                raise ConnectorValidationError(
                    f"Failed to apply items JSONPath '{self.items_path}': {exc}"
                ) from exc
            if not matches:
                return []
            if len(matches) == 1 and isinstance(matches[0], list):
                items = matches[0]
            else:
                items = matches
        elif isinstance(response_json, list):
            items = response_json
        elif isinstance(response_json, dict):
            items = []
            for key in ("items", "results", "data", "records"):
                if key in response_json and isinstance(response_json[key], list):
                    items = response_json[key]
                    break
            else:
                for value in response_json.values():
                    if isinstance(value, list):
                        items = value
                        break
        else:
            items = []

        return [it for it in items if isinstance(it, Mapping)]

    def _extract_next_cursor(self, response_json: Any) -> Optional[str]:
        """Extract cursor value for cursor-based pagination."""
        cursor_path = self.pagination_config.get("next_cursor_path")
        if not cursor_path:
            field = self.pagination_config.get("next_cursor_field")
            if field and isinstance(response_json, Mapping):
                value = response_json.get(field)
                return str(value) if value is not None else None
            return None

        if _jsonpath is None:
            return None

        try:
            matches = _jsonpath(response_json, cursor_path)
        except Exception:
            return None

        if not matches:
            return None
        return str(matches[0]) if matches[0] is not None else None

    # -- Item → Document mapping --------------------------------------------

    def _item_to_document(self, item: Mapping[str, Any]) -> Document:
        """Map a single API item to a ``Document``."""
        raw_id = self._get_typed_field_value(self.id_field, item) if self.id_field else None
        if raw_id is None:
            raw_id = hash128(f"rest_api_item:{repr(item)}")
        doc_id = hash128(f"rest_api:{raw_id}")

        if self.content_template:
            content_text = self._render_content_template(item)
        else:
            parts = []
            for field in self.content_fields:
                val = self._get_typed_field_value(field, item)
                if val is not None:
                    text = self._strip_html(self._coerce_to_text(val))
                    if text:
                        parts.append(text)
            content_text = "\n".join(parts)
        blob = content_text.encode("utf-8")

        metadata: Dict[str, Any] = {}
        for field in self.metadata_fields:
            value = self._get_typed_field_value(field, item)
            if value is not None:
                metadata[field] = self._serialize_metadata_value(value)

        doc_updated_at = self._extract_timestamp(item) or datetime.now(timezone.utc)

        sem = str(self._extract_field(item, self.content_fields[0]) if self.content_fields else raw_id)
        sem = self._strip_html(sem).replace("\n", " ").replace("\r", " ").strip()[:100] or str(doc_id)

        return Document(
            id=doc_id,
            source=DocumentSource.REST_API,
            semantic_identifier=sem,
            extension=".txt",
            blob=blob,
            doc_updated_at=doc_updated_at,
            size_bytes=len(blob),
            metadata=metadata or None,
        )

    # -- Field extraction ---------------------------------------------------

    def _extract_field(self, item: Mapping[str, Any], path: str) -> Any:
        """Extract a value using dot-notation with optional array indexing.

        Examples: ``country.name``, ``tags[0].label``, ``tags[*].label``
        """
        values = self._extract_field_values(item, path)
        if not values:
            return None
        return values[0] if len(values) == 1 else values

    def _extract_field_values(self, item: Mapping[str, Any], path: str) -> List[Any]:
        """Return all raw values for a dot-notation field path with wildcards."""
        if not path:
            return []

        current_values: List[Any] = [item]
        for segment in path.split("."):
            if not segment:
                return []

            match = _FIELD_SEGMENT_RE.match(segment)
            key = segment
            index: Optional[str] = None
            if match:
                key = match.group("key")
                index = match.group("index")

            next_values: List[Any] = []
            for value in current_values:
                if not isinstance(value, Mapping):
                    continue
                child = value.get(key)
                if child is None:
                    continue
                if index is None:
                    next_values.append(child)
                elif not isinstance(child, list):
                    continue
                elif index == "*":
                    next_values.extend(child)
                else:
                    try:
                        idx = int(index)
                    except ValueError:
                        continue
                    if 0 <= idx < len(child):
                        next_values.append(child[idx])

            current_values = next_values
            if not current_values:
                break

        return current_values

    def _get_typed_field_value(self, path: str, item: Mapping[str, Any]) -> Any:
        """Extract a field value, applying type hints, defaults, and array joining."""
        values = self._extract_field_values(item, path)
        if not values:
            return self.field_default_values.get(path)

        hint = self.field_type_hints.get(path)

        def _convert(v: Any) -> Any:
            if hint == "string":
                return "" if v is None else str(v)
            if hint == "number":
                if v is None:
                    return None
                try:
                    num = float(v)
                    return int(num) if num.is_integer() else num
                except Exception:
                    return None
            if hint == "date":
                if isinstance(v, datetime):
                    return v.isoformat()
                dt = self._parse_datetime(v)
                if dt is not None:
                    return dt.isoformat()
                return str(v) if v is not None else None
            return v

        converted = [_convert(v) for v in values]
        non_null = [v for v in converted if v is not None]
        if not non_null:
            return None
        if len(non_null) == 1:
            return non_null[0]
        return ", ".join(self._coerce_to_text(v) for v in non_null)

    # -- Timestamp parsing --------------------------------------------------

    def _extract_timestamp(self, item: Mapping[str, Any]) -> Optional[datetime]:
        """Extract and normalise a timestamp from ``poll_timestamp_field``."""
        if not self.poll_timestamp_field:
            return None

        value = self._extract_field(item, self.poll_timestamp_field)
        if isinstance(value, list) and value:
            value = value[0]
        return self._parse_datetime(value)

    @staticmethod
    def _parse_datetime(value: Any) -> Optional[datetime]:
        """Parse a raw value into a UTC datetime, or return None."""
        if value is None:
            return None

        if isinstance(value, datetime):
            return (value if value.tzinfo else value.replace(tzinfo=timezone.utc)).astimezone(timezone.utc)

        if isinstance(value, (int, float)):
            try:
                return datetime.fromtimestamp(float(value), tz=timezone.utc)
            except Exception:
                return None

        if isinstance(value, str):
            ts = value.strip()
            for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
                try:
                    return datetime.strptime(ts, fmt).replace(tzinfo=timezone.utc)
                except Exception:
                    continue
            try:
                dt = datetime.fromisoformat(ts.replace("Z", "+00:00").replace(" ", "T"))
                return (dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)).astimezone(timezone.utc)
            except Exception:
                return None

        return None

    # -- Content template rendering -----------------------------------------

    class _SafeDict(dict):
        """Dict subclass that returns empty string for missing keys in format_map."""
        def __missing__(self, key: str) -> str:
            return ""

    def _render_content_template(self, item: Mapping[str, Any]) -> str:
        """Render content using a user-provided template with ``{field}`` placeholders."""
        template = self.content_template or ""
        values: Dict[str, str] = {}
        for field_path in set(self.content_fields + self.metadata_fields):
            val = self._get_typed_field_value(field_path, item)
            if val is None:
                continue
            name = re.sub(r"\[\d+\]|\[\*\]", "", field_path).replace(".", "_")
            values[name] = self._coerce_to_text(val)

        try:
            rendered = template.format_map(self._SafeDict(values))
        except Exception as exc:
            logging.warning("Failed to render content template: %s", exc)
            parts = [self._coerce_to_text(self._get_typed_field_value(f, item)) for f in self.content_fields]
            rendered = "\n".join(p for p in parts if p)

        return self._strip_html(rendered)

    # -- Static helpers -----------------------------------------------------

    @staticmethod
    def _strip_html(text: str) -> str:
        """Remove basic HTML tags and normalise whitespace."""
        if "<" not in text or ">" not in text:
            return text
        cleaned = re.sub(r"<[^>]+>", " ", text)
        return re.sub(r"\s+", " ", cleaned).strip()

    @staticmethod
    def _coerce_to_text(value: Any) -> str:
        """Convert any value to a plain-text string."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (int, float, bool)):
            return str(value)
        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return str(value)

    @staticmethod
    def _serialize_metadata_value(value: Any) -> Any:
        """Serialise a metadata value for storage."""
        if isinstance(value, datetime):
            if value.tzinfo is None:
                value = value.replace(tzinfo=timezone.utc)
            return value.isoformat()
        if isinstance(value, (int, float, bool, str)):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return str(value)

    @staticmethod
    def _doc_in_time_window(doc: Document, start: datetime, end: datetime) -> bool:
        if not doc.doc_updated_at:
            return False
        dt = doc.doc_updated_at
        dt = (dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)).astimezone(timezone.utc)
        return start <= dt < end
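
The dot-notation field paths that `_extract_field_values` walks (`country.name`, `tags[0].label`, `tags[*].label`) can be tried outside the connector. The following is a minimal standalone sketch, not the connector's code: `SEGMENT_RE` and `extract_values` are hypothetical names, and the regex is only an assumption about what `_FIELD_SEGMENT_RE` matches, since its definition is not shown in this section. Unlike the method above, the sketch returns an empty list for any segment the pattern does not match.

```python
import re
from typing import Any, List, Mapping

# Assumed segment pattern: a key optionally followed by "[<digits>]" or "[*]".
SEGMENT_RE = re.compile(r"^(?P<key>[^\[\]]+)(?:\[(?P<index>\d+|\*)\])?$")


def extract_values(item: Mapping[str, Any], path: str) -> List[Any]:
    """Return every value reachable through a dot-notation path."""
    current: List[Any] = [item]
    for segment in path.split("."):
        match = SEGMENT_RE.match(segment)
        if match is None:
            return []  # empty or malformed segment
        key, index = match.group("key"), match.group("index")
        found: List[Any] = []
        for value in current:
            if not isinstance(value, Mapping):
                continue
            child = value.get(key)
            if child is None:
                continue
            if index is None:
                found.append(child)  # plain key lookup
            elif isinstance(child, list):
                if index == "*":
                    found.extend(child)  # wildcard fans out over the list
                elif int(index) < len(child):
                    found.append(child[int(index)])  # bounded index lookup
        current = found
        if not current:
            break
    return current


item = {"country": {"name": "Chile"}, "tags": [{"label": "a"}, {"label": "b"}]}
print(extract_values(item, "country.name"))
print(extract_values(item, "tags[0].label"))
print(extract_values(item, "tags[*].label"))
print(extract_values(item, "tags[5].label"))
```

A wildcard segment fans out over every list element, which is why the connector's `_get_typed_field_value` then joins multiple results with `", "`, while an out-of-range index yields no values.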