mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
feat: skip rerank if only one dataset is retrieved (#30075)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@ -515,6 +515,7 @@ class DatasetRetrieval:
|
|||||||
0
|
0
|
||||||
].embedding_model_provider
|
].embedding_model_provider
|
||||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||||
|
dataset_count = len(available_datasets)
|
||||||
with measure_time() as timer:
|
with measure_time() as timer:
|
||||||
cancel_event = threading.Event()
|
cancel_event = threading.Event()
|
||||||
thread_exceptions: list[Exception] = []
|
thread_exceptions: list[Exception] = []
|
||||||
@ -537,6 +538,7 @@ class DatasetRetrieval:
|
|||||||
"score_threshold": score_threshold,
|
"score_threshold": score_threshold,
|
||||||
"query": query,
|
"query": query,
|
||||||
"attachment_id": None,
|
"attachment_id": None,
|
||||||
|
"dataset_count": dataset_count,
|
||||||
"cancel_event": cancel_event,
|
"cancel_event": cancel_event,
|
||||||
"thread_exceptions": thread_exceptions,
|
"thread_exceptions": thread_exceptions,
|
||||||
},
|
},
|
||||||
@ -562,6 +564,7 @@ class DatasetRetrieval:
|
|||||||
"score_threshold": score_threshold,
|
"score_threshold": score_threshold,
|
||||||
"query": None,
|
"query": None,
|
||||||
"attachment_id": attachment_id,
|
"attachment_id": attachment_id,
|
||||||
|
"dataset_count": dataset_count,
|
||||||
"cancel_event": cancel_event,
|
"cancel_event": cancel_event,
|
||||||
"thread_exceptions": thread_exceptions,
|
"thread_exceptions": thread_exceptions,
|
||||||
},
|
},
|
||||||
@ -1422,6 +1425,7 @@ class DatasetRetrieval:
|
|||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
query: str | None,
|
query: str | None,
|
||||||
attachment_id: str | None,
|
attachment_id: str | None,
|
||||||
|
dataset_count: int,
|
||||||
cancel_event: threading.Event | None = None,
|
cancel_event: threading.Event | None = None,
|
||||||
thread_exceptions: list[Exception] | None = None,
|
thread_exceptions: list[Exception] | None = None,
|
||||||
):
|
):
|
||||||
@ -1470,7 +1474,8 @@ class DatasetRetrieval:
|
|||||||
if cancel_event and cancel_event.is_set():
|
if cancel_event and cancel_event.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
if reranking_enable:
|
# Skip second reranking when there is only one dataset
|
||||||
|
if reranking_enable and dataset_count > 1:
|
||||||
# do rerank for searched documents
|
# do rerank for searched documents
|
||||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||||
if query:
|
if query:
|
||||||
|
|||||||
@ -73,6 +73,7 @@ import pytest
|
|||||||
|
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
@ -1518,6 +1519,282 @@ class TestRetrievalService:
|
|||||||
call_kwargs = mock_retrieve.call_args.kwargs
|
call_kwargs = mock_retrieve.call_args.kwargs
|
||||||
assert call_kwargs["reranking_model"] == reranking_model
|
assert call_kwargs["reranking_model"] == reranking_model
|
||||||
|
|
||||||
|
# ==================== Multiple Retrieve Thread Tests ====================
|
||||||
|
|
||||||
|
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||||
|
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||||
|
def test_multiple_retrieve_thread_skips_second_reranking_with_single_dataset(
|
||||||
|
self, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that _multiple_retrieve_thread skips second reranking when dataset_count is 1.
|
||||||
|
|
||||||
|
When there is only one dataset, the second reranking is unnecessary
|
||||||
|
because the documents are already ranked from the first retrieval.
|
||||||
|
This optimization avoids the overhead of reranking when it won't
|
||||||
|
provide any benefit.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- DataPostProcessor is NOT called when dataset_count == 1
|
||||||
|
- Documents are still added to all_documents
|
||||||
|
- Standard scoring logic is applied instead
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
dataset_retrieval = DatasetRetrieval()
|
||||||
|
tenant_id = str(uuid4())
|
||||||
|
|
||||||
|
# Create test documents
|
||||||
|
doc1 = Document(
|
||||||
|
page_content="Test content 1",
|
||||||
|
metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
)
|
||||||
|
doc2 = Document(
|
||||||
|
page_content="Test content 2",
|
||||||
|
metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock _retriever to return documents
|
||||||
|
def side_effect_retriever(
|
||||||
|
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||||
|
):
|
||||||
|
all_documents.extend([doc1, doc2])
|
||||||
|
|
||||||
|
mock_retriever.side_effect = side_effect_retriever
|
||||||
|
|
||||||
|
# Set up dataset with high_quality indexing
|
||||||
|
mock_dataset.indexing_technique = "high_quality"
|
||||||
|
|
||||||
|
all_documents = []
|
||||||
|
|
||||||
|
# Act - Call with dataset_count = 1
|
||||||
|
dataset_retrieval._multiple_retrieve_thread(
|
||||||
|
flask_app=mock_flask_app,
|
||||||
|
available_datasets=[mock_dataset],
|
||||||
|
metadata_condition=None,
|
||||||
|
metadata_filter_document_ids=None,
|
||||||
|
all_documents=all_documents,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
reranking_enable=True,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||||
|
weights=None,
|
||||||
|
top_k=5,
|
||||||
|
score_threshold=0.5,
|
||||||
|
query="test query",
|
||||||
|
attachment_id=None,
|
||||||
|
dataset_count=1, # Single dataset - should skip second reranking
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# DataPostProcessor should NOT be called (second reranking skipped)
|
||||||
|
mock_data_processor_class.assert_not_called()
|
||||||
|
|
||||||
|
# Documents should still be added to all_documents
|
||||||
|
assert len(all_documents) == 2
|
||||||
|
assert all_documents[0].page_content == "Test content 1"
|
||||||
|
assert all_documents[1].page_content == "Test content 2"
|
||||||
|
|
||||||
|
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||||
|
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||||
|
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
|
||||||
|
def test_multiple_retrieve_thread_performs_second_reranking_with_multiple_datasets(
|
||||||
|
self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that _multiple_retrieve_thread performs second reranking when dataset_count > 1.
|
||||||
|
|
||||||
|
When there are multiple datasets, the second reranking is necessary
|
||||||
|
to merge and re-rank results from different datasets. This ensures
|
||||||
|
the most relevant documents across all datasets are returned.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- DataPostProcessor IS called when dataset_count > 1
|
||||||
|
- Reranking is applied with correct parameters
|
||||||
|
- Documents are processed correctly
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
dataset_retrieval = DatasetRetrieval()
|
||||||
|
tenant_id = str(uuid4())
|
||||||
|
|
||||||
|
# Create test documents
|
||||||
|
doc1 = Document(
|
||||||
|
page_content="Test content 1",
|
||||||
|
metadata={"doc_id": "doc1", "score": 0.7, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
)
|
||||||
|
doc2 = Document(
|
||||||
|
page_content="Test content 2",
|
||||||
|
metadata={"doc_id": "doc2", "score": 0.6, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock _retriever to return documents
|
||||||
|
def side_effect_retriever(
|
||||||
|
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||||
|
):
|
||||||
|
all_documents.extend([doc1, doc2])
|
||||||
|
|
||||||
|
mock_retriever.side_effect = side_effect_retriever
|
||||||
|
|
||||||
|
# Set up dataset with high_quality indexing
|
||||||
|
mock_dataset.indexing_technique = "high_quality"
|
||||||
|
|
||||||
|
# Mock DataPostProcessor instance and its invoke method
|
||||||
|
mock_processor_instance = Mock()
|
||||||
|
# Simulate reranking - return documents in reversed order with updated scores
|
||||||
|
reranked_docs = [
|
||||||
|
Document(
|
||||||
|
page_content="Test content 2",
|
||||||
|
metadata={"doc_id": "doc2", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
page_content="Test content 1",
|
||||||
|
metadata={"doc_id": "doc1", "score": 0.85, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
mock_processor_instance.invoke.return_value = reranked_docs
|
||||||
|
mock_data_processor_class.return_value = mock_processor_instance
|
||||||
|
|
||||||
|
all_documents = []
|
||||||
|
|
||||||
|
# Create second dataset
|
||||||
|
mock_dataset2 = Mock(spec=Dataset)
|
||||||
|
mock_dataset2.id = str(uuid4())
|
||||||
|
mock_dataset2.indexing_technique = "high_quality"
|
||||||
|
mock_dataset2.provider = "dify"
|
||||||
|
|
||||||
|
# Act - Call with dataset_count = 2
|
||||||
|
dataset_retrieval._multiple_retrieve_thread(
|
||||||
|
flask_app=mock_flask_app,
|
||||||
|
available_datasets=[mock_dataset, mock_dataset2],
|
||||||
|
metadata_condition=None,
|
||||||
|
metadata_filter_document_ids=None,
|
||||||
|
all_documents=all_documents,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
reranking_enable=True,
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||||
|
weights=None,
|
||||||
|
top_k=5,
|
||||||
|
score_threshold=0.5,
|
||||||
|
query="test query",
|
||||||
|
attachment_id=None,
|
||||||
|
dataset_count=2, # Multiple datasets - should perform second reranking
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# DataPostProcessor SHOULD be called (second reranking performed)
|
||||||
|
mock_data_processor_class.assert_called_once_with(
|
||||||
|
tenant_id,
|
||||||
|
"reranking_model",
|
||||||
|
{"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify invoke was called with correct parameters
|
||||||
|
mock_processor_instance.invoke.assert_called_once()
|
||||||
|
|
||||||
|
# Documents should be added to all_documents after reranking
|
||||||
|
assert len(all_documents) == 2
|
||||||
|
# The reranked order should be reflected
|
||||||
|
assert all_documents[0].page_content == "Test content 2"
|
||||||
|
assert all_documents[1].page_content == "Test content 1"
|
||||||
|
|
||||||
|
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||||
|
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||||
|
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
|
||||||
|
def test_multiple_retrieve_thread_single_dataset_uses_standard_scoring(
|
||||||
|
self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that _multiple_retrieve_thread uses standard scoring when dataset_count is 1
|
||||||
|
and reranking is enabled.
|
||||||
|
|
||||||
|
When there's only one dataset, instead of using DataPostProcessor,
|
||||||
|
the method should fall through to the standard scoring logic
|
||||||
|
(calculate_vector_score for high_quality datasets).
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- DataPostProcessor is NOT called
|
||||||
|
- calculate_vector_score IS called for high_quality indexing
|
||||||
|
- Documents are scored correctly
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
dataset_retrieval = DatasetRetrieval()
|
||||||
|
tenant_id = str(uuid4())
|
||||||
|
|
||||||
|
# Create test documents
|
||||||
|
doc1 = Document(
|
||||||
|
page_content="Test content 1",
|
||||||
|
metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
)
|
||||||
|
doc2 = Document(
|
||||||
|
page_content="Test content 2",
|
||||||
|
metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock _retriever to return documents
|
||||||
|
def side_effect_retriever(
|
||||||
|
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||||
|
):
|
||||||
|
all_documents.extend([doc1, doc2])
|
||||||
|
|
||||||
|
mock_retriever.side_effect = side_effect_retriever
|
||||||
|
|
||||||
|
# Set up dataset with high_quality indexing
|
||||||
|
mock_dataset.indexing_technique = "high_quality"
|
||||||
|
|
||||||
|
# Mock calculate_vector_score to return scored documents
|
||||||
|
scored_docs = [
|
||||||
|
Document(
|
||||||
|
page_content="Test content 1",
|
||||||
|
metadata={"doc_id": "doc1", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||||
|
provider="dify",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
mock_calculate_vector_score.return_value = scored_docs
|
||||||
|
|
||||||
|
all_documents = []
|
||||||
|
|
||||||
|
# Act - Call with dataset_count = 1
|
||||||
|
dataset_retrieval._multiple_retrieve_thread(
|
||||||
|
flask_app=mock_flask_app,
|
||||||
|
available_datasets=[mock_dataset],
|
||||||
|
metadata_condition=None,
|
||||||
|
metadata_filter_document_ids=None,
|
||||||
|
all_documents=all_documents,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
reranking_enable=True, # Reranking enabled but should be skipped for single dataset
|
||||||
|
reranking_mode="reranking_model",
|
||||||
|
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||||
|
weights=None,
|
||||||
|
top_k=5,
|
||||||
|
score_threshold=0.5,
|
||||||
|
query="test query",
|
||||||
|
attachment_id=None,
|
||||||
|
dataset_count=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# DataPostProcessor should NOT be called
|
||||||
|
mock_data_processor_class.assert_not_called()
|
||||||
|
|
||||||
|
# calculate_vector_score SHOULD be called for high_quality datasets
|
||||||
|
mock_calculate_vector_score.assert_called_once()
|
||||||
|
call_args = mock_calculate_vector_score.call_args
|
||||||
|
assert call_args[0][1] == 5 # top_k
|
||||||
|
|
||||||
|
# Documents should be added after standard scoring
|
||||||
|
assert len(all_documents) == 1
|
||||||
|
assert all_documents[0].page_content == "Test content 1"
|
||||||
|
|
||||||
|
|
||||||
class TestRetrievalMethods:
|
class TestRetrievalMethods:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user