From fc0a4a6b568d03c7e8b72ee80da42dc806541aad Mon Sep 17 00:00:00 2001 From: FFXN Date: Wed, 13 May 2026 14:06:27 +0800 Subject: [PATCH] fix: When hit-testing, an empty document dict is returned due to DocumentSegment type modification. --- .../console/datasets/hit_testing_base.py | 19 +++--- api/services/hit_testing_service.py | 58 ++++++++++++++++- .../console/datasets/test_hit_testing_base.py | 4 +- api/tests/unit_tests/services/hit_service.py | 64 +++++++++++++++++++ 4 files changed, 131 insertions(+), 14 deletions(-) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 71ab1513ed..bb725a5f6c 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel): class DatasetsHitTestingBase: @staticmethod - def _normalize_hit_testing_query(query: Any) -> str: - """Return the user-visible query string from legacy and current response shapes.""" - if isinstance(query, str): - return query - + def _extract_hit_testing_query(query: Any) -> str: + """Return the query string from the service response shape.""" if isinstance(query, dict): content = query.get("content") if isinstance(content, str): @@ -52,15 +49,15 @@ class DatasetsHitTestingBase: raise ValueError("Invalid hit testing query response") @staticmethod - def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]: - """Coerce nullable collection fields into lists before response validation.""" + def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]: + """Ensure collection fields match the API schema before response validation.""" if not isinstance(records, list): - return [] + raise ValueError("Invalid hit testing records response") normalized_records: list[dict[str, Any]] = [] for record in records: if not isinstance(record, dict): - continue + raise ValueError("Invalid hit testing record response") normalized_record = dict(record) segment = normalized_record.get("segment") @@ -118,8 +115,8 @@ class DatasetsHitTestingBase: limit=10, ) return { - "query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")), - "records": DatasetsHitTestingBase._normalize_hit_testing_records( + "query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")), + "records": DatasetsHitTestingBase._prepare_hit_testing_records( marshal(response.get("records", []), hit_testing_record_fields) ), } diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 2e5987dd28..42c531ae48 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,6 +3,8 @@ import logging import time from typing import Any, TypedDict, cast +from sqlalchemy import select + from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.index_processor.constant.query_type import QueryType @@ -13,6 +15,7 @@ from extensions.ext_database import db from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery +from models.dataset import Document as DatasetDocument from models.enums import CreatorUserRole, DatasetQuerySource logger = logging.getLogger(__name__) @@ -41,6 +44,59 @@ class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False): class HitTestingService: + @staticmethod + def _dump_dataset_document(document: DatasetDocument) -> dict[str, Any]: + return { + "id": document.id, + "data_source_type": document.data_source_type, + "name": document.name, + "doc_type": document.doc_type, + "doc_metadata": document.doc_metadata, + } + + @classmethod + def _dump_retrieval_records(cls, records: list[Any]) -> list[dict[str, Any]]: + dumped_records = [record.model_dump() for record in records] + document_ids = { + segment.get("document_id") + for record in dumped_records + if isinstance(record, dict) + for segment in [record.get("segment")] + if isinstance(segment, dict) and segment.get("document_id") + } + if not document_ids: + return dumped_records + + documents = { + document.id: cls._dump_dataset_document(document) + for document in db.session.scalars( + select(DatasetDocument).where(DatasetDocument.id.in_(document_ids)) + ).all() + } + + records_with_documents: list[dict[str, Any]] = [] + missing_document_ids: set[str] = set() + for record in dumped_records: + segment = record.get("segment") + if not isinstance(segment, dict): + records_with_documents.append(record) + continue + + document_id = segment.get("document_id") + if document_id in documents: + segment["document"] = documents[document_id] + records_with_documents.append(record) + elif document_id: + missing_document_ids.add(document_id) + + if missing_document_ids: + logger.warning( + "Skipping hit-testing records with missing documents, document_ids=%s", + sorted(missing_document_ids), + ) + + return records_with_documents + @classmethod def retrieve( cls, @@ -174,7 +230,7 @@ class HitTestingService: "query": { "content": query, }, - "records": [record.model_dump() for record in records], + "records": cls._dump_retrieval_records(records), } @classmethod diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index d29b34beb2..d2b1e5b735 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -120,7 +120,7 @@ class TestParseArgs: class TestPerformHitTesting: def test_success(self, dataset): response = { - "query": "hello", + "query": {"content": "hello"}, "records": [], } @@ -134,7 +134,7 @@ class TestPerformHitTesting: assert result["query"] == "hello" assert result["records"] == [] - def test_success_normalizes_legacy_query_and_nullable_list_fields(self, dataset): + def test_success_prepares_nullable_list_fields(self, dataset): response = { "query": {"content": "hello"}, "records": [ diff --git a/api/tests/unit_tests/services/hit_service.py b/api/tests/unit_tests/services/hit_service.py index ddbc7dc041..48722622bb 100644 --- a/api/tests/unit_tests/services/hit_service.py +++ b/api/tests/unit_tests/services/hit_service.py @@ -574,6 +574,70 @@ class TestHitTestingServiceCompactRetrieveResponse: assert result["records"][0]["score"] == 0.95 mock_format.assert_called_once_with(documents) + def test_compact_retrieve_response_includes_segment_document(self): + query = "test query" + documents = [HitTestingTestDataFactory.create_document_mock(content="Doc 1")] + mock_record = Mock() + mock_record.model_dump.return_value = { + "segment": { + "id": "segment-1", + "document_id": "document-1", + }, + "score": 0.95, + } + dataset_document = Mock() + dataset_document.id = "document-1" + dataset_document.data_source_type = "upload_file" + dataset_document.name = "guide.md" + dataset_document.doc_type = None + dataset_document.doc_metadata = {"source": "manual"} + scalars_result = Mock() + scalars_result.all.return_value = [dataset_document] + + with ( + patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format, + patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result), + ): + mock_format.return_value = [mock_record] + + result = HitTestingService.compact_retrieve_response(query, documents) + + assert result["records"][0]["segment"]["document"] == { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": {"source": "manual"}, + } + + def test_compact_retrieve_response_skips_records_with_missing_document(self): + query = "test query" + documents = [HitTestingTestDataFactory.create_document_mock(content="Doc 1")] + mock_record = Mock() + mock_record.model_dump.return_value = { + "segment": { + "id": "segment-1", + "document_id": "missing-document", + }, + "score": 0.95, + } + scalars_result = Mock() + scalars_result.all.return_value = [] + + with ( + patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format, + patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result), + ): + mock_format.return_value = [mock_record] + + result = HitTestingService.compact_retrieve_response(query, documents) + + assert result["records"] == [] + def test_compact_retrieve_response_empty_documents(self): """ Test response formatting with empty document list.