ragflow/common/data_source/rest_api_connector.py
Ahmad Intisar e994051eb9 Feature/generic api connector (#13545)
# feat: Add Generic REST API Connector

## What problem does this PR solve?

RAGFlow supports many specific data source connectors (MySQL, Slack,
Google Drive, etc.), but there was no way to connect an arbitrary REST
API as a data source. Users with custom or third-party APIs had to write
a new connector class for each one.

This PR adds a **generic, configuration-driven REST API connector** that
lets users connect any REST API as a data source entirely through the UI
— no code changes needed per API.

---

## Features

### Core Connector (`common/data_source/rest_api_connector.py`)

- Implements `LoadConnector` and `PollConnector` interfaces for full and
incremental sync
- **Configurable authentication:** None, API Key (custom header), Bearer
Token, Basic Auth
- **Pluggable pagination:** Page-based, Offset-based, Cursor-based, or
None
- Smart page-size inference from user's query parameters to avoid
duplicate/conflicting params
- Configurable request delay between pages to prevent API rate limiting
- Auto-detection of the items array in JSON responses (`items`,
`results`, `data`, `records`, or first list found)
- **Advanced field mapping** with dot-notation (`country.name`), array
wildcards (`newsType[*].name`), type hints, and default values
- Optional content template rendering (`"Title: {title}\nBody: {body}"`)
- HTML stripping for content fields
- Stable document IDs via `hash128` from a configurable ID field or
auto-generated from item content
- Pydantic configuration schema with automatic coercion of UI string
inputs to dicts/lists
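
The dot-notation and array-wildcard mapping listed above can be sketched roughly as follows. This is a simplified, standalone illustration rather than the connector's actual code: `extract_values`, `SEGMENT_RE`, and the sample `item` are hypothetical names chosen for the example, and type hints and default values are left out.

```python
import re
from typing import Any, List, Mapping

# One path segment: a key, optionally followed by [N] or [*].
SEGMENT_RE = re.compile(r"^(?P<key>[^\[\]]+)(\[(?P<index>\d+|\*)\])?$")


def extract_values(item: Mapping[str, Any], path: str) -> List[Any]:
    """Return every value matched by a dot-notation path (hypothetical helper)."""
    current: List[Any] = [item]
    for segment in path.split("."):
        match = SEGMENT_RE.match(segment)
        if match is None:
            return []  # malformed or empty segment
        key, index = match.group("key"), match.group("index")
        found: List[Any] = []
        for value in current:
            if not isinstance(value, Mapping) or value.get(key) is None:
                continue
            child = value[key]
            if index is None:
                found.append(child)
            elif isinstance(child, list):
                if index == "*":
                    found.extend(child)  # fan out over every element
                elif int(index) < len(child):
                    found.append(child[int(index)])
        current = found
    return current


item = {
    "country": {"name": "Norway"},
    "newsType": [{"name": "Politics"}, {"name": "Economy"}],
}
print(extract_values(item, "country.name"))      # ['Norway']
print(extract_values(item, "newsType[*].name"))  # ['Politics', 'Economy']
print(extract_values(item, "newsType[1].name"))  # ['Economy']
print(extract_values(item, "missing.field"))     # []
```

Returning a list for every path keeps the wildcard and single-value cases uniform; a caller can then decide whether to take the first element or join several values into one string.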

### Backend Registration (`rag/svr/sync_data_source.py`,
`common/constants.py`, `common/data_source/config.py`)

- `REST_API` sync class wired into RAGFlow's `func_factory`
- Full sync (`load_from_state`) and incremental polling (`poll_source`)
support
- Credentials and config passed from task to connector following
existing patterns (MySQL, SeaFile, etc.)

### Test Connection Endpoint (`api/apps/connector_app.py`)

- `POST /v1/connector/<id>/test` validates config schema,
authentication, and API connectivity without triggering a sync
- Clear error messages for auth failures vs. config issues

### Frontend UI (`web/src/pages/user-setting/data-source/constant/`)

- **Postman-style configuration:** Base URL, Query Parameters (key=value
per line), Auth, Content Fields, Metadata Fields, Pagination Type
- Auth-type-aware form: fields for API key header/value, Bearer token,
or Basic username/password appear only when relevant
- **Advanced Settings** toggle for: Custom Headers, Max Pages, Request
Delay, Poll Timestamp Field, Request Body (POST)
- Connector icon (SVG) and i18n strings (English)
- **"Test Connection"** button to validate before syncing
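
As a rough illustration of the "key=value per line" input format, the sketch below parses such text into a dict. It covers only the line-based form; the connector's own `_text_to_dict` helper also accepts dicts and JSON strings. The function name `parse_key_value_lines` and the sample input are assumptions made for this example.

```python
from typing import Dict


def parse_key_value_lines(text: str) -> Dict[str, str]:
    """Parse ``key=value`` lines into a dict, skipping blanks and # comments."""
    result: Dict[str, str] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if "=" in line:
            key, _, value = line.partition("=")  # split on the first "=" only
            result[key.strip()] = value.strip()
    return result


raw = """
# pagination
limit = 50
q=status=open
"""
print(parse_key_value_lines(raw))  # {'limit': '50', 'q': 'status=open'}
```

Splitting on the first `=` only means values that themselves contain `=` survive intact.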

---

## Controls & Safety

- Configurable max pages safety cap (default: 1000, adjustable in UI)
- Configurable request delay between pages (default: 0.5s, adjustable in
UI)
- Auth errors (401/403) fail immediately without retries; transient
errors retry with exponential backoff
- Diagnostic logging: auth setup confirmation, request details on
failure, content field extraction status
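
The retry policy described above (fail fast on auth errors, back off on transient ones) might look roughly like the sketch below. It is not the project's `retry_builder`, whose implementation is not shown here; `call_with_backoff`, `AuthError`, and `TransientError` are hypothetical names, and the defaults simply mirror the `tries=5, delay=1, max_delay=30, backoff=2` arguments used in the connector source.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


class AuthError(Exception):
    """Hypothetical: stands in for a 401/403 response; never retried."""


class TransientError(Exception):
    """Hypothetical: stands in for timeouts, 429, and 5xx responses."""


def call_with_backoff(
    func: Callable[[], T],
    tries: int = 5,
    delay: float = 1.0,
    backoff: float = 2.0,
    max_delay: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call ``func``, retrying only TransientError with growing waits."""
    wait = delay
    for attempt in range(1, tries + 1):
        try:
            return func()
        except TransientError:
            if attempt == tries:
                raise  # out of attempts: surface the last error
            sleep(min(wait, max_delay))
            wait *= backoff
    raise RuntimeError("tries must be at least 1")


attempts = {"n": 0}
waits: List[float] = []


def flaky() -> str:
    """Fail twice with a transient error, then succeed."""
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise TransientError("HTTP 503")
    return "ok"


print(call_with_backoff(flaky, sleep=waits.append))  # ok
print(waits)  # [1.0, 2.0]
```

Because `AuthError` is not caught inside the loop, it propagates on the first attempt without any sleep. Injecting `sleep` as a parameter keeps the example fast to run and easy to check.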

---

## Type of change

- [x] New Feature (non-breaking change which adds functionality)


## Visual Screenshots of Features
<img width="482" height="510" alt="Screenshot 2026-03-11 at 5 19 52 PM"
src="https://github.com/user-attachments/assets/dcb7ab4a-1622-44f3-bb02-d6f0527314c4"
/>
(Connector can be configured within the external data sources tab)

Configuration Parameters:
<img width="661" height="682" alt="Screenshot 2026-03-11 at 5 20 46 PM"
src="https://github.com/user-attachments/assets/5e154e71-4ab5-4872-bfb2-04f02b73c18a"
/>
<img width="661" height="682" alt="Screenshot 2026-03-11 at 5 20 54 PM"
src="https://github.com/user-attachments/assets/00cb14b7-0bcf-4b94-9d71-34e93369ecb2"
/>

The connection can be tested before attaching it to a dataset:
<img width="981" height="681" alt="Screenshot 2026-03-11 at 5 21 40 PM"
src="https://github.com/user-attachments/assets/aaa6eeeb-89a7-4349-bc34-2423bf8be9ee"
/>

Ingestion tested with the API connector (works as expected):
<img width="1062" height="705" alt="Screenshot 2026-03-11 at 5 22 30 PM"
src="https://github.com/user-attachments/assets/afcd0d58-cadd-4152-badc-d2f14d96fbec"
/>

Search and retrieval also work, with metadata flowing through:
<img width="1062" height="705" alt="Screenshot 2026-03-11 at 5 23 05 PM"
src="https://github.com/user-attachments/assets/d41ee935-dcf7-4456-b317-22a76ca032c0"
/>

---------

Co-authored-by: Ahmad Intisar <ahmadintisar@Ahmads-MacBook-M4-Pro.local>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-13 20:35:01 +08:00

1013 lines
39 KiB
Python

"""Generic, configuration-driven REST API data source connector.
Connect any REST API as a RAGFlow data source without code changes.
All behaviour — URL, auth, pagination, field mapping — is controlled
via the ``RestAPIConnectorConfig`` schema exposed by the UI.
"""
from __future__ import annotations

import ipaddress
import json
import logging
import re
import socket
import time
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional
from urllib.parse import parse_qs, urlparse, urlunparse

import requests
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, ValidationError, field_validator

from api.utils.common import hash128
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import (
    ConnectorMissingCredentialError,
    ConnectorValidationError,
)
from common.data_source.interfaces import (
    LoadConnector,
    PollConnector,
    SecondsSinceUnixEpoch,
)
from common.data_source.models import Document
from common.data_source.utils import rl_requests, retry_builder

try:
    from jsonpath import jsonpath as _jsonpath  # type: ignore[import]
except Exception:  # pragma: no cover
    _jsonpath = None

logger = logging.getLogger(__name__)

_FIELD_SEGMENT_RE = re.compile(r'^(?P<key>[^\[\]]+)(\[(?P<index>\d+|\*)\])?$')
_DEFAULT_MAX_PAGES = 1000


class AuthType:
    NONE = "none"
    API_KEY_HEADER = "api_key_header"
    BEARER = "bearer"
    BASIC = "basic"


class PaginationType:
    NONE = "none"
    PAGE = "page"
    OFFSET = "offset"
    CURSOR = "cursor"


def _text_to_dict(v: Any) -> Dict[str, str]:
    """Parse a dict, JSON string, or ``key=value`` text (one per line) into a dict.

    This is module-level because Pydantic ``@field_validator`` classmethods
    on ``RestAPIConnectorConfig`` need to call it before any instance exists.
    """
    if v is None or v == "":
        return {}
    if isinstance(v, dict):
        return {str(k): str(vv) for k, vv in v.items()}
    if isinstance(v, str):
        try:
            parsed = json.loads(v)
            if isinstance(parsed, dict):
                return {str(k): str(vv) for k, vv in parsed.items()}
        except Exception:
            pass
        result: Dict[str, str] = {}
        for line in v.splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if "=" in line:
                k, _, val = line.partition("=")
                result[k.strip()] = val.strip()
        return result
    return {}


class RestAPIConnectorConfig(BaseModel):
    """Validated schema for the REST API connector configuration."""

    model_config = ConfigDict(extra="ignore")

    url: HttpUrl
    method: str = "GET"
    headers: Dict[str, str] = Field(default_factory=dict)
    query_params: Dict[str, str] = Field(default_factory=dict)
    auth_type: str = AuthType.NONE
    auth_config: Dict[str, Any] = Field(default_factory=dict)
    items_path: Optional[str] = None
    id_field: Optional[str] = None
    content_fields: List[str] = Field(default_factory=list)
    metadata_fields: List[str] = Field(default_factory=list)
    pagination_type: str = PaginationType.NONE
    pagination_config: Dict[str, Any] = Field(default_factory=dict)
    poll_timestamp_field: Optional[str] = None
    request_body: Optional[Dict[str, Any]] = None
    field_type_hints: Dict[str, str] = Field(default_factory=dict)
    field_default_values: Dict[str, Any] = Field(default_factory=dict)
    content_template: Optional[str] = None
    batch_size: int = INDEX_BATCH_SIZE
    max_pages: int = _DEFAULT_MAX_PAGES
    request_delay: float = 0.5

    @field_validator("headers", mode="before")
    @classmethod
    def _coerce_headers(cls, v: Any) -> Dict[str, str]:
        return _text_to_dict(v)

    @field_validator("query_params", mode="before")
    @classmethod
    def _coerce_query_params(cls, v: Any) -> Dict[str, str]:
        return _text_to_dict(v)

    @field_validator("content_fields", "metadata_fields", mode="before")
    @classmethod
    def _coerce_field_list(cls, v: Any) -> List[str]:
        if v is None or v == "":
            return []
        if isinstance(v, str):
            return [p.strip() for p in v.split(",") if p.strip()]
        if isinstance(v, list):
            return [str(p).strip() for p in v if str(p).strip()]
        return []

    def normalized_method(self) -> str:
        m = (self.method or "GET").upper()
        if m not in {"GET", "POST"}:
            raise ConnectorValidationError(f"Unsupported HTTP method '{m}'.")
        return m

    def normalized_auth_type(self) -> str:
        if self.auth_type not in {AuthType.NONE, AuthType.API_KEY_HEADER, AuthType.BEARER, AuthType.BASIC}:
            raise ConnectorValidationError(f"Unsupported auth_type '{self.auth_type}'.")
        return self.auth_type

    def normalized_pagination_type(self) -> str:
        if self.pagination_type not in {PaginationType.NONE, PaginationType.PAGE, PaginationType.OFFSET, PaginationType.CURSOR}:
            raise ConnectorValidationError(f"Unsupported pagination_type '{self.pagination_type}'.")
        return self.pagination_type

    def ensure_required_fields(self) -> None:
        if not self.content_fields:
            raise ConnectorValidationError("At least one content field must be configured (content_fields).")


class RestAPIConnector(LoadConnector, PollConnector):
    """Configuration-driven REST API connector.

    Implements ``LoadConnector`` and ``PollConnector`` to fetch documents
    from any REST API using user-provided configuration (URL, auth,
    pagination, field mapping).
    """

    @staticmethod
    def _validate_url_for_ssrf(url: str) -> None:
        """Validate that the URL does not point to localhost or private/internal networks.

        Raises:
            ConnectorValidationError: If the URL is considered unsafe.
        """
        parsed = urlparse(str(url))
        if parsed.scheme not in ("http", "https"):
            msg = f"Unsupported URL scheme for REST API connector: {parsed.scheme!r}. Only http/https are allowed."
            logger.warning(msg)
            raise ConnectorValidationError(msg)
        hostname = parsed.hostname
        if not hostname:
            msg = "REST API connector URL must include a hostname."
            logger.warning(msg)
            raise ConnectorValidationError(msg)
        # Quick checks for obvious localhost-style hostnames.
        lower_host = hostname.lower()
        if lower_host in ("localhost",):
            msg = f"REST API connector URL hostname {hostname!r} is not allowed (localhost is blocked)."
            logger.warning(msg)
            raise ConnectorValidationError(msg)
        try:
            addrinfo_list = socket.getaddrinfo(hostname, None)
        except OSError as exc:
            # If resolution fails, log and let higher-level validation (if any) decide.
            # We do not treat this as an SSRF condition by itself.
            logger.info("DNS resolution failed for REST API connector URL %r: %s", url, exc)
            return
        for family, _, _, _, sockaddr in addrinfo_list:
            ip_str = sockaddr[0]
            try:
                ip_obj = ipaddress.ip_address(ip_str)
            except ValueError:
                # Not an IP address we understand; skip.
                logger.debug("Skipping non-IP address resolved from %r: %r", hostname, ip_str)
                continue
            if (
                ip_obj.is_loopback
                or ip_obj.is_private
                or ip_obj.is_link_local
                or ip_obj.is_reserved
                or ip_obj.is_multicast
            ):
                msg = (
                    f"REST API connector URL {url!r} resolves to disallowed address {ip_str} "
                    "(localhost, private, link-local, reserved, or multicast addresses are blocked)."
                )
                logger.warning(msg)
                raise ConnectorValidationError(msg)
        logger.debug("REST API connector URL %r passed SSRF safety validation.", url)

    def __init__(
        self,
        url: str,
        method: str = "GET",
        headers: Optional[Dict[str, str]] = None,
        query_params: Optional[Dict[str, str]] = None,
        auth_type: str = AuthType.NONE,
        auth_config: Optional[Dict[str, Any]] = None,
        items_path: Optional[str] = None,
        id_field: Optional[str] = None,
        content_fields: Optional[List[str]] = None,
        metadata_fields: Optional[List[str]] = None,
        pagination_type: str = PaginationType.NONE,
        pagination_config: Optional[Dict[str, Any]] = None,
        poll_timestamp_field: Optional[str] = None,
        batch_size: int = INDEX_BATCH_SIZE,
        max_pages: int = _DEFAULT_MAX_PAGES,
        request_delay: float = 0.5,
        request_body: Optional[Dict[str, Any]] = None,
        field_type_hints: Optional[Dict[str, str]] = None,
        field_default_values: Optional[Dict[str, Any]] = None,
        content_template: Optional[str] = None,
    ) -> None:
        # Validate URL against SSRF-style targets (localhost, private/internal ranges, etc.)
        self._validate_url_for_ssrf(url)
        parsed = urlparse(str(url))
        self._base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, "", "", ""))
        self._url_params: Dict[str, str] = {}
        if parsed.query:
            for k, v_list in parse_qs(parsed.query, keep_blank_values=True).items():
                self._url_params[k] = v_list[-1]
        self._explicit_query_params: Dict[str, str] = (
            _text_to_dict(query_params) if isinstance(query_params, str) else (query_params or {})
        )
        self.url = self._base_url
        self.method = (method or "GET").upper()
        self._base_headers: Dict[str, str] = (
            _text_to_dict(headers) if isinstance(headers, str) else (headers or {})
        )
        self.auth_type = auth_type or AuthType.NONE
        self.auth_config: Dict[str, Any] = auth_config or {}
        self.items_path = items_path
        self.id_field = id_field
        self.content_fields: List[str] = content_fields or []
        self.metadata_fields: List[str] = metadata_fields or []
        self.pagination_type = pagination_type or PaginationType.NONE
        self.pagination_config: Dict[str, Any] = pagination_config or {}
        self._static_request_body: Dict[str, Any] = (
            request_body if request_body is not None
            else self.pagination_config.get("request_body") or {}
        )
        self.poll_timestamp_field = poll_timestamp_field
        self.batch_size = batch_size
        self.max_pages = max_pages
        self.request_delay = max(request_delay, 0.0)
        self.field_type_hints: Dict[str, str] = field_type_hints or {}
        self.field_default_values: Dict[str, Any] = field_default_values or {}
        self.content_template = content_template
        self._credentials: Dict[str, Any] = {}
        self._auth_headers: Dict[str, str] = {}
        self._basic_auth: Optional[requests.auth.HTTPBasicAuth] = None

    # -- Credentials --------------------------------------------------------
    def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
        """Apply authentication credentials (no network call).

        Use ``validate_config()`` to perform a live connectivity check.
        """
        self._credentials = credentials or {}
        self._build_auth()
        return None

    def _build_auth(self) -> None:
        """Derive auth headers / basic-auth object from credentials."""
        self._auth_headers = {}
        self._basic_auth = None
        if self.auth_type == AuthType.NONE:
            logging.info("REST API auth_type=none, no authentication configured.")
            return
        if self.auth_type == AuthType.API_KEY_HEADER:
            header_name = self.auth_config.get("header_name")
            api_key = (
                self._credentials.get("api_key")
                or self.auth_config.get("api_key_value")
                or self.auth_config.get("api_key")
            )
            if not header_name or not api_key:
                logging.warning(
                    "REST API auth setup failed: header_name=%s, api_key present=%s, "
                    "credentials keys=%s, auth_config keys=%s",
                    header_name, bool(api_key),
                    list(self._credentials.keys()), list(self.auth_config.keys()),
                )
                raise ConnectorMissingCredentialError(
                    "REST API (api_key_header) requires 'header_name' in auth_config and 'api_key' in credentials"
                )
            self._auth_headers[header_name] = str(api_key)
            logging.info("REST API auth configured: header '%s' set.", header_name)
            return
        if self.auth_type == AuthType.BEARER:
            token = self._credentials.get("token") or self.auth_config.get("token")
            if not token:
                raise ConnectorMissingCredentialError("REST API (bearer) requires 'token' in credentials")
            self._auth_headers["Authorization"] = f"Bearer {token}"
            logging.info("REST API auth configured: Bearer token set.")
            return
        if self.auth_type == AuthType.BASIC:
            username = self._credentials.get("username") or self.auth_config.get("username")
            password = self._credentials.get("password") or self.auth_config.get("password")
            if not username or password is None:
                raise ConnectorMissingCredentialError("REST API (basic) requires 'username' and 'password'")
            self._basic_auth = requests.auth.HTTPBasicAuth(str(username), str(password))
            logging.info("REST API auth configured: Basic auth for user '%s'.", username)
            return
        raise ConnectorValidationError(f"Unsupported auth_type: {self.auth_type}")

    # -- Config validation (test connection) --------------------------------
    @classmethod
    def parse_storage_config(cls, raw: Dict[str, Any]) -> RestAPIConnectorConfig:
        """Parse connector config as stored on the connector row (no network I/O).

        ``credentials`` live under ``raw`` but are excluded from the schema and
        must be applied via ``load_credentials`` separately.
        """
        body = {k: v for k, v in raw.items() if k != "credentials"}
        try:
            cfg = RestAPIConnectorConfig(**body)
        except ValidationError as exc:
            raise ConnectorValidationError(f"Invalid REST API config: {exc}") from exc
        cfg.normalized_method()
        cfg.normalized_auth_type()
        cfg.normalized_pagination_type()
        cfg.ensure_required_fields()
        return cfg

    @classmethod
    def from_parsed_config(
        cls,
        cfg: RestAPIConnectorConfig,
        *,
        max_pages: Optional[int] = None,
    ) -> RestAPIConnector:
        """Build a connector from validated config (``__init__`` runs SSRF validation)."""
        return cls(
            url=str(cfg.url),
            method=cfg.normalized_method(),
            headers=cfg.headers,
            query_params=cfg.query_params,
            auth_type=cfg.normalized_auth_type(),
            auth_config=cfg.auth_config,
            items_path=cfg.items_path,
            id_field=cfg.id_field,
            content_fields=cfg.content_fields,
            metadata_fields=cfg.metadata_fields,
            pagination_type=cfg.normalized_pagination_type(),
            pagination_config=cfg.pagination_config,
            poll_timestamp_field=cfg.poll_timestamp_field,
            batch_size=cfg.batch_size,
            max_pages=max_pages if max_pages is not None else cfg.max_pages,
            request_delay=cfg.request_delay,
            request_body=cfg.request_body,
            field_type_hints=cfg.field_type_hints,
            field_default_values=cfg.field_default_values,
            content_template=cfg.content_template,
        )

    @classmethod
    def validate_config(
        cls,
        config: Dict[str, Any],
        credentials: Optional[Dict[str, Any]] = None,
    ) -> RestAPIConnectorConfig:
        """Validate config schema and optionally perform a live API call.

        Args:
            config: Raw config dict from the UI / database.
            credentials: Optional credentials dict; when provided a live
                connectivity check is performed.

        Returns:
            The validated ``RestAPIConnectorConfig`` instance.

        Raises:
            ConnectorValidationError: On schema or connectivity failure.
        """
        cfg = cls.parse_storage_config(config)
        validation_cap = min(cfg.max_pages, 10)
        connector = cls.from_parsed_config(cfg, max_pages=validation_cap)
        if credentials is None and cfg.auth_type != AuthType.NONE:
            return cfg
        if credentials is not None:
            connector.load_credentials(credentials)
        else:
            connector._credentials = {}
            connector._build_auth()
        try:
            logging.info("Validating REST API connector by fetching first page")
            _ = next(connector._page_iter_for_validation())
        except StopIteration:
            pass
        return cfg

    # -- LoadConnector / PollConnector interface -----------------------------
    def load_from_state(self) -> Generator[List[Document], None, None]:
        """Full fetch with pagination."""
        return self._yield_documents(time_window=None)

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> Generator[List[Document], None, None]:
        """Incremental fetch; filters by ``poll_timestamp_field`` if configured."""
        if not self.poll_timestamp_field:
            logging.warning(
                "poll_source called without poll_timestamp_field; "
                "falling back to full fetch with in-memory filtering."
            )
        return self._yield_documents(
            time_window=(
                datetime.fromtimestamp(start, tz=timezone.utc),
                datetime.fromtimestamp(end, tz=timezone.utc),
            )
        )

    # -- Document generation ------------------------------------------------
    def _yield_documents(
        self,
        time_window: tuple[datetime, datetime] | None,
    ) -> Generator[List[Document], None, None]:
        batch: List[Document] = []
        for item in self._iter_items():
            try:
                doc = self._item_to_document(item)
            except Exception as exc:
                logging.warning("Failed to convert REST API item to Document: %s", exc)
                continue
            if time_window is not None and not self._doc_in_time_window(doc, *time_window):
                continue
            batch.append(doc)
            if len(batch) >= self.batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    # -- Pagination & page fetching -----------------------------------------
    def _iter_items(self) -> Iterable[Mapping[str, Any]]:
        """Iterate over raw items across all pages."""
        page_count = 0
        page = int(self.pagination_config.get("start_page", 1))
        per_page = self._resolve_page_size()
        offset = int(self.pagination_config.get("start_offset", 0))
        limit = int(self.pagination_config.get("limit", per_page))
        if limit <= 0:
            limit = per_page
        cursor: Optional[str] = self.pagination_config.get("initial_cursor")
        while True:
            if page_count >= self.max_pages:
                logging.warning("REST API connector reached max_pages=%d, stopping.", self.max_pages)
                break
            params: Dict[str, Any] = {}
            if self.pagination_type == PaginationType.PAGE:
                self._apply_page_pagination(params, page, per_page)
            elif self.pagination_type == PaginationType.OFFSET:
                self._apply_offset_pagination(params, offset, limit)
            elif self.pagination_type == PaginationType.CURSOR and cursor is not None:
                self._apply_cursor_pagination(params, cursor)
            if page_count > 0 and self.request_delay > 0:
                time.sleep(self.request_delay)
            try:
                response_json = self._fetch_page(params)
            except (ConnectorValidationError, ConnectorMissingCredentialError):
                raise
            except Exception as exc:
                raise ConnectorValidationError(f"REST API page fetch failed: {exc}") from exc
            items = self._extract_items(response_json)
            if not items:
                break
            for item in items:
                if isinstance(item, Mapping):
                    yield item
            page_count += 1
            if self.pagination_type == PaginationType.NONE:
                break
            elif self.pagination_type == PaginationType.PAGE:
                if len(items) < per_page:
                    break
                page += 1
            elif self.pagination_type == PaginationType.OFFSET:
                if len(items) < limit:
                    break
                offset += limit
            elif self.pagination_type == PaginationType.CURSOR:
                next_cursor = self._extract_next_cursor(response_json)
                if not next_cursor:
                    break
                cursor = next_cursor

    def _page_iter_for_validation(self) -> Iterable[Mapping[str, Any]]:
        """Single-page iterator used for connectivity checks."""
        params: Dict[str, Any] = {}
        if self.pagination_type == PaginationType.PAGE:
            page = int(self.pagination_config.get("start_page", 1))
            per_page = self._resolve_page_size()
            self._apply_page_pagination(params, page, per_page)
        elif self.pagination_type == PaginationType.OFFSET:
            per_page = self._resolve_page_size()
            offset = int(self.pagination_config.get("start_offset", 0))
            limit = int(self.pagination_config.get("limit", per_page))
            if limit <= 0:
                limit = per_page
            self._apply_offset_pagination(params, offset, limit)
        elif self.pagination_type == PaginationType.CURSOR:
            cursor = self.pagination_config.get("initial_cursor")
            if cursor is not None:
                self._apply_cursor_pagination(params, cursor)
        response_json = self._fetch_page(params=params)
        for item in self._extract_items(response_json):
            yield item

    @retry_builder(
        tries=5, delay=1, max_delay=30, backoff=2,
        exceptions=(requests.ConnectionError, requests.Timeout, requests.HTTPError),
    )
    def _fetch_page(self, params: Dict[str, Any]) -> Any:
        """Fetch a single page with retry and exponential backoff."""
        headers = {**self._base_headers, **self._auth_headers}
        merged: Dict[str, Any] = {**self._url_params}
        merged.update(self._explicit_query_params)
        merged.update(params)
        url, query_params = self._build_url_with_templates(merged)
        # Mask well-known credential-bearing keys before logging.
        sensitive = {"authorization", "apikey", "api-key", "x-api-key"}
        logging.debug(
            "REST API request: %s %s | params=%s | headers=%s",
            self.method, url,
            {k: ("***" if k.lower() in sensitive else v) for k, v in query_params.items()},
            {k: ("***" if k.lower() in sensitive else v) for k, v in headers.items()},
        )
        if self.method == "GET":
            resp = rl_requests.get(url, headers=headers, params=query_params, auth=self._basic_auth, timeout=60)
        elif self.method == "POST":
            resp = rl_requests.post(
                url, headers=headers, params=query_params,
                json=self._static_request_body or {}, auth=self._basic_auth, timeout=60,
            )
        else:
            raise ConnectorValidationError(f"Unsupported HTTP method: {self.method}")
        try:
            resp.raise_for_status()
        except requests.HTTPError as exc:
            status = exc.response.status_code if exc.response is not None else None
            if status in (401, 403):
                # Log header names only, never values.
                logging.warning(
                    "REST API %d for %s %s | auth_type=%s | "
                    "request header keys=%s | auth_header keys=%s",
                    status, self.method, resp.url,
                    self.auth_type,
                    list(headers),
                    list(self._auth_headers),
                )
                raise ConnectorMissingCredentialError(
                    f"REST API authentication failed with status {status}"
                ) from exc
            if status is not None and 400 <= status < 500 and status != 429:
                logging.warning(
                    "REST API client error %d for %s %s; not retrying.",
                    status,
                    self.method,
                    resp.url,
                )
                raise ConnectorValidationError(
                    f"REST API request failed with non-retriable client error status {status}"
                ) from exc
            raise
        try:
            return resp.json()
        except ValueError as exc:
            raise ConnectorValidationError("REST API response is not valid JSON") from exc

    def _build_url_with_templates(self, params: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
        """Substitute ``{key}`` placeholders in the URL; return remaining query params."""
        url = self.url
        query_params = dict(params)
        used_keys: List[str] = []
        for key, value in list(query_params.items()):
            placeholder = "{" + key + "}"
            if placeholder in url:
                url = url.replace(placeholder, str(value))
                used_keys.append(key)
        for key in used_keys:
            query_params.pop(key, None)
        return url, query_params

    # -- Pagination helpers -------------------------------------------------
    def _resolve_page_size(self) -> int:
        """Determine per-page size from config, query params, or batch_size fallback.

        Priority: explicit ``page_size`` in pagination_config > value already
        present in user query params for the same param name > batch_size.
        """
        explicit = self.pagination_config.get("page_size")
        if explicit is not None:
            val = int(explicit)
            if val > 0:
                return val
        size_param = self.pagination_config.get("page_size_param") or self.pagination_config.get("limit_param")
        if size_param:
            for source in (self._explicit_query_params, self._url_params):
                if size_param in source:
                    try:
                        val = int(source[size_param])
                        if val > 0:
                            return val
                    except (ValueError, TypeError):
                        pass
        return self.batch_size

    def _apply_page_pagination(self, params: Dict[str, Any], page: int, per_page: int) -> None:
        params[self.pagination_config.get("page_param", "page")] = page
        size_param = self.pagination_config.get("page_size_param")
        if size_param:
            params[size_param] = per_page

    def _apply_offset_pagination(self, params: Dict[str, Any], offset: int, limit: int) -> None:
        params[self.pagination_config.get("offset_param", "offset")] = offset
        limit_param = self.pagination_config.get("limit_param")
        if limit_param:
            params[limit_param] = limit

    def _apply_cursor_pagination(self, params: Dict[str, Any], cursor: str) -> None:
        params[self.pagination_config.get("cursor_param", "cursor")] = cursor

    # -- JSON extraction ----------------------------------------------------
    def _extract_items(self, response_json: Any) -> List[Mapping[str, Any]]:
        """Extract the items array from a JSON response."""
        if self.items_path and _jsonpath is not None:
            try:
                matches = _jsonpath(response_json, self.items_path)
            except Exception as exc:
                raise ConnectorValidationError(
                    f"Failed to apply items JSONPath '{self.items_path}': {exc}"
                ) from exc
            if not matches:
                return []
            if len(matches) == 1 and isinstance(matches[0], list):
                items = matches[0]
            else:
                items = matches
        elif isinstance(response_json, list):
            items = response_json
        elif isinstance(response_json, dict):
            items = []
            for key in ("items", "results", "data", "records"):
                if key in response_json and isinstance(response_json[key], list):
                    items = response_json[key]
                    break
            else:
                for value in response_json.values():
                    if isinstance(value, list):
                        items = value
                        break
        else:
            items = []
        return [it for it in items if isinstance(it, Mapping)]

    def _extract_next_cursor(self, response_json: Any) -> Optional[str]:
        """Extract cursor value for cursor-based pagination."""
        cursor_path = self.pagination_config.get("next_cursor_path")
        if not cursor_path:
            field = self.pagination_config.get("next_cursor_field")
            if field and isinstance(response_json, Mapping):
                value = response_json.get(field)
                return str(value) if value is not None else None
            return None
        if _jsonpath is None:
            return None
        try:
            matches = _jsonpath(response_json, cursor_path)
        except Exception:
            return None
        if not matches:
            return None
        return str(matches[0]) if matches[0] is not None else None

    # -- Item → Document mapping --------------------------------------------
    def _item_to_document(self, item: Mapping[str, Any]) -> Document:
        """Map a single API item to a ``Document``."""
        raw_id = self._get_typed_field_value(self.id_field, item) if self.id_field else None
        if raw_id is None:
            raw_id = hash128(f"rest_api_item:{repr(item)}")
        doc_id = hash128(f"rest_api:{raw_id}")
        if self.content_template:
            content_text = self._render_content_template(item)
        else:
            parts = []
            for field in self.content_fields:
                val = self._get_typed_field_value(field, item)
                if val is not None:
                    text = self._strip_html(self._coerce_to_text(val))
                    if text:
                        parts.append(text)
            content_text = "\n".join(parts)
        blob = content_text.encode("utf-8")
        metadata: Dict[str, Any] = {}
        for field in self.metadata_fields:
            value = self._get_typed_field_value(field, item)
            if value is not None:
                metadata[field] = self._serialize_metadata_value(value)
        doc_updated_at = self._extract_timestamp(item) or datetime.now(timezone.utc)
        sem = str(self._extract_field(item, self.content_fields[0]) if self.content_fields else raw_id)
        sem = self._strip_html(sem).replace("\n", " ").replace("\r", " ").strip()[:100] or str(doc_id)
        return Document(
            id=doc_id,
            source=DocumentSource.REST_API,
            semantic_identifier=sem,
            extension=".txt",
            blob=blob,
            doc_updated_at=doc_updated_at,
            size_bytes=len(blob),
            metadata=metadata or None,
        )

    # -- Field extraction ---------------------------------------------------
    def _extract_field(self, item: Mapping[str, Any], path: str) -> Any:
        """Extract a value using dot-notation with optional array indexing.

        Examples: ``country.name``, ``tags[0].label``, ``tags[*].label``
        """
        values = self._extract_field_values(item, path)
        if not values:
            return None
        return values[0] if len(values) == 1 else values

    def _extract_field_values(self, item: Mapping[str, Any], path: str) -> List[Any]:
        """Return all raw values for a dot-notation field path with wildcards."""
        if not path:
            return []
        current_values: List[Any] = [item]
        for segment in path.split("."):
            if not segment:
                return []
            match = _FIELD_SEGMENT_RE.match(segment)
            key = segment
            index: Optional[str] = None
            if match:
                key = match.group("key")
                index = match.group("index")
            next_values: List[Any] = []
            for value in current_values:
                if not isinstance(value, Mapping):
                    continue
                child = value.get(key)
                if child is None:
                    continue
                if index is None:
                    next_values.append(child)
                elif not isinstance(child, list):
                    continue
                elif index == "*":
                    next_values.extend(child)
                else:
                    try:
                        idx = int(index)
                    except ValueError:
                        continue
                    if 0 <= idx < len(child):
                        next_values.append(child[idx])
            current_values = next_values
            if not current_values:
                break
        return current_values

    def _get_typed_field_value(self, path: str, item: Mapping[str, Any]) -> Any:
        """Extract a field value, applying type hints, defaults, and array joining."""
        values = self._extract_field_values(item, path)
        if not values:
            return self.field_default_values.get(path)
        hint = self.field_type_hints.get(path)

        def _convert(v: Any) -> Any:
            if hint == "string":
                return "" if v is None else str(v)
            if hint == "number":
                if v is None:
                    return None
                try:
                    num = float(v)
                    return int(num) if num.is_integer() else num
                except Exception:
                    return None
            if hint == "date":
                if isinstance(v, datetime):
                    return v.isoformat()
                dt = self._parse_datetime(v)
                if dt is not None:
                    return dt.isoformat()
                return str(v) if v is not None else None
            return v

        converted = [_convert(v) for v in values]
        non_null = [v for v in converted if v is not None]
        if not non_null:
            return None
        if len(non_null) == 1:
            return non_null[0]
        return ", ".join(self._coerce_to_text(v) for v in non_null)

    # -- Timestamp parsing --------------------------------------------------

    def _extract_timestamp(self, item: Mapping[str, Any]) -> Optional[datetime]:
        """Extract and normalise a timestamp from ``poll_timestamp_field``."""
        if not self.poll_timestamp_field:
            return None
        value = self._extract_field(item, self.poll_timestamp_field)
        # A wildcard path may yield several values; use the first one.
        if isinstance(value, list) and value:
            value = value[0]
        return self._parse_datetime(value)

    @staticmethod
    def _parse_datetime(value: Any) -> Optional[datetime]:
        """Parse a raw value into a UTC datetime, or return None.

        Naive datetimes and strings without an offset are treated as UTC.
        Numeric values are interpreted as Unix epoch seconds.
        """
        if value is None:
            return None
        if isinstance(value, datetime):
            return (value if value.tzinfo else value.replace(tzinfo=timezone.utc)).astimezone(timezone.utc)
        if isinstance(value, (int, float)):
            try:
                return datetime.fromtimestamp(float(value), tz=timezone.utc)
            except Exception:
                return None
        if isinstance(value, str):
            ts = value.strip()
            # Try the common fixed formats first, then fall back to ISO 8601.
            for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
                try:
                    return datetime.strptime(ts, fmt).replace(tzinfo=timezone.utc)
                except Exception:
                    continue
            try:
                dt = datetime.fromisoformat(ts.replace("Z", "+00:00").replace(" ", "T"))
                return (dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)).astimezone(timezone.utc)
            except Exception:
                return None
        return None

    # -- Content template rendering -----------------------------------------

    class _SafeDict(dict):
        """Dict subclass that returns empty string for missing keys in format_map."""

        def __missing__(self, key: str) -> str:
            return ""

    def _render_content_template(self, item: Mapping[str, Any]) -> str:
        """Render content using a user-provided template with ``{field}`` placeholders.

        Placeholder names are field paths with array indices removed and dots
        replaced by underscores, so ``country.name`` is referenced as
        ``{country_name}``. Missing fields render as empty strings. If the
        template cannot be rendered, the content fields are joined with
        newlines instead.
        """
        template = self.content_template or ""
        values: Dict[str, str] = {}
        for field_path in set(self.content_fields + self.metadata_fields):
            val = self._get_typed_field_value(field_path, item)
            if val is None:
                continue
            name = re.sub(r"\[\d+\]|\[\*\]", "", field_path).replace(".", "_")
            values[name] = self._coerce_to_text(val)
        try:
            rendered = template.format_map(self._SafeDict(values))
        except Exception as exc:
            logging.warning("Failed to render content template: %s", exc)
            parts = [self._coerce_to_text(self._get_typed_field_value(f, item)) for f in self.content_fields]
            rendered = "\n".join(p for p in parts if p)
        return self._strip_html(rendered)

    # -- Static helpers -----------------------------------------------------

    @staticmethod
    def _strip_html(text: str) -> str:
        """Remove basic HTML tags and normalise whitespace."""
        if "<" not in text or ">" not in text:
            return text
        cleaned = re.sub(r"<[^>]+>", " ", text)
        return re.sub(r"\s+", " ", cleaned).strip()

    @staticmethod
    def _coerce_to_text(value: Any) -> str:
        """Convert any value to a plain-text string."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (int, float, bool)):
            return str(value)
        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return str(value)

    @staticmethod
    def _serialize_metadata_value(value: Any) -> Any:
        """Serialise a metadata value for storage."""
        if isinstance(value, datetime):
            if value.tzinfo is None:
                value = value.replace(tzinfo=timezone.utc)
            return value.isoformat()
        if isinstance(value, (int, float, bool, str)):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return str(value)
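
    @staticmethod
    def _doc_in_time_window(doc: Document, start: datetime, end: datetime) -> bool:
        """Return True if the document was updated within ``[start, end)``.

        Documents without ``doc_updated_at`` are excluded. A naive timestamp
        is treated as UTC before comparison.
        """
        if not doc.doc_updated_at:
            return False
        dt = doc.doc_updated_at
        dt = (dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)).astimezone(timezone.utc)
        return start <= dt < end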