refactor(api): replace dict/Mapping with TypedDict in core/rag retrieval_service.py (#33615)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
BitToby
2026-03-18 04:49:09 +02:00
committed by GitHub
parent d7f70f3c0f
commit 485da15a4d
18 changed files with 165 additions and 71 deletions

View File

@ -1,19 +1,20 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Any, NotRequired
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from typing_extensions import TypedDict
from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
class SegmentAttachmentResult(TypedDict):
attachment_info: AttachmentInfoDict
segment_id: str
class SegmentAttachmentInfoResult(TypedDict):
attachment_id: str
attachment_info: AttachmentInfoDict
segment_id: str
class ChildChunkDetail(TypedDict):
id: str
content: str
position: int
score: float
class SegmentChildMapDetail(TypedDict):
max_score: float
child_chunks: list[ChildChunkDetail]
class SegmentRecord(TypedDict):
segment: DocumentSegment
score: NotRequired[float]
child_chunks: NotRequired[list[ChildChunkDetail]]
files: NotRequired[list[AttachmentInfoDict]]
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod | str
reranking_enable: bool
reranking_model: RerankingModelDict
top_k: int
score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -56,9 +96,9 @@ class RetrievalService:
query: str,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
):
@ -235,7 +275,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: RetrievalMethod,
exceptions: list,
@ -277,8 +317,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
@ -288,8 +328,8 @@ class RetrievalService:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
provider=reranking_model["reranking_provider_name"],
model=reranking_model["reranking_model_name"],
model_type=ModelType.RERANK,
)
if is_support_vision:
@ -329,7 +369,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: str,
exceptions: list,
@ -349,8 +389,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
@ -459,7 +499,7 @@ class RetrievalService:
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[dict[str, Any]]] = {}
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
@ -544,12 +584,12 @@ class RetrievalService:
segment_summary_map[summary.chunk_id] = summary.summary_content
include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []
segment_child_map: dict[str, SegmentChildMapDetail] = {}
records: list[SegmentRecord] = []
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
@ -560,14 +600,14 @@ class RetrievalService:
max_score = summary_score_map.get(segment.id, 0.0)
if child_chunks or attachment_infos:
child_chunk_details = []
child_chunk_details: list[ChildChunkDetail] = []
for child_chunk in child_chunks:
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
if child_document:
child_score = child_document.metadata.get("score", 0.0)
else:
child_score = 0.0
child_chunk_detail = {
child_chunk_detail: ChildChunkDetail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
@ -580,7 +620,7 @@ class RetrievalService:
if file_document:
max_score = max(max_score, file_document.metadata.get("score", 0.0))
map_detail = {
map_detail: SegmentChildMapDetail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
@ -593,7 +633,7 @@ class RetrievalService:
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
record: SegmentRecord = {
"segment": segment,
}
records.append(record)
@ -617,19 +657,19 @@ class RetrievalService:
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
another_record: SegmentRecord = {
"segment": segment,
"score": max_score,
}
records.append(record)
records.append(another_record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
record["files"] = attachment_map[record["segment"].id]
result: list[RetrievalSegments] = []
for record in records:
@ -693,9 +733,9 @@ class RetrievalService:
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
):
@ -807,7 +847,7 @@ class RetrievalService:
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
) -> SegmentAttachmentResult | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
@ -816,7 +856,7 @@ class RetrievalService:
.first()
)
if attachment_binding:
attachment_info = {
attachment_info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -828,8 +868,10 @@ class RetrievalService:
return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
attachment_infos = []
def get_segment_attachment_infos(
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
@ -843,7 +885,7 @@ class RetrievalService:
if attachment_bindings:
for upload_file in upload_files:
attachment_binding = attachment_binding_map.get(upload_file.id)
attachment_info = {
info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -855,7 +897,7 @@ class RetrievalService:
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
"attachment_info": attachment_info,
"attachment_info": info,
"segment_id": attachment_binding.segment_id,
}
)