refactor(api): tighten core rag typing batch 1 (#35210)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Authored by tmimmanuel on 2026-04-21 06:32:43 +02:00, committed by GitHub
parent 77d6c108e7
commit 77f8f2babb
8 changed files with 46 additions and 30 deletions

View File

@@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
             "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
         }
         dataset_keyword_table = self.dataset.dataset_keyword_table
-        keyword_data_source_type = dataset_keyword_table.data_source_type
+        keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
         if keyword_data_source_type == "database":
+            if dataset_keyword_table is None:
+                return
             dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
             db.session.commit()
         else:
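Note: the new guard narrows the Optional relationship before attribute access. A minimal runnable sketch of the same idiom, using a hypothetical stand-in type rather than Dify's model:

```python
from dataclasses import dataclass


@dataclass
class KeywordTable:  # hypothetical stand-in for the dataset keyword table
    data_source_type: str = "database"


def persist(table: KeywordTable | None) -> None:
    # Falling back to "file" keeps the attribute access safe when table is None.
    source_type = table.data_source_type if table else "file"
    if source_type == "database":
        if table is None:  # narrows KeywordTable | None to KeywordTable
            return
        print(f"persist via {table.data_source_type}")


persist(KeywordTable())
persist(None)
```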

View File

@@ -1,4 +1,5 @@
 import re
+from collections.abc import Callable
 from operator import itemgetter
 from typing import cast
@@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
     def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
         # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
-        top_k = kwargs.pop("topK", top_k)
+        top_k = cast(int | None, kwargs.pop("topK", top_k))
+        if top_k is None:
+            top_k = 20
         cut = getattr(jieba, "cut", None)
         if self._lcut:
             tokens = self._lcut(sentence)
         elif callable(cut):
-            tokens = list(cut(sentence))
+            tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
         else:
             tokens = re.findall(r"\w+", sentence)
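Note: `typing.cast` here is a static-only annotation; it returns its argument unchanged at runtime and performs no check. A small sketch of the same narrowing (the tokenizer is a stand-in, not jieba itself):

```python
from collections.abc import Callable
from typing import cast


def tokenize(sentence: str, cut: object) -> list[str]:
    if callable(cut):
        # cast() does no runtime conversion; it only tells the type checker
        # to treat `cut` as a str -> list[str] callable from here on.
        return list(cast(Callable[[str], list[str]], cut)(sentence))
    return sentence.split()


print(tokenize("static typing only", str.split))
```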
@@ -108,7 +111,7 @@ class JiebaKeywordTableHandler:
             sentence=text,
             topK=max_keywords_per_chunk,
         )
-        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
         keywords = cast(list[str], keywords)
         return set(self._expand_tokens_with_subtokens(set(keywords)))

View File

@@ -158,7 +158,7 @@ class RetrievalService:
             )
         if futures:
-            for future in concurrent.futures.as_completed(futures, timeout=3600):
+            for _ in concurrent.futures.as_completed(futures, timeout=3600):
                 if exceptions:
                     for f in futures:
                         f.cancel()
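Note: renaming the loop variable to `_` documents that the loop only drives completion; the workers presumably record failures in the shared `exceptions` list. A hedged sketch of that drain-and-cancel pattern, not Dify's actual worker code:

```python
import concurrent.futures

exceptions: list[Exception] = []


def work(n: int) -> int:
    try:
        if n == 3:
            raise ValueError("boom")
        return n * n
    except Exception as e:  # workers record failures in a shared list
        exceptions.append(e)
        raise


with concurrent.futures.ThreadPoolExecutor() as pool:
    futures = [pool.submit(work, n) for n in range(5)]
    # The loop variable is unused; the loop only waits for completions,
    # cancelling the remaining futures once any worker recorded an error.
    for _ in concurrent.futures.as_completed(futures, timeout=60):
        if exceptions:
            for f in futures:
                f.cancel()
            break
```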

View File

@ -94,6 +94,7 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]: ) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE: if extract_setting.datasource_type == DatasourceType.FILE:
upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file upload_file = extract_setting.upload_file
if not file_path: if not file_path:
@@ -104,6 +105,7 @@ class ExtractProcessor:
                     storage.download(upload_file.key, file_path)
                 input_file = Path(file_path)
                 file_extension = input_file.suffix.lower()
+                assert upload_file is not None, "upload_file is required"
                 etl_type = dify_config.ETL_TYPE
                 extractor: BaseExtractor | None = None
                 if etl_type == "Unstructured":
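Note: the added assert narrows `upload_file` from its Optional type for the checker and fails fast at runtime if the invariant is ever violated. A sketch with a hypothetical stand-in type:

```python
from dataclasses import dataclass


@dataclass
class UploadFile:  # hypothetical stand-in for the ORM model
    key: str


def download_key(upload_file: UploadFile | None) -> str:
    # The assert narrows UploadFile | None to UploadFile for the type
    # checker, and raises AssertionError if the value is unexpectedly None.
    assert upload_file is not None, "upload_file is required"
    return upload_file.key


print(download_key(UploadFile(key="datasets/doc.pdf")))
```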

View File

@ -28,7 +28,7 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."), SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query), UserPromptMessage(content=query),
] ]
result: LLMResult = model_instance.invoke_llm( result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=dataset_tools, tools=dataset_tools,
stream=False, stream=False,
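Note: a targeted `# pyright: ignore[...]` suppresses only the listed diagnostics on that one line; any other error there would still be reported, unlike a blanket `# type: ignore`. An illustrative sketch, unrelated to Dify's API:

```python
def double(n: int) -> int:
    return n * 2


value: int | str = 2  # statically int | str, but known to be an int here
# Only reportArgumentType is silenced on this line; a new, different
# diagnostic on the same call would still surface.
print(double(value))  # pyright: ignore[reportArgumentType]
```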

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
 import codecs
 import re
-from collections.abc import Collection
+from collections.abc import Set as AbstractSet
 from typing import Any, Literal

 from core.model_manager import ModelInstance
@@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
         cls: type[T],
         embedding_model_instance: ModelInstance | None,
-        allowed_special: Literal["all"] | set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
         **kwargs: Any,
     ) -> T:
         def _token_encoder(texts: list[str]) -> list[int]:
@@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
             return [len(text) for text in texts]

+        _ = _token_encoder  # kept for future token-length wiring
         return cls(length_function=_character_encoder, **kwargs)
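Note: `frozenset()` sidesteps the shared-mutable-default pitfall of `set()`, and `collections.abc.Set` (imported as `AbstractSet`) accepts both `set` and `frozenset` arguments. A standalone sketch of the signature style:

```python
from collections.abc import Set as AbstractSet
from typing import Literal


def split(
    text: str,
    allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
) -> list[str]:
    # frozenset() is immutable, so the default cannot be mutated across
    # calls the way a shared set() default could; AbstractSet still admits
    # plain sets at call sites.
    if allowed_special == "all" or "<|sep|>" in allowed_special:
        return text.split("<|sep|>")
    return [text]


print(split("a<|sep|>b", allowed_special={"<|sep|>"}))
```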

View File

@@ -4,7 +4,8 @@ import copy
 import logging
 import re
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Collection, Iterable, Sequence, Set
+from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Set as AbstractSet
 from dataclasses import dataclass
 from typing import Any, Literal
@@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
         self,
         encoding_name: str = "gpt2",
         model_name: str | None = None,
-        allowed_special: Literal["all"] | Set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
         **kwargs: Any,
     ):
         """Create a new TextSplitter."""
@@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
         else:
             enc = tiktoken.get_encoding(encoding_name)
         self._tokenizer = enc
-        self._allowed_special = allowed_special
-        self._disallowed_special = disallowed_special
+        self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
+        self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special

     def split_text(self, text: str) -> list[str]:
         def _encode(_text: str) -> list[int]:
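Note: these parameters mirror tiktoken's `Encoding.encode`, which accepts `"all"` or a set of special-token strings for both arguments. A short sketch, assuming tiktoken is installed:

```python
import tiktoken

enc = tiktoken.get_encoding("gpt2")
# By default every special token is disallowed; passing allowed_special
# opts specific markers back in, matching the splitter's parameter types.
ids = enc.encode("hello <|endoftext|>", allowed_special={"<|endoftext|>"})
print(len(ids))
```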

View File

@@ -1,10 +1,10 @@
 import json
 import logging
 import time
-from typing import Any, TypedDict
+from typing import Any, TypedDict, cast

 from core.app.app_config.entities import ModelConfig
-from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
 from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -36,6 +36,10 @@ default_retrieval_model = {
 }


+class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
+    metadata_filtering_conditions: dict[str, Any]
+
+
 class HitTestingService:
     @classmethod
     def retrieve(
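Note: with `total=False`, only the keys declared on the subclass become optional; keys inherited from `DefaultRetrievalModelDict` keep their original totality. A sketch with an assumed parent shape:

```python
from typing import Any, TypedDict


class DefaultRetrievalModelDict(TypedDict):  # shape assumed for illustration
    search_method: str
    top_k: int


class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
    # Only keys declared here are optional; the inherited search_method
    # and top_k remain required.
    metadata_filtering_conditions: dict[str, Any]


model: HitTestingRetrievalModelDict = {"search_method": "semantic_search", "top_k": 4}
print(model.get("metadata_filtering_conditions", {}))
```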
@@ -51,17 +55,18 @@ class HitTestingService:
         start = time.perf_counter()

         # get retrieval model , if the model is not setting , using default
-        if not retrieval_model:
-            retrieval_model = dataset.retrieval_model or default_retrieval_model
-        assert isinstance(retrieval_model, dict)
+        resolved_retrieval_model = cast(
+            HitTestingRetrievalModelDict,
+            retrieval_model or dataset.retrieval_model or default_retrieval_model,
+        )

         document_ids_filter = None
-        metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
-        if metadata_filtering_conditions and query:
+        metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {})
+        if metadata_filtering_conditions_raw and query:
             dataset_retrieval = DatasetRetrieval()
             from core.rag.entities import MetadataFilteringCondition

-            metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
+            metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw)
             metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
                 dataset_ids=[dataset.id],
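Note: the `or`-chain resolves the first truthy source (caller override, then the dataset's stored model, then the module default), and the `cast` to the TypedDict is trusted rather than validated at runtime. A sketch with assumed shapes:

```python
from typing import TypedDict, cast


class RetrievalModel(TypedDict, total=False):  # assumed minimal shape
    search_method: str
    top_k: int


DEFAULT: RetrievalModel = {"search_method": "semantic_search", "top_k": 4}


def resolve(override: dict | None, stored: dict | None) -> RetrievalModel:
    # `or` returns the first truthy candidate; cast() adds no runtime
    # validation, so a malformed dict would pass through unchecked.
    return cast(RetrievalModel, override or stored or DEFAULT)


print(resolve(None, {"search_method": "keyword_search"}))
```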
@@ -78,19 +83,21 @@ class HitTestingService:
         if metadata_condition and not document_ids_filter:
             return cls.compact_retrieve_response(query, [])

         all_documents = RetrievalService.retrieve(
-            retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
+            retrieval_method=RetrievalMethod(
+                resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
+            ),
             dataset_id=dataset.id,
             query=query,
             attachment_ids=attachment_ids,
-            top_k=retrieval_model.get("top_k", 4),
-            score_threshold=retrieval_model.get("score_threshold", 0.0)
-            if retrieval_model["score_threshold_enabled"]
+            top_k=resolved_retrieval_model.get("top_k", 4),
+            score_threshold=resolved_retrieval_model.get("score_threshold", 0.0)
+            if resolved_retrieval_model["score_threshold_enabled"]
             else 0.0,
-            reranking_model=retrieval_model.get("reranking_model", None)
-            if retrieval_model["reranking_enable"]
+            reranking_model=resolved_retrieval_model.get("reranking_model", None)
+            if resolved_retrieval_model["reranking_enable"]
             else None,
-            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
-            weights=retrieval_model.get("weights", None),
+            reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model",
+            weights=resolved_retrieval_model.get("weights", None),
             document_ids_filter=document_ids_filter,
         )
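Note: the `.get(...)` calls supply defaults for optional TypedDict keys, while the `[...]` accesses assume the enable flags are always present. A minimal sketch of that convention, with an assumed shape:

```python
from typing import TypedDict


class Model(TypedDict, total=False):  # assumed minimal shape
    score_threshold: float
    score_threshold_enabled: bool


model: Model = {"score_threshold_enabled": False}
# .get() tolerates a missing optional key; the flag decides whether the
# threshold is applied at all.
threshold = model.get("score_threshold", 0.0) if model["score_threshold_enabled"] else 0.0
print(threshold)  # 0.0
```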