mirror of
https://github.com/langgenius/dify.git
synced 2026-04-20 10:47:21 +08:00
refactor(api): tighten shared adapter typing contracts
This commit is contained in:
@ -139,7 +139,7 @@ class HologresVector(BaseVector):
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
"""Get document IDs by metadata field key and value."""
|
||||
result = self._client.execute(
|
||||
psql.SQL("SELECT id FROM {} WHERE meta->>{} = {}").format(
|
||||
@ -149,7 +149,7 @@ class HologresVector(BaseVector):
|
||||
)
|
||||
if result:
|
||||
return [row[0] for row in result]
|
||||
return None
|
||||
return []
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
"""Delete documents by their doc_id list."""
|
||||
|
||||
@ -130,8 +130,11 @@ class LindormVectorStore(BaseVector):
|
||||
Field.METADATA_KEY: documents[i].metadata,
|
||||
}
|
||||
if self._using_ugc:
|
||||
action_header["index"]["routing"] = self._routing
|
||||
action_values[ROUTING_FIELD] = self._routing
|
||||
routing = self._routing
|
||||
if routing is None:
|
||||
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
|
||||
action_header["index"]["routing"] = routing
|
||||
action_values[ROUTING_FIELD] = routing
|
||||
|
||||
actions.append(action_header)
|
||||
actions.append(action_values)
|
||||
|
||||
@ -7,7 +7,9 @@ from core.rag.models.document import Document
|
||||
|
||||
|
||||
class BaseVector(ABC):
|
||||
def __init__(self, collection_name: str):
|
||||
_collection_name: str
|
||||
|
||||
def __init__(self, collection_name: str) -> None:
|
||||
self._collection_name = collection_name
|
||||
|
||||
@abstractmethod
|
||||
@ -30,7 +32,7 @@ class BaseVector(ABC):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@ -63,5 +65,5 @@ class BaseVector(ABC):
|
||||
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
|
||||
|
||||
@property
|
||||
def collection_name(self):
|
||||
def collection_name(self) -> str:
|
||||
return self._collection_name
|
||||
|
||||
@ -2,7 +2,8 @@ import base64
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -13,7 +14,7 @@ from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -24,19 +25,40 @@ from models.model import UploadFile
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStoreIndexConfig(TypedDict):
|
||||
class_prefix: str
|
||||
|
||||
|
||||
class VectorIndexStructDict(TypedDict):
|
||||
type: VectorType
|
||||
vector_store: VectorStoreIndexConfig
|
||||
|
||||
|
||||
class MultimodalEmbeddingPayload(TypedDict):
|
||||
content: str
|
||||
content_type: str
|
||||
file_id: str
|
||||
|
||||
|
||||
VectorDocumentInput = Document | ChildDocument | AttachmentDocument
|
||||
|
||||
|
||||
class AbstractVectorFactory(ABC):
|
||||
@abstractmethod
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
def init_vector(self, dataset: Dataset, attributes: list[str], embeddings: Embeddings) -> BaseVector:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
|
||||
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
|
||||
index_struct_dict: VectorIndexStructDict = {
|
||||
"type": vector_type,
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
return index_struct_dict
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list | None = None):
|
||||
def __init__(self, dataset: Dataset, attributes: list[str] | None = None) -> None:
|
||||
if attributes is None:
|
||||
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
|
||||
self._dataset = dataset
|
||||
@ -198,7 +220,7 @@ class Vector:
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
def create(self, texts: list | None = None, **kwargs):
|
||||
def create(self, texts: Sequence[Document | ChildDocument] | None = None, **kwargs: Any) -> None:
|
||||
if texts:
|
||||
start = time.time()
|
||||
logger.info("start embedding %s texts %s", len(texts), start)
|
||||
@ -212,10 +234,12 @@ class Vector:
|
||||
logger.info(
|
||||
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
|
||||
)
|
||||
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
||||
self._vector_processor.create(
|
||||
texts=self._normalize_documents(batch), embeddings=batch_embeddings, **kwargs
|
||||
)
|
||||
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
|
||||
|
||||
def create_multimodal(self, file_documents: list | None = None, **kwargs):
|
||||
def create_multimodal(self, file_documents: list[AttachmentDocument] | None = None, **kwargs: Any) -> None:
|
||||
if file_documents:
|
||||
start = time.time()
|
||||
logger.info("start embedding %s files %s", len(file_documents), start)
|
||||
@ -227,14 +251,16 @@ class Vector:
|
||||
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
|
||||
|
||||
# Batch query all upload files to avoid N+1 queries
|
||||
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
|
||||
attachment_ids = [doc.metadata["doc_id"] for doc in batch if doc.metadata is not None]
|
||||
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||
upload_files = db.session.scalars(stmt).all()
|
||||
upload_file_map = {str(f.id): f for f in upload_files}
|
||||
|
||||
file_base64_list = []
|
||||
real_batch = []
|
||||
file_base64_list: list[dict[str, str]] = []
|
||||
real_batch: list[AttachmentDocument] = []
|
||||
for document in batch:
|
||||
if document.metadata is None:
|
||||
continue
|
||||
attachment_id = document.metadata["doc_id"]
|
||||
doc_type = document.metadata["doc_type"]
|
||||
upload_file = upload_file_map.get(attachment_id)
|
||||
@ -249,14 +275,20 @@ class Vector:
|
||||
}
|
||||
)
|
||||
real_batch.append(document)
|
||||
if not real_batch:
|
||||
continue
|
||||
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
|
||||
logger.info(
|
||||
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
|
||||
)
|
||||
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
|
||||
self._vector_processor.create(
|
||||
texts=self._normalize_documents(real_batch),
|
||||
embeddings=batch_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
|
||||
|
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
def add_texts(self, documents: list[Document], **kwargs: Any) -> None:
|
||||
if kwargs.get("duplicate_check", False):
|
||||
documents = self._filter_duplicate_texts(documents)
|
||||
|
||||
@ -266,10 +298,10 @@ class Vector:
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self._vector_processor.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._vector_processor.delete_by_ids(ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._vector_processor.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@ -295,7 +327,7 @@ class Vector:
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._vector_processor.search_by_full_text(query, **kwargs)
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> None:
|
||||
self._vector_processor.delete()
|
||||
# delete collection redis cache
|
||||
if self._vector_processor.collection_name:
|
||||
@ -325,7 +357,26 @@ class Vector:
|
||||
|
||||
return texts
|
||||
|
||||
def __getattr__(self, name):
|
||||
@staticmethod
|
||||
def _normalize_documents(documents: Sequence[VectorDocumentInput]) -> list[Document]:
|
||||
normalized_documents: list[Document] = []
|
||||
for document in documents:
|
||||
if isinstance(document, Document):
|
||||
normalized_documents.append(document)
|
||||
continue
|
||||
|
||||
normalized_documents.append(
|
||||
Document(
|
||||
page_content=document.page_content,
|
||||
vector=document.vector,
|
||||
metadata=document.metadata,
|
||||
provider=document.provider if isinstance(document, AttachmentDocument) else "dify",
|
||||
)
|
||||
)
|
||||
|
||||
return normalized_documents
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if self._vector_processor is not None:
|
||||
method = getattr(self._vector_processor, name)
|
||||
if callable(method):
|
||||
|
||||
@ -1,7 +1,11 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from models.dataset import Document
|
||||
from models.model import UploadFile
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
|
||||
class NotionInfo(BaseModel):
|
||||
@ -12,7 +16,7 @@ class NotionInfo(BaseModel):
|
||||
credential_id: str | None = None
|
||||
notion_workspace_id: str | None = ""
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
notion_page_type: Literal["database", "page"]
|
||||
document: Document | None = None
|
||||
tenant_id: str
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@ -25,10 +29,10 @@ class WebsiteInfo(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
provider: str
|
||||
provider: AuthType
|
||||
job_id: str
|
||||
url: str
|
||||
mode: str
|
||||
mode: Literal["crawl", "crawl_return_urls", "scrape"]
|
||||
tenant_id: str
|
||||
only_main_content: bool = False
|
||||
|
||||
@ -38,7 +42,7 @@ class ExtractSetting(BaseModel):
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
datasource_type: str
|
||||
datasource_type: DatasourceType
|
||||
upload_file: UploadFile | None = None
|
||||
notion_info: NotionInfo | None = None
|
||||
website_info: WebsiteInfo | None = None
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import TypeAlias
|
||||
from urllib.parse import unquote
|
||||
|
||||
from configs import dify_config
|
||||
@ -31,19 +32,27 @@ from core.rag.extractor.word_extractor import WordExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
|
||||
USER_AGENT = (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124"
|
||||
" Safari/537.36"
|
||||
)
|
||||
ExtractProcessorOutput: TypeAlias = list[Document] | str
|
||||
|
||||
|
||||
class ExtractProcessor:
|
||||
@staticmethod
|
||||
def _build_temp_file_path(temp_dir: str, suffix: str) -> str:
|
||||
file_descriptor, file_path = tempfile.mkstemp(dir=temp_dir, suffix=suffix)
|
||||
os.close(file_descriptor)
|
||||
return file_path
|
||||
|
||||
@classmethod
|
||||
def load_from_upload_file(
|
||||
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
|
||||
) -> Union[list[Document], str]:
|
||||
) -> ExtractProcessorOutput:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model"
|
||||
)
|
||||
@ -54,7 +63,7 @@ class ExtractProcessor:
|
||||
return cls.extract(extract_setting, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> ExtractProcessorOutput:
|
||||
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
@ -65,17 +74,16 @@ class ExtractProcessor:
|
||||
suffix = "." + response.headers.get("Content-Type").split("/")[-1]
|
||||
else:
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
match = re.search(r"\.(\w+)$", filename)
|
||||
if match:
|
||||
suffix = "." + match.group(1)
|
||||
else:
|
||||
suffix = ""
|
||||
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
|
||||
# Generate a temporary filename under the created temp_dir and ensure the directory exists
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
if content_disposition:
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
match = re.search(r"\.(\w+)$", filename)
|
||||
if match:
|
||||
suffix = "." + match.group(1)
|
||||
else:
|
||||
suffix = ""
|
||||
file_path = cls._build_temp_file_path(temp_dir, suffix)
|
||||
Path(file_path).write_bytes(response.content)
|
||||
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
|
||||
if return_text:
|
||||
@ -94,13 +102,12 @@ class ExtractProcessor:
|
||||
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
|
||||
) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE:
|
||||
upload_file = extract_setting.upload_file
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if not file_path:
|
||||
assert extract_setting.upload_file is not None, "upload_file is required"
|
||||
upload_file: UploadFile = extract_setting.upload_file
|
||||
assert upload_file is not None, "upload_file is required"
|
||||
suffix = Path(upload_file.key).suffix
|
||||
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
file_path = cls._build_temp_file_path(temp_dir, suffix)
|
||||
storage.download(upload_file.key, file_path)
|
||||
input_file = Path(file_path)
|
||||
file_extension = input_file.suffix.lower()
|
||||
@ -113,6 +120,7 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
assert upload_file is not None, "upload_file is required for PDF extraction"
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension in {".md", ".markdown", ".mdx"}:
|
||||
extractor = (
|
||||
@ -123,6 +131,7 @@ class ExtractProcessor:
|
||||
elif file_extension in {".htm", ".html"}:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension == ".docx":
|
||||
assert upload_file is not None, "upload_file is required for DOCX extraction"
|
||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension == ".doc":
|
||||
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
@ -149,12 +158,14 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
assert upload_file is not None, "upload_file is required for PDF extraction"
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension in {".md", ".markdown", ".mdx"}:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in {".htm", ".html"}:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension == ".docx":
|
||||
assert upload_file is not None, "upload_file is required for DOCX extraction"
|
||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
@ -177,7 +188,7 @@ class ExtractProcessor:
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.WEBSITE:
|
||||
assert extract_setting.website_info is not None, "website_info is required"
|
||||
if extract_setting.website_info.provider == "firecrawl":
|
||||
if extract_setting.website_info.provider == AuthType.FIRECRAWL:
|
||||
extractor = FirecrawlWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
@ -186,7 +197,7 @@ class ExtractProcessor:
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.website_info.provider == "watercrawl":
|
||||
elif extract_setting.website_info.provider == AuthType.WATERCRAWL:
|
||||
extractor = WaterCrawlWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
@ -195,7 +206,7 @@ class ExtractProcessor:
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.website_info.provider == "jinareader":
|
||||
elif extract_setting.website_info.provider == AuthType.JINA:
|
||||
extractor = JinaReaderWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
|
||||
@ -2,10 +2,12 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Interface for extract files."""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self):
|
||||
def extract(self) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -30,7 +30,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
|
||||
For large files, reading only a sample is sufficient and prevents timeout.
|
||||
"""
|
||||
|
||||
def read_and_detect(filename: str):
|
||||
def read_and_detect(filename: str) -> list[FileEncoding]:
|
||||
rst = charset_normalizer.from_path(filename)
|
||||
best = rst.best()
|
||||
if best is None:
|
||||
|
||||
@ -3,7 +3,7 @@ import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
@ -19,6 +19,16 @@ from .processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexAndCleanResult(TypedDict):
|
||||
dataset_id: str
|
||||
dataset_name: str
|
||||
batch: str
|
||||
document_id: str
|
||||
document_name: str
|
||||
created_at: float
|
||||
display_status: Literal["completed"]
|
||||
|
||||
|
||||
class IndexProcessor:
|
||||
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
@ -50,9 +60,9 @@ class IndexProcessor:
|
||||
document_id: str,
|
||||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
summary_index_setting: dict | None = None,
|
||||
):
|
||||
batch: str,
|
||||
summary_index_setting: dict[str, object] | None = None,
|
||||
) -> IndexAndCleanResult:
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if not document:
|
||||
@ -131,7 +141,12 @@ class IndexProcessor:
|
||||
}
|
||||
|
||||
def get_preview_output(
|
||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||
self,
|
||||
chunks: Any,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
chunk_structure: str,
|
||||
summary_index_setting: dict[str, object] | None,
|
||||
) -> Preview:
|
||||
doc_language = None
|
||||
with session_factory.create_session() as session:
|
||||
|
||||
@ -138,18 +138,20 @@ class BaseIndexProcessor(ABC):
|
||||
embedding_model_instance=embedding_model_instance,
|
||||
)
|
||||
|
||||
return character_splitter # type: ignore
|
||||
return character_splitter
|
||||
|
||||
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
|
||||
"""
|
||||
Get the content files from the document.
|
||||
"""
|
||||
multi_model_documents: list[AttachmentDocument] = []
|
||||
if document.metadata is None:
|
||||
return multi_model_documents
|
||||
text = document.page_content
|
||||
images = self._extract_markdown_images(text)
|
||||
if not images:
|
||||
return multi_model_documents
|
||||
upload_file_id_list = []
|
||||
upload_file_id_list: list[str] = []
|
||||
|
||||
for image in images:
|
||||
# Collect all upload_file_ids including duplicates to preserve occurrence count
|
||||
|
||||
@ -10,7 +10,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
|
||||
class IndexProcessorFactory:
|
||||
"""IndexProcessorInit."""
|
||||
|
||||
def __init__(self, index_type: str | None):
|
||||
def __init__(self, index_type: str | None) -> None:
|
||||
self._index_type = index_type
|
||||
|
||||
def init_index_processor(self) -> BaseIndexProcessor:
|
||||
@ -19,11 +19,12 @@ class IndexProcessorFactory:
|
||||
if not self._index_type:
|
||||
raise ValueError("Index type must be specified.")
|
||||
|
||||
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
|
||||
return ParagraphIndexProcessor()
|
||||
elif self._index_type == IndexStructureType.QA_INDEX:
|
||||
return QAIndexProcessor()
|
||||
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
return ParentChildIndexProcessor()
|
||||
else:
|
||||
raise ValueError(f"Index type {self._index_type} is not supported.")
|
||||
match self._index_type:
|
||||
case IndexStructureType.PARAGRAPH_INDEX:
|
||||
return ParagraphIndexProcessor()
|
||||
case IndexStructureType.QA_INDEX:
|
||||
return QAIndexProcessor()
|
||||
case IndexStructureType.PARENT_CHILD_INDEX:
|
||||
return ParentChildIndexProcessor()
|
||||
case _:
|
||||
raise ValueError(f"Index type {self._index_type} is not supported.")
|
||||
|
||||
@ -30,7 +30,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> TS:
|
||||
def _token_encoder(texts: list[str]) -> list[int]:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
@ -8,7 +8,7 @@ class BaseStorage(ABC):
|
||||
"""Interface for file storage."""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, filename: str, data: bytes):
|
||||
def save(self, filename: str, data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@ -16,7 +16,7 @@ class BaseStorage(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
def load_stream(self, filename: str) -> Generator[bytes, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@ -28,10 +28,10 @@ class BaseStorage(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, filename: str):
|
||||
def delete(self, filename: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def scan(self, path, files=True, directories=False) -> list[str]:
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
"""
|
||||
Scan files and directories in the given path.
|
||||
This method is implemented only in some storage backends.
|
||||
|
||||
@ -43,58 +43,6 @@ core/ops/tencent_trace/utils.py
|
||||
core/plugin/backwards_invocation/base.py
|
||||
core/plugin/backwards_invocation/model.py
|
||||
core/prompt/utils/extract_thread_messages.py
|
||||
core/rag/datasource/keyword/jieba/jieba.py
|
||||
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
|
||||
core/rag/datasource/vdb/baidu/baidu_vector.py
|
||||
core/rag/datasource/vdb/chroma/chroma_vector.py
|
||||
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
|
||||
core/rag/datasource/vdb/couchbase/couchbase_vector.py
|
||||
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
|
||||
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
|
||||
core/rag/datasource/vdb/lindorm/lindorm_vector.py
|
||||
core/rag/datasource/vdb/matrixone/matrixone_vector.py
|
||||
core/rag/datasource/vdb/milvus/milvus_vector.py
|
||||
core/rag/datasource/vdb/myscale/myscale_vector.py
|
||||
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
|
||||
core/rag/datasource/vdb/opensearch/opensearch_vector.py
|
||||
core/rag/datasource/vdb/oracle/oraclevector.py
|
||||
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
|
||||
core/rag/datasource/vdb/relyt/relyt_vector.py
|
||||
core/rag/datasource/vdb/tablestore/tablestore_vector.py
|
||||
core/rag/datasource/vdb/tencent/tencent_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
|
||||
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
|
||||
core/rag/datasource/vdb/upstash/upstash_vector.py
|
||||
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
|
||||
core/rag/datasource/vdb/weaviate/weaviate_vector.py
|
||||
core/rag/extractor/csv_extractor.py
|
||||
core/rag/extractor/excel_extractor.py
|
||||
core/rag/extractor/firecrawl/firecrawl_app.py
|
||||
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
|
||||
core/rag/extractor/html_extractor.py
|
||||
core/rag/extractor/jina_reader_extractor.py
|
||||
core/rag/extractor/markdown_extractor.py
|
||||
core/rag/extractor/notion_extractor.py
|
||||
core/rag/extractor/pdf_extractor.py
|
||||
core/rag/extractor/text_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_doc_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_eml_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_epub_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_msg_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_xml_extractor.py
|
||||
core/rag/extractor/watercrawl/client.py
|
||||
core/rag/extractor/watercrawl/extractor.py
|
||||
core/rag/extractor/watercrawl/provider.py
|
||||
core/rag/extractor/word_extractor.py
|
||||
core/rag/index_processor/processor/paragraph_index_processor.py
|
||||
core/rag/index_processor/processor/parent_child_index_processor.py
|
||||
core/rag/index_processor/processor/qa_index_processor.py
|
||||
core/rag/retrieval/router/multi_dataset_function_call_router.py
|
||||
core/rag/summary_index/summary_index.py
|
||||
core/repositories/sqlalchemy_workflow_execution_repository.py
|
||||
@ -140,27 +88,10 @@ dify_graph/nodes/variable_assigner/v2/node.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
extensions/storage/aliyun_oss_storage.py
|
||||
extensions/storage/aws_s3_storage.py
|
||||
extensions/storage/azure_blob_storage.py
|
||||
extensions/storage/baidu_obs_storage.py
|
||||
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
|
||||
extensions/storage/clickzetta_volume/file_lifecycle.py
|
||||
extensions/storage/google_cloud_storage.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
extensions/storage/opendal_storage.py
|
||||
extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
services/auth/firecrawl/firecrawl.py
|
||||
services/auth/jina.py
|
||||
services/auth/jina/jina.py
|
||||
services/auth/watercrawl/watercrawl.py
|
||||
services/conversation_service.py
|
||||
services/dataset_service.py
|
||||
services/document_indexing_proxy/document_indexing_task_proxy.py
|
||||
@ -183,3 +114,84 @@ tasks/regenerate_summary_index_task.py
|
||||
tasks/trigger_processing_tasks.py
|
||||
tasks/workflow_cfs_scheduler/cfs_scheduler.py
|
||||
tasks/workflow_execution_tasks.py
|
||||
|
||||
# not need to be fixed by now: storage adapters
|
||||
extensions/storage/aliyun_oss_storage.py
|
||||
extensions/storage/aws_s3_storage.py
|
||||
extensions/storage/azure_blob_storage.py
|
||||
extensions/storage/baidu_obs_storage.py
|
||||
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
|
||||
extensions/storage/clickzetta_volume/file_lifecycle.py
|
||||
extensions/storage/google_cloud_storage.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
extensions/storage/opendal_storage.py
|
||||
extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
|
||||
# not need to be fixed by now: auth adapters
|
||||
services/auth/firecrawl/firecrawl.py
|
||||
services/auth/jina.py
|
||||
services/auth/jina/jina.py
|
||||
services/auth/watercrawl/watercrawl.py
|
||||
|
||||
# not need to be fixed by now: keyword adapters
|
||||
core/rag/datasource/keyword/jieba/jieba.py
|
||||
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
|
||||
|
||||
# not need to be fixed by now: vector db adapters
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
|
||||
core/rag/datasource/vdb/baidu/baidu_vector.py
|
||||
core/rag/datasource/vdb/chroma/chroma_vector.py
|
||||
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
|
||||
core/rag/datasource/vdb/couchbase/couchbase_vector.py
|
||||
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
|
||||
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
|
||||
core/rag/datasource/vdb/lindorm/lindorm_vector.py
|
||||
core/rag/datasource/vdb/matrixone/matrixone_vector.py
|
||||
core/rag/datasource/vdb/milvus/milvus_vector.py
|
||||
core/rag/datasource/vdb/myscale/myscale_vector.py
|
||||
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
|
||||
core/rag/datasource/vdb/opensearch/opensearch_vector.py
|
||||
core/rag/datasource/vdb/oracle/oraclevector.py
|
||||
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
|
||||
core/rag/datasource/vdb/relyt/relyt_vector.py
|
||||
core/rag/datasource/vdb/tablestore/tablestore_vector.py
|
||||
core/rag/datasource/vdb/tencent/tencent_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
|
||||
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
|
||||
core/rag/datasource/vdb/upstash/upstash_vector.py
|
||||
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
|
||||
core/rag/datasource/vdb/weaviate/weaviate_vector.py
|
||||
|
||||
# not need to be fixed by now: extractors
|
||||
core/rag/extractor/csv_extractor.py
|
||||
core/rag/extractor/excel_extractor.py
|
||||
core/rag/extractor/firecrawl/firecrawl_app.py
|
||||
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
|
||||
core/rag/extractor/html_extractor.py
|
||||
core/rag/extractor/jina_reader_extractor.py
|
||||
core/rag/extractor/markdown_extractor.py
|
||||
core/rag/extractor/notion_extractor.py
|
||||
core/rag/extractor/pdf_extractor.py
|
||||
core/rag/extractor/text_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_doc_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_eml_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_epub_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_msg_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_xml_extractor.py
|
||||
core/rag/extractor/watercrawl/client.py
|
||||
core/rag/extractor/watercrawl/extractor.py
|
||||
core/rag/extractor/watercrawl/provider.py
|
||||
core/rag/extractor/word_extractor.py
|
||||
|
||||
# not need to be fixed by now: index processors
|
||||
core/rag/index_processor/processor/paragraph_index_processor.py
|
||||
core/rag/index_processor/processor/parent_child_index_processor.py
|
||||
core/rag/index_processor/processor/qa_index_processor.py
|
||||
|
||||
@ -1,10 +1,27 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Annotated, NotRequired, TypedDict
|
||||
|
||||
from pydantic import StringConstraints
|
||||
|
||||
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
|
||||
|
||||
|
||||
class ApiKeyAuthConfig(TypedDict):
|
||||
api_key: NotRequired[str]
|
||||
base_url: NotRequired[str]
|
||||
|
||||
|
||||
class ApiKeyAuthCredentials(TypedDict):
|
||||
auth_type: NonEmptyString
|
||||
config: ApiKeyAuthConfig
|
||||
|
||||
|
||||
class ApiKeyAuthBase(ABC):
|
||||
def __init__(self, credentials: dict):
|
||||
credentials: ApiKeyAuthCredentials
|
||||
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials) -> None:
|
||||
self.credentials = credentials
|
||||
|
||||
@abstractmethod
|
||||
def validate_credentials(self):
|
||||
def validate_credentials(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.auth_type import AuthType
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
from services.auth.auth_type import AuthProvider, AuthType
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
|
||||
def __init__(self, provider: str, credentials: dict):
|
||||
auth: ApiKeyAuthBase
|
||||
|
||||
def __init__(self, provider: AuthProvider, credentials: ApiKeyAuthCredentials) -> None:
|
||||
auth_factory = self.get_apikey_auth_factory(provider)
|
||||
self.auth = auth_factory(credentials)
|
||||
|
||||
def validate_credentials(self):
|
||||
def validate_credentials(self) -> bool:
|
||||
return self.auth.validate_credentials()
|
||||
|
||||
@staticmethod
|
||||
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
|
||||
def get_apikey_auth_factory(provider: AuthProvider) -> type[ApiKeyAuthBase]:
|
||||
match provider:
|
||||
case AuthType.FIRECRAWL:
|
||||
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||
|
||||
@ -1,40 +1,75 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Annotated, TypedDict, TypeVar
|
||||
|
||||
from pydantic import StringConstraints, TypeAdapter, ValidationError
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthConfig, ApiKeyAuthCredentials
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.auth_type import AuthProvider
|
||||
|
||||
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
|
||||
ValidatedPayload = TypeVar("ValidatedPayload")
|
||||
|
||||
|
||||
class ApiKeyAuthCreateArgs(TypedDict):
|
||||
category: NonEmptyString
|
||||
provider: NonEmptyString
|
||||
credentials: ApiKeyAuthCredentials
|
||||
|
||||
|
||||
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(ApiKeyAuthCredentials)
|
||||
AUTH_CREATE_ARGS_ADAPTER = TypeAdapter(ApiKeyAuthCreateArgs)
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
def get_provider_auth_list(tenant_id: str) -> list[DataSourceApiKeyAuthBinding]:
|
||||
data_source_api_key_bindings = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
)
|
||||
).all()
|
||||
return data_source_api_key_bindings
|
||||
return list(data_source_api_key_bindings)
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict):
|
||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||
@classmethod
|
||||
def create_provider_auth(cls, tenant_id: str, args: Mapping[str, object] | ApiKeyAuthCreateArgs) -> None:
|
||||
validated_args = cls.validate_api_key_auth_args(args)
|
||||
auth_result = ApiKeyAuthFactory(
|
||||
validated_args["provider"], validated_args["credentials"]
|
||||
).validate_credentials()
|
||||
if auth_result:
|
||||
stored_config: ApiKeyAuthConfig = {}
|
||||
if "api_key" in validated_args["credentials"]["config"]:
|
||||
stored_config["api_key"] = validated_args["credentials"]["config"]["api_key"]
|
||||
if "base_url" in validated_args["credentials"]["config"]:
|
||||
stored_config["base_url"] = validated_args["credentials"]["config"]["base_url"]
|
||||
stored_credentials: ApiKeyAuthCredentials = {
|
||||
"auth_type": validated_args["credentials"]["auth_type"],
|
||||
"config": stored_config,
|
||||
}
|
||||
# Encrypt the api key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
api_key_value = stored_credentials["config"].get("api_key")
|
||||
if api_key_value is None:
|
||||
raise ValueError("credentials config api_key is required")
|
||||
api_key = encrypter.encrypt_token(tenant_id, api_key_value)
|
||||
stored_credentials["config"]["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
tenant_id=tenant_id,
|
||||
category=validated_args["category"],
|
||||
provider=validated_args["provider"],
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
data_source_api_key_binding.credentials = json.dumps(stored_credentials, ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: AuthProvider) -> ApiKeyAuthCredentials | None:
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(
|
||||
@ -49,11 +84,11 @@ class ApiKeyAuthService:
|
||||
return None
|
||||
if not data_source_api_key_bindings.credentials:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
raw_credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return ApiKeyAuthService._validate_credentials_payload(raw_credentials)
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str) -> None:
|
||||
data_source_api_key_binding = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
|
||||
@ -64,14 +99,16 @@ class ApiKeyAuthService:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
if "category" not in args or not args["category"]:
|
||||
raise ValueError("category is required")
|
||||
if "provider" not in args or not args["provider"]:
|
||||
raise ValueError("provider is required")
|
||||
if "credentials" not in args or not args["credentials"]:
|
||||
raise ValueError("credentials is required")
|
||||
if not isinstance(args["credentials"], dict):
|
||||
raise ValueError("credentials must be a dictionary")
|
||||
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
|
||||
raise ValueError("auth_type is required")
|
||||
def validate_api_key_auth_args(cls, args: Mapping[str, object] | None) -> ApiKeyAuthCreateArgs:
|
||||
return cls._validate_payload(AUTH_CREATE_ARGS_ADAPTER, args)
|
||||
|
||||
@staticmethod
|
||||
def _validate_credentials_payload(raw_credentials: object) -> ApiKeyAuthCredentials:
|
||||
return ApiKeyAuthService._validate_payload(AUTH_CREDENTIALS_ADAPTER, raw_credentials)
|
||||
|
||||
@staticmethod
|
||||
def _validate_payload(adapter: TypeAdapter[ValidatedPayload], payload: object) -> ValidatedPayload:
|
||||
try:
|
||||
return adapter.validate_python(payload)
|
||||
except ValidationError as exc:
|
||||
raise ValueError(exc.errors()[0]["msg"]) from exc
|
||||
|
||||
@ -5,3 +5,6 @@ class AuthType(StrEnum):
|
||||
FIRECRAWL = "firecrawl"
|
||||
WATERCRAWL = "watercrawl"
|
||||
JINA = "jinareader"
|
||||
|
||||
|
||||
AuthProvider = AuthType | str
|
||||
|
||||
@ -7,7 +7,7 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
|
||||
|
||||
@ -7,7 +7,7 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
|
||||
|
||||
@ -7,7 +7,7 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
|
||||
|
||||
@ -8,7 +8,7 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
class WatercrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "x-api-key":
|
||||
raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")
|
||||
|
||||
@ -81,7 +81,10 @@ class TestApiKeyAuthService:
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
|
||||
|
||||
# Verify factory class calls
|
||||
mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
|
||||
mock_factory.assert_called_once_with(
|
||||
self.provider,
|
||||
{"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}},
|
||||
)
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
|
||||
# Verify encryption calls
|
||||
@ -129,9 +132,8 @@ class TestApiKeyAuthService:
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
|
||||
|
||||
# Verify original key is replaced with encrypted key
|
||||
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
|
||||
assert args_copy["credentials"]["config"]["api_key"] != original_key
|
||||
# Verify the service does not mutate the caller's payload while still encrypting persisted credentials
|
||||
assert args_copy["credentials"]["config"]["api_key"] == original_key
|
||||
|
||||
# Verify encryption function is called correctly
|
||||
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
|
||||
@ -230,7 +232,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
del args["category"]
|
||||
|
||||
with pytest.raises(ValueError, match="category is required"):
|
||||
with pytest.raises(ValueError, match="Field required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_empty_category(self):
|
||||
@ -238,7 +240,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
args["category"] = ""
|
||||
|
||||
with pytest.raises(ValueError, match="category is required"):
|
||||
with pytest.raises(ValueError, match="at least 1 character"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_missing_provider(self):
|
||||
@ -246,7 +248,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
del args["provider"]
|
||||
|
||||
with pytest.raises(ValueError, match="provider is required"):
|
||||
with pytest.raises(ValueError, match="Field required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_empty_provider(self):
|
||||
@ -254,7 +256,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
args["provider"] = ""
|
||||
|
||||
with pytest.raises(ValueError, match="provider is required"):
|
||||
with pytest.raises(ValueError, match="at least 1 character"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_missing_credentials(self):
|
||||
@ -262,7 +264,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
del args["credentials"]
|
||||
|
||||
with pytest.raises(ValueError, match="credentials is required"):
|
||||
with pytest.raises(ValueError, match="Field required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_empty_credentials(self):
|
||||
@ -270,7 +272,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"] = None
|
||||
|
||||
with pytest.raises(ValueError, match="credentials is required"):
|
||||
with pytest.raises(ValueError, match="valid dictionary"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_invalid_credentials_type(self):
|
||||
@ -278,7 +280,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"] = "not_a_dict"
|
||||
|
||||
with pytest.raises(ValueError, match="credentials must be a dictionary"):
|
||||
with pytest.raises(ValueError, match="valid dictionary"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_missing_auth_type(self):
|
||||
@ -286,7 +288,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
del args["credentials"]["auth_type"]
|
||||
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
with pytest.raises(ValueError, match="Field required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_empty_auth_type(self):
|
||||
@ -294,7 +296,7 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"]["auth_type"] = ""
|
||||
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
with pytest.raises(ValueError, match="at least 1 character"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -374,7 +376,7 @@ class TestApiKeyAuthService:
|
||||
|
||||
def test_validate_api_key_auth_args_none_input(self):
|
||||
"""Test API key auth args validation - None input"""
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError, match="valid dictionary"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(None)
|
||||
|
||||
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
|
||||
@ -382,6 +384,5 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"]["auth_type"] = ["api_key"]
|
||||
|
||||
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
|
||||
# So this should not raise exception, this test should pass
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
with pytest.raises(ValueError, match="valid string"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
@ -121,7 +121,7 @@ import pytest
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||
from models.dataset import ChildChunk, Dataset, DatasetDocument, DatasetProcessRule, DocumentSegment
|
||||
from services.vector_service import VectorService
|
||||
|
||||
@ -1300,6 +1300,68 @@ class TestVector:
|
||||
|
||||
assert mock_vector_processor.create.call_count == 3
|
||||
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
|
||||
def test_vector_create_normalizes_child_documents(self, mock_get_embeddings, mock_init_vector):
|
||||
dataset = VectorServiceTestDataFactory.create_dataset_mock()
|
||||
documents = [
|
||||
ChildDocument(page_content="Child content", metadata={"doc_id": "child-1", "dataset_id": "dataset-123"})
|
||||
]
|
||||
|
||||
mock_embeddings = Mock()
|
||||
mock_embeddings.embed_documents = Mock(return_value=[[0.1] * 1536])
|
||||
mock_get_embeddings.return_value = mock_embeddings
|
||||
|
||||
mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock()
|
||||
mock_init_vector.return_value = mock_vector_processor
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
vector.create(texts=documents)
|
||||
|
||||
create_call = mock_vector_processor.create.call_args.kwargs
|
||||
normalized_document = create_call["texts"][0]
|
||||
assert isinstance(normalized_document, Document)
|
||||
assert normalized_document.page_content == "Child content"
|
||||
assert normalized_document.metadata["doc_id"] == "child-1"
|
||||
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.storage")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.db.session")
|
||||
def test_vector_create_multimodal_normalizes_attachment_documents(
|
||||
self, mock_session, mock_storage, mock_get_embeddings, mock_init_vector
|
||||
):
|
||||
dataset = VectorServiceTestDataFactory.create_dataset_mock()
|
||||
file_document = AttachmentDocument(
|
||||
page_content="Attachment content",
|
||||
provider="custom-provider",
|
||||
metadata={"doc_id": "file-1", "doc_type": "image/png"},
|
||||
)
|
||||
upload_file = Mock(id="file-1", key="upload-key")
|
||||
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.all.return_value = [upload_file]
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
mock_storage.load_once.return_value = b"binary-content"
|
||||
|
||||
mock_embeddings = Mock()
|
||||
mock_embeddings.embed_multimodal_documents = Mock(return_value=[[0.2] * 1536])
|
||||
mock_get_embeddings.return_value = mock_embeddings
|
||||
|
||||
mock_vector_processor = VectorServiceTestDataFactory.create_vector_processor_mock()
|
||||
mock_init_vector.return_value = mock_vector_processor
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
vector.create_multimodal(file_documents=[file_document])
|
||||
|
||||
create_call = mock_vector_processor.create.call_args.kwargs
|
||||
normalized_document = create_call["texts"][0]
|
||||
assert isinstance(normalized_document, Document)
|
||||
assert normalized_document.provider == "custom-provider"
|
||||
assert normalized_document.metadata["doc_id"] == "file-1"
|
||||
|
||||
# ========================================================================
|
||||
# Tests for Vector.add_texts
|
||||
# ========================================================================
|
||||
|
||||
Reference in New Issue
Block a user