Mirror of https://github.com/langgenius/dify.git
refactor(api): tighten core rag typing batch 1 (#35210)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
```diff
@@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
             "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
         }
         dataset_keyword_table = self.dataset.dataset_keyword_table
-        keyword_data_source_type = dataset_keyword_table.data_source_type
+        keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
         if keyword_data_source_type == "database":
+            if dataset_keyword_table is None:
+                return
             dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
             db.session.commit()
         else:
```
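The guard above is the standard Optional-narrowing pattern: check for None on the exact path that dereferences the value. A minimal runnable sketch of the same shape (`Record` and `save_table` are illustrative names, not from the commit):

```python
from dataclasses import dataclass


@dataclass
class Record:
    data_source_type: str
    keyword_table: str = ""


def save_table(record: Record | None, payload: str) -> None:
    # Fall back to a default when the relationship is not loaded.
    data_source_type = record.data_source_type if record else "file"
    if data_source_type == "database":
        if record is None:  # narrows Record | None to Record below
            return
        record.keyword_table = payload  # no Optional-access error here
```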
```diff
@@ -1,4 +1,5 @@
 import re
+from collections.abc import Callable
 from operator import itemgetter
 from typing import cast
 
```
```diff
@@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
 
     def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
         # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
-        top_k = kwargs.pop("topK", top_k)
+        top_k = cast(int | None, kwargs.pop("topK", top_k))
+        if top_k is None:
+            top_k = 20
         cut = getattr(jieba, "cut", None)
         if self._lcut:
             tokens = self._lcut(sentence)
         elif callable(cut):
-            tokens = list(cut(sentence))
+            tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
         else:
             tokens = re.findall(r"\w+", sentence)
 
```
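Two casts do the work in this hunk: values popped from `**kwargs` are typed `Any`, so `cast()` restates the intended type for the checker, while the runtime None-check supplies the default that `cast()` alone cannot. A hedged sketch with illustrative names:

```python
from typing import cast


def top_items(items: list[str], top_k: int | None = 20, **kwargs) -> list[str]:
    # kwargs.pop returns Any; cast() documents the expected type for the checker.
    top_k = cast(int | None, kwargs.pop("topK", top_k))  # legacy alias wins if passed
    if top_k is None:
        top_k = 20  # cast() has no runtime effect, so apply the default explicitly
    return items[:top_k]
```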
```diff
@@ -108,7 +111,7 @@ class JiebaKeywordTableHandler:
             sentence=text,
             topK=max_keywords_per_chunk,
         )
-        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
         keywords = cast(list[str], keywords)
 
         return set(self._expand_tokens_with_subtokens(set(keywords)))
```
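Worth noting: `cast()` performs no conversion or validation; it only tells the checker what an untyped third-party return is expected to be, as with jieba's `extract_tags` here. A tiny illustration (`untyped_tags` is a stand-in, not a real API):

```python
from typing import Any, cast


def untyped_tags(text: str) -> Any:  # stand-in for an untyped third-party function
    return text.split()


keywords = cast(list[str], untyped_tags("alpha beta"))  # checker now sees list[str]
```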
```diff
@@ -158,7 +158,7 @@ class RetrievalService:
             )
 
         if futures:
-            for future in concurrent.futures.as_completed(futures, timeout=3600):
+            for _ in concurrent.futures.as_completed(futures, timeout=3600):
                 if exceptions:
                     for f in futures:
                         f.cancel()
```
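Renaming the loop variable to `_` signals that the completed future itself is never read; the loop exists only to block until work finishes (or an exception is recorded). A minimal sketch of that synchronization-only pattern:

```python
import concurrent.futures


def wait_for_all(futures: list[concurrent.futures.Future]) -> None:
    # Drain as_completed purely for synchronization; results are collected elsewhere.
    for _ in concurrent.futures.as_completed(futures, timeout=3600):
        pass
```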
```diff
@@ -94,6 +94,7 @@ class ExtractProcessor:
         cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
     ) -> list[Document]:
         if extract_setting.datasource_type == DatasourceType.FILE:
+            upload_file = extract_setting.upload_file
             with tempfile.TemporaryDirectory() as temp_dir:
                 upload_file = extract_setting.upload_file
                 if not file_path:
```
```diff
@@ -104,6 +105,7 @@ class ExtractProcessor:
                     storage.download(upload_file.key, file_path)
                 input_file = Path(file_path)
                 file_extension = input_file.suffix.lower()
+                assert upload_file is not None, "upload_file is required"
                 etl_type = dify_config.ETL_TYPE
                 extractor: BaseExtractor | None = None
                 if etl_type == "Unstructured":
```
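The added assert is a runtime check that doubles as type narrowing: after `assert upload_file is not None`, checkers treat the name as non-Optional for the rest of the block. A self-contained illustration (`UploadFile` here is a stand-in class, not the project's model):

```python
class UploadFile:
    def __init__(self, key: str) -> None:
        self.key = key


def storage_key(upload_file: UploadFile | None) -> str:
    assert upload_file is not None, "upload_file is required"
    return upload_file.key  # narrowed from UploadFile | None to UploadFile
```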
```diff
@@ -28,7 +28,7 @@ class FunctionCallMultiDatasetRouter:
             SystemPromptMessage(content="You are a helpful AI assistant."),
             UserPromptMessage(content=query),
         ]
-        result: LLMResult = model_instance.invoke_llm(
+        result: LLMResult = model_instance.invoke_llm(  # pyright: ignore[reportCallIssue, reportArgumentType]
             prompt_messages=prompt_messages,
             tools=dataset_tools,
             stream=False,
```
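Rather than a blanket `# type: ignore`, the suppression names the exact pyright rules being silenced, so any new diagnostic on the same line still surfaces. A one-line illustration of the same scoped form:

```python
def takes_int(x: int) -> int:
    return x


result = takes_int("3")  # pyright: ignore[reportArgumentType]
```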
```diff
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import codecs
 import re
-from collections.abc import Collection
+from collections.abc import Set as AbstractSet
 from typing import Any, Literal
 
 from core.model_manager import ModelInstance
```
```diff
@@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
         cls: type[T],
         embedding_model_instance: ModelInstance | None,
-        allowed_special: Literal["all"] | set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
         **kwargs: Any,
     ) -> T:
         def _token_encoder(texts: list[str]) -> list[int]:
```
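The signature swap is about read-only variance and defaults: `collections.abc.Set` (imported as `AbstractSet`) is a read-only protocol, so callers may pass a `set`, `frozenset`, or dict keys view, and the immutable `frozenset()` default avoids the shared-mutable-default pitfall of `set()`. A sketch under those assumptions (`filter_tokens` is illustrative, not project code):

```python
from collections.abc import Set as AbstractSet
from typing import Literal


def filter_tokens(
    tokens: list[str],
    allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
) -> list[str]:
    if allowed_special == "all":
        return tokens
    # Keep special tokens only when explicitly allowed.
    return [t for t in tokens if not t.startswith("<") or t in allowed_special]


filter_tokens(["a", "<|endoftext|>"], frozenset({"<|endoftext|>"}))  # set or frozenset both accepted
```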
```diff
@@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
 
             return [len(text) for text in texts]
 
+        _ = _token_encoder  # kept for future token-length wiring
         return cls(length_function=_character_encoder, **kwargs)
 
 
```
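The `_ = _token_encoder` assignment keeps the helper defined while telling linters the non-use is deliberate, instead of deleting code the comment says will be wired up later. A compact sketch of the idiom (names illustrative):

```python
def build_length_fn():
    def _token_length(texts: list[str]) -> list[int]:
        return [len(t.split()) for t in texts]

    _ = _token_length  # deliberate no-op reference; silences unused-name lint noise
    return len
```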
```diff
@@ -4,7 +4,8 @@ import copy
 import logging
 import re
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Collection, Iterable, Sequence, Set
+from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Set as AbstractSet
 from dataclasses import dataclass
 from typing import Any, Literal
 
```
```diff
@@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
         self,
         encoding_name: str = "gpt2",
         model_name: str | None = None,
-        allowed_special: Literal["all"] | Set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
         **kwargs: Any,
     ):
         """Create a new TextSplitter."""
```
```diff
@@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
         else:
             enc = tiktoken.get_encoding(encoding_name)
         self._tokenizer = enc
-        self._allowed_special = allowed_special
-        self._disallowed_special = disallowed_special
+        self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
+        self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
 
     def split_text(self, text: str) -> list[str]:
         def _encode(_text: str) -> list[int]:
```
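Annotating the attributes explicitly pins their declared type rather than leaving it to inference across assignment sites, which could otherwise settle on something narrower than the parameter's full union. A reduced sketch of the same move:

```python
from collections.abc import Set as AbstractSet
from typing import Literal


class Splitter:
    def __init__(self, allowed_special: Literal["all"] | AbstractSet[str] = frozenset()) -> None:
        # Explicit annotation keeps the attribute as wide as the parameter,
        # independent of what any single assignment would let the checker infer.
        self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
```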
```diff
@@ -1,10 +1,10 @@
 import json
 import logging
 import time
-from typing import Any, TypedDict
+from typing import Any, TypedDict, cast
 
 from core.app.app_config.entities import ModelConfig
-from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
 from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
```
```diff
@@ -36,6 +36,10 @@ default_retrieval_model = {
 }
 
 
+class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
+    metadata_filtering_conditions: dict[str, Any]
+
+
 class HitTestingService:
     @classmethod
     def retrieve(
```
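The new TypedDict layers one optional key over the base retrieval-model schema: with `total=False`, keys declared in the subclass are optional while inherited keys keep their original requiredness. A self-contained sketch (`BaseModelDict` stands in for `DefaultRetrievalModelDict`, whose real fields live in retrieval_service):

```python
from typing import Any, TypedDict


class BaseModelDict(TypedDict):
    top_k: int


class ExtendedModelDict(BaseModelDict, total=False):
    metadata_filtering_conditions: dict[str, Any]


m: ExtendedModelDict = {"top_k": 4}  # the extra key may be omitted
```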
```diff
@@ -51,17 +55,18 @@ class HitTestingService:
         start = time.perf_counter()
 
         # get retrieval model , if the model is not setting , using default
-        if not retrieval_model:
-            retrieval_model = dataset.retrieval_model or default_retrieval_model
-        assert isinstance(retrieval_model, dict)
+        resolved_retrieval_model = cast(
+            HitTestingRetrievalModelDict,
+            retrieval_model or dataset.retrieval_model or default_retrieval_model,
+        )
         document_ids_filter = None
-        metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
-        if metadata_filtering_conditions and query:
+        metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {})
+        if metadata_filtering_conditions_raw and query:
             dataset_retrieval = DatasetRetrieval()
 
             from core.rag.entities import MetadataFilteringCondition
 
-            metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
+            metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw)
 
             metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
                 dataset_ids=[dataset.id],
```
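The rewrite folds the branch-and-assert into a single expression and casts the result to the TypedDict, so every later `.get()` is checked against a known schema; binding a fresh name (`resolved_retrieval_model`) also avoids re-typing the parameter mid-function, which the next hunk relies on. A reduced sketch (`ModelDict` and `resolve` are illustrative):

```python
from typing import Any, TypedDict, cast


class ModelDict(TypedDict, total=False):
    top_k: int


DEFAULT_MODEL: ModelDict = {"top_k": 4}


def resolve(retrieval_model: dict[str, Any] | None) -> ModelDict:
    # cast() asserts the shape for the checker; the or-chain picks the
    # first non-empty source at runtime.
    return cast(ModelDict, retrieval_model or DEFAULT_MODEL)
```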
```diff
@@ -78,19 +83,21 @@ class HitTestingService:
         if metadata_condition and not document_ids_filter:
             return cls.compact_retrieve_response(query, [])
         all_documents = RetrievalService.retrieve(
-            retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
+            retrieval_method=RetrievalMethod(
+                resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
+            ),
             dataset_id=dataset.id,
             query=query,
             attachment_ids=attachment_ids,
-            top_k=retrieval_model.get("top_k", 4),
-            score_threshold=retrieval_model.get("score_threshold", 0.0)
-            if retrieval_model["score_threshold_enabled"]
+            top_k=resolved_retrieval_model.get("top_k", 4),
+            score_threshold=resolved_retrieval_model.get("score_threshold", 0.0)
+            if resolved_retrieval_model["score_threshold_enabled"]
             else 0.0,
-            reranking_model=retrieval_model.get("reranking_model", None)
-            if retrieval_model["reranking_enable"]
+            reranking_model=resolved_retrieval_model.get("reranking_model", None)
+            if resolved_retrieval_model["reranking_enable"]
             else None,
-            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
-            weights=retrieval_model.get("weights", None),
+            reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model",
+            weights=resolved_retrieval_model.get("weights", None),
             document_ids_filter=document_ids_filter,
         )
 
```