mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
Merge branch 'main' into feat/support-knowledge-metadata
# Conflicts: # api/core/rag/datasource/retrieval_service.py # api/core/workflow/nodes/code/code_node.py # api/services/dataset_service.py
This commit is contained in:
@ -1,9 +1,12 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import load_only
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
@ -27,6 +30,7 @@ default_retrieval_model = {
|
||||
|
||||
|
||||
class RetrievalService:
|
||||
# Cache precompiled regular expressions to avoid repeated compilation
|
||||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
@ -42,77 +46,62 @@ class RetrievalService:
|
||||
):
|
||||
if not query:
|
||||
return []
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
|
||||
all_documents: list[Document] = []
|
||||
threads: list[threading.Thread] = []
|
||||
exceptions: list[str] = []
|
||||
# retrieval_model source with keyword
|
||||
if retrieval_method == "keyword_search":
|
||||
keyword_thread = threading.Thread(
|
||||
target=RetrievalService.keyword_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
},
|
||||
)
|
||||
threads.append(keyword_thread)
|
||||
keyword_thread.start()
|
||||
# retrieval_model source with semantic
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
embedding_thread = threading.Thread(
|
||||
target=RetrievalService.embedding_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"retrieval_method": retrieval_method,
|
||||
"exceptions": exceptions,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
},
|
||||
)
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
full_text_index_thread = threading.Thread(
|
||||
target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"retrieval_method": retrieval_method,
|
||||
"score_threshold": score_threshold,
|
||||
"top_k": top_k,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
"document_ids_filter": document_ids_filter,
|
||||
},
|
||||
)
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
if retrieval_method == "keyword_search":
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.keyword_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
all_documents=all_documents,
|
||||
exceptions=exceptions,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.embedding_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
)
|
||||
)
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.full_text_index_search,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
all_documents=all_documents,
|
||||
retrieval_method=retrieval_method,
|
||||
exceptions=exceptions,
|
||||
)
|
||||
)
|
||||
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
|
||||
|
||||
if exceptions:
|
||||
exception_message = ";\n".join(exceptions)
|
||||
raise ValueError(exception_message)
|
||||
raise ValueError(";\n".join(exceptions))
|
||||
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(
|
||||
@ -137,6 +126,10 @@ class RetrievalService:
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
|
||||
return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
cls,
|
||||
@ -150,7 +143,7 @@ class RetrievalService:
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
@ -179,12 +172,11 @@ class RetrievalService:
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
documents = vector.search_by_vector(
|
||||
query,
|
||||
search_type="similarity_score_threshold",
|
||||
@ -202,7 +194,7 @@ class RetrievalService:
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
@ -232,13 +224,11 @@ class RetrievalService:
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = cls._get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("dataset not found")
|
||||
|
||||
vector_processor = Vector(
|
||||
dataset=dataset,
|
||||
)
|
||||
vector_processor = Vector(dataset=dataset)
|
||||
|
||||
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
|
||||
if documents:
|
||||
@ -249,7 +239,7 @@ class RetrievalService:
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
@ -268,64 +258,104 @@ class RetrievalService:
|
||||
def escape_query_for_search(query: str) -> str:
|
||||
return json.dumps(query).strip('"')
|
||||
|
||||
@staticmethod
|
||||
def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
|
||||
records = []
|
||||
include_segment_ids = []
|
||||
segment_child_map = {}
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
|
||||
if dataset_document:
|
||||
@classmethod
|
||||
def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
|
||||
"""Format retrieval documents with optimized batch processing"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Collect document IDs
|
||||
document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
|
||||
if not document_ids:
|
||||
return []
|
||||
|
||||
# Batch query dataset documents
|
||||
dataset_documents = {
|
||||
doc.id: doc
|
||||
for doc in db.session.query(DatasetDocument)
|
||||
.filter(DatasetDocument.id.in_(document_ids))
|
||||
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
|
||||
.all()
|
||||
}
|
||||
|
||||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
|
||||
# Process documents
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
result = (
|
||||
db.session.query(ChildChunk, DocumentSegment)
|
||||
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
|
||||
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
|
||||
)
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
ChildChunk.index_node_id == child_index_node_id,
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == child_chunk.segment_id,
|
||||
)
|
||||
.options(
|
||||
load_only(
|
||||
DocumentSegment.id,
|
||||
DocumentSegment.content,
|
||||
DocumentSegment.answer,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if result:
|
||||
child_chunk, segment = result
|
||||
if not segment:
|
||||
continue
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.append(segment.id)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
)
|
||||
else:
|
||||
index_node_id = document.metadata["doc_id"]
|
||||
# Handle normal documents
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
@ -340,16 +370,21 @@ class RetrievalService:
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
include_segment_ids.append(segment.id)
|
||||
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score", None),
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
}
|
||||
|
||||
records.append(record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
|
||||
return [RetrievalSegments(**record) for record in records]
|
||||
return [RetrievalSegments(**record) for record in records]
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
|
||||
@ -119,8 +119,9 @@ class ChromaVector(BaseVector):
|
||||
for index in range(len(ids)):
|
||||
distance = distances[index]
|
||||
metadata = dict(metadatas[index])
|
||||
if distance >= score_threshold:
|
||||
metadata["score"] = distance
|
||||
score = 1 - distance
|
||||
if score > score_threshold:
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=documents[index],
|
||||
metadata=metadata,
|
||||
|
||||
@ -179,7 +179,7 @@ class LindormVectorStore(BaseVector):
|
||||
if self._using_ugc:
|
||||
params["routing"] = self._routing
|
||||
response = self._client.search(index=self._collection_name, body=query, params=params)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception(f"Error executing vector search, query: {query}")
|
||||
raise
|
||||
|
||||
@ -355,7 +355,8 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
|
||||
}
|
||||
|
||||
if excludes_from_source:
|
||||
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
|
||||
# e.g. {"excludes": ["vector_field"]}
|
||||
mapping["mappings"]["_source"] = {"excludes": excludes_from_source}
|
||||
|
||||
if using_ugc and method_name == "ivfpq":
|
||||
mapping["settings"]["index"]["knn_routing"] = True
|
||||
|
||||
@ -79,7 +79,13 @@ class DatasetDocumentStore:
|
||||
model=self._dataset.embedding_model,
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
if embedding_model:
|
||||
page_content_list = [doc.page_content for doc in docs]
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(page_content_list)
|
||||
else:
|
||||
tokens_list = [0] * len(docs)
|
||||
|
||||
for doc, tokens in zip(docs, tokens_list):
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError("doc must be a Document")
|
||||
|
||||
@ -94,12 +100,6 @@ class DatasetDocumentStore:
|
||||
f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite."
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
if embedding_model:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
|
||||
else:
|
||||
tokens = 0
|
||||
|
||||
if not segment_document:
|
||||
max_position += 1
|
||||
|
||||
|
||||
@ -74,7 +74,7 @@ class CacheEmbedding(Embeddings):
|
||||
embedding_queue_embeddings.append(normalized_embedding)
|
||||
except IntegrityError:
|
||||
db.session.rollback()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("Failed transform embedding")
|
||||
cache_embeddings = []
|
||||
try:
|
||||
|
||||
@ -50,7 +50,7 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
self.web_path = self.file_path
|
||||
# TODO: use a better way to handle the file
|
||||
self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115
|
||||
self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115
|
||||
self.temp_file.write(r.content)
|
||||
self.file_path = self.temp_file.name
|
||||
elif not os.path.isfile(self.file_path):
|
||||
@ -234,7 +234,7 @@ class WordExtractor(BaseExtractor):
|
||||
continue
|
||||
for i in url_pattern.findall(x.text):
|
||||
hyperlinks_url = str(i)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to parse HYPERLINK xml")
|
||||
|
||||
def parse_paragraph(paragraph):
|
||||
|
||||
@ -47,6 +47,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
embedding_model_instance=kwargs.get("embedding_model_instance"),
|
||||
)
|
||||
for document in documents:
|
||||
if kwargs.get("preview") and len(all_documents) >= 10:
|
||||
return all_documents
|
||||
# document clean
|
||||
document_text = CleanProcessor.clean(document.page_content, process_rule)
|
||||
document.page_content = document_text
|
||||
|
||||
@ -26,9 +26,7 @@ from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
@ -451,7 +449,7 @@ class DatasetRetrieval:
|
||||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: Optional[TraceQueueManager] = (
|
||||
trace_manager: TraceQueueManager | None = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
@ -595,6 +593,8 @@ class DatasetRetrieval:
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
dataset=dataset,
|
||||
top_k=top_k,
|
||||
@ -606,20 +606,24 @@ class DatasetRetrieval:
|
||||
|
||||
tools.append(tool)
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if retrieve_config.reranking_model is not None:
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
top_k=retrieve_config.top_k or 2,
|
||||
score_threshold=retrieve_config.score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
)
|
||||
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
|
||||
tools.append(tool)
|
||||
if retrieve_config.reranking_model is None:
|
||||
raise ValueError("Reranking model is required for multiple retrieval")
|
||||
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
top_k=retrieve_config.top_k or 2,
|
||||
score_threshold=retrieve_config.score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
@ -30,14 +30,14 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
|
||||
**kwargs: Any,
|
||||
):
|
||||
def _token_encoder(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
def _token_encoder(texts: list[str]) -> list[int]:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
if embedding_model_instance:
|
||||
return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
|
||||
return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
|
||||
else:
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
@ -65,8 +65,9 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
chunks = [text]
|
||||
|
||||
final_chunks = []
|
||||
for chunk in chunks:
|
||||
if self._length_function(chunk) > self._chunk_size:
|
||||
chunks_lengths = self._length_function(chunks)
|
||||
for chunk, chunk_length in zip(chunks, chunks_lengths):
|
||||
if chunk_length > self._chunk_size:
|
||||
final_chunks.extend(self.recursive_split_text(chunk))
|
||||
else:
|
||||
final_chunks.append(chunk)
|
||||
@ -93,8 +94,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
for s in splits:
|
||||
s_len = self._length_function(s)
|
||||
s_lens = self._length_function(splits)
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
if s_len < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
_good_splits_lengths.append(s_len)
|
||||
|
||||
@ -45,7 +45,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
self,
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: Callable[[str], int] = len,
|
||||
length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
|
||||
keep_separator: bool = False,
|
||||
add_start_index: bool = False,
|
||||
) -> None:
|
||||
@ -106,7 +106,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
|
||||
# We now want to combine these smaller pieces into medium size
|
||||
# chunks to send to the LLM.
|
||||
separator_len = self._length_function(separator)
|
||||
separator_len = self._length_function([separator])[0]
|
||||
|
||||
docs = []
|
||||
current_doc: list[str] = []
|
||||
@ -129,7 +129,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
while total > self._chunk_overlap or (
|
||||
total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
|
||||
):
|
||||
total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
|
||||
total -= self._length_function([current_doc[0]])[0] + (
|
||||
separator_len if len(current_doc) > 1 else 0
|
||||
)
|
||||
current_doc = current_doc[1:]
|
||||
current_doc.append(d)
|
||||
total += _len + (separator_len if len(current_doc) > 1 else 0)
|
||||
@ -155,7 +157,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. Please install it with `pip install transformers`."
|
||||
)
|
||||
return cls(length_function=_huggingface_tokenizer_length, **kwargs)
|
||||
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_tiktoken_encoder(
|
||||
@ -199,7 +201,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=_tiktoken_encoder, **kwargs)
|
||||
return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
|
||||
|
||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
@ -224,8 +226,8 @@ class CharacterTextSplitter(TextSplitter):
|
||||
splits = _split_text_with_regex(text, self._separator, self._keep_separator)
|
||||
_separator = "" if self._keep_separator else self._separator
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
for split in splits:
|
||||
_good_splits_lengths.append(self._length_function(split))
|
||||
if splits:
|
||||
_good_splits_lengths.extend(self._length_function(splits))
|
||||
return self._merge_splits(splits, _separator, _good_splits_lengths)
|
||||
|
||||
|
||||
@ -478,9 +480,8 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
_good_splits = []
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
_separator = "" if self._keep_separator else separator
|
||||
|
||||
for s in splits:
|
||||
s_len = self._length_function(s)
|
||||
s_lens = self._length_function(splits)
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
if s_len < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
_good_splits_lengths.append(s_len)
|
||||
|
||||
Reference in New Issue
Block a user