refactor(api): replace dict/Mapping with TypedDict in dataset models (#33550)

This commit is contained in:
statxc
2026-03-17 03:33:29 +02:00
committed by GitHub
parent fa82a0f708
commit f886f11094
3 changed files with 77 additions and 19 deletions

View File

@ -5,6 +5,7 @@ import re
import threading
import time
import uuid
from collections.abc import Mapping
from typing import Any
from flask import Flask, current_app
@ -37,7 +38,7 @@ from extensions.ext_storage import storage
from libs import helper
from libs.datetime_utils import naive_utc_now
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService
@ -265,7 +266,7 @@ class IndexingRunner:
self,
tenant_id: str,
extract_settings: list[ExtractSetting],
tmp_processing_rule: dict,
tmp_processing_rule: Mapping[str, Any],
doc_form: str | None = None,
doc_language: str = "English",
dataset_id: str | None = None,
@ -376,7 +377,7 @@ class IndexingRunner:
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: Mapping[str, Any]
) -> list[Document]:
data_source_info = dataset_document.data_source_info_dict
text_docs = []
@ -543,6 +544,7 @@ class IndexingRunner:
"""
Clean the document text according to the processing rules.
"""
rules: AutomaticRulesConfig | dict[str, Any]
if processing_rule.mode == "automatic":
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
@ -756,7 +758,7 @@ class IndexingRunner:
dataset: Dataset,
text_docs: list[Document],
doc_language: str,
process_rule: dict,
process_rule: Mapping[str, Any],
current_user: Account | None = None,
) -> list[Document]:
# get embedding model instance

View File

@ -10,7 +10,7 @@ import re
import time
from datetime import datetime
from json import JSONDecodeError
from typing import Any, cast
from typing import Any, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@ -37,6 +37,61 @@ from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adj
logger = logging.getLogger(__name__)
class PreProcessingRuleItem(TypedDict):
id: str
enabled: bool
class SegmentationConfig(TypedDict):
delimiter: str
max_tokens: int
chunk_overlap: int
class AutomaticRulesConfig(TypedDict):
pre_processing_rules: list[PreProcessingRuleItem]
segmentation: SegmentationConfig
class ProcessRuleDict(TypedDict):
id: str
dataset_id: str
mode: str
rules: dict[str, Any] | None
class DocMetadataDetailItem(TypedDict):
id: str
name: str
type: str
value: Any
class AttachmentItem(TypedDict):
id: str
name: str
size: int
extension: str
mime_type: str
source_url: str
class DatasetBindingItem(TypedDict):
id: str
name: str
class ExternalKnowledgeApiDict(TypedDict):
id: str
tenant_id: str
name: str
description: str
settings: dict[str, Any] | None
dataset_bindings: list[DatasetBindingItem]
created_by: str
created_at: str
class DatasetPermissionEnum(enum.StrEnum):
ONLY_ME = "only_me"
ALL_TEAM = "all_team_members"
@ -334,7 +389,7 @@ class DatasetProcessRule(Base): # bug
MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
AUTOMATIC_RULES: dict[str, Any] = {
AUTOMATIC_RULES: AutomaticRulesConfig = {
"pre_processing_rules": [
{"id": "remove_extra_spaces", "enabled": True},
{"id": "remove_urls_emails", "enabled": False},
@ -342,7 +397,7 @@ class DatasetProcessRule(Base): # bug
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
}
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> ProcessRuleDict:
return {
"id": self.id,
"dataset_id": self.dataset_id,
@ -531,7 +586,7 @@ class Document(Base):
return self.updated_at
@property
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
@ -541,9 +596,9 @@ class Document(Base):
)
.all()
)
metadata_list: list[dict[str, Any]] = []
metadata_list: list[DocMetadataDetailItem] = []
for metadata in document_metadatas:
metadata_dict: dict[str, Any] = {
metadata_dict: DocMetadataDetailItem = {
"id": metadata.id,
"name": metadata.name,
"type": metadata.type,
@ -557,13 +612,13 @@ class Document(Base):
return None
@property
def process_rule_dict(self) -> dict[str, Any] | None:
def process_rule_dict(self) -> ProcessRuleDict | None:
if self.dataset_process_rule_id and self.dataset_process_rule:
return self.dataset_process_rule.to_dict()
return None
def get_built_in_fields(self) -> list[dict[str, Any]]:
built_in_fields: list[dict[str, Any]] = []
def get_built_in_fields(self) -> list[DocMetadataDetailItem]:
built_in_fields: list[DocMetadataDetailItem] = []
built_in_fields.append(
{
"id": "built-in",
@ -877,7 +932,7 @@ class DocumentSegment(Base):
return text
@property
def attachments(self) -> list[dict[str, Any]]:
def attachments(self) -> list[AttachmentItem]:
# Use JOIN to fetch attachments in a single query instead of two separate queries
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
@ -891,7 +946,7 @@ class DocumentSegment(Base):
).all()
if not attachments_with_bindings:
return []
attachment_list = []
attachment_list: list[AttachmentItem] = []
for _, attachment in attachments_with_bindings:
upload_file_id = attachment.id
nonce = os.urandom(16).hex()
@ -1261,7 +1316,7 @@ class ExternalKnowledgeApis(TypeBase):
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> ExternalKnowledgeApiDict:
return {
"id": self.id,
"tenant_id": self.tenant_id,
@ -1281,13 +1336,13 @@ class ExternalKnowledgeApis(TypeBase):
return None
@property
def dataset_bindings(self) -> list[dict[str, Any]]:
def dataset_bindings(self) -> list[DatasetBindingItem]:
external_knowledge_bindings = db.session.scalars(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
).all()
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
dataset_bindings: list[dict[str, Any]] = []
dataset_bindings: list[DatasetBindingItem] = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})

View File

@ -156,7 +156,8 @@ class VectorService:
)
# use full doc mode to generate segment's child chunk
processing_rule_dict = processing_rule.to_dict()
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC
if processing_rule_dict["rules"] is not None:
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC
documents = index_processor.transform(
[document],
embedding_model_instance=embedding_model_instance,