From b034449a0cedaf9880f36bbbbe449ac8a09e83df Mon Sep 17 00:00:00 2001 From: chariri Date: Wed, 27 May 2026 15:51:42 +0900 Subject: [PATCH] refactor(api): migrate console/service_api.dataset.hit_testing to BaseModel (#36533) --- .../console/datasets/hit_testing.py | 98 +----------- .../console/datasets/hit_testing_base.py | 41 +++-- .../service_api/dataset/hit_testing.py | 23 +-- api/fields/hit_testing_fields.py | 146 +++++++++++------- api/openapi/markdown/console-swagger.md | 98 ++++++------ api/openapi/markdown/service-swagger.md | 102 ++++++++++-- .../console/datasets/test_hit_testing.py | 85 ++++++---- .../console/datasets/test_hit_testing_base.py | 100 ++++++++---- .../service_api/dataset/test_hit_testing.py | 116 ++++++++------ .../api/console/datasets/types.gen.ts | 96 ++++++------ .../generated/api/console/datasets/zod.gen.ts | 99 ++++++------ .../generated/api/service/types.gen.ts | 76 ++++++++- .../generated/api/service/zod.gen.ts | 93 ++++++++++- 13 files changed, 732 insertions(+), 441 deletions(-) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 110a2e16f5..37640138eb 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,15 +1,12 @@ from __future__ import annotations -from datetime import datetime -from typing import Any from uuid import UUID from flask_restx import Resource -from pydantic import Field, field_validator -from controllers.common.schema import register_schema_models -from fields.base import ResponseModel -from libs.helper import to_timestamp +from controllers.common.schema import register_response_schema_models, register_schema_models +from fields.hit_testing_fields import HitTestingResponse +from libs.helper import dump_response from libs.login import login_required from .. import console_ns @@ -20,86 +17,8 @@ from ..wraps import ( setup_required, ) - -class HitTestingDocument(ResponseModel): - id: str | None = None - data_source_type: str | None = None - name: str | None = None - doc_type: str | None = None - doc_metadata: Any | None = None - - -class HitTestingSegment(ResponseModel): - id: str | None = None - position: int | None = None - document_id: str | None = None - content: str | None = None - sign_content: str | None = None - answer: str | None = None - word_count: int | None = None - tokens: int | None = None - keywords: list[str] = Field(default_factory=list) - index_node_id: str | None = None - index_node_hash: str | None = None - hit_count: int | None = None - enabled: bool | None = None - disabled_at: int | None = None - disabled_by: str | None = None - status: str | None = None - created_by: str | None = None - created_at: int | None = None - indexing_at: int | None = None - completed_at: int | None = None - error: str | None = None - stopped_at: int | None = None - document: HitTestingDocument | None = None - - @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before") - @classmethod - def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return to_timestamp(value) - - -class HitTestingChildChunk(ResponseModel): - id: str | None = None - content: str | None = None - position: int | None = None - score: float | None = None - - -class HitTestingFile(ResponseModel): - id: str | None = None - name: str | None = None - size: int | None = None - extension: str | None = None - mime_type: str | None = None - source_url: str | None = None - - -class HitTestingRecord(ResponseModel): - segment: HitTestingSegment | None = None - child_chunks: list[HitTestingChildChunk] = Field(default_factory=list) - score: float | None = None - tsne_position: Any | None = None - files: list[HitTestingFile] = Field(default_factory=list) - summary: str | None = None - - -class HitTestingResponse(ResponseModel): - query: str - records: list[HitTestingRecord] = Field(default_factory=list) - - -register_schema_models( - console_ns, - HitTestingPayload, - HitTestingDocument, - HitTestingSegment, - HitTestingChildChunk, - HitTestingFile, - HitTestingRecord, - HitTestingResponse, -) +register_schema_models(console_ns, HitTestingPayload) +register_response_schema_models(console_ns, HitTestingResponse) @console_ns.route("/datasets//hit-testing") @@ -119,12 +38,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self, dataset_id: UUID): + def post(self, dataset_id: UUID) -> dict[str, object]: dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - payload = HitTestingPayload.model_validate(console_ns.payload or {}) - args = payload.model_dump(exclude_none=True) + args = self.parse_args(console_ns.payload) self.hit_testing_args_check(args) - return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json") + return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args)) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index bb725a5f6c..4be91e0e54 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,6 @@ import logging -from typing import Any +from typing import Any, cast -from flask_restx import marshal from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -19,10 +18,10 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from fields.hit_testing_fields import hit_testing_record_fields from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_user from models.account import Account +from models.dataset import Dataset from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.hit_testing_service import HitTestingService @@ -38,16 +37,6 @@ class HitTestingPayload(BaseModel): class DatasetsHitTestingBase: - @staticmethod - def _extract_hit_testing_query(query: Any) -> str: - """Return the query string from the service response shape.""" - if isinstance(query, dict): - content = query.get("content") - if isinstance(content, str): - return content - - raise ValueError("Invalid hit testing query response") - @staticmethod def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]: """Ensure collection fields match the API schema before response validation.""" @@ -63,6 +52,7 @@ class DatasetsHitTestingBase: segment = normalized_record.get("segment") if isinstance(segment, dict): normalized_segment = dict(segment) + normalized_segment.setdefault("sign_content", None) if normalized_segment.get("keywords") is None: normalized_segment["keywords"] = [] normalized_record["segment"] = normalized_segment @@ -73,12 +63,15 @@ class DatasetsHitTestingBase: if normalized_record.get("files") is None: normalized_record["files"] = [] + normalized_record.setdefault("tsne_position", None) + normalized_record.setdefault("summary", None) + normalized_records.append(normalized_record) return normalized_records @staticmethod - def get_and_validate_dataset(dataset_id: str): + def get_and_validate_dataset(dataset_id: str) -> Dataset: assert isinstance(current_user, Account) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: @@ -92,33 +85,35 @@ class DatasetsHitTestingBase: return dataset @staticmethod - def hit_testing_args_check(args: dict[str, Any]): + def hit_testing_args_check(args: dict[str, Any]) -> None: HitTestingService.hit_testing_args_check(args) @staticmethod - def parse_args(payload: dict[str, Any]) -> dict[str, Any]: + def parse_args(payload: dict[str, Any] | None) -> dict[str, Any]: """Validate and return hit-testing arguments from an incoming payload.""" hit_testing_payload = HitTestingPayload.model_validate(payload or {}) return hit_testing_payload.model_dump(exclude_none=True) @staticmethod - def perform_hit_testing(dataset, args): + def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]: assert isinstance(current_user, Account) try: response = HitTestingService.retrieve( dataset=dataset, - query=args.get("query"), + query=cast(str, args.get("query")), account=current_user, retrieval_model=args.get("retrieval_model"), - external_retrieval_model=args.get("external_retrieval_model"), + external_retrieval_model=cast(dict[str, Any], args.get("external_retrieval_model")), attachment_ids=args.get("attachment_ids"), limit=10, ) + query = response.get("query") + if not isinstance(query, dict) or not isinstance(query.get("content"), str): + raise ValueError("Invalid hit testing query response") + return { - "query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")), - "records": DatasetsHitTestingBase._prepare_hit_testing_records( - marshal(response.get("records", []), hit_testing_record_fields) - ), + "query": {"content": query["content"]}, + "records": DatasetsHitTestingBase._prepare_hit_testing_records(response.get("records", [])), } except services.errors.index.IndexNotInitializedError: raise DatasetNotInitializedError() diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index ba914c4dd4..55a1c47c42 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,11 +1,14 @@ from uuid import UUID -from controllers.common.schema import register_schema_model +from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +from fields.hit_testing_fields import HitTestingResponse +from libs.helper import dump_response -register_schema_model(service_api_ns, HitTestingPayload) +register_schema_models(service_api_ns, HitTestingPayload) +register_response_schema_models(service_api_ns, HitTestingResponse) @service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") @@ -13,16 +16,16 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): @service_api_ns.doc("dataset_hit_testing") @service_api_ns.doc(description="Perform hit testing on a dataset") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) - @service_api_ns.doc( - responses={ - 200: "Hit testing results", - 401: "Unauthorized - invalid API token", - 404: "Dataset not found", - } + @service_api_ns.response( + 200, + "Hit testing results", + model=service_api_ns.models[HitTestingResponse.__name__], ) + @service_api_ns.response(401, "Unauthorized - invalid API token") + @service_api_ns.response(404, "Dataset not found") @service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__]) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id: UUID): + def post(self, tenant_id: str, dataset_id: UUID) -> dict[str, object]: """Perform hit testing on a dataset. Tests retrieval performance for the specified dataset. @@ -33,4 +36,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): args = self.parse_args(service_api_ns.payload) self.hit_testing_args_check(args) - return self.perform_hit_testing(dataset, args) + return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args)) diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 0b54992835..dd7865afed 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,62 +1,96 @@ -from flask_restx import fields +from datetime import datetime +from typing import Any -from libs.helper import TimestampField +from pydantic import field_validator -document_fields = { - "id": fields.String, - "data_source_type": fields.String, - "name": fields.String, - "doc_type": fields.String, - "doc_metadata": fields.Raw, -} +from fields.base import ResponseModel +from libs.helper import to_timestamp -segment_fields = { - "id": fields.String, - "position": fields.Integer, - "document_id": fields.String, - "content": fields.String, - "sign_content": fields.String, - "answer": fields.String, - "word_count": fields.Integer, - "tokens": fields.Integer, - "keywords": fields.List(fields.String), - "index_node_id": fields.String, - "index_node_hash": fields.String, - "hit_count": fields.Integer, - "enabled": fields.Boolean, - "disabled_at": TimestampField, - "disabled_by": fields.String, - "status": fields.String, - "created_by": fields.String, - "created_at": TimestampField, - "indexing_at": TimestampField, - "completed_at": TimestampField, - "error": fields.String, - "stopped_at": TimestampField, - "document": fields.Nested(document_fields), -} -child_chunk_fields = { - "id": fields.String, - "content": fields.String, - "position": fields.Integer, - "score": fields.Float, -} +class HitTestingQuery(ResponseModel): + content: str -files_fields = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "mime_type": fields.String, - "source_url": fields.String, -} -hit_testing_record_fields = { - "segment": fields.Nested(segment_fields), - "child_chunks": fields.List(fields.Nested(child_chunk_fields)), - "score": fields.Float, - "tsne_position": fields.Raw, - "files": fields.List(fields.Nested(files_fields)), - "summary": fields.String, # Summary content if retrieved via summary index -} +class HitTestingDocument(ResponseModel): + id: str + data_source_type: str + name: str + doc_type: str | None + doc_metadata: Any | None + + @field_validator("data_source_type", "doc_type", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + +class HitTestingSegment(ResponseModel): + id: str + position: int + document_id: str + content: str + sign_content: str | None + answer: str | None + word_count: int + tokens: int + keywords: list[str] + index_node_id: str | None + index_node_hash: str | None + hit_count: int + enabled: bool + disabled_at: int | None + disabled_by: str | None + status: str + created_by: str + created_at: int + indexing_at: int | None + completed_at: int | None + error: str | None + stopped_at: int | None + document: HitTestingDocument + + @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + @field_validator("status", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + +class HitTestingChildChunk(ResponseModel): + id: str + content: str + position: int + score: float + + +class HitTestingFile(ResponseModel): + id: str + name: str + size: int + extension: str + mime_type: str + source_url: str + + +class HitTestingRecord(ResponseModel): + segment: HitTestingSegment + child_chunks: list[HitTestingChildChunk] + score: float | None + tsne_position: Any | None + files: list[HitTestingFile] + summary: str | None + + +class HitTestingResponse(ResponseModel): + query: HitTestingQuery + records: list[HitTestingRecord] + + +def _normalize_enum(value: Any) -> Any: + if isinstance(value, str) or value is None: + return value + return getattr(value, "value", value) diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index 194fd631c3..a2302be32c 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -13051,31 +13051,31 @@ Request payload for bulk downloading documents as a zip archive. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| content | string | | No | -| id | string | | No | -| position | integer | | No | -| score | number | | No | +| content | string | | Yes | +| id | string | | Yes | +| position | integer | | Yes | +| score | number | | Yes | #### HitTestingDocument | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| data_source_type | string | | No | -| doc_metadata | | | No | -| doc_type | string | | No | -| id | string | | No | -| name | string | | No | +| data_source_type | string | | Yes | +| doc_metadata | | | Yes | +| doc_type | string | | Yes | +| id | string | | Yes | +| name | string | | Yes | #### HitTestingFile | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| extension | string | | No | -| id | string | | No | -| mime_type | string | | No | -| name | string | | No | -| size | integer | | No | -| source_url | string | | No | +| extension | string | | Yes | +| id | string | | Yes | +| mime_type | string | | Yes | +| name | string | | Yes | +| size | integer | | Yes | +| source_url | string | | Yes | #### HitTestingPayload @@ -13086,51 +13086,57 @@ Request payload for bulk downloading documents as a zip archive. | query | string | | Yes | | retrieval_model | [RetrievalModel](#retrievalmodel) | | No | +#### HitTestingQuery + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| content | string | | Yes | + #### HitTestingRecord | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | No | -| files | [ [HitTestingFile](#hittestingfile) ] | | No | -| score | number | | No | -| segment | [HitTestingSegment](#hittestingsegment) | | No | -| summary | string | | No | -| tsne_position | | | No | +| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | Yes | +| files | [ [HitTestingFile](#hittestingfile) ] | | Yes | +| score | number | | Yes | +| segment | [HitTestingSegment](#hittestingsegment) | | Yes | +| summary | string | | Yes | +| tsne_position | | | Yes | #### HitTestingResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| query | string | | Yes | -| records | [ [HitTestingRecord](#hittestingrecord) ] | | No | +| query | [HitTestingQuery](#hittestingquery) | | Yes | +| records | [ [HitTestingRecord](#hittestingrecord) ] | | Yes | #### HitTestingSegment | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| answer | string | | No | -| completed_at | integer | | No | -| content | string | | No | -| created_at | integer | | No | -| created_by | string | | No | -| disabled_at | integer | | No | -| disabled_by | string | | No | -| document | [HitTestingDocument](#hittestingdocument) | | No | -| document_id | string | | No | -| enabled | boolean | | No | -| error | string | | No | -| hit_count | integer | | No | -| id | string | | No | -| index_node_hash | string | | No | -| index_node_id | string | | No | -| indexing_at | integer | | No | -| keywords | [ string ] | | No | -| position | integer | | No | -| sign_content | string | | No | -| status | string | | No | -| stopped_at | integer | | No | -| tokens | integer | | No | -| word_count | integer | | No | +| answer | string | | Yes | +| completed_at | integer | | Yes | +| content | string | | Yes | +| created_at | integer | | Yes | +| created_by | string | | Yes | +| disabled_at | integer | | Yes | +| disabled_by | string | | Yes | +| document | [HitTestingDocument](#hittestingdocument) | | Yes | +| document_id | string | | Yes | +| enabled | boolean | | Yes | +| error | string | | Yes | +| hit_count | integer | | Yes | +| id | string | | Yes | +| index_node_hash | string | | Yes | +| index_node_id | string | | Yes | +| indexing_at | integer | | Yes | +| keywords | [ string ] | | Yes | +| position | integer | | Yes | +| sign_content | string | | Yes | +| status | string | | Yes | +| stopped_at | integer | | Yes | +| tokens | integer | | Yes | +| word_count | integer | | Yes | #### HumanInputContent diff --git a/api/openapi/markdown/service-swagger.md b/api/openapi/markdown/service-swagger.md index ee801b1b8e..071b1b526c 100644 --- a/api/openapi/markdown/service-swagger.md +++ b/api/openapi/markdown/service-swagger.md @@ -1363,11 +1363,11 @@ Tests retrieval performance for the specified dataset. ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Hit testing results | -| 401 | Unauthorized - invalid API token | -| 404 | Dataset not found | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Hit testing results | [HitTestingResponse](#hittestingresponse) | +| 401 | Unauthorized - invalid API token | | +| 404 | Dataset not found | | ### /datasets/{dataset_id}/metadata @@ -1614,11 +1614,11 @@ Tests retrieval performance for the specified dataset. ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Hit testing results | -| 401 | Unauthorized - invalid API token | -| 404 | Dataset not found | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Hit testing results | [HitTestingResponse](#hittestingresponse) | +| 401 | Unauthorized - invalid API token | | +| 404 | Dataset not found | | ### /datasets/{dataset_id}/tags @@ -2691,6 +2691,36 @@ Note: The SQLAlchemy model defines an `is_anonymous` property for Flask-Login se | tenant_id | string | | No | | user_id | string | | No | +#### HitTestingChildChunk + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| content | string | | Yes | +| id | string | | Yes | +| position | integer | | Yes | +| score | number | | Yes | + +#### HitTestingDocument + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| data_source_type | string | | Yes | +| doc_metadata | | | Yes | +| doc_type | string | | Yes | +| id | string | | Yes | +| name | string | | Yes | + +#### HitTestingFile + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| extension | string | | Yes | +| id | string | | Yes | +| mime_type | string | | Yes | +| name | string | | Yes | +| size | integer | | Yes | +| source_url | string | | Yes | + #### HitTestingPayload | Name | Type | Description | Required | @@ -2700,6 +2730,58 @@ Note: The SQLAlchemy model defines an `is_anonymous` property for Flask-Login se | query | string | | Yes | | retrieval_model | [RetrievalModel](#retrievalmodel) | | No | +#### HitTestingQuery + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| content | string | | Yes | + +#### HitTestingRecord + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | Yes | +| files | [ [HitTestingFile](#hittestingfile) ] | | Yes | +| score | number | | Yes | +| segment | [HitTestingSegment](#hittestingsegment) | | Yes | +| summary | string | | Yes | +| tsne_position | | | Yes | + +#### HitTestingResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| query | [HitTestingQuery](#hittestingquery) | | Yes | +| records | [ [HitTestingRecord](#hittestingrecord) ] | | Yes | + +#### HitTestingSegment + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| answer | string | | Yes | +| completed_at | integer | | Yes | +| content | string | | Yes | +| created_at | integer | | Yes | +| created_by | string | | Yes | +| disabled_at | integer | | Yes | +| disabled_by | string | | Yes | +| document | [HitTestingDocument](#hittestingdocument) | | Yes | +| document_id | string | | Yes | +| enabled | boolean | | Yes | +| error | string | | Yes | +| hit_count | integer | | Yes | +| id | string | | Yes | +| index_node_hash | string | | Yes | +| index_node_id | string | | Yes | +| indexing_at | integer | | Yes | +| keywords | [ string ] | | Yes | +| position | integer | | Yes | +| sign_content | string | | Yes | +| status | string | | Yes | +| stopped_at | integer | | Yes | +| tokens | integer | | Yes | +| word_count | integer | | Yes | + #### HumanInputFormSubmitPayload | Name | Type | Description | Required | diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 4fa5d21493..faedd4d7e1 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -1,5 +1,5 @@ import uuid -from unittest.mock import MagicMock, PropertyMock, patch +from unittest.mock import PropertyMock, patch import pytest from flask import Flask @@ -8,7 +8,7 @@ from werkzeug.exceptions import NotFound from controllers.console import console_ns from controllers.console.datasets.hit_testing import HitTestingApi -from controllers.console.datasets.hit_testing_base import HitTestingPayload +from models.dataset import Dataset def unwrap(func): @@ -32,7 +32,48 @@ def dataset_id(): @pytest.fixture def dataset(): - return MagicMock(id="dataset-1") + return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1") + + +def hit_testing_record() -> dict[str, object]: + return { + "segment": { + "id": "segment-1", + "position": 1, + "document_id": "document-1", + "content": "Chunk text", + "sign_content": "Chunk text", + "answer": None, + "word_count": 2, + "tokens": 3, + "keywords": [], + "index_node_id": None, + "index_node_hash": None, + "hit_count": 0, + "enabled": True, + "disabled_at": None, + "disabled_by": None, + "status": "completed", + "created_by": "account-1", + "created_at": 1_700_000_000, + "indexing_at": None, + "completed_at": None, + "error": None, + "stopped_at": None, + "document": { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": None, + }, + }, + "child_chunks": [], + "score": None, + "tsne_position": None, + "files": [], + "summary": None, + } @pytest.fixture(autouse=True) @@ -63,7 +104,6 @@ class TestHitTestingApi: payload = { "query": "what is vector search", - "top_k": 3, } with ( @@ -74,11 +114,6 @@ class TestHitTestingApi: new_callable=PropertyMock, return_value=payload, ), - patch.object( - HitTestingPayload, - "model_validate", - return_value=MagicMock(model_dump=lambda **_: payload), - ), patch.object( HitTestingApi, "get_and_validate_dataset", @@ -91,7 +126,7 @@ class TestHitTestingApi: patch.object( HitTestingApi, "perform_hit_testing", - return_value={"query": "what is vector search", "records": []}, + return_value={"query": {"content": "what is vector search"}, "records": []}, ), ): result = method(api, dataset_id) @@ -107,16 +142,7 @@ class TestHitTestingApi: payload = { "query": "what is vector search", } - records = [ - { - "segment": None, - "child_chunks": [], - "score": None, - "tsne_position": None, - "files": [], - "summary": None, - } - ] + records = [hit_testing_record()] with ( app.test_request_context("/"), @@ -126,11 +152,6 @@ class TestHitTestingApi: new_callable=PropertyMock, return_value=payload, ), - patch.object( - HitTestingPayload, - "model_validate", - return_value=MagicMock(model_dump=lambda **_: payload), - ), patch.object( HitTestingApi, "get_and_validate_dataset", @@ -143,13 +164,16 @@ class TestHitTestingApi: patch.object( HitTestingApi, "perform_hit_testing", - return_value={"query": payload["query"], "records": records}, + return_value={"query": {"content": payload["query"]}, "records": records}, ), ): result = method(api, dataset_id) - assert result["query"] == payload["query"] - assert result["records"] == records + assert result["query"] == {"content": payload["query"]} + assert result["records"][0]["segment"]["keywords"] == [] + assert result["records"][0]["child_chunks"] == [] + assert result["records"][0]["files"] == [] + assert result["records"][0]["score"] is None def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id): api = HitTestingApi() @@ -192,11 +216,6 @@ class TestHitTestingApi: new_callable=PropertyMock, return_value=payload, ), - patch.object( - HitTestingPayload, - "model_validate", - return_value=MagicMock(model_dump=lambda **_: payload), - ), patch.object( HitTestingApi, "get_and_validate_dataset", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index 77e9cfeb5b..072aa559df 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -22,6 +22,7 @@ from core.errors.error import ( ) from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account +from models.dataset import Dataset from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService @@ -43,7 +44,45 @@ def patch_current_user(mocker, account): @pytest.fixture def dataset(): - return MagicMock(id="dataset-1") + return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1") + + +def hit_testing_record() -> dict[str, object]: + return { + "segment": { + "id": "segment-1", + "position": 1, + "document_id": "document-1", + "content": "Chunk text", + "answer": None, + "word_count": 2, + "tokens": 3, + "keywords": None, + "index_node_id": None, + "index_node_hash": None, + "hit_count": 0, + "enabled": True, + "disabled_at": None, + "disabled_by": None, + "status": "completed", + "created_by": "account-1", + "created_at": 1_700_000_000, + "indexing_at": None, + "completed_at": None, + "error": None, + "stopped_at": None, + "document": { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": None, + }, + }, + "child_chunks": None, + "files": None, + "score": 0.8, + } class TestGetAndValidateDataset: @@ -116,6 +155,13 @@ class TestParseArgs: with pytest.raises(ValueError): DatasetsHitTestingBase.parse_args(payload) + def test_parse_args_ignores_unknown_fields_for_compatibility(self): + payload = {"query": "hello", "top_k": 3} + + result = DatasetsHitTestingBase.parse_args(payload) + + assert result == {"query": "hello"} + class TestPerformHitTesting: def test_success(self, dataset): @@ -131,48 +177,42 @@ class TestPerformHitTesting: ): result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) - assert result["query"] == "hello" + assert result["query"] == {"content": "hello"} assert result["records"] == [] def test_success_prepares_nullable_list_fields(self, dataset): response = { "query": {"content": "hello"}, - "records": [ - { - "segment": {"id": "segment-1", "keywords": None}, - "child_chunks": None, - "files": None, - "score": 0.8, - } - ], + "records": [hit_testing_record()], } + with patch.object( + HitTestingService, + "retrieve", + return_value=response, + ): + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + assert result["query"] == {"content": "hello"} + record = result["records"][0] + assert record["segment"]["keywords"] == [] + assert record["segment"]["sign_content"] is None + assert record["child_chunks"] == [] + assert record["files"] == [] + assert record["score"] == 0.8 + assert record["tsne_position"] is None + assert record["summary"] is None + + def test_invalid_query_response_raises_value_error(self, dataset): with ( patch.object( HitTestingService, "retrieve", - return_value=response, - ), - patch( - "controllers.console.datasets.hit_testing_base.marshal", - return_value=response["records"], + return_value={"query": "hello", "records": []}, ), + pytest.raises(ValueError, match="Invalid hit testing query response"), ): - result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) - - assert result["query"] == "hello" - assert result["records"] == [ - { - "segment": {"id": "segment-1", "keywords": []}, - "child_chunks": [], - "files": [], - "score": 0.8, - } - ] - - def test_invalid_query_response_raises_value_error(self): - with pytest.raises(ValueError, match="Invalid hit testing query response"): - DatasetsHitTestingBase._extract_hit_testing_query("hello") + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) def test_invalid_records_response_raises_value_error(self): with pytest.raises(ValueError, match="Invalid hit testing records response"): diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index 38fcb55fc0..4809cc0e8a 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -24,12 +24,53 @@ from werkzeug.exceptions import Forbidden, NotFound import services from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload from models.account import Account +from models.dataset import Dataset +from services.entities.knowledge_entities.knowledge_entities import RetrievalModel # --------------------------------------------------------------------------- # HitTestingPayload Model Tests # --------------------------------------------------------------------------- +def hit_testing_record() -> dict[str, object]: + return { + "segment": { + "id": "segment-1", + "position": 1, + "document_id": "document-1", + "content": "Chunk text", + "sign_content": "Chunk text", + "answer": None, + "word_count": 2, + "tokens": 3, + "keywords": None, + "index_node_id": None, + "index_node_hash": None, + "hit_count": 0, + "enabled": True, + "disabled_at": None, + "disabled_by": None, + "status": "completed", + "created_by": "account-1", + "created_at": 1_700_000_000, + "indexing_at": None, + "completed_at": None, + "error": None, + "stopped_at": None, + "document": { + "id": "document-1", + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": None, + }, + }, + "child_chunks": None, + "files": None, + "score": 0.9, + } + + class TestHitTestingPayload: """Test suite for HitTestingPayload Pydantic model.""" @@ -48,7 +89,7 @@ class TestHitTestingPayload: } payload = HitTestingPayload( query="test query", - retrieval_model=retrieval_model_data, + retrieval_model=RetrievalModel.model_validate(retrieval_model_data), external_retrieval_model={"provider": "openai"}, attachment_ids=["att_1", "att_2"], ) @@ -68,6 +109,12 @@ class TestHitTestingPayload: payload = HitTestingPayload(query="x" * 250) assert len(payload.query) == 250 + def test_payload_ignores_unknown_fields_for_compatibility(self): + """Top-level fields outside the documented schema remain ignored as before.""" + payload = HitTestingPayload.model_validate({"query": "test query", "top_k": 3}) + + assert payload.model_dump(exclude_none=True) == {"query": "test query"} + # --------------------------------------------------------------------------- # HitTestingApi Tests @@ -80,8 +127,11 @@ class TestHitTestingPayload: class TestHitTestingApiPost: """Tests for HitTestingApi.post() via __wrapped__ to skip billing decorator.""" + @staticmethod + def _dataset(dataset_id: str, tenant_id: str) -> Dataset: + return Dataset(id=dataset_id, tenant_id=tenant_id, name="Dataset", created_by="account-1") + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") - @patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) @@ -90,7 +140,6 @@ class TestHitTestingApiPost: mock_current_user, mock_dataset_svc, mock_hit_svc, - mock_marshal, mock_ns, app: Flask, ): @@ -98,15 +147,13 @@ class TestHitTestingApiPost: dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None mock_hit_svc.retrieve.return_value = {"query": {"content": "test query"}, "records": []} mock_hit_svc.hit_testing_args_check.return_value = None - mock_marshal.return_value = [] mock_ns.payload = {"query": "test query"} @@ -115,11 +162,10 @@ class TestHitTestingApiPost: # Skip billing decorator via __wrapped__ response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) - assert response["query"] == "test query" + assert response["query"] == {"content": "test query"} mock_hit_svc.retrieve.assert_called_once() @patch("controllers.service_api.dataset.hit_testing.service_api_ns") - @patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) @@ -128,7 +174,6 @@ class TestHitTestingApiPost: mock_current_user, mock_dataset_svc, mock_hit_svc, - mock_marshal, mock_ns, app: Flask, ): @@ -136,8 +181,7 @@ class TestHitTestingApiPost: dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None @@ -152,7 +196,6 @@ class TestHitTestingApiPost: mock_hit_svc.retrieve.return_value = {"query": {"content": "complex query"}, "records": []} mock_hit_svc.hit_testing_args_check.return_value = None - mock_marshal.return_value = [] mock_ns.payload = { "query": "complex query", @@ -164,7 +207,7 @@ class TestHitTestingApiPost: api = HitTestingApi() response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) - assert response["query"] == "complex query" + assert response["query"] == {"content": "complex query"} call_kwargs = mock_hit_svc.retrieve.call_args # retrieval_model is serialized via model_dump, verify key fields passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model") @@ -173,7 +216,6 @@ class TestHitTestingApiPost: assert passed_retrieval_model["top_k"] == 10 @patch("controllers.service_api.dataset.hit_testing.service_api_ns") - @patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) @@ -182,7 +224,6 @@ class TestHitTestingApiPost: mock_current_user, mock_dataset_svc, mock_hit_svc, - mock_marshal, mock_ns, app: Flask, ): @@ -190,14 +231,12 @@ class TestHitTestingApiPost: dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None mock_hit_svc.retrieve.return_value = {"query": {"content": "filtered query"}, "records": []} mock_hit_svc.hit_testing_args_check.return_value = None - mock_marshal.return_value = [] metadata_filtering_conditions = { "logical_operator": "and", @@ -229,7 +268,6 @@ class TestHitTestingApiPost: assert passed_retrieval_model["metadata_filtering_conditions"] == metadata_filtering_conditions @patch("controllers.service_api.dataset.hit_testing.service_api_ns") - @patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) @@ -238,30 +276,23 @@ class TestHitTestingApiPost: mock_current_user, mock_dataset_svc, mock_hit_svc, - mock_marshal, mock_ns, app: Flask, ): - """Test service API prepares nullable list fields from marshalled records.""" + """Test service API prepares nullable list fields from retrieval records.""" dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None - mock_hit_svc.retrieve.return_value = {"query": {"content": "legacy query"}, "records": ["placeholder"]} + mock_hit_svc.retrieve.return_value = { + "query": {"content": "legacy query"}, + "records": [hit_testing_record()], + } mock_hit_svc.hit_testing_args_check.return_value = None - mock_marshal.return_value = [ - { - "segment": {"id": "segment-1", "keywords": None}, - "child_chunks": None, - "files": None, - "score": 0.9, - } - ] mock_ns.payload = {"query": "legacy query"} @@ -269,15 +300,15 @@ class TestHitTestingApiPost: api = HitTestingApi() response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) - assert response["query"] == "legacy query" - assert response["records"] == [ - { - "segment": {"id": "segment-1", "keywords": []}, - "child_chunks": [], - "files": [], - "score": 0.9, - } - ] + assert response["query"] == {"content": "legacy query"} + record = response["records"][0] + assert record["segment"]["id"] == "segment-1" + assert record["segment"]["keywords"] == [] + assert record["child_chunks"] == [] + assert record["files"] == [] + assert record["score"] == 0.9 + assert record["tsne_position"] is None + assert record["summary"] is None @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @@ -315,8 +346,7 @@ class TestHitTestingApiPost: dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) - mock_dataset = Mock() - mock_dataset.id = dataset_id + mock_dataset = self._dataset(dataset_id, tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError( diff --git a/packages/contracts/generated/api/console/datasets/types.gen.ts b/packages/contracts/generated/api/console/datasets/types.gen.ts index 6c96c80d4a..938331f9c9 100644 --- a/packages/contracts/generated/api/console/datasets/types.gen.ts +++ b/packages/contracts/generated/api/console/datasets/types.gen.ts @@ -396,8 +396,8 @@ export type HitTestingPayload = { } export type HitTestingResponse = { - query: string - records?: Array + query: HitTestingQuery + records: Array } export type DocumentStatusListResponse = { @@ -666,13 +666,17 @@ export type DocumentStatusResponse = { total_segments?: number | null } +export type HitTestingQuery = { + content: string +} + export type HitTestingRecord = { - child_chunks?: Array - files?: Array - score?: number | null - segment?: HitTestingSegment - summary?: string | null - tsne_position?: unknown + child_chunks: Array + files: Array + score: number | null + segment: HitTestingSegment + summary: string | null + tsne_position: unknown } export type DatasetMetadataListItemResponse = { @@ -768,45 +772,45 @@ export type MetadataDetail = { } export type HitTestingChildChunk = { - content?: string | null - id?: string | null - position?: number | null - score?: number | null + content: string + id: string + position: number + score: number } export type HitTestingFile = { - extension?: string | null - id?: string | null - mime_type?: string | null - name?: string | null - size?: number | null - source_url?: string | null + extension: string + id: string + mime_type: string + name: string + size: number + source_url: string } export type HitTestingSegment = { - answer?: string | null - completed_at?: number | null - content?: string | null - created_at?: number | null - created_by?: string | null - disabled_at?: number | null - disabled_by?: string | null - document?: HitTestingDocument - document_id?: string | null - enabled?: boolean | null - error?: string | null - hit_count?: number | null - id?: string | null - index_node_hash?: string | null - index_node_id?: string | null - indexing_at?: number | null - keywords?: Array - position?: number | null - sign_content?: string | null - status?: string | null - stopped_at?: number | null - tokens?: number | null - word_count?: number | null + answer: string | null + completed_at: number | null + content: string + created_at: number + created_by: string + disabled_at: number | null + disabled_by: string | null + document: HitTestingDocument + document_id: string + enabled: boolean + error: string | null + hit_count: number + id: string + index_node_hash: string | null + index_node_id: string | null + indexing_at: number | null + keywords: Array + position: number + sign_content: string | null + status: string + stopped_at: number | null + tokens: number + word_count: number } export type DatasetQueryContentResponse = { @@ -898,11 +902,11 @@ export type WeightVectorSetting = { } export type HitTestingDocument = { - data_source_type?: string | null - doc_metadata?: unknown - doc_type?: string | null - id?: string | null - name?: string | null + data_source_type: string + doc_metadata: unknown + doc_type: string | null + id: string + name: string } export type DatasetQueryFileInfoResponse = { diff --git a/packages/contracts/generated/api/console/datasets/zod.gen.ts b/packages/contracts/generated/api/console/datasets/zod.gen.ts index f795d17f2f..082695d39d 100644 --- a/packages/contracts/generated/api/console/datasets/zod.gen.ts +++ b/packages/contracts/generated/api/console/datasets/zod.gen.ts @@ -530,6 +530,13 @@ export const zDocumentStatusListResponse = z.object({ data: z.array(zDocumentStatusResponse), }) +/** + * HitTestingQuery + */ +export const zHitTestingQuery = z.object({ + content: z.string(), +}) + /** * DatasetMetadataListItemResponse */ @@ -632,22 +639,22 @@ export const zMetadataOperationData = z.object({ * HitTestingChildChunk */ export const zHitTestingChildChunk = z.object({ - content: z.string().nullish(), - id: z.string().nullish(), - position: z.int().nullish(), - score: z.number().nullish(), + content: z.string(), + id: z.string(), + position: z.int(), + score: z.number(), }) /** * HitTestingFile */ export const zHitTestingFile = z.object({ - extension: z.string().nullish(), - id: z.string().nullish(), - mime_type: z.string().nullish(), - name: z.string().nullish(), - size: z.int().nullish(), - source_url: z.string().nullish(), + extension: z.string(), + id: z.string(), + mime_type: z.string(), + name: z.string(), + size: z.int(), + source_url: z.string(), }) /** @@ -1036,60 +1043,60 @@ export const zHitTestingPayload = z.object({ * HitTestingDocument */ export const zHitTestingDocument = z.object({ - data_source_type: z.string().nullish(), - doc_metadata: z.unknown().optional(), - doc_type: z.string().nullish(), - id: z.string().nullish(), - name: z.string().nullish(), + data_source_type: z.string(), + doc_metadata: z.unknown(), + doc_type: z.string().nullable(), + id: z.string(), + name: z.string(), }) /** * HitTestingSegment */ export const zHitTestingSegment = z.object({ - answer: z.string().nullish(), - completed_at: z.int().nullish(), - content: z.string().nullish(), - created_at: z.int().nullish(), - created_by: z.string().nullish(), - disabled_at: z.int().nullish(), - disabled_by: z.string().nullish(), - document: zHitTestingDocument.optional(), - document_id: z.string().nullish(), - enabled: z.boolean().nullish(), - error: z.string().nullish(), - hit_count: z.int().nullish(), - id: z.string().nullish(), - index_node_hash: z.string().nullish(), - index_node_id: z.string().nullish(), - indexing_at: z.int().nullish(), - keywords: z.array(z.string()).optional(), - position: z.int().nullish(), - sign_content: z.string().nullish(), - status: z.string().nullish(), - stopped_at: z.int().nullish(), - tokens: z.int().nullish(), - word_count: z.int().nullish(), + answer: z.string().nullable(), + completed_at: z.int().nullable(), + content: z.string(), + created_at: z.int(), + created_by: z.string(), + disabled_at: z.int().nullable(), + disabled_by: z.string().nullable(), + document: zHitTestingDocument, + document_id: z.string(), + enabled: z.boolean(), + error: z.string().nullable(), + hit_count: z.int(), + id: z.string(), + index_node_hash: z.string().nullable(), + index_node_id: z.string().nullable(), + indexing_at: z.int().nullable(), + keywords: z.array(z.string()), + position: z.int(), + sign_content: z.string().nullable(), + status: z.string(), + stopped_at: z.int().nullable(), + tokens: z.int(), + word_count: z.int(), }) /** * HitTestingRecord */ export const zHitTestingRecord = z.object({ - child_chunks: z.array(zHitTestingChildChunk).optional(), - files: z.array(zHitTestingFile).optional(), - score: z.number().nullish(), - segment: zHitTestingSegment.optional(), - summary: z.string().nullish(), - tsne_position: z.unknown().optional(), + child_chunks: z.array(zHitTestingChildChunk), + files: z.array(zHitTestingFile), + score: z.number().nullable(), + segment: zHitTestingSegment, + summary: z.string().nullable(), + tsne_position: z.unknown(), }) /** * HitTestingResponse */ export const zHitTestingResponse = z.object({ - query: z.string(), - records: z.array(zHitTestingRecord).optional(), + query: zHitTestingQuery, + records: z.array(zHitTestingRecord), }) /** diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index 54ce811a95..cd84f94d81 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -469,6 +469,30 @@ export type FileResponse = { user_id?: string | null } +export type HitTestingChildChunk = { + content: string + id: string + position: number + score: number +} + +export type HitTestingDocument = { + data_source_type: string + doc_metadata: unknown + doc_type: string | null + id: string + name: string +} + +export type HitTestingFile = { + extension: string + id: string + mime_type: string + name: string + size: number + source_url: string +} + export type HitTestingPayload = { attachment_ids?: Array | null external_retrieval_model?: { @@ -478,6 +502,50 @@ export type HitTestingPayload = { retrieval_model?: RetrievalModel } +export type HitTestingQuery = { + content: string +} + +export type HitTestingRecord = { + child_chunks: Array + files: Array + score: number | null + segment: HitTestingSegment + summary: string | null + tsne_position: unknown +} + +export type HitTestingResponse = { + query: HitTestingQuery + records: Array +} + +export type HitTestingSegment = { + answer: string | null + completed_at: number | null + content: string + created_at: number + created_by: string + disabled_at: number | null + disabled_by: string | null + document: HitTestingDocument + document_id: string + enabled: boolean + error: string | null + hit_count: number + id: string + index_node_hash: string | null + index_node_id: string | null + indexing_at: number | null + keywords: Array + position: number + sign_content: string | null + status: string + stopped_at: number | null + tokens: number + word_count: number +} + export type HumanInputFormSubmitPayload = { action: string inputs: { @@ -2510,9 +2578,7 @@ export type PostDatasetsByDatasetIdHitTestingError = PostDatasetsByDatasetIdHitTestingErrors[keyof PostDatasetsByDatasetIdHitTestingErrors] export type PostDatasetsByDatasetIdHitTestingResponses = { - 200: { - [key: string]: unknown - } + 200: HitTestingResponse } export type PostDatasetsByDatasetIdHitTestingResponse @@ -2794,9 +2860,7 @@ export type PostDatasetsByDatasetIdRetrieveError = PostDatasetsByDatasetIdRetrieveErrors[keyof PostDatasetsByDatasetIdRetrieveErrors] export type PostDatasetsByDatasetIdRetrieveResponses = { - 200: { - [key: string]: unknown - } + 200: HitTestingResponse } export type PostDatasetsByDatasetIdRetrieveResponse diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index 22e4b24721..e3008ddfbf 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -553,6 +553,95 @@ export const zFileResponse = z.object({ user_id: z.string().nullish(), }) +/** + * HitTestingChildChunk + */ +export const zHitTestingChildChunk = z.object({ + content: z.string(), + id: z.string(), + position: z.int(), + score: z.number(), +}) + +/** + * HitTestingDocument + */ +export const zHitTestingDocument = z.object({ + data_source_type: z.string(), + doc_metadata: z.unknown(), + doc_type: z.string().nullable(), + id: z.string(), + name: z.string(), +}) + +/** + * HitTestingFile + */ +export const zHitTestingFile = z.object({ + extension: z.string(), + id: z.string(), + mime_type: z.string(), + name: z.string(), + size: z.int(), + source_url: z.string(), +}) + +/** + * HitTestingQuery + */ +export const zHitTestingQuery = z.object({ + content: z.string(), +}) + +/** + * HitTestingSegment + */ +export const zHitTestingSegment = z.object({ + answer: z.string().nullable(), + completed_at: z.int().nullable(), + content: z.string(), + created_at: z.int(), + created_by: z.string(), + disabled_at: z.int().nullable(), + disabled_by: z.string().nullable(), + document: zHitTestingDocument, + document_id: z.string(), + enabled: z.boolean(), + error: z.string().nullable(), + hit_count: z.int(), + id: z.string(), + index_node_hash: z.string().nullable(), + index_node_id: z.string().nullable(), + indexing_at: z.int().nullable(), + keywords: z.array(z.string()), + position: z.int(), + sign_content: z.string().nullable(), + status: z.string(), + stopped_at: z.int().nullable(), + tokens: z.int(), + word_count: z.int(), +}) + +/** + * HitTestingRecord + */ +export const zHitTestingRecord = z.object({ + child_chunks: z.array(zHitTestingChildChunk), + files: z.array(zHitTestingFile), + score: z.number().nullable(), + segment: zHitTestingSegment, + summary: z.string().nullable(), + tsne_position: z.unknown(), +}) + +/** + * HitTestingResponse + */ +export const zHitTestingResponse = z.object({ + query: zHitTestingQuery, + records: z.array(zHitTestingRecord), +}) + /** * IndexInfoResponse */ @@ -1720,7 +1809,7 @@ export const zPostDatasetsByDatasetIdHitTestingPath = z.object({ /** * Hit testing results */ -export const zPostDatasetsByDatasetIdHitTestingResponse = z.record(z.string(), z.unknown()) +export const zPostDatasetsByDatasetIdHitTestingResponse = zHitTestingResponse export const zGetDatasetsByDatasetIdMetadataPath = z.object({ dataset_id: z.string(), @@ -1834,7 +1923,7 @@ export const zPostDatasetsByDatasetIdRetrievePath = z.object({ /** * Hit testing results */ -export const zPostDatasetsByDatasetIdRetrieveResponse = z.record(z.string(), z.unknown()) +export const zPostDatasetsByDatasetIdRetrieveResponse = zHitTestingResponse export const zGetDatasetsByDatasetIdTagsPath = z.object({ dataset_id: z.string(),