refactor(api): tighten shared adapter typing contracts

This commit is contained in:
Yanli 盐粒
2026-03-17 19:47:16 +08:00
parent a717519822
commit 7572db15ff
24 changed files with 420 additions and 195 deletions

View File

@ -139,7 +139,7 @@ class HologresVector(BaseVector):
)
return bool(result)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
"""Get document IDs by metadata field key and value."""
result = self._client.execute(
psql.SQL("SELECT id FROM {} WHERE meta->>{} = {}").format(
@ -149,7 +149,7 @@ class HologresVector(BaseVector):
)
if result:
return [row[0] for row in result]
return None
return []
def delete_by_ids(self, ids: list[str]):
"""Delete documents by their doc_id list."""

View File

@ -130,8 +130,11 @@ class LindormVectorStore(BaseVector):
Field.METADATA_KEY: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[ROUTING_FIELD] = self._routing
routing = self._routing
if routing is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
action_header["index"]["routing"] = routing
action_values[ROUTING_FIELD] = routing
actions.append(action_header)
actions.append(action_values)

View File

@ -7,7 +7,9 @@ from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
_collection_name: str
def __init__(self, collection_name: str) -> None:
self._collection_name = collection_name
@abstractmethod
@ -30,7 +32,7 @@ class BaseVector(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
raise NotImplementedError
@abstractmethod
@ -63,5 +65,5 @@ class BaseVector(ABC):
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
@property
def collection_name(self):
def collection_name(self) -> str:
return self._collection_name

View File

@ -2,7 +2,8 @@ import base64
import logging
import time
from abc import ABC, abstractmethod
from typing import Any
from collections.abc import Sequence
from typing import Any, TypedDict
from sqlalchemy import select
@ -13,7 +14,7 @@ from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -24,19 +25,40 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
class VectorStoreIndexConfig(TypedDict):
class_prefix: str
class VectorIndexStructDict(TypedDict):
type: VectorType
vector_store: VectorStoreIndexConfig
class MultimodalEmbeddingPayload(TypedDict):
content: str
content_type: str
file_id: str
VectorDocumentInput = Document | ChildDocument | AttachmentDocument
class AbstractVectorFactory(ABC):
@abstractmethod
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
def init_vector(self, dataset: Dataset, attributes: list[str], embeddings: Embeddings) -> BaseVector:
raise NotImplementedError
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
index_struct_dict: VectorIndexStructDict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name},
}
return index_struct_dict
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
def __init__(self, dataset: Dataset, attributes: list[str] | None = None) -> None:
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
self._dataset = dataset
@ -198,7 +220,7 @@ class Vector:
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
def create(self, texts: list | None = None, **kwargs):
def create(self, texts: Sequence[Document | ChildDocument] | None = None, **kwargs: Any) -> None:
if texts:
start = time.time()
logger.info("start embedding %s texts %s", len(texts), start)
@ -212,10 +234,12 @@ class Vector:
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
self._vector_processor.create(
texts=self._normalize_documents(batch), embeddings=batch_embeddings, **kwargs
)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def create_multimodal(self, file_documents: list | None = None, **kwargs):
def create_multimodal(self, file_documents: list[AttachmentDocument] | None = None, **kwargs: Any) -> None:
if file_documents:
start = time.time()
logger.info("start embedding %s files %s", len(file_documents), start)
@ -227,14 +251,16 @@ class Vector:
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
# Batch query all upload files to avoid N+1 queries
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
attachment_ids = [doc.metadata["doc_id"] for doc in batch if doc.metadata is not None]
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
upload_files = db.session.scalars(stmt).all()
upload_file_map = {str(f.id): f for f in upload_files}
file_base64_list = []
real_batch = []
file_base64_list: list[dict[str, str]] = []
real_batch: list[AttachmentDocument] = []
for document in batch:
if document.metadata is None:
continue
attachment_id = document.metadata["doc_id"]
doc_type = document.metadata["doc_type"]
upload_file = upload_file_map.get(attachment_id)
@ -249,14 +275,20 @@ class Vector:
}
)
real_batch.append(document)
if not real_batch:
continue
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
self._vector_processor.create(
texts=self._normalize_documents(real_batch),
embeddings=batch_embeddings,
**kwargs,
)
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
def add_texts(self, documents: list[Document], **kwargs: Any) -> None:
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
@ -266,10 +298,10 @@ class Vector:
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]):
def delete_by_ids(self, ids: list[str]) -> None:
self._vector_processor.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str):
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
@ -295,7 +327,7 @@ class Vector:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self):
def delete(self) -> None:
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:
@ -325,7 +357,26 @@ class Vector:
return texts
def __getattr__(self, name):
@staticmethod
def _normalize_documents(documents: Sequence[VectorDocumentInput]) -> list[Document]:
normalized_documents: list[Document] = []
for document in documents:
if isinstance(document, Document):
normalized_documents.append(document)
continue
normalized_documents.append(
Document(
page_content=document.page_content,
vector=document.vector,
metadata=document.metadata,
provider=document.provider if isinstance(document, AttachmentDocument) else "dify",
)
)
return normalized_documents
def __getattr__(self, name: str) -> Any:
if self._vector_processor is not None:
method = getattr(self._vector_processor, name)
if callable(method):

View File

@ -1,7 +1,11 @@
from typing import Literal
from pydantic import BaseModel, ConfigDict
from core.rag.extractor.entity.datasource_type import DatasourceType
from models.dataset import Document
from models.model import UploadFile
from services.auth.auth_type import AuthType
class NotionInfo(BaseModel):
@ -12,7 +16,7 @@ class NotionInfo(BaseModel):
credential_id: str | None = None
notion_workspace_id: str | None = ""
notion_obj_id: str
notion_page_type: str
notion_page_type: Literal["database", "page"]
document: Document | None = None
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
@ -25,10 +29,10 @@ class WebsiteInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
provider: str
provider: AuthType
job_id: str
url: str
mode: str
mode: Literal["crawl", "crawl_return_urls", "scrape"]
tenant_id: str
only_main_content: bool = False
@ -38,7 +42,7 @@ class ExtractSetting(BaseModel):
Model class for provider response.
"""
datasource_type: str
datasource_type: DatasourceType
upload_file: UploadFile | None = None
notion_info: NotionInfo | None = None
website_info: WebsiteInfo | None = None

View File

@ -1,7 +1,8 @@
import os
import re
import tempfile
from pathlib import Path
from typing import Union
from typing import TypeAlias
from urllib.parse import unquote
from configs import dify_config
@ -31,19 +32,27 @@ from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
from models.model import UploadFile
from services.auth.auth_type import AuthType
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
USER_AGENT = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124"
" Safari/537.36"
)
ExtractProcessorOutput: TypeAlias = list[Document] | str
class ExtractProcessor:
@staticmethod
def _build_temp_file_path(temp_dir: str, suffix: str) -> str:
file_descriptor, file_path = tempfile.mkstemp(dir=temp_dir, suffix=suffix)
os.close(file_descriptor)
return file_path
@classmethod
def load_from_upload_file(
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> Union[list[Document], str]:
) -> ExtractProcessorOutput:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model"
)
@ -54,7 +63,7 @@ class ExtractProcessor:
return cls.extract(extract_setting, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
def load_from_url(cls, url: str, return_text: bool = False) -> ExtractProcessorOutput:
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
with tempfile.TemporaryDirectory() as temp_dir:
@ -65,17 +74,16 @@ class ExtractProcessor:
suffix = "." + response.headers.get("Content-Type").split("/")[-1]
else:
content_disposition = response.headers.get("Content-Disposition")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
# Generate a temporary filename under the created temp_dir and ensure the directory exists
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
if content_disposition:
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
file_path = cls._build_temp_file_path(temp_dir, suffix)
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
if return_text:
@ -94,13 +102,12 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
assert upload_file is not None, "upload_file is required"
suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
file_path = cls._build_temp_file_path(temp_dir, suffix)
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
@ -113,6 +120,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
assert upload_file is not None, "upload_file is required for PDF extraction"
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
@ -123,6 +131,7 @@ class ExtractProcessor:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
assert upload_file is not None, "upload_file is required for DOCX extraction"
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".doc":
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
@ -149,12 +158,14 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
assert upload_file is not None, "upload_file is required for PDF extraction"
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
assert upload_file is not None, "upload_file is required for DOCX extraction"
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
@ -177,7 +188,7 @@ class ExtractProcessor:
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
if extract_setting.website_info.provider == AuthType.FIRECRAWL:
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
@ -186,7 +197,7 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == "watercrawl":
elif extract_setting.website_info.provider == AuthType.WATERCRAWL:
extractor = WaterCrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
@ -195,7 +206,7 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == "jinareader":
elif extract_setting.website_info.provider == AuthType.JINA:
extractor = JinaReaderWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,

View File

@ -2,10 +2,12 @@
from abc import ABC, abstractmethod
from core.rag.models.document import Document
class BaseExtractor(ABC):
"""Interface for extract files."""
@abstractmethod
def extract(self):
def extract(self) -> list[Document]:
raise NotImplementedError

View File

@ -30,7 +30,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
For large files, reading only a sample is sufficient and prevents timeout.
"""
def read_and_detect(filename: str):
def read_and_detect(filename: str) -> list[FileEncoding]:
rst = charset_normalizer.from_path(filename)
best = rst.best()
if best is None:

View File

@ -3,7 +3,7 @@ import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from typing import Any, Literal, TypedDict
from flask import current_app
from sqlalchemy import delete, func, select
@ -19,6 +19,16 @@ from .processor.paragraph_index_processor import ParagraphIndexProcessor
logger = logging.getLogger(__name__)
class IndexAndCleanResult(TypedDict):
dataset_id: str
dataset_name: str
batch: str
document_id: str
document_name: str
created_at: float
display_status: Literal["completed"]
class IndexProcessor:
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
@ -50,9 +60,9 @@ class IndexProcessor:
document_id: str,
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: dict | None = None,
):
batch: str,
summary_index_setting: dict[str, object] | None = None,
) -> IndexAndCleanResult:
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if not document:
@ -131,7 +141,12 @@ class IndexProcessor:
}
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
self,
chunks: Any,
dataset_id: str,
document_id: str,
chunk_structure: str,
summary_index_setting: dict[str, object] | None,
) -> Preview:
doc_language = None
with session_factory.create_session() as session:

View File

@ -138,18 +138,20 @@ class BaseIndexProcessor(ABC):
embedding_model_instance=embedding_model_instance,
)
return character_splitter # type: ignore
return character_splitter
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""
Get the content files from the document.
"""
multi_model_documents: list[AttachmentDocument] = []
if document.metadata is None:
return multi_model_documents
text = document.page_content
images = self._extract_markdown_images(text)
if not images:
return multi_model_documents
upload_file_id_list = []
upload_file_id_list: list[str] = []
for image in images:
# Collect all upload_file_ids including duplicates to preserve occurrence count

View File

@ -10,7 +10,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
class IndexProcessorFactory:
"""IndexProcessorInit."""
def __init__(self, index_type: str | None):
def __init__(self, index_type: str | None) -> None:
self._index_type = index_type
def init_index_processor(self) -> BaseIndexProcessor:
@ -19,11 +19,12 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexStructureType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")
match self._index_type:
case IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
case IndexStructureType.QA_INDEX:
return QAIndexProcessor()
case IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
case _:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@ -30,7 +30,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any,
):
) -> TS:
def _token_encoder(texts: list[str]) -> list[int]:
if not texts:
return []

View File

@ -8,7 +8,7 @@ class BaseStorage(ABC):
"""Interface for file storage."""
@abstractmethod
def save(self, filename: str, data: bytes):
def save(self, filename: str, data: bytes) -> None:
raise NotImplementedError
@abstractmethod
@ -16,7 +16,7 @@ class BaseStorage(ABC):
raise NotImplementedError
@abstractmethod
def load_stream(self, filename: str) -> Generator:
def load_stream(self, filename: str) -> Generator[bytes, None, None]:
raise NotImplementedError
@abstractmethod
@ -28,10 +28,10 @@ class BaseStorage(ABC):
raise NotImplementedError
@abstractmethod
def delete(self, filename: str):
def delete(self, filename: str) -> None:
raise NotImplementedError
def scan(self, path, files=True, directories=False) -> list[str]:
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""
Scan files and directories in the given path.
This method is implemented only in some storage backends.

View File

@ -43,58 +43,6 @@ core/ops/tencent_trace/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/lindorm/lindorm_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/pdf_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
core/rag/extractor/word_extractor.py
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py
core/rag/retrieval/router/multi_dataset_function_call_router.py
core/rag/summary_index/summary_index.py
core/repositories/sqlalchemy_workflow_execution_repository.py
@ -140,27 +88,10 @@ dify_graph/nodes/variable_assigner/v2/node.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
libs/gmpy2_pkcs10aep_cipher.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
services/auth/firecrawl/firecrawl.py
services/auth/jina.py
services/auth/jina/jina.py
services/auth/watercrawl/watercrawl.py
services/conversation_service.py
services/dataset_service.py
services/document_indexing_proxy/document_indexing_task_proxy.py
@ -183,3 +114,84 @@ tasks/regenerate_summary_index_task.py
tasks/trigger_processing_tasks.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py
tasks/workflow_execution_tasks.py
# no need to fix for now: storage adapters
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
# no need to fix for now: auth adapters
services/auth/firecrawl/firecrawl.py
services/auth/jina.py
services/auth/jina/jina.py
services/auth/watercrawl/watercrawl.py
# no need to fix for now: keyword adapters
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
# no need to fix for now: vector db adapters
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/lindorm/lindorm_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
# no need to fix for now: extractors
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/pdf_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
core/rag/extractor/word_extractor.py
# no need to fix for now: index processors
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py

View File

@ -1,10 +1,27 @@
from abc import ABC, abstractmethod
from typing import Annotated, NotRequired, TypedDict
from pydantic import StringConstraints
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
class ApiKeyAuthConfig(TypedDict):
api_key: NotRequired[str]
base_url: NotRequired[str]
class ApiKeyAuthCredentials(TypedDict):
auth_type: NonEmptyString
config: ApiKeyAuthConfig
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
credentials: ApiKeyAuthCredentials
def __init__(self, credentials: ApiKeyAuthCredentials) -> None:
self.credentials = credentials
@abstractmethod
def validate_credentials(self):
def validate_credentials(self) -> bool:
raise NotImplementedError

View File

@ -1,17 +1,19 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.auth_type import AuthType
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
from services.auth.auth_type import AuthProvider, AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
auth: ApiKeyAuthBase
def __init__(self, provider: AuthProvider, credentials: ApiKeyAuthCredentials) -> None:
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)
def validate_credentials(self):
def validate_credentials(self) -> bool:
return self.auth.validate_credentials()
@staticmethod
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
def get_apikey_auth_factory(provider: AuthProvider) -> type[ApiKeyAuthBase]:
match provider:
case AuthType.FIRECRAWL:
from services.auth.firecrawl.firecrawl import FirecrawlAuth

View File

@ -1,40 +1,75 @@
import json
from collections.abc import Mapping
from typing import Annotated, TypedDict, TypeVar
from pydantic import StringConstraints, TypeAdapter, ValidationError
from sqlalchemy import select
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_base import ApiKeyAuthConfig, ApiKeyAuthCredentials
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.auth_type import AuthProvider
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
ValidatedPayload = TypeVar("ValidatedPayload")
class ApiKeyAuthCreateArgs(TypedDict):
category: NonEmptyString
provider: NonEmptyString
credentials: ApiKeyAuthCredentials
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(ApiKeyAuthCredentials)
AUTH_CREATE_ARGS_ADAPTER = TypeAdapter(ApiKeyAuthCreateArgs)
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str):
def get_provider_auth_list(tenant_id: str) -> list[DataSourceApiKeyAuthBinding]:
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return data_source_api_key_bindings
return list(data_source_api_key_bindings)
@staticmethod
def create_provider_auth(tenant_id: str, args: dict):
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
@classmethod
def create_provider_auth(cls, tenant_id: str, args: Mapping[str, object] | ApiKeyAuthCreateArgs) -> None:
validated_args = cls.validate_api_key_auth_args(args)
auth_result = ApiKeyAuthFactory(
validated_args["provider"], validated_args["credentials"]
).validate_credentials()
if auth_result:
stored_config: ApiKeyAuthConfig = {}
if "api_key" in validated_args["credentials"]["config"]:
stored_config["api_key"] = validated_args["credentials"]["config"]["api_key"]
if "base_url" in validated_args["credentials"]["config"]:
stored_config["base_url"] = validated_args["credentials"]["config"]["base_url"]
stored_credentials: ApiKeyAuthCredentials = {
"auth_type": validated_args["credentials"]["auth_type"],
"config": stored_config,
}
# Encrypt the api key
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
args["credentials"]["config"]["api_key"] = api_key
api_key_value = stored_credentials["config"].get("api_key")
if api_key_value is None:
raise ValueError("credentials config api_key is required")
api_key = encrypter.encrypt_token(tenant_id, api_key_value)
stored_credentials["config"]["api_key"] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
tenant_id=tenant_id,
category=validated_args["category"],
provider=validated_args["provider"],
)
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
data_source_api_key_binding.credentials = json.dumps(stored_credentials, ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
def get_auth_credentials(tenant_id: str, category: str, provider: AuthProvider) -> ApiKeyAuthCredentials | None:
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(
@ -49,11 +84,11 @@ class ApiKeyAuthService:
return None
if not data_source_api_key_bindings.credentials:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
return credentials
raw_credentials = json.loads(data_source_api_key_bindings.credentials)
return ApiKeyAuthService._validate_credentials_payload(raw_credentials)
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
def delete_provider_auth(tenant_id: str, binding_id: str) -> None:
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
@ -64,14 +99,16 @@ class ApiKeyAuthService:
db.session.commit()
@classmethod
def validate_api_key_auth_args(cls, args):
if "category" not in args or not args["category"]:
raise ValueError("category is required")
if "provider" not in args or not args["provider"]:
raise ValueError("provider is required")
if "credentials" not in args or not args["credentials"]:
raise ValueError("credentials is required")
if not isinstance(args["credentials"], dict):
raise ValueError("credentials must be a dictionary")
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
raise ValueError("auth_type is required")
def validate_api_key_auth_args(cls, args: Mapping[str, object] | None) -> ApiKeyAuthCreateArgs:
return cls._validate_payload(AUTH_CREATE_ARGS_ADAPTER, args)
@staticmethod
def _validate_credentials_payload(raw_credentials: object) -> ApiKeyAuthCredentials:
return ApiKeyAuthService._validate_payload(AUTH_CREDENTIALS_ADAPTER, raw_credentials)
@staticmethod
def _validate_payload(adapter: TypeAdapter[ValidatedPayload], payload: object) -> ValidatedPayload:
try:
return adapter.validate_python(payload)
except ValidationError as exc:
raise ValueError(exc.errors()[0]["msg"]) from exc

View File

@ -5,3 +5,6 @@ class AuthType(StrEnum):
FIRECRAWL = "firecrawl"
WATERCRAWL = "watercrawl"
JINA = "jinareader"
AuthProvider = AuthType | str

View File

@ -7,7 +7,7 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")

View File

@ -7,7 +7,7 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")

View File

@ -7,7 +7,7 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")

View File

@ -8,7 +8,7 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials)
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":
raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")

View File

@ -81,7 +81,10 @@ class TestApiKeyAuthService:
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
# Verify factory class calls
mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
mock_factory.assert_called_once_with(
self.provider,
{"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}},
)
mock_auth_instance.validate_credentials.assert_called_once()
# Verify encryption calls
@ -129,9 +132,8 @@ class TestApiKeyAuthService:
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
# Verify original key is replaced with encrypted key
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
assert args_copy["credentials"]["config"]["api_key"] != original_key
# Verify the service does not mutate the caller's payload while still encrypting persisted credentials
assert args_copy["credentials"]["config"]["api_key"] == original_key
# Verify encryption function is called correctly
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
@ -230,7 +232,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
del args["category"]
with pytest.raises(ValueError, match="category is required"):
with pytest.raises(ValueError, match="Field required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_category(self):
@ -238,7 +240,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
args["category"] = ""
with pytest.raises(ValueError, match="category is required"):
with pytest.raises(ValueError, match="at least 1 character"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_provider(self):
@ -246,7 +248,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
del args["provider"]
with pytest.raises(ValueError, match="provider is required"):
with pytest.raises(ValueError, match="Field required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_provider(self):
@ -254,7 +256,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
args["provider"] = ""
with pytest.raises(ValueError, match="provider is required"):
with pytest.raises(ValueError, match="at least 1 character"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_credentials(self):
@ -262,7 +264,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
del args["credentials"]
with pytest.raises(ValueError, match="credentials is required"):
with pytest.raises(ValueError, match="Field required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_credentials(self):
@ -270,7 +272,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
args["credentials"] = None
with pytest.raises(ValueError, match="credentials is required"):
with pytest.raises(ValueError, match="valid dictionary"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_invalid_credentials_type(self):
@ -278,7 +280,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
args["credentials"] = "not_a_dict"
with pytest.raises(ValueError, match="credentials must be a dictionary"):
with pytest.raises(ValueError, match="valid dictionary"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_auth_type(self):
@ -286,7 +288,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
del args["credentials"]["auth_type"]
with pytest.raises(ValueError, match="auth_type is required"):
with pytest.raises(ValueError, match="Field required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_auth_type(self):
@ -294,7 +296,7 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ""
with pytest.raises(ValueError, match="auth_type is required"):
with pytest.raises(ValueError, match="at least 1 character"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@pytest.mark.parametrize(
@ -374,7 +376,7 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_none_input(self):
"""Test API key auth args validation - None input"""
with pytest.raises(TypeError):
with pytest.raises(ValueError, match="valid dictionary"):
ApiKeyAuthService.validate_api_key_auth_args(None)
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
@ -382,6 +384,5 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ["api_key"]
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
# So this should not raise exception, this test should pass
ApiKeyAuthService.validate_api_key_auth_args(args)
with pytest.raises(ValueError, match="valid string"):
ApiKeyAuthService.validate_api_key_auth_args(args)

View File

@ -121,7 +121,7 @@ import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment
from services.vector_service import VectorService
@ -1300,6 +1300,68 @@ class TestVector:
assert mock_vector_processor.create.call_count == 3
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
def test_vector_create_normalizes_child_documents(self, mock_get_embeddings, mock_init_vector):
dataset = VectorServiceTestDataFactory.create_dataset_mock()
documents = [
ChildDocument(page_content="Child content", metadata={"doc_id": "child-1", "dataset_id": "dataset-123"})
]
mock_embeddings = Mock()
mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536])
mock_get_embeddings.return_value = mock_embeddings
mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock()
mock_init_vector.return_value = mock_vector_processor
vector = Vector(dataset=dataset)
vector.create(texts=documents)
create_call = mock_vector_processor.create.call_args.kwargs
normalized_document = create_call["texts"][0]
assert isinstance(normalized_document, Document)
assert normalized_document.page_content == "Child content"
assert normalized_document.metadata["doc_id"] == "child-1"
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
@patch("core.rag.datasource.vdb.vector_factory.storage")
@patch("core.rag.datasource.vdb.vector_factory.db.session")
def test_vector_create_multimodal_normalizes_attachment_documents(
self, mock_session, mock_storage, mock_get_embeddings, mock_init_vector
):
dataset = VectorServiceTestDataFactory.create_dataset_mock()
file_document = AttachmentDocument(
page_content="Attachment content",
provider="custom-provider",
metadata={"doc_id": "file-1", "doc_type": "image/png"},
)
upload_file = Mock(id="file-1", key="upload-key")
mock_scalars = Mock()
mock_scalars.all.return_value = [upload_file]
mock_session.scalars.return_value = mock_scalars
mock_storage.load_once.return_value = b"binary-content"
mock_embeddings = Mock()
mock_embeddings.embed_multimodal_documents = Mock(return_value=[[0.2] * 1536])
mock_get_embeddings.return_value = mock_embeddings
mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock()
mock_init_vector.return_value = mock_vector_processor
vector = Vector(dataset=dataset)
vector.create_multimodal(file_documents=[file_document])
create_call = mock_vector_processor.create.call_args.kwargs
normalized_document = create_call["texts"][0]
assert isinstance(normalized_document, Document)
assert normalized_document.provider == "custom-provider"
assert normalized_document.metadata["doc_id"] == "file-1"
# ========================================================================
# Tests for Vector.add_texts
# ========================================================================