fix: hit testing returned an empty document dict because of the DocumentSegment type change.

This commit is contained in:
FFXN
2026-05-13 14:06:27 +08:00
parent 0a3bb67778
commit fc0a4a6b56
4 changed files with 131 additions and 14 deletions

View File

@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel):
class DatasetsHitTestingBase:
@staticmethod
def _normalize_hit_testing_query(query: Any) -> str:
"""Return the user-visible query string from legacy and current response shapes."""
if isinstance(query, str):
return query
def _extract_hit_testing_query(query: Any) -> str:
"""Return the query string from the service response shape."""
if isinstance(query, dict):
content = query.get("content")
if isinstance(content, str):
@ -52,15 +49,15 @@ class DatasetsHitTestingBase:
raise ValueError("Invalid hit testing query response")
@staticmethod
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Coerce nullable collection fields into lists before response validation."""
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Ensure collection fields match the API schema before response validation."""
if not isinstance(records, list):
return []
raise ValueError("Invalid hit testing records response")
normalized_records: list[dict[str, Any]] = []
for record in records:
if not isinstance(record, dict):
continue
raise ValueError("Invalid hit testing record response")
normalized_record = dict(record)
segment = normalized_record.get("segment")
@ -118,8 +115,8 @@ class DatasetsHitTestingBase:
limit=10,
)
return {
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
marshal(response.get("records", []), hit_testing_record_fields)
),
}

View File

@ -3,6 +3,8 @@ import logging
import time
from typing import Any, TypedDict, cast
from sqlalchemy import select
from core.app.app_config.entities import ModelConfig
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
@ -13,6 +15,7 @@ from extensions.ext_database import db
from graphon.model_runtime.entities import LLMMode
from models import Account
from models.dataset import Dataset, DatasetQuery
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
@ -41,6 +44,59 @@ class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
class HitTestingService:
@staticmethod
def _dump_dataset_document(document: DatasetDocument) -> dict[str, Any]:
return {
"id": document.id,
"data_source_type": document.data_source_type,
"name": document.name,
"doc_type": document.doc_type,
"doc_metadata": document.doc_metadata,
}
@classmethod
def _dump_retrieval_records(cls, records: list[Any]) -> list[dict[str, Any]]:
dumped_records = [record.model_dump() for record in records]
document_ids = {
segment.get("document_id")
for record in dumped_records
if isinstance(record, dict)
for segment in [record.get("segment")]
if isinstance(segment, dict) and segment.get("document_id")
}
if not document_ids:
return dumped_records
documents = {
document.id: cls._dump_dataset_document(document)
for document in db.session.scalars(
select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
).all()
}
records_with_documents: list[dict[str, Any]] = []
missing_document_ids: set[str] = set()
for record in dumped_records:
segment = record.get("segment")
if not isinstance(segment, dict):
records_with_documents.append(record)
continue
document_id = segment.get("document_id")
if document_id in documents:
segment["document"] = documents[document_id]
records_with_documents.append(record)
elif document_id:
missing_document_ids.add(document_id)
if missing_document_ids:
logger.warning(
"Skipping hit-testing records with missing documents, document_ids=%s",
sorted(missing_document_ids),
)
return records_with_documents
@classmethod
def retrieve(
cls,
@ -174,7 +230,7 @@ class HitTestingService:
"query": {
"content": query,
},
"records": [record.model_dump() for record in records],
"records": cls._dump_retrieval_records(records),
}
@classmethod

View File

@ -120,7 +120,7 @@ class TestParseArgs:
class TestPerformHitTesting:
def test_success(self, dataset):
response = {
"query": "hello",
"query": {"content": "hello"},
"records": [],
}
@ -134,7 +134,7 @@ class TestPerformHitTesting:
assert result["query"] == "hello"
assert result["records"] == []
def test_success_normalizes_legacy_query_and_nullable_list_fields(self, dataset):
def test_success_prepares_nullable_list_fields(self, dataset):
response = {
"query": {"content": "hello"},
"records": [

View File

@ -574,6 +574,70 @@ class TestHitTestingServiceCompactRetrieveResponse:
assert result["records"][0]["score"] == 0.95
mock_format.assert_called_once_with(documents)
def test_compact_retrieve_response_includes_segment_document(self):
    """A record's segment carries the dump of the dataset document it references."""
    retrieved_documents = [HitTestingTestDataFactory.create_document_mock(content="Doc 1")]

    retrieval_record = Mock(
        **{
            "model_dump.return_value": {
                "segment": {
                    "id": "segment-1",
                    "document_id": "document-1",
                },
                "score": 0.95,
            }
        }
    )

    # configure_mock is used because ``name`` passed to the Mock constructor would
    # name the mock itself instead of setting a ``name`` attribute.
    stored_document = Mock()
    stored_document.configure_mock(
        id="document-1",
        data_source_type="upload_file",
        name="guide.md",
        doc_type=None,
        doc_metadata={"source": "manual"},
    )
    scalars_result = Mock(**{"all.return_value": [stored_document]})

    with (
        patch(
            "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
        ) as mock_format,
        patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result),
    ):
        mock_format.return_value = [retrieval_record]
        response = HitTestingService.compact_retrieve_response("test query", retrieved_documents)

    attached_document = response["records"][0]["segment"]["document"]
    assert attached_document == {
        "id": "document-1",
        "data_source_type": "upload_file",
        "name": "guide.md",
        "doc_type": None,
        "doc_metadata": {"source": "manual"},
    }
def test_compact_retrieve_response_skips_records_with_missing_document(self):
    """A record is left out when the document lookup returns nothing for its document_id."""
    retrieved_documents = [HitTestingTestDataFactory.create_document_mock(content="Doc 1")]

    orphan_record = Mock(
        **{
            "model_dump.return_value": {
                "segment": {
                    "id": "segment-1",
                    "document_id": "missing-document",
                },
                "score": 0.95,
            }
        }
    )
    # The patched lookup yields no rows, so "missing-document" cannot be resolved.
    empty_scalars = Mock(**{"all.return_value": []})

    with (
        patch(
            "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
        ) as mock_format,
        patch("services.hit_testing_service.db.session.scalars", return_value=empty_scalars),
    ):
        mock_format.return_value = [orphan_record]
        response = HitTestingService.compact_retrieve_response("test query", retrieved_documents)

    assert response["records"] == []
def test_compact_retrieve_response_empty_documents(self):
"""
Test response formatting with empty document list.