diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 7eebd9ec95..275c1fc110 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,6 +5,7 @@ import re import threading import time import uuid +from collections.abc import Mapping from typing import Any from flask import Flask, current_app @@ -37,7 +38,7 @@ from extensions.ext_storage import storage from libs import helper from libs.datetime_utils import naive_utc_now from models import Account -from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.feature_service import FeatureService @@ -265,7 +266,7 @@ class IndexingRunner: self, tenant_id: str, extract_settings: list[ExtractSetting], - tmp_processing_rule: dict, + tmp_processing_rule: Mapping[str, Any], doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, @@ -376,7 +377,7 @@ class IndexingRunner: return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( - self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict + self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: Mapping[str, Any] ) -> list[Document]: data_source_info = dataset_document.data_source_info_dict text_docs = [] @@ -543,6 +544,7 @@ class IndexingRunner: """ Clean the document text according to the processing rules. """ + rules: AutomaticRulesConfig | dict[str, Any] if processing_rule.mode == "automatic": rules = DatasetProcessRule.AUTOMATIC_RULES else: @@ -756,7 +758,7 @@ class IndexingRunner: dataset: Dataset, text_docs: list[Document], doc_language: str, - process_rule: dict, + process_rule: Mapping[str, Any], current_user: Account | None = None, ) -> list[Document]: # get embedding model instance diff --git a/api/models/dataset.py b/api/models/dataset.py index b3fa11a58c..8438fda25f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -10,7 +10,7 @@ import re import time from datetime import datetime from json import JSONDecodeError -from typing import Any, cast +from typing import Any, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -37,6 +37,61 @@ from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adj logger = logging.getLogger(__name__) +class PreProcessingRuleItem(TypedDict): + id: str + enabled: bool + + +class SegmentationConfig(TypedDict): + delimiter: str + max_tokens: int + chunk_overlap: int + + +class AutomaticRulesConfig(TypedDict): + pre_processing_rules: list[PreProcessingRuleItem] + segmentation: SegmentationConfig + + +class ProcessRuleDict(TypedDict): + id: str + dataset_id: str + mode: str + rules: dict[str, Any] | None + + +class DocMetadataDetailItem(TypedDict): + id: str + name: str + type: str + value: Any + + +class AttachmentItem(TypedDict): + id: str + name: str + size: int + extension: str + mime_type: str + source_url: str + + +class DatasetBindingItem(TypedDict): + id: str + name: str + + +class ExternalKnowledgeApiDict(TypedDict): + id: str + tenant_id: str + name: str + description: str + settings: dict[str, Any] | None + dataset_bindings: list[DatasetBindingItem] + created_by: str + created_at: str + + class DatasetPermissionEnum(enum.StrEnum): ONLY_ME = "only_me" ALL_TEAM = "all_team_members" @@ -334,7 +389,7 @@ class DatasetProcessRule(Base): # bug MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES: dict[str, Any] = { + AUTOMATIC_RULES: AutomaticRulesConfig = { "pre_processing_rules": [ {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, @@ -342,7 +397,7 @@ class DatasetProcessRule(Base): # bug "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, } - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> ProcessRuleDict: return { "id": self.id, "dataset_id": self.dataset_id, @@ -531,7 +586,7 @@ class Document(Base): return self.updated_at @property - def doc_metadata_details(self) -> list[dict[str, Any]] | None: + def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None: if self.doc_metadata: document_metadatas = ( db.session.query(DatasetMetadata) @@ -541,9 +596,9 @@ class Document(Base): ) .all() ) - metadata_list: list[dict[str, Any]] = [] + metadata_list: list[DocMetadataDetailItem] = [] for metadata in document_metadatas: - metadata_dict: dict[str, Any] = { + metadata_dict: DocMetadataDetailItem = { "id": metadata.id, "name": metadata.name, "type": metadata.type, @@ -557,13 +612,13 @@ class Document(Base): return None @property - def process_rule_dict(self) -> dict[str, Any] | None: + def process_rule_dict(self) -> ProcessRuleDict | None: if self.dataset_process_rule_id and self.dataset_process_rule: return self.dataset_process_rule.to_dict() return None - def get_built_in_fields(self) -> list[dict[str, Any]]: - built_in_fields: list[dict[str, Any]] = [] + def get_built_in_fields(self) -> list[DocMetadataDetailItem]: + built_in_fields: list[DocMetadataDetailItem] = [] built_in_fields.append( { "id": "built-in", @@ -877,7 +932,7 @@ class DocumentSegment(Base): return text @property - def attachments(self) -> list[dict[str, Any]]: + def attachments(self) -> list[AttachmentItem]: # Use JOIN to fetch attachments in a single query instead of two separate queries attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) @@ -891,7 +946,7 @@ class DocumentSegment(Base): ).all() if not attachments_with_bindings: return [] - attachment_list = [] + attachment_list: list[AttachmentItem] = [] for _, attachment in attachments_with_bindings: upload_file_id = attachment.id nonce = os.urandom(16).hex() @@ -1261,7 +1316,7 @@ class ExternalKnowledgeApis(TypeBase): DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> ExternalKnowledgeApiDict: return { "id": self.id, "tenant_id": self.tenant_id, @@ -1281,13 +1336,13 @@ class ExternalKnowledgeApis(TypeBase): return None @property - def dataset_bindings(self) -> list[dict[str, Any]]: + def dataset_bindings(self) -> list[DatasetBindingItem]: external_knowledge_bindings = db.session.scalars( select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) ).all() dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() - dataset_bindings: list[dict[str, Any]] = [] + dataset_bindings: list[DatasetBindingItem] = [] for dataset in datasets: dataset_bindings.append({"id": dataset.id, "name": dataset.name}) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 73bb46b797..b66fdd7a20 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -156,7 +156,8 @@ class VectorService: ) # use full doc mode to generate segment's child chunk processing_rule_dict = processing_rule.to_dict() - processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC + if processing_rule_dict["rules"] is not None: + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC documents = index_processor.transform( [document], embedding_model_instance=embedding_model_instance,