mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
refactor(api): replace dict/Mapping with TypedDict in core/rag retrieval_service.py (#33615)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,19 +1,20 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
|
||||
class SegmentAttachmentResult(TypedDict):
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class SegmentAttachmentInfoResult(TypedDict):
|
||||
attachment_id: str
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class ChildChunkDetail(TypedDict):
|
||||
id: str
|
||||
content: str
|
||||
position: int
|
||||
score: float
|
||||
|
||||
|
||||
class SegmentChildMapDetail(TypedDict):
|
||||
max_score: float
|
||||
child_chunks: list[ChildChunkDetail]
|
||||
|
||||
|
||||
class SegmentRecord(TypedDict):
|
||||
segment: DocumentSegment
|
||||
score: NotRequired[float]
|
||||
child_chunks: NotRequired[list[ChildChunkDetail]]
|
||||
files: NotRequired[list[AttachmentInfoDict]]
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod | str
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@ -56,9 +96,9 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_ids: list | None = None,
|
||||
):
|
||||
@ -235,7 +275,7 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
@ -277,8 +317,8 @@ class RetrievalService:
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
@ -288,8 +328,8 @@ class RetrievalService:
|
||||
model_manager = ModelManager()
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model.get("reranking_provider_name") or "",
|
||||
model=reranking_model.get("reranking_model_name") or "",
|
||||
provider=reranking_model["reranking_provider_name"],
|
||||
model=reranking_model["reranking_model_name"],
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if is_support_vision:
|
||||
@ -329,7 +369,7 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
@ -349,8 +389,8 @@ class RetrievalService:
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
@ -459,7 +499,7 @@ class RetrievalService:
|
||||
segment_ids: list[str] = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
|
||||
@ -544,12 +584,12 @@ class RetrievalService:
|
||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
|
||||
include_segment_ids = set()
|
||||
segment_child_map: dict[str, dict[str, Any]] = {}
|
||||
records: list[dict[str, Any]] = []
|
||||
segment_child_map: dict[str, SegmentChildMapDetail] = {}
|
||||
records: list[SegmentRecord] = []
|
||||
|
||||
for segment in segments:
|
||||
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
@ -560,14 +600,14 @@ class RetrievalService:
|
||||
max_score = summary_score_map.get(segment.id, 0.0)
|
||||
|
||||
if child_chunks or attachment_infos:
|
||||
child_chunk_details = []
|
||||
child_chunk_details: list[ChildChunkDetail] = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
|
||||
if child_document:
|
||||
child_score = child_document.metadata.get("score", 0.0)
|
||||
else:
|
||||
child_score = 0.0
|
||||
child_chunk_detail = {
|
||||
child_chunk_detail: ChildChunkDetail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
@ -580,7 +620,7 @@ class RetrievalService:
|
||||
if file_document:
|
||||
max_score = max(max_score, file_document.metadata.get("score", 0.0))
|
||||
|
||||
map_detail = {
|
||||
map_detail: SegmentChildMapDetail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
}
|
||||
@ -593,7 +633,7 @@ class RetrievalService:
|
||||
"max_score": summary_score,
|
||||
"child_chunks": [],
|
||||
}
|
||||
record: dict[str, Any] = {
|
||||
record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
@ -617,19 +657,19 @@ class RetrievalService:
|
||||
if file_doc:
|
||||
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||
|
||||
record = {
|
||||
another_record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
"score": max_score,
|
||||
}
|
||||
records.append(record)
|
||||
records.append(another_record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
record["files"] = attachment_map[record["segment"].id]
|
||||
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
@ -693,9 +733,9 @@ class RetrievalService:
|
||||
query: str | None = None,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_id: str | None = None,
|
||||
):
|
||||
@ -807,7 +847,7 @@ class RetrievalService:
|
||||
@classmethod
|
||||
def get_segment_attachment_info(
|
||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||
) -> dict[str, Any] | None:
|
||||
) -> SegmentAttachmentResult | None:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||
if upload_file:
|
||||
attachment_binding = (
|
||||
@ -816,7 +856,7 @@ class RetrievalService:
|
||||
.first()
|
||||
)
|
||||
if attachment_binding:
|
||||
attachment_info = {
|
||||
attachment_info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
@ -828,8 +868,10 @@ class RetrievalService:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
|
||||
attachment_infos = []
|
||||
def get_segment_attachment_infos(
|
||||
cls, attachment_ids: list[str], session: Session
|
||||
) -> list[SegmentAttachmentInfoResult]:
|
||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
@ -843,7 +885,7 @@ class RetrievalService:
|
||||
if attachment_bindings:
|
||||
for upload_file in upload_files:
|
||||
attachment_binding = attachment_binding_map.get(upload_file.id)
|
||||
attachment_info = {
|
||||
info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
@ -855,7 +897,7 @@ class RetrievalService:
|
||||
attachment_infos.append(
|
||||
{
|
||||
"attachment_id": attachment_binding.attachment_id,
|
||||
"attachment_info": attachment_info,
|
||||
"attachment_info": info,
|
||||
"segment_id": attachment_binding.segment_id,
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user