mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-23 01:18:22 +08:00
### What problem does this PR solve? S3-family connector syncs currently re-download every in-window object just so we can compute `xxhash128(blob)` and compare against `Document.content_hash`. Anything that bumps `LastModified` without changing bytes (`aws s3 cp` touches, bucket re-encryption, etc.) pays full bandwidth and re-parses files that didn't actually change. #14628 covers the broader incremental-ingestion redesign; this PR is the first slice. The fix is a pre-listing short-circuit. `BlobStorageConnector` (S3 / R2 / GCS / OCI / S3-compat) now implements a new `FingerprintConnector` interface: `list_keys()` paginates `list_objects_v2` and yields `KeyRecord(key, fingerprint)` where `fingerprint = xxhash128(ETag)`. The orchestrator joins those against the connector's existing `{doc_id: content_hash}` map and only calls `get_value(key)` when the fingerprint differs. Unchanged keys are skipped entirely — no `GetObject`, no re-parse. No DDL. xxhash128(ETag) is 32 hex chars and reuses the existing `Document.content_hash` column per @yingfeng's suggestion; the connector decides at listing time whether to populate it. Local uploads and connectors that don't opt in fall through to the existing post-download `xxhash128(blob)` path with no behavior change. This is PR-1 of a 4-PR series — full design lives on #14628. Subsequent PRs extend tier 1 to local FS / WebDAV / Dropbox / Seafile / RDBMS (PR-2), wire up tier 2 cursor connectors with `SyncLogs.next_checkpoint` (PR-3), and unify deletion via `KeyRecord(deleted=True)` reconciliation (PR-4). Holding those back keeps this PR additive and reviewable on its own. 
#### Files touched - `common/data_source/models.py` — new `KeyRecord`; optional `fingerprint` on `Document` - `common/data_source/interfaces.py` — `IncrementalCapability` enum, `FingerprintConnector` ABC - `common/data_source/blob_connector.py` — `BlobStorageConnector` implements `FingerprintConnector`; per-object download factored into `_build_document_from_obj()` so `_yield_blob_objects`, `list_keys`, `get_value` all share it - `rag/svr/sync_data_source.py` — `_BlobLikeBase._fingerprint_filtered_generator` does the bypass loop; `_run_task_logic` plumbs `doc.fingerprint` into the upload dict - `api/db/services/document_service.py` — `list_id_content_hash_map_by_kb_and_source_type()` helper - `api/db/services/connector_service.py` + `file_service.py` — fingerprint flows through `duplicate_and_parse → upload_document` and lands in `content_hash` - `test/unit_test/common/test_blob_connector_fingerprint.py` — 14 tests covering ETag normalization (single-part, multipart, quoted, empty), `list_keys()` not calling `GetObject`, `get_value()` materializing with fingerprint, deterministic/stable fingerprints, and the bypass loop asserting `GetObject` is *not* called on a match #### Worth flagging for review Old `_BlobLikeBase._generate` called `poll_source(start, now)` with a `LastModified` window when `poll_range_start` was set. New code uses `_fingerprint_filtered_generator` (full bucket listing + fingerprint compare) outside of explicit `reindex=1`. Strictly better for unchanged-bucket cases since it skips `GetObject`, but it does mean every sync now does a full `list_objects_v2` paginate. Should still be cheap for most buckets — flagging in case anyone has a very large bucket where the time-window filter was meaningful. On migration: existing rows have `content_hash = xxhash128(blob)` from the old code. The first sync after this lands sees ETag-derived fingerprints that don't match, re-fetches every object once, and writes the new fingerprint. 
From the second sync onward the bypass works as expected. "Slow day one, fast every day after." A `fingerprint_backfill: trust` opt-out is sketched in the design doc but not in this PR. #### Test plan - [x] `uv run ruff check` — clean on all 8 touched files - [x] `uv run pytest test/unit_test/common/test_blob_connector_fingerprint.py -v` — 14 passed - [x] Broader unit-test suite — no regressions in anything I touched - [ ] Manual smoke against a real S3 bucket — configure a connector, run sync twice, expect the second sync to log `bypassed=N, fetched=0` and no `GetObject` calls in CloudTrail / bucket access logs - [ ] Manual smoke with `reindex=1` — confirm the full re-download path still works ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
430 lines
16 KiB
Python
430 lines
16 KiB
Python
"""Blob storage connector"""
|
|
import logging
|
|
import os
|
|
from collections.abc import Iterator
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Optional
|
|
|
|
import xxhash
|
|
|
|
from common.data_source.utils import (
|
|
create_s3_client,
|
|
detect_bucket_region,
|
|
download_object,
|
|
extract_size_bytes,
|
|
get_file_ext,
|
|
)
|
|
from common.data_source.config import BlobType, DocumentSource, BLOB_STORAGE_SIZE_THRESHOLD, INDEX_BATCH_SIZE
|
|
from common.data_source.exceptions import (
|
|
ConnectorMissingCredentialError,
|
|
ConnectorValidationError,
|
|
CredentialExpiredError,
|
|
InsufficientPermissionsError
|
|
)
|
|
from common.data_source.interfaces import (
|
|
FingerprintConnector,
|
|
LoadConnector,
|
|
PollConnector,
|
|
)
|
|
from common.data_source.models import (
|
|
Document,
|
|
KeyRecord,
|
|
SecondsSinceUnixEpoch,
|
|
GenerateDocumentsOutput,
|
|
GenerateSlimDocumentOutput,
|
|
SlimDocument,
|
|
)
|
|
|
|
|
|
def _normalize_etag(raw_etag: Optional[str]) -> Optional[str]:
|
|
"""Return a 32-char hex fingerprint derived from an S3 ETag.
|
|
|
|
S3 ETags are MD5 (32 hex chars) for single-part uploads and "<md5>-<n>"
|
|
(34+ chars) for multipart. We always hash so the column format is uniform
|
|
regardless of upload type or provider quirks; equality of the hashed value
|
|
is sufficient for change detection.
|
|
"""
|
|
if not raw_etag:
|
|
return None
|
|
return xxhash.xxh128(raw_etag.strip('"').encode()).hexdigest()
|
|
|
|
|
|
class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector):
    """Blob storage connector for S3, R2, GCS, OCI and S3-compatible buckets.

    Offers three ways to read the bucket (restricted to ``prefix``):

    * ``load_from_state`` / ``poll_source`` list objects and download every
      object in the requested ``LastModified`` window.
    * ``list_keys`` / ``get_value`` (``FingerprintConnector``) list keys with
      ETag-derived fingerprints and download only the keys the caller asks for.
    * ``retrieve_all_slim_docs_perm_sync`` lists object IDs without content.
    """

    def __init__(
        self,
        bucket_type: str,
        bucket_name: str,
        prefix: str = "",
        batch_size: int = INDEX_BATCH_SIZE,
        european_residency: bool = False,
    ) -> None:
        """Store configuration; no client is created until ``load_credentials``.

        Args:
            bucket_type: A ``BlobType`` value selecting the storage provider.
            bucket_name: Bucket to read; surrounding whitespace is stripped.
            prefix: Key prefix restricting listings; a trailing "/" is
                appended when missing.
            batch_size: Maximum number of documents per yielded batch.
            european_residency: Forwarded to ``create_s3_client``.
        """
        self.bucket_type: BlobType = BlobType(bucket_type)
        self.bucket_name = bucket_name.strip()
        self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
        self.batch_size = batch_size
        self.s3_client: Optional[Any] = None
        self._allow_images: bool | None = None
        self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
        self.bucket_region: Optional[str] = None
        self.european_residency: bool = european_residency
        # Populated by list_keys() so a subsequent get_value(key) can find the
        # raw S3 object metadata (LastModified, ETag, Key, Size) without a second
        # head_object call. Lifetime is one list_keys() pass.
        self._listing_cache: dict[str, dict[str, Any]] = {}
        # Basename -> occurrence count from the most recent list_keys() pass.
        self._filename_counts: dict[str, int] = {}

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as a UTC-aware datetime.

        Naive values are labelled UTC. Aware values are converted (not
        relabelled), so a non-UTC offset does not silently shift the instant.
        """
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    def _doc_id(self, key: str) -> str:
        """Build the stable document ID for object *key* in this bucket."""
        return f"{self.bucket_type}:{self.bucket_name}:{key}"

    def set_allow_images(self, allow_images: bool) -> None:
        """Set whether to process images"""
        logging.info(f"Setting allow_images to {allow_images}.")
        self._allow_images = allow_images

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Validate *credentials* for the bucket type and build the S3 client.

        Also detects the bucket region for ``BlobType.S3``.

        Raises:
            ConnectorMissingCredentialError: if a required credential field is
                missing or the S3 authentication method is unsupported.
            ValueError: if the bucket type is not supported.

        Returns:
            Always ``None``.
        """
        logging.debug(
            f"Loading credentials for {self.bucket_name} of type {self.bucket_type}"
        )

        # Validate credentials
        if self.bucket_type == BlobType.R2:
            if not all(
                credentials.get(key)
                for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
            ):
                raise ConnectorMissingCredentialError("Cloudflare R2")

        elif self.bucket_type == BlobType.S3:
            authentication_method = credentials.get("authentication_method", "access_key")

            if authentication_method == "access_key":
                if not all(
                    credentials.get(key)
                    for key in ["aws_access_key_id", "aws_secret_access_key"]
                ):
                    raise ConnectorMissingCredentialError("Amazon S3")

            elif authentication_method == "iam_role":
                if not credentials.get("aws_role_arn"):
                    raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")

            elif authentication_method == "assume_role":
                # No fields are validated here for assume_role.
                pass

            else:
                raise ConnectorMissingCredentialError("Unsupported S3 authentication method")

        elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
            if not all(
                credentials.get(key) for key in ["access_key_id", "secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Google Cloud Storage")

        elif self.bucket_type == BlobType.OCI_STORAGE:
            if not all(
                credentials.get(key)
                for key in ["namespace", "region", "access_key_id", "secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")

        elif self.bucket_type == BlobType.S3_COMPATIBLE:
            if not all(
                credentials.get(key)
                for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key", "addressing_style"]
            ):
                raise ConnectorMissingCredentialError("S3 Compatible Storage")

        else:
            raise ValueError(f"Unsupported bucket type: {self.bucket_type}")

        # Create S3 client
        self.s3_client = create_s3_client(
            self.bucket_type, credentials, self.european_residency
        )

        # Detect bucket region (only important for S3)
        if self.bucket_type == BlobType.S3:
            self.bucket_region = detect_bucket_region(self.s3_client, self.bucket_name)

        return None

    def _build_document_from_obj(
        self,
        obj: dict[str, Any],
        filename_counts: dict[str, int],
    ) -> Optional[Document]:
        """Materialize a Document for one listed S3 object, downloading its body.

        Args:
            obj: One entry from a ``list_objects_v2`` page (``Key``,
                ``LastModified``, ``ETag``, size fields).
            filename_counts: Basename occurrence counts used to pick the
                semantic identifier.

        Returns:
            The Document, or ``None`` if the object exceeds ``size_threshold``
            or ``download_object`` returns ``None``.
        """
        key = obj["Key"]
        file_name = os.path.basename(key)
        last_modified = self._as_utc(obj["LastModified"])

        size_bytes = extract_size_bytes(obj)
        if (
            self.size_threshold is not None
            and isinstance(size_bytes, int)
            and size_bytes > self.size_threshold
        ):
            logging.warning(
                f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
            )
            return None

        blob = download_object(
            self.s3_client, self.bucket_name, key, self.size_threshold
        )
        if blob is None:
            return None

        return Document(
            id=self._doc_id(key),
            blob=blob,
            source=DocumentSource(self.bucket_type.value),
            semantic_identifier=self._get_semantic_id(key, file_name, filename_counts),
            extension=get_file_ext(file_name),
            doc_updated_at=last_modified,
            size_bytes=size_bytes if size_bytes else 0,
            fingerprint=_normalize_etag(obj.get("ETag")),
        )

    def _yield_blob_objects(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
    ) -> GenerateDocumentsOutput:
        """Download objects with ``start < LastModified <= end`` in batches.

        ``None`` for either bound means unbounded on that side. Objects for
        which no Document is built are left out. An exception while building
        one is logged and that object is skipped, so one bad object does not
        abort the run.
        """
        all_objects, filename_counts = self._collect_blob_objects(start, end)

        batch: list[Document] = []
        for obj in all_objects:
            try:
                doc = self._build_document_from_obj(obj, filename_counts)
            except Exception:
                logging.exception(f"Error decoding object {obj.get('Key')}")
                continue
            if doc is None:
                continue
            batch.append(doc)
            # The yield stays outside the try so errors raised at the consumer
            # side are not mislogged as decoding errors.
            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    def list_keys(self) -> Iterator[KeyRecord]:
        """Enumerate the full keyspace under ``prefix`` with per-object fingerprints.

        Cheap path: relies on list_objects_v2 which returns ETag in the listing,
        so no GetObject call is needed. Caches each object's metadata so a
        subsequent get_value(key) call can rebuild the Document without a second
        round-trip to S3. The cache and filename counts are replaced at the
        start of every pass.

        No ``LastModified`` window is applied, so objects whose timestamp is
        ahead of the local clock are still listed.

        Raises:
            ConnectorMissingCredentialError: on first iteration, if credentials
                were not loaded.
        """
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        all_objects, filename_counts = self._collect_blob_objects()
        self._filename_counts = filename_counts
        self._listing_cache = {}

        for obj in all_objects:
            doc_id = self._doc_id(obj["Key"])
            self._listing_cache[doc_id] = obj
            yield KeyRecord(
                key=doc_id,
                fingerprint=_normalize_etag(obj.get("ETag")),
            )

    def get_value(self, key: str) -> Document:
        """Materialize the Document for a key previously yielded by list_keys().

        Must be called within the same list_keys() pass that produced the key,
        since the metadata cache lives on the connector instance and is reset
        each list_keys() call.

        Raises:
            KeyError: if *key* is not in the current listing cache.
            RuntimeError: if no Document could be built (e.g. the object
                exceeds ``size_threshold`` or the download returned nothing).
        """
        obj = self._listing_cache.get(key)
        if obj is None:
            raise KeyError(
                f"get_value({key!r}) called before list_keys() yielded the key, "
                "or after a subsequent list_keys() reset the cache"
            )
        doc = self._build_document_from_obj(obj, self._filename_counts)
        if doc is None:
            raise RuntimeError(f"Failed to materialize Document for key {key!r}")
        return doc

    def _collect_blob_objects(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
    ) -> tuple[list[dict[str, Any]], dict[str, int]]:
        """Collect object metadata under ``prefix`` within an optional window.

        Args:
            start: Exclusive lower bound on ``LastModified``; ``None`` means
                unbounded.
            end: Inclusive upper bound on ``LastModified``; ``None`` means
                unbounded.

        Returns:
            ``(objects, filename_counts)``, where ``filename_counts`` maps each
            basename to how many collected objects share it.

        Raises:
            ConnectorMissingCredentialError: if credentials were not loaded.
        """
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)

        # Collect all objects first to count filename occurrences
        all_objects: list[dict[str, Any]] = []
        for page in pages:
            for obj in page.get("Contents") or []:
                # Keys ending in "/" are skipped (typically "folder" markers).
                if obj["Key"].endswith("/"):
                    continue
                last_modified = self._as_utc(obj["LastModified"])
                if start is not None and last_modified <= start:
                    continue
                if end is not None and last_modified > end:
                    continue
                all_objects.append(obj)

        filename_counts: dict[str, int] = {}
        for obj in all_objects:
            file_name = os.path.basename(obj["Key"])
            filename_counts[file_name] = filename_counts.get(file_name, 0) + 1

        return all_objects, filename_counts

    def _get_semantic_id(
        self,
        key: str,
        file_name: str,
        filename_counts: dict[str, int],
    ) -> str:
        """Use full relative path only when filenames collide."""
        if filename_counts.get(file_name, 0) > 1:
            relative_path = key
            if self.prefix and key.startswith(self.prefix):
                relative_path = key[len(self.prefix):]
            return relative_path.replace("/", " / ") if relative_path else file_name
        return file_name

    def retrieve_all_slim_docs_perm_sync(
        self,
        callback: Any = None,
    ) -> GenerateSlimDocumentOutput:
        """Return a full current snapshot of blob object IDs without downloading content.

        No ``LastModified`` window is applied, so an object whose timestamp is
        ahead of the local clock is not missing from the snapshot.

        Args:
            callback: Accepted for interface compatibility; ignored.
        """
        del callback

        all_objects, _ = self._collect_blob_objects()

        batch: list[SlimDocument] = []
        for obj in all_objects:
            batch.append(SlimDocument(id=self._doc_id(obj["Key"])))
            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load every object under ``prefix`` (no ``LastModified`` window)."""
        logging.debug("Loading blob objects")
        return self._yield_blob_objects()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Yield batches of objects with ``start < LastModified <= end`` (epoch seconds).

        Raises:
            ConnectorMissingCredentialError: on first iteration, if credentials
                were not loaded.
        """
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        yield from self._yield_blob_objects(start_datetime, end_datetime)

    def validate_connector_settings(self) -> None:
        """Check that the bucket is reachable with a one-key listing.

        Raises:
            ConnectorMissingCredentialError: if credentials were not loaded.
            ConnectorValidationError: if no bucket name is set, the bucket is
                missing, or the listing fails for an unrecognized reason.
            InsufficientPermissionsError: on access-denied style errors.
            CredentialExpiredError: on invalid/expired credential errors.
        """
        if self.s3_client is None:
            raise ConnectorMissingCredentialError(
                "Blob storage credentials not loaded."
            )

        if not self.bucket_name:
            raise ConnectorValidationError(
                "No bucket name was provided in connector settings."
            )

        try:
            # Lightweight validation step
            self.s3_client.list_objects_v2(
                Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
            )

        except Exception as e:
            # botocore ClientError exposes a dict ``response``; other exception
            # types may lack it or carry a non-dict value, which would otherwise
            # raise AttributeError here and hide the real error.
            response = getattr(e, "response", None)
            if not isinstance(response, dict):
                response = {}
            error_code = (response.get("Error") or {}).get("Code", "")
            status_code = (response.get("ResponseMetadata") or {}).get("HTTPStatusCode")

            # Common S3 error scenarios
            if error_code in [
                "AccessDenied",
                "InvalidAccessKeyId",
                "SignatureDoesNotMatch",
            ]:
                if status_code == 403 or error_code == "AccessDenied":
                    raise InsufficientPermissionsError(
                        f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
                        "Please check your bucket policy and/or IAM policy."
                    ) from e
                if status_code == 401 or error_code == "SignatureDoesNotMatch":
                    raise CredentialExpiredError(
                        "Provided blob storage credentials appear invalid or expired."
                    ) from e

                raise CredentialExpiredError(
                    f"Credential issue encountered ({error_code})."
                ) from e

            if error_code == "NoSuchBucket" or status_code == 404:
                raise ConnectorValidationError(
                    f"Bucket '{self.bucket_name}' does not exist or cannot be found."
                ) from e

            raise ConnectorValidationError(
                f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
            ) from e
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: connect using env-provided AWS keys and print the
    # documents from the first batch only.
    demo_credentials = {
        "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
    }

    demo_connector = BlobStorageConnector(
        bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
        bucket_name=os.environ.get("BUCKET_NAME") or "yyboombucket",
        prefix="",
    )

    try:
        demo_connector.load_credentials(demo_credentials)
        first_batch = next(iter(demo_connector.load_from_state()), None)
        if first_batch is not None:
            print("First batch of documents:")
            for document in first_batch:
                print(f"Document ID: {document.id}")
                print(f"Semantic Identifier: {document.semantic_identifier}")
                print(f"Source: {document.source}")
                print(f"Updated At: {document.doc_updated_at}")
                print("---")
    except ConnectorMissingCredentialError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
|