mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
refactor(api): replace dict/Mapping with TypedDict in core/rag retrieval_service.py (#33615)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class RerankingModelDict(TypedDict):
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSettingDict(TypedDict):
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSettingDict(TypedDict):
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightsDict(TypedDict):
|
||||
vector_setting: VectorSettingDict
|
||||
keyword_setting: KeywordSettingDict
|
||||
|
||||
|
||||
class DataPostProcessor:
|
||||
"""Interface for data post-processing document."""
|
||||
|
||||
@ -17,8 +39,8 @@ class DataPostProcessor:
|
||||
self,
|
||||
tenant_id: str,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
reorder_enabled: bool = False,
|
||||
):
|
||||
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
|
||||
@ -45,8 +67,8 @@ class DataPostProcessor:
|
||||
self,
|
||||
reranking_mode: str,
|
||||
tenant_id: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
) -> BaseRerankRunner | None:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
|
||||
runner = RerankRunnerFactory.create_rerank_runner(
|
||||
@ -79,12 +101,14 @@ class DataPostProcessor:
|
||||
return ReorderRunner()
|
||||
return None
|
||||
|
||||
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
|
||||
def _get_rerank_model_instance(
|
||||
self, tenant_id: str, reranking_model: RerankingModelDict | None
|
||||
) -> ModelInstance | None:
|
||||
if reranking_model:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
reranking_provider_name = reranking_model.get("reranking_provider_name")
|
||||
reranking_model_name = reranking_model.get("reranking_model_name")
|
||||
reranking_provider_name = reranking_model["reranking_provider_name"]
|
||||
reranking_model_name = reranking_model["reranking_model_name"]
|
||||
if not reranking_provider_name or not reranking_model_name:
|
||||
return None
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
|
||||
@ -1,19 +1,20 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
|
||||
class SegmentAttachmentResult(TypedDict):
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class SegmentAttachmentInfoResult(TypedDict):
|
||||
attachment_id: str
|
||||
attachment_info: AttachmentInfoDict
|
||||
segment_id: str
|
||||
|
||||
|
||||
class ChildChunkDetail(TypedDict):
|
||||
id: str
|
||||
content: str
|
||||
position: int
|
||||
score: float
|
||||
|
||||
|
||||
class SegmentChildMapDetail(TypedDict):
|
||||
max_score: float
|
||||
child_chunks: list[ChildChunkDetail]
|
||||
|
||||
|
||||
class SegmentRecord(TypedDict):
|
||||
segment: DocumentSegment
|
||||
score: NotRequired[float]
|
||||
child_chunks: NotRequired[list[ChildChunkDetail]]
|
||||
files: NotRequired[list[AttachmentInfoDict]]
|
||||
|
||||
|
||||
class DefaultRetrievalModelDict(TypedDict):
|
||||
search_method: RetrievalMethod | str
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModelDict
|
||||
top_k: int
|
||||
score_threshold_enabled: bool
|
||||
|
||||
|
||||
default_retrieval_model: DefaultRetrievalModelDict = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@ -56,9 +96,9 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_ids: list | None = None,
|
||||
):
|
||||
@ -235,7 +275,7 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
@ -277,8 +317,8 @@ class RetrievalService:
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
@ -288,8 +328,8 @@ class RetrievalService:
|
||||
model_manager = ModelManager()
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=reranking_model.get("reranking_provider_name") or "",
|
||||
model=reranking_model.get("reranking_model_name") or "",
|
||||
provider=reranking_model["reranking_provider_name"],
|
||||
model=reranking_model["reranking_model_name"],
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if is_support_vision:
|
||||
@ -329,7 +369,7 @@ class RetrievalService:
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
@ -349,8 +389,8 @@ class RetrievalService:
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and reranking_model["reranking_model_name"]
|
||||
and reranking_model["reranking_provider_name"]
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
@ -459,7 +499,7 @@ class RetrievalService:
|
||||
segment_ids: list[str] = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
|
||||
@ -544,12 +584,12 @@ class RetrievalService:
|
||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
|
||||
include_segment_ids = set()
|
||||
segment_child_map: dict[str, dict[str, Any]] = {}
|
||||
records: list[dict[str, Any]] = []
|
||||
segment_child_map: dict[str, SegmentChildMapDetail] = {}
|
||||
records: list[SegmentRecord] = []
|
||||
|
||||
for segment in segments:
|
||||
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
@ -560,14 +600,14 @@ class RetrievalService:
|
||||
max_score = summary_score_map.get(segment.id, 0.0)
|
||||
|
||||
if child_chunks or attachment_infos:
|
||||
child_chunk_details = []
|
||||
child_chunk_details: list[ChildChunkDetail] = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
|
||||
if child_document:
|
||||
child_score = child_document.metadata.get("score", 0.0)
|
||||
else:
|
||||
child_score = 0.0
|
||||
child_chunk_detail = {
|
||||
child_chunk_detail: ChildChunkDetail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
@ -580,7 +620,7 @@ class RetrievalService:
|
||||
if file_document:
|
||||
max_score = max(max_score, file_document.metadata.get("score", 0.0))
|
||||
|
||||
map_detail = {
|
||||
map_detail: SegmentChildMapDetail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
}
|
||||
@ -593,7 +633,7 @@ class RetrievalService:
|
||||
"max_score": summary_score,
|
||||
"child_chunks": [],
|
||||
}
|
||||
record: dict[str, Any] = {
|
||||
record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
@ -617,19 +657,19 @@ class RetrievalService:
|
||||
if file_doc:
|
||||
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||
|
||||
record = {
|
||||
another_record: SegmentRecord = {
|
||||
"segment": segment,
|
||||
"score": max_score,
|
||||
}
|
||||
records.append(record)
|
||||
records.append(another_record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
record["files"] = attachment_map[record["segment"].id]
|
||||
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
@ -693,9 +733,9 @@ class RetrievalService:
|
||||
query: str | None = None,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
reranking_model: dict | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: dict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_id: str | None = None,
|
||||
):
|
||||
@ -807,7 +847,7 @@ class RetrievalService:
|
||||
@classmethod
|
||||
def get_segment_attachment_info(
|
||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||
) -> dict[str, Any] | None:
|
||||
) -> SegmentAttachmentResult | None:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||
if upload_file:
|
||||
attachment_binding = (
|
||||
@ -816,7 +856,7 @@ class RetrievalService:
|
||||
.first()
|
||||
)
|
||||
if attachment_binding:
|
||||
attachment_info = {
|
||||
attachment_info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
@ -828,8 +868,10 @@ class RetrievalService:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
|
||||
attachment_infos = []
|
||||
def get_segment_attachment_infos(
|
||||
cls, attachment_ids: list[str], session: Session
|
||||
) -> list[SegmentAttachmentInfoResult]:
|
||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
@ -843,7 +885,7 @@ class RetrievalService:
|
||||
if attachment_bindings:
|
||||
for upload_file in upload_files:
|
||||
attachment_binding = attachment_binding_map.get(upload_file.id)
|
||||
attachment_info = {
|
||||
info: AttachmentInfoDict = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
@ -855,7 +897,7 @@ class RetrievalService:
|
||||
attachment_infos.append(
|
||||
{
|
||||
"attachment_id": attachment_binding.attachment_id,
|
||||
"attachment_info": attachment_info,
|
||||
"attachment_info": info,
|
||||
"segment_id": attachment_binding.segment_id,
|
||||
}
|
||||
)
|
||||
|
||||
@ -1,8 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class AttachmentInfoDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
size: int
|
||||
|
||||
|
||||
class RetrievalChildChunk(BaseModel):
|
||||
"""Retrieval segments."""
|
||||
|
||||
@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel):
|
||||
segment: DocumentSegment
|
||||
child_chunks: list[RetrievalChildChunk] | None = None
|
||||
score: float | None = None
|
||||
files: list[dict[str, str | int]] | None = None
|
||||
files: list[AttachmentInfoDict] | None = None
|
||||
summary: str | None = None # Summary content if retrieved via summary index
|
||||
|
||||
@ -15,6 +15,7 @@ import httpx
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
@ -98,7 +99,7 @@ class BaseIndexProcessor(ABC):
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
||||
@ -11,6 +11,7 @@ from core.db.session_factory import session_factory
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
||||
@ -15,6 +15,7 @@ from core.db.session_factory import session_factory
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
reranking_model: RerankingModelDict,
|
||||
):
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(
|
||||
|
||||
@ -31,7 +31,7 @@ from core.ops.utils import measure_time
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
@ -727,8 +727,8 @@ class DatasetRetrieval:
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None = None,
|
||||
weights: dict[str, Any] | None = None,
|
||||
reranking_model: RerankingModelDict | None = None,
|
||||
weights: WeightsDict | None = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: str | None = None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||
@ -1181,8 +1181,8 @@ class DatasetRetrieval:
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"],
|
||||
reranking_model_name=retrieve_config.reranking_model["reranking_model_name"],
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
@ -1685,8 +1685,8 @@ class DatasetRetrieval:
|
||||
tenant_id: str,
|
||||
reranking_enable: bool,
|
||||
reranking_mode: str,
|
||||
reranking_model: dict | None,
|
||||
weights: dict[str, Any] | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
weights: WeightsDict | None,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
query: str | None,
|
||||
|
||||
Reference in New Issue
Block a user