mirror of
https://github.com/langgenius/dify.git
synced 2026-05-27 20:36:18 +08:00
refactor(api): migrate console/service_api.dataset.hit_testing to BaseModel (#36533)
This commit is contained in:
@ -1,15 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.hit_testing_fields import HitTestingResponse
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -20,86 +17,8 @@ from ..wraps import (
|
||||
setup_required,
|
||||
)
|
||||
|
||||
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
name: str | None = None
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any | None = None
|
||||
|
||||
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str | None = None
|
||||
position: int | None = None
|
||||
document_id: str | None = None
|
||||
content: str | None = None
|
||||
sign_content: str | None = None
|
||||
answer: str | None = None
|
||||
word_count: int | None = None
|
||||
tokens: int | None = None
|
||||
keywords: list[str] = Field(default_factory=list)
|
||||
index_node_id: str | None = None
|
||||
index_node_hash: str | None = None
|
||||
hit_count: int | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
status: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
indexing_at: int | None = None
|
||||
completed_at: int | None = None
|
||||
error: str | None = None
|
||||
stopped_at: int | None = None
|
||||
document: HitTestingDocument | None = None
|
||||
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str | None = None
|
||||
content: str | None = None
|
||||
position: int | None = None
|
||||
score: float | None = None
|
||||
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
source_url: str | None = None
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment | None = None
|
||||
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
|
||||
score: float | None = None
|
||||
tsne_position: Any | None = None
|
||||
files: list[HitTestingFile] = Field(default_factory=list)
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: str
|
||||
records: list[HitTestingRecord] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
HitTestingPayload,
|
||||
HitTestingDocument,
|
||||
HitTestingSegment,
|
||||
HitTestingChildChunk,
|
||||
HitTestingFile,
|
||||
HitTestingRecord,
|
||||
HitTestingResponse,
|
||||
)
|
||||
register_schema_models(console_ns, HitTestingPayload)
|
||||
register_response_schema_models(console_ns, HitTestingResponse)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@ -119,12 +38,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id: UUID) -> dict[str, object]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
payload = HitTestingPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args = self.parse_args(console_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
|
||||
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -19,10 +18,10 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -38,16 +37,6 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _extract_hit_testing_query(query: Any) -> str:
|
||||
"""Return the query string from the service response shape."""
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Ensure collection fields match the API schema before response validation."""
|
||||
@ -63,6 +52,7 @@ class DatasetsHitTestingBase:
|
||||
segment = normalized_record.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
normalized_segment = dict(segment)
|
||||
normalized_segment.setdefault("sign_content", None)
|
||||
if normalized_segment.get("keywords") is None:
|
||||
normalized_segment["keywords"] = []
|
||||
normalized_record["segment"] = normalized_segment
|
||||
@ -73,12 +63,15 @@ class DatasetsHitTestingBase:
|
||||
if normalized_record.get("files") is None:
|
||||
normalized_record["files"] = []
|
||||
|
||||
normalized_record.setdefault("tsne_position", None)
|
||||
normalized_record.setdefault("summary", None)
|
||||
|
||||
normalized_records.append(normalized_record)
|
||||
|
||||
return normalized_records
|
||||
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
def get_and_validate_dataset(dataset_id: str) -> Dataset:
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
@ -92,33 +85,35 @@ class DatasetsHitTestingBase:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args: dict[str, Any]):
|
||||
def hit_testing_args_check(args: dict[str, Any]) -> None:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
def parse_args(payload: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Validate and return hit-testing arguments from an incoming payload."""
|
||||
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
|
||||
return hit_testing_payload.model_dump(exclude_none=True)
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]:
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args.get("query"),
|
||||
query=cast(str, args.get("query")),
|
||||
account=current_user,
|
||||
retrieval_model=args.get("retrieval_model"),
|
||||
external_retrieval_model=args.get("external_retrieval_model"),
|
||||
external_retrieval_model=cast(dict[str, Any], args.get("external_retrieval_model")),
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
limit=10,
|
||||
)
|
||||
query = response.get("query")
|
||||
if not isinstance(query, dict) or not isinstance(query.get("content"), str):
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
"query": {"content": query["content"]},
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(response.get("records", [])),
|
||||
}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
from uuid import UUID
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.hit_testing_fields import HitTestingResponse
|
||||
from libs.helper import dump_response
|
||||
|
||||
register_schema_model(service_api_ns, HitTestingPayload)
|
||||
register_schema_models(service_api_ns, HitTestingPayload)
|
||||
register_response_schema_models(service_api_ns, HitTestingResponse)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
@ -13,16 +16,16 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
@service_api_ns.doc("dataset_hit_testing")
|
||||
@service_api_ns.doc(description="Perform hit testing on a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Hit testing results",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Dataset not found",
|
||||
}
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Hit testing results",
|
||||
model=service_api_ns.models[HitTestingResponse.__name__],
|
||||
)
|
||||
@service_api_ns.response(401, "Unauthorized - invalid API token")
|
||||
@service_api_ns.response(404, "Dataset not found")
|
||||
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
def post(self, tenant_id: str, dataset_id: UUID) -> dict[str, object]:
|
||||
"""Perform hit testing on a dataset.
|
||||
|
||||
Tests retrieval performance for the specified dataset.
|
||||
@ -33,4 +36,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
args = self.parse_args(service_api_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
|
||||
|
||||
@ -1,62 +1,96 @@
|
||||
from flask_restx import fields
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from pydantic import field_validator
|
||||
|
||||
document_fields = {
|
||||
"id": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"name": fields.String,
|
||||
"doc_type": fields.String,
|
||||
"doc_metadata": fields.Raw,
|
||||
}
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
|
||||
segment_fields = {
|
||||
"id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"document_id": fields.String,
|
||||
"content": fields.String,
|
||||
"sign_content": fields.String,
|
||||
"answer": fields.String,
|
||||
"word_count": fields.Integer,
|
||||
"tokens": fields.Integer,
|
||||
"keywords": fields.List(fields.String),
|
||||
"index_node_id": fields.String,
|
||||
"index_node_hash": fields.String,
|
||||
"hit_count": fields.Integer,
|
||||
"enabled": fields.Boolean,
|
||||
"disabled_at": TimestampField,
|
||||
"disabled_by": fields.String,
|
||||
"status": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"indexing_at": TimestampField,
|
||||
"completed_at": TimestampField,
|
||||
"error": fields.String,
|
||||
"stopped_at": TimestampField,
|
||||
"document": fields.Nested(document_fields),
|
||||
}
|
||||
|
||||
child_chunk_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"position": fields.Integer,
|
||||
"score": fields.Float,
|
||||
}
|
||||
class HitTestingQuery(ResponseModel):
|
||||
content: str
|
||||
|
||||
files_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"size": fields.Integer,
|
||||
"extension": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"source_url": fields.String,
|
||||
}
|
||||
|
||||
hit_testing_record_fields = {
|
||||
"segment": fields.Nested(segment_fields),
|
||||
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
||||
"score": fields.Float,
|
||||
"tsne_position": fields.Raw,
|
||||
"files": fields.List(fields.Nested(files_fields)),
|
||||
"summary": fields.String, # Summary content if retrieved via summary index
|
||||
}
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str
|
||||
data_source_type: str
|
||||
name: str
|
||||
doc_type: str | None
|
||||
doc_metadata: Any | None
|
||||
|
||||
@field_validator("data_source_type", "doc_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str
|
||||
position: int
|
||||
document_id: str
|
||||
content: str
|
||||
sign_content: str | None
|
||||
answer: str | None
|
||||
word_count: int
|
||||
tokens: int
|
||||
keywords: list[str]
|
||||
index_node_id: str | None
|
||||
index_node_hash: str | None
|
||||
hit_count: int
|
||||
enabled: bool
|
||||
disabled_at: int | None
|
||||
disabled_by: str | None
|
||||
status: str
|
||||
created_by: str
|
||||
created_at: int
|
||||
indexing_at: int | None
|
||||
completed_at: int | None
|
||||
error: str | None
|
||||
stopped_at: int | None
|
||||
document: HitTestingDocument
|
||||
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
@field_validator("status", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str
|
||||
content: str
|
||||
position: int
|
||||
score: float
|
||||
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment
|
||||
child_chunks: list[HitTestingChildChunk]
|
||||
score: float | None
|
||||
tsne_position: Any | None
|
||||
files: list[HitTestingFile]
|
||||
summary: str | None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: HitTestingQuery
|
||||
records: list[HitTestingRecord]
|
||||
|
||||
|
||||
def _normalize_enum(value: Any) -> Any:
|
||||
if isinstance(value, str) or value is None:
|
||||
return value
|
||||
return getattr(value, "value", value)
|
||||
|
||||
@ -13051,31 +13051,31 @@ Request payload for bulk downloading documents as a zip archive.
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| content | string | | No |
|
||||
| id | string | | No |
|
||||
| position | integer | | No |
|
||||
| score | number | | No |
|
||||
| content | string | | Yes |
|
||||
| id | string | | Yes |
|
||||
| position | integer | | Yes |
|
||||
| score | number | | Yes |
|
||||
|
||||
#### HitTestingDocument
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| data_source_type | string | | No |
|
||||
| doc_metadata | | | No |
|
||||
| doc_type | string | | No |
|
||||
| id | string | | No |
|
||||
| name | string | | No |
|
||||
| data_source_type | string | | Yes |
|
||||
| doc_metadata | | | Yes |
|
||||
| doc_type | string | | Yes |
|
||||
| id | string | | Yes |
|
||||
| name | string | | Yes |
|
||||
|
||||
#### HitTestingFile
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| extension | string | | No |
|
||||
| id | string | | No |
|
||||
| mime_type | string | | No |
|
||||
| name | string | | No |
|
||||
| size | integer | | No |
|
||||
| source_url | string | | No |
|
||||
| extension | string | | Yes |
|
||||
| id | string | | Yes |
|
||||
| mime_type | string | | Yes |
|
||||
| name | string | | Yes |
|
||||
| size | integer | | Yes |
|
||||
| source_url | string | | Yes |
|
||||
|
||||
#### HitTestingPayload
|
||||
|
||||
@ -13086,51 +13086,57 @@ Request payload for bulk downloading documents as a zip archive.
|
||||
| query | string | | Yes |
|
||||
| retrieval_model | [RetrievalModel](#retrievalmodel) | | No |
|
||||
|
||||
#### HitTestingQuery
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| content | string | | Yes |
|
||||
|
||||
#### HitTestingRecord
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | No |
|
||||
| files | [ [HitTestingFile](#hittestingfile) ] | | No |
|
||||
| score | number | | No |
|
||||
| segment | [HitTestingSegment](#hittestingsegment) | | No |
|
||||
| summary | string | | No |
|
||||
| tsne_position | | | No |
|
||||
| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | Yes |
|
||||
| files | [ [HitTestingFile](#hittestingfile) ] | | Yes |
|
||||
| score | number | | Yes |
|
||||
| segment | [HitTestingSegment](#hittestingsegment) | | Yes |
|
||||
| summary | string | | Yes |
|
||||
| tsne_position | | | Yes |
|
||||
|
||||
#### HitTestingResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| query | string | | Yes |
|
||||
| records | [ [HitTestingRecord](#hittestingrecord) ] | | No |
|
||||
| query | [HitTestingQuery](#hittestingquery) | | Yes |
|
||||
| records | [ [HitTestingRecord](#hittestingrecord) ] | | Yes |
|
||||
|
||||
#### HitTestingSegment
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| answer | string | | No |
|
||||
| completed_at | integer | | No |
|
||||
| content | string | | No |
|
||||
| created_at | integer | | No |
|
||||
| created_by | string | | No |
|
||||
| disabled_at | integer | | No |
|
||||
| disabled_by | string | | No |
|
||||
| document | [HitTestingDocument](#hittestingdocument) | | No |
|
||||
| document_id | string | | No |
|
||||
| enabled | boolean | | No |
|
||||
| error | string | | No |
|
||||
| hit_count | integer | | No |
|
||||
| id | string | | No |
|
||||
| index_node_hash | string | | No |
|
||||
| index_node_id | string | | No |
|
||||
| indexing_at | integer | | No |
|
||||
| keywords | [ string ] | | No |
|
||||
| position | integer | | No |
|
||||
| sign_content | string | | No |
|
||||
| status | string | | No |
|
||||
| stopped_at | integer | | No |
|
||||
| tokens | integer | | No |
|
||||
| word_count | integer | | No |
|
||||
| answer | string | | Yes |
|
||||
| completed_at | integer | | Yes |
|
||||
| content | string | | Yes |
|
||||
| created_at | integer | | Yes |
|
||||
| created_by | string | | Yes |
|
||||
| disabled_at | integer | | Yes |
|
||||
| disabled_by | string | | Yes |
|
||||
| document | [HitTestingDocument](#hittestingdocument) | | Yes |
|
||||
| document_id | string | | Yes |
|
||||
| enabled | boolean | | Yes |
|
||||
| error | string | | Yes |
|
||||
| hit_count | integer | | Yes |
|
||||
| id | string | | Yes |
|
||||
| index_node_hash | string | | Yes |
|
||||
| index_node_id | string | | Yes |
|
||||
| indexing_at | integer | | Yes |
|
||||
| keywords | [ string ] | | Yes |
|
||||
| position | integer | | Yes |
|
||||
| sign_content | string | | Yes |
|
||||
| status | string | | Yes |
|
||||
| stopped_at | integer | | Yes |
|
||||
| tokens | integer | | Yes |
|
||||
| word_count | integer | | Yes |
|
||||
|
||||
#### HumanInputContent
|
||||
|
||||
|
||||
@ -1363,11 +1363,11 @@ Tests retrieval performance for the specified dataset.
|
||||
|
||||
##### Responses
|
||||
|
||||
| Code | Description |
|
||||
| ---- | ----------- |
|
||||
| 200 | Hit testing results |
|
||||
| 401 | Unauthorized - invalid API token |
|
||||
| 404 | Dataset not found |
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Hit testing results | [HitTestingResponse](#hittestingresponse) |
|
||||
| 401 | Unauthorized - invalid API token | |
|
||||
| 404 | Dataset not found | |
|
||||
|
||||
### /datasets/{dataset_id}/metadata
|
||||
|
||||
@ -1614,11 +1614,11 @@ Tests retrieval performance for the specified dataset.
|
||||
|
||||
##### Responses
|
||||
|
||||
| Code | Description |
|
||||
| ---- | ----------- |
|
||||
| 200 | Hit testing results |
|
||||
| 401 | Unauthorized - invalid API token |
|
||||
| 404 | Dataset not found |
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Hit testing results | [HitTestingResponse](#hittestingresponse) |
|
||||
| 401 | Unauthorized - invalid API token | |
|
||||
| 404 | Dataset not found | |
|
||||
|
||||
### /datasets/{dataset_id}/tags
|
||||
|
||||
@ -2691,6 +2691,36 @@ Note: The SQLAlchemy model defines an `is_anonymous` property for Flask-Login se
|
||||
| tenant_id | string | | No |
|
||||
| user_id | string | | No |
|
||||
|
||||
#### HitTestingChildChunk
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| content | string | | Yes |
|
||||
| id | string | | Yes |
|
||||
| position | integer | | Yes |
|
||||
| score | number | | Yes |
|
||||
|
||||
#### HitTestingDocument
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| data_source_type | string | | Yes |
|
||||
| doc_metadata | | | Yes |
|
||||
| doc_type | string | | Yes |
|
||||
| id | string | | Yes |
|
||||
| name | string | | Yes |
|
||||
|
||||
#### HitTestingFile
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| extension | string | | Yes |
|
||||
| id | string | | Yes |
|
||||
| mime_type | string | | Yes |
|
||||
| name | string | | Yes |
|
||||
| size | integer | | Yes |
|
||||
| source_url | string | | Yes |
|
||||
|
||||
#### HitTestingPayload
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
@ -2700,6 +2730,58 @@ Note: The SQLAlchemy model defines an `is_anonymous` property for Flask-Login se
|
||||
| query | string | | Yes |
|
||||
| retrieval_model | [RetrievalModel](#retrievalmodel) | | No |
|
||||
|
||||
#### HitTestingQuery
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| content | string | | Yes |
|
||||
|
||||
#### HitTestingRecord
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| child_chunks | [ [HitTestingChildChunk](#hittestingchildchunk) ] | | Yes |
|
||||
| files | [ [HitTestingFile](#hittestingfile) ] | | Yes |
|
||||
| score | number | | Yes |
|
||||
| segment | [HitTestingSegment](#hittestingsegment) | | Yes |
|
||||
| summary | string | | Yes |
|
||||
| tsne_position | | | Yes |
|
||||
|
||||
#### HitTestingResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| query | [HitTestingQuery](#hittestingquery) | | Yes |
|
||||
| records | [ [HitTestingRecord](#hittestingrecord) ] | | Yes |
|
||||
|
||||
#### HitTestingSegment
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| answer | string | | Yes |
|
||||
| completed_at | integer | | Yes |
|
||||
| content | string | | Yes |
|
||||
| created_at | integer | | Yes |
|
||||
| created_by | string | | Yes |
|
||||
| disabled_at | integer | | Yes |
|
||||
| disabled_by | string | | Yes |
|
||||
| document | [HitTestingDocument](#hittestingdocument) | | Yes |
|
||||
| document_id | string | | Yes |
|
||||
| enabled | boolean | | Yes |
|
||||
| error | string | | Yes |
|
||||
| hit_count | integer | | Yes |
|
||||
| id | string | | Yes |
|
||||
| index_node_hash | string | | Yes |
|
||||
| index_node_id | string | | Yes |
|
||||
| indexing_at | integer | | Yes |
|
||||
| keywords | [ string ] | | Yes |
|
||||
| position | integer | | Yes |
|
||||
| sign_content | string | | Yes |
|
||||
| status | string | | Yes |
|
||||
| stopped_at | integer | | Yes |
|
||||
| tokens | integer | | Yes |
|
||||
| word_count | integer | | Yes |
|
||||
|
||||
#### HumanInputFormSubmitPayload
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -8,7 +8,7 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing import HitTestingApi
|
||||
from controllers.console.datasets.hit_testing_base import HitTestingPayload
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@ -32,7 +32,48 @@ def dataset_id():
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return MagicMock(id="dataset-1")
|
||||
return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1")
|
||||
|
||||
|
||||
def hit_testing_record() -> dict[str, object]:
|
||||
return {
|
||||
"segment": {
|
||||
"id": "segment-1",
|
||||
"position": 1,
|
||||
"document_id": "document-1",
|
||||
"content": "Chunk text",
|
||||
"sign_content": "Chunk text",
|
||||
"answer": None,
|
||||
"word_count": 2,
|
||||
"tokens": 3,
|
||||
"keywords": [],
|
||||
"index_node_id": None,
|
||||
"index_node_hash": None,
|
||||
"hit_count": 0,
|
||||
"enabled": True,
|
||||
"disabled_at": None,
|
||||
"disabled_by": None,
|
||||
"status": "completed",
|
||||
"created_by": "account-1",
|
||||
"created_at": 1_700_000_000,
|
||||
"indexing_at": None,
|
||||
"completed_at": None,
|
||||
"error": None,
|
||||
"stopped_at": None,
|
||||
"document": {
|
||||
"id": "document-1",
|
||||
"data_source_type": "upload_file",
|
||||
"name": "guide.md",
|
||||
"doc_type": None,
|
||||
"doc_metadata": None,
|
||||
},
|
||||
},
|
||||
"child_chunks": [],
|
||||
"score": None,
|
||||
"tsne_position": None,
|
||||
"files": [],
|
||||
"summary": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -63,7 +104,6 @@ class TestHitTestingApi:
|
||||
|
||||
payload = {
|
||||
"query": "what is vector search",
|
||||
"top_k": 3,
|
||||
}
|
||||
|
||||
with (
|
||||
@ -74,11 +114,6 @@ class TestHitTestingApi:
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
@ -91,7 +126,7 @@ class TestHitTestingApi:
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"perform_hit_testing",
|
||||
return_value={"query": "what is vector search", "records": []},
|
||||
return_value={"query": {"content": "what is vector search"}, "records": []},
|
||||
),
|
||||
):
|
||||
result = method(api, dataset_id)
|
||||
@ -107,16 +142,7 @@ class TestHitTestingApi:
|
||||
payload = {
|
||||
"query": "what is vector search",
|
||||
}
|
||||
records = [
|
||||
{
|
||||
"segment": None,
|
||||
"child_chunks": [],
|
||||
"score": None,
|
||||
"tsne_position": None,
|
||||
"files": [],
|
||||
"summary": None,
|
||||
}
|
||||
]
|
||||
records = [hit_testing_record()]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -126,11 +152,6 @@ class TestHitTestingApi:
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
@ -143,13 +164,16 @@ class TestHitTestingApi:
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"perform_hit_testing",
|
||||
return_value={"query": payload["query"], "records": records},
|
||||
return_value={"query": {"content": payload["query"]}, "records": records},
|
||||
),
|
||||
):
|
||||
result = method(api, dataset_id)
|
||||
|
||||
assert result["query"] == payload["query"]
|
||||
assert result["records"] == records
|
||||
assert result["query"] == {"content": payload["query"]}
|
||||
assert result["records"][0]["segment"]["keywords"] == []
|
||||
assert result["records"][0]["child_chunks"] == []
|
||||
assert result["records"][0]["files"] == []
|
||||
assert result["records"][0]["score"] is None
|
||||
|
||||
def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id):
|
||||
api = HitTestingApi()
|
||||
@ -192,11 +216,6 @@ class TestHitTestingApi:
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
|
||||
@ -22,6 +22,7 @@ from core.errors.error import (
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
@ -43,7 +44,45 @@ def patch_current_user(mocker, account):
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return MagicMock(id="dataset-1")
|
||||
return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1")
|
||||
|
||||
|
||||
def hit_testing_record() -> dict[str, object]:
|
||||
return {
|
||||
"segment": {
|
||||
"id": "segment-1",
|
||||
"position": 1,
|
||||
"document_id": "document-1",
|
||||
"content": "Chunk text",
|
||||
"answer": None,
|
||||
"word_count": 2,
|
||||
"tokens": 3,
|
||||
"keywords": None,
|
||||
"index_node_id": None,
|
||||
"index_node_hash": None,
|
||||
"hit_count": 0,
|
||||
"enabled": True,
|
||||
"disabled_at": None,
|
||||
"disabled_by": None,
|
||||
"status": "completed",
|
||||
"created_by": "account-1",
|
||||
"created_at": 1_700_000_000,
|
||||
"indexing_at": None,
|
||||
"completed_at": None,
|
||||
"error": None,
|
||||
"stopped_at": None,
|
||||
"document": {
|
||||
"id": "document-1",
|
||||
"data_source_type": "upload_file",
|
||||
"name": "guide.md",
|
||||
"doc_type": None,
|
||||
"doc_metadata": None,
|
||||
},
|
||||
},
|
||||
"child_chunks": None,
|
||||
"files": None,
|
||||
"score": 0.8,
|
||||
}
|
||||
|
||||
|
||||
class TestGetAndValidateDataset:
|
||||
@ -116,6 +155,13 @@ class TestParseArgs:
|
||||
with pytest.raises(ValueError):
|
||||
DatasetsHitTestingBase.parse_args(payload)
|
||||
|
||||
def test_parse_args_ignores_unknown_fields_for_compatibility(self):
|
||||
payload = {"query": "hello", "top_k": 3}
|
||||
|
||||
result = DatasetsHitTestingBase.parse_args(payload)
|
||||
|
||||
assert result == {"query": "hello"}
|
||||
|
||||
|
||||
class TestPerformHitTesting:
|
||||
def test_success(self, dataset):
|
||||
@ -131,48 +177,42 @@ class TestPerformHitTesting:
|
||||
):
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
assert result["query"] == "hello"
|
||||
assert result["query"] == {"content": "hello"}
|
||||
assert result["records"] == []
|
||||
|
||||
def test_success_prepares_nullable_list_fields(self, dataset):
|
||||
response = {
|
||||
"query": {"content": "hello"},
|
||||
"records": [
|
||||
{
|
||||
"segment": {"id": "segment-1", "keywords": None},
|
||||
"child_chunks": None,
|
||||
"files": None,
|
||||
"score": 0.8,
|
||||
}
|
||||
],
|
||||
"records": [hit_testing_record()],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
return_value=response,
|
||||
):
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
assert result["query"] == {"content": "hello"}
|
||||
record = result["records"][0]
|
||||
assert record["segment"]["keywords"] == []
|
||||
assert record["segment"]["sign_content"] is None
|
||||
assert record["child_chunks"] == []
|
||||
assert record["files"] == []
|
||||
assert record["score"] == 0.8
|
||||
assert record["tsne_position"] is None
|
||||
assert record["summary"] is None
|
||||
|
||||
def test_invalid_query_response_raises_value_error(self, dataset):
|
||||
with (
|
||||
patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
return_value=response,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.hit_testing_base.marshal",
|
||||
return_value=response["records"],
|
||||
return_value={"query": "hello", "records": []},
|
||||
),
|
||||
pytest.raises(ValueError, match="Invalid hit testing query response"),
|
||||
):
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
assert result["query"] == "hello"
|
||||
assert result["records"] == [
|
||||
{
|
||||
"segment": {"id": "segment-1", "keywords": []},
|
||||
"child_chunks": [],
|
||||
"files": [],
|
||||
"score": 0.8,
|
||||
}
|
||||
]
|
||||
|
||||
def test_invalid_query_response_raises_value_error(self):
|
||||
with pytest.raises(ValueError, match="Invalid hit testing query response"):
|
||||
DatasetsHitTestingBase._extract_hit_testing_query("hello")
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_invalid_records_response_raises_value_error(self):
|
||||
with pytest.raises(ValueError, match="Invalid hit testing records response"):
|
||||
|
||||
@ -24,12 +24,53 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
import services
|
||||
from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HitTestingPayload Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def hit_testing_record() -> dict[str, object]:
|
||||
return {
|
||||
"segment": {
|
||||
"id": "segment-1",
|
||||
"position": 1,
|
||||
"document_id": "document-1",
|
||||
"content": "Chunk text",
|
||||
"sign_content": "Chunk text",
|
||||
"answer": None,
|
||||
"word_count": 2,
|
||||
"tokens": 3,
|
||||
"keywords": None,
|
||||
"index_node_id": None,
|
||||
"index_node_hash": None,
|
||||
"hit_count": 0,
|
||||
"enabled": True,
|
||||
"disabled_at": None,
|
||||
"disabled_by": None,
|
||||
"status": "completed",
|
||||
"created_by": "account-1",
|
||||
"created_at": 1_700_000_000,
|
||||
"indexing_at": None,
|
||||
"completed_at": None,
|
||||
"error": None,
|
||||
"stopped_at": None,
|
||||
"document": {
|
||||
"id": "document-1",
|
||||
"data_source_type": "upload_file",
|
||||
"name": "guide.md",
|
||||
"doc_type": None,
|
||||
"doc_metadata": None,
|
||||
},
|
||||
},
|
||||
"child_chunks": None,
|
||||
"files": None,
|
||||
"score": 0.9,
|
||||
}
|
||||
|
||||
|
||||
class TestHitTestingPayload:
|
||||
"""Test suite for HitTestingPayload Pydantic model."""
|
||||
|
||||
@ -48,7 +89,7 @@ class TestHitTestingPayload:
|
||||
}
|
||||
payload = HitTestingPayload(
|
||||
query="test query",
|
||||
retrieval_model=retrieval_model_data,
|
||||
retrieval_model=RetrievalModel.model_validate(retrieval_model_data),
|
||||
external_retrieval_model={"provider": "openai"},
|
||||
attachment_ids=["att_1", "att_2"],
|
||||
)
|
||||
@ -68,6 +109,12 @@ class TestHitTestingPayload:
|
||||
payload = HitTestingPayload(query="x" * 250)
|
||||
assert len(payload.query) == 250
|
||||
|
||||
def test_payload_ignores_unknown_fields_for_compatibility(self):
|
||||
"""Top-level fields outside the documented schema remain ignored as before."""
|
||||
payload = HitTestingPayload.model_validate({"query": "test query", "top_k": 3})
|
||||
|
||||
assert payload.model_dump(exclude_none=True) == {"query": "test query"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HitTestingApi Tests
|
||||
@ -80,8 +127,11 @@ class TestHitTestingPayload:
|
||||
class TestHitTestingApiPost:
|
||||
"""Tests for HitTestingApi.post() via __wrapped__ to skip billing decorator."""
|
||||
|
||||
@staticmethod
|
||||
def _dataset(dataset_id: str, tenant_id: str) -> Dataset:
|
||||
return Dataset(id=dataset_id, tenant_id=tenant_id, name="Dataset", created_by="account-1")
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.marshal")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
@ -90,7 +140,6 @@ class TestHitTestingApiPost:
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_marshal,
|
||||
mock_ns,
|
||||
app: Flask,
|
||||
):
|
||||
@ -98,15 +147,13 @@ class TestHitTestingApiPost:
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": {"content": "test query"}, "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
mock_ns.payload = {"query": "test query"}
|
||||
|
||||
@ -115,11 +162,10 @@ class TestHitTestingApiPost:
|
||||
# Skip billing decorator via __wrapped__
|
||||
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
assert response["query"] == "test query"
|
||||
assert response["query"] == {"content": "test query"}
|
||||
mock_hit_svc.retrieve.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.marshal")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
@ -128,7 +174,6 @@ class TestHitTestingApiPost:
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_marshal,
|
||||
mock_ns,
|
||||
app: Flask,
|
||||
):
|
||||
@ -136,8 +181,7 @@ class TestHitTestingApiPost:
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
@ -152,7 +196,6 @@ class TestHitTestingApiPost:
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": {"content": "complex query"}, "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
mock_ns.payload = {
|
||||
"query": "complex query",
|
||||
@ -164,7 +207,7 @@ class TestHitTestingApiPost:
|
||||
api = HitTestingApi()
|
||||
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
assert response["query"] == "complex query"
|
||||
assert response["query"] == {"content": "complex query"}
|
||||
call_kwargs = mock_hit_svc.retrieve.call_args
|
||||
# retrieval_model is serialized via model_dump, verify key fields
|
||||
passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model")
|
||||
@ -173,7 +216,6 @@ class TestHitTestingApiPost:
|
||||
assert passed_retrieval_model["top_k"] == 10
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.marshal")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
@ -182,7 +224,6 @@ class TestHitTestingApiPost:
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_marshal,
|
||||
mock_ns,
|
||||
app: Flask,
|
||||
):
|
||||
@ -190,14 +231,12 @@ class TestHitTestingApiPost:
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
mock_hit_svc.retrieve.return_value = {"query": {"content": "filtered query"}, "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
metadata_filtering_conditions = {
|
||||
"logical_operator": "and",
|
||||
@ -229,7 +268,6 @@ class TestHitTestingApiPost:
|
||||
assert passed_retrieval_model["metadata_filtering_conditions"] == metadata_filtering_conditions
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.marshal")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
@ -238,30 +276,23 @@ class TestHitTestingApiPost:
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_marshal,
|
||||
mock_ns,
|
||||
app: Flask,
|
||||
):
|
||||
"""Test service API prepares nullable list fields from marshalled records."""
|
||||
"""Test service API prepares nullable list fields from retrieval records."""
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
|
||||
mock_hit_svc.retrieve.return_value = {"query": {"content": "legacy query"}, "records": ["placeholder"]}
|
||||
mock_hit_svc.retrieve.return_value = {
|
||||
"query": {"content": "legacy query"},
|
||||
"records": [hit_testing_record()],
|
||||
}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = [
|
||||
{
|
||||
"segment": {"id": "segment-1", "keywords": None},
|
||||
"child_chunks": None,
|
||||
"files": None,
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
|
||||
mock_ns.payload = {"query": "legacy query"}
|
||||
|
||||
@ -269,15 +300,15 @@ class TestHitTestingApiPost:
|
||||
api = HitTestingApi()
|
||||
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
assert response["query"] == "legacy query"
|
||||
assert response["records"] == [
|
||||
{
|
||||
"segment": {"id": "segment-1", "keywords": []},
|
||||
"child_chunks": [],
|
||||
"files": [],
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
assert response["query"] == {"content": "legacy query"}
|
||||
record = response["records"][0]
|
||||
assert record["segment"]["id"] == "segment-1"
|
||||
assert record["segment"]["keywords"] == []
|
||||
assert record["child_chunks"] == []
|
||||
assert record["files"] == []
|
||||
assert record["score"] == 0.9
|
||||
assert record["tsne_position"] is None
|
||||
assert record["summary"] is None
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@ -315,8 +346,7 @@ class TestHitTestingApiPost:
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError(
|
||||
|
||||
@ -396,8 +396,8 @@ export type HitTestingPayload = {
|
||||
}
|
||||
|
||||
export type HitTestingResponse = {
|
||||
query: string
|
||||
records?: Array<HitTestingRecord>
|
||||
query: HitTestingQuery
|
||||
records: Array<HitTestingRecord>
|
||||
}
|
||||
|
||||
export type DocumentStatusListResponse = {
|
||||
@ -666,13 +666,17 @@ export type DocumentStatusResponse = {
|
||||
total_segments?: number | null
|
||||
}
|
||||
|
||||
export type HitTestingQuery = {
|
||||
content: string
|
||||
}
|
||||
|
||||
export type HitTestingRecord = {
|
||||
child_chunks?: Array<HitTestingChildChunk>
|
||||
files?: Array<HitTestingFile>
|
||||
score?: number | null
|
||||
segment?: HitTestingSegment
|
||||
summary?: string | null
|
||||
tsne_position?: unknown
|
||||
child_chunks: Array<HitTestingChildChunk>
|
||||
files: Array<HitTestingFile>
|
||||
score: number | null
|
||||
segment: HitTestingSegment
|
||||
summary: string | null
|
||||
tsne_position: unknown
|
||||
}
|
||||
|
||||
export type DatasetMetadataListItemResponse = {
|
||||
@ -768,45 +772,45 @@ export type MetadataDetail = {
|
||||
}
|
||||
|
||||
export type HitTestingChildChunk = {
|
||||
content?: string | null
|
||||
id?: string | null
|
||||
position?: number | null
|
||||
score?: number | null
|
||||
content: string
|
||||
id: string
|
||||
position: number
|
||||
score: number
|
||||
}
|
||||
|
||||
export type HitTestingFile = {
|
||||
extension?: string | null
|
||||
id?: string | null
|
||||
mime_type?: string | null
|
||||
name?: string | null
|
||||
size?: number | null
|
||||
source_url?: string | null
|
||||
extension: string
|
||||
id: string
|
||||
mime_type: string
|
||||
name: string
|
||||
size: number
|
||||
source_url: string
|
||||
}
|
||||
|
||||
export type HitTestingSegment = {
|
||||
answer?: string | null
|
||||
completed_at?: number | null
|
||||
content?: string | null
|
||||
created_at?: number | null
|
||||
created_by?: string | null
|
||||
disabled_at?: number | null
|
||||
disabled_by?: string | null
|
||||
document?: HitTestingDocument
|
||||
document_id?: string | null
|
||||
enabled?: boolean | null
|
||||
error?: string | null
|
||||
hit_count?: number | null
|
||||
id?: string | null
|
||||
index_node_hash?: string | null
|
||||
index_node_id?: string | null
|
||||
indexing_at?: number | null
|
||||
keywords?: Array<string>
|
||||
position?: number | null
|
||||
sign_content?: string | null
|
||||
status?: string | null
|
||||
stopped_at?: number | null
|
||||
tokens?: number | null
|
||||
word_count?: number | null
|
||||
answer: string | null
|
||||
completed_at: number | null
|
||||
content: string
|
||||
created_at: number
|
||||
created_by: string
|
||||
disabled_at: number | null
|
||||
disabled_by: string | null
|
||||
document: HitTestingDocument
|
||||
document_id: string
|
||||
enabled: boolean
|
||||
error: string | null
|
||||
hit_count: number
|
||||
id: string
|
||||
index_node_hash: string | null
|
||||
index_node_id: string | null
|
||||
indexing_at: number | null
|
||||
keywords: Array<string>
|
||||
position: number
|
||||
sign_content: string | null
|
||||
status: string
|
||||
stopped_at: number | null
|
||||
tokens: number
|
||||
word_count: number
|
||||
}
|
||||
|
||||
export type DatasetQueryContentResponse = {
|
||||
@ -898,11 +902,11 @@ export type WeightVectorSetting = {
|
||||
}
|
||||
|
||||
export type HitTestingDocument = {
|
||||
data_source_type?: string | null
|
||||
doc_metadata?: unknown
|
||||
doc_type?: string | null
|
||||
id?: string | null
|
||||
name?: string | null
|
||||
data_source_type: string
|
||||
doc_metadata: unknown
|
||||
doc_type: string | null
|
||||
id: string
|
||||
name: string
|
||||
}
|
||||
|
||||
export type DatasetQueryFileInfoResponse = {
|
||||
|
||||
@ -530,6 +530,13 @@ export const zDocumentStatusListResponse = z.object({
|
||||
data: z.array(zDocumentStatusResponse),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingQuery
|
||||
*/
|
||||
export const zHitTestingQuery = z.object({
|
||||
content: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* DatasetMetadataListItemResponse
|
||||
*/
|
||||
@ -632,22 +639,22 @@ export const zMetadataOperationData = z.object({
|
||||
* HitTestingChildChunk
|
||||
*/
|
||||
export const zHitTestingChildChunk = z.object({
|
||||
content: z.string().nullish(),
|
||||
id: z.string().nullish(),
|
||||
position: z.int().nullish(),
|
||||
score: z.number().nullish(),
|
||||
content: z.string(),
|
||||
id: z.string(),
|
||||
position: z.int(),
|
||||
score: z.number(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingFile
|
||||
*/
|
||||
export const zHitTestingFile = z.object({
|
||||
extension: z.string().nullish(),
|
||||
id: z.string().nullish(),
|
||||
mime_type: z.string().nullish(),
|
||||
name: z.string().nullish(),
|
||||
size: z.int().nullish(),
|
||||
source_url: z.string().nullish(),
|
||||
extension: z.string(),
|
||||
id: z.string(),
|
||||
mime_type: z.string(),
|
||||
name: z.string(),
|
||||
size: z.int(),
|
||||
source_url: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
@ -1036,60 +1043,60 @@ export const zHitTestingPayload = z.object({
|
||||
* HitTestingDocument
|
||||
*/
|
||||
export const zHitTestingDocument = z.object({
|
||||
data_source_type: z.string().nullish(),
|
||||
doc_metadata: z.unknown().optional(),
|
||||
doc_type: z.string().nullish(),
|
||||
id: z.string().nullish(),
|
||||
name: z.string().nullish(),
|
||||
data_source_type: z.string(),
|
||||
doc_metadata: z.unknown(),
|
||||
doc_type: z.string().nullable(),
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingSegment
|
||||
*/
|
||||
export const zHitTestingSegment = z.object({
|
||||
answer: z.string().nullish(),
|
||||
completed_at: z.int().nullish(),
|
||||
content: z.string().nullish(),
|
||||
created_at: z.int().nullish(),
|
||||
created_by: z.string().nullish(),
|
||||
disabled_at: z.int().nullish(),
|
||||
disabled_by: z.string().nullish(),
|
||||
document: zHitTestingDocument.optional(),
|
||||
document_id: z.string().nullish(),
|
||||
enabled: z.boolean().nullish(),
|
||||
error: z.string().nullish(),
|
||||
hit_count: z.int().nullish(),
|
||||
id: z.string().nullish(),
|
||||
index_node_hash: z.string().nullish(),
|
||||
index_node_id: z.string().nullish(),
|
||||
indexing_at: z.int().nullish(),
|
||||
keywords: z.array(z.string()).optional(),
|
||||
position: z.int().nullish(),
|
||||
sign_content: z.string().nullish(),
|
||||
status: z.string().nullish(),
|
||||
stopped_at: z.int().nullish(),
|
||||
tokens: z.int().nullish(),
|
||||
word_count: z.int().nullish(),
|
||||
answer: z.string().nullable(),
|
||||
completed_at: z.int().nullable(),
|
||||
content: z.string(),
|
||||
created_at: z.int(),
|
||||
created_by: z.string(),
|
||||
disabled_at: z.int().nullable(),
|
||||
disabled_by: z.string().nullable(),
|
||||
document: zHitTestingDocument,
|
||||
document_id: z.string(),
|
||||
enabled: z.boolean(),
|
||||
error: z.string().nullable(),
|
||||
hit_count: z.int(),
|
||||
id: z.string(),
|
||||
index_node_hash: z.string().nullable(),
|
||||
index_node_id: z.string().nullable(),
|
||||
indexing_at: z.int().nullable(),
|
||||
keywords: z.array(z.string()),
|
||||
position: z.int(),
|
||||
sign_content: z.string().nullable(),
|
||||
status: z.string(),
|
||||
stopped_at: z.int().nullable(),
|
||||
tokens: z.int(),
|
||||
word_count: z.int(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingRecord
|
||||
*/
|
||||
export const zHitTestingRecord = z.object({
|
||||
child_chunks: z.array(zHitTestingChildChunk).optional(),
|
||||
files: z.array(zHitTestingFile).optional(),
|
||||
score: z.number().nullish(),
|
||||
segment: zHitTestingSegment.optional(),
|
||||
summary: z.string().nullish(),
|
||||
tsne_position: z.unknown().optional(),
|
||||
child_chunks: z.array(zHitTestingChildChunk),
|
||||
files: z.array(zHitTestingFile),
|
||||
score: z.number().nullable(),
|
||||
segment: zHitTestingSegment,
|
||||
summary: z.string().nullable(),
|
||||
tsne_position: z.unknown(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingResponse
|
||||
*/
|
||||
export const zHitTestingResponse = z.object({
|
||||
query: z.string(),
|
||||
records: z.array(zHitTestingRecord).optional(),
|
||||
query: zHitTestingQuery,
|
||||
records: z.array(zHitTestingRecord),
|
||||
})
|
||||
|
||||
/**
|
||||
|
||||
@ -469,6 +469,30 @@ export type FileResponse = {
|
||||
user_id?: string | null
|
||||
}
|
||||
|
||||
export type HitTestingChildChunk = {
|
||||
content: string
|
||||
id: string
|
||||
position: number
|
||||
score: number
|
||||
}
|
||||
|
||||
export type HitTestingDocument = {
|
||||
data_source_type: string
|
||||
doc_metadata: unknown
|
||||
doc_type: string | null
|
||||
id: string
|
||||
name: string
|
||||
}
|
||||
|
||||
export type HitTestingFile = {
|
||||
extension: string
|
||||
id: string
|
||||
mime_type: string
|
||||
name: string
|
||||
size: number
|
||||
source_url: string
|
||||
}
|
||||
|
||||
export type HitTestingPayload = {
|
||||
attachment_ids?: Array<string> | null
|
||||
external_retrieval_model?: {
|
||||
@ -478,6 +502,50 @@ export type HitTestingPayload = {
|
||||
retrieval_model?: RetrievalModel
|
||||
}
|
||||
|
||||
export type HitTestingQuery = {
|
||||
content: string
|
||||
}
|
||||
|
||||
export type HitTestingRecord = {
|
||||
child_chunks: Array<HitTestingChildChunk>
|
||||
files: Array<HitTestingFile>
|
||||
score: number | null
|
||||
segment: HitTestingSegment
|
||||
summary: string | null
|
||||
tsne_position: unknown
|
||||
}
|
||||
|
||||
export type HitTestingResponse = {
|
||||
query: HitTestingQuery
|
||||
records: Array<HitTestingRecord>
|
||||
}
|
||||
|
||||
export type HitTestingSegment = {
|
||||
answer: string | null
|
||||
completed_at: number | null
|
||||
content: string
|
||||
created_at: number
|
||||
created_by: string
|
||||
disabled_at: number | null
|
||||
disabled_by: string | null
|
||||
document: HitTestingDocument
|
||||
document_id: string
|
||||
enabled: boolean
|
||||
error: string | null
|
||||
hit_count: number
|
||||
id: string
|
||||
index_node_hash: string | null
|
||||
index_node_id: string | null
|
||||
indexing_at: number | null
|
||||
keywords: Array<string>
|
||||
position: number
|
||||
sign_content: string | null
|
||||
status: string
|
||||
stopped_at: number | null
|
||||
tokens: number
|
||||
word_count: number
|
||||
}
|
||||
|
||||
export type HumanInputFormSubmitPayload = {
|
||||
action: string
|
||||
inputs: {
|
||||
@ -2510,9 +2578,7 @@ export type PostDatasetsByDatasetIdHitTestingError
|
||||
= PostDatasetsByDatasetIdHitTestingErrors[keyof PostDatasetsByDatasetIdHitTestingErrors]
|
||||
|
||||
export type PostDatasetsByDatasetIdHitTestingResponses = {
|
||||
200: {
|
||||
[key: string]: unknown
|
||||
}
|
||||
200: HitTestingResponse
|
||||
}
|
||||
|
||||
export type PostDatasetsByDatasetIdHitTestingResponse
|
||||
@ -2794,9 +2860,7 @@ export type PostDatasetsByDatasetIdRetrieveError
|
||||
= PostDatasetsByDatasetIdRetrieveErrors[keyof PostDatasetsByDatasetIdRetrieveErrors]
|
||||
|
||||
export type PostDatasetsByDatasetIdRetrieveResponses = {
|
||||
200: {
|
||||
[key: string]: unknown
|
||||
}
|
||||
200: HitTestingResponse
|
||||
}
|
||||
|
||||
export type PostDatasetsByDatasetIdRetrieveResponse
|
||||
|
||||
@ -553,6 +553,95 @@ export const zFileResponse = z.object({
|
||||
user_id: z.string().nullish(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingChildChunk
|
||||
*/
|
||||
export const zHitTestingChildChunk = z.object({
|
||||
content: z.string(),
|
||||
id: z.string(),
|
||||
position: z.int(),
|
||||
score: z.number(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingDocument
|
||||
*/
|
||||
export const zHitTestingDocument = z.object({
|
||||
data_source_type: z.string(),
|
||||
doc_metadata: z.unknown(),
|
||||
doc_type: z.string().nullable(),
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingFile
|
||||
*/
|
||||
export const zHitTestingFile = z.object({
|
||||
extension: z.string(),
|
||||
id: z.string(),
|
||||
mime_type: z.string(),
|
||||
name: z.string(),
|
||||
size: z.int(),
|
||||
source_url: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingQuery
|
||||
*/
|
||||
export const zHitTestingQuery = z.object({
|
||||
content: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingSegment
|
||||
*/
|
||||
export const zHitTestingSegment = z.object({
|
||||
answer: z.string().nullable(),
|
||||
completed_at: z.int().nullable(),
|
||||
content: z.string(),
|
||||
created_at: z.int(),
|
||||
created_by: z.string(),
|
||||
disabled_at: z.int().nullable(),
|
||||
disabled_by: z.string().nullable(),
|
||||
document: zHitTestingDocument,
|
||||
document_id: z.string(),
|
||||
enabled: z.boolean(),
|
||||
error: z.string().nullable(),
|
||||
hit_count: z.int(),
|
||||
id: z.string(),
|
||||
index_node_hash: z.string().nullable(),
|
||||
index_node_id: z.string().nullable(),
|
||||
indexing_at: z.int().nullable(),
|
||||
keywords: z.array(z.string()),
|
||||
position: z.int(),
|
||||
sign_content: z.string().nullable(),
|
||||
status: z.string(),
|
||||
stopped_at: z.int().nullable(),
|
||||
tokens: z.int(),
|
||||
word_count: z.int(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingRecord
|
||||
*/
|
||||
export const zHitTestingRecord = z.object({
|
||||
child_chunks: z.array(zHitTestingChildChunk),
|
||||
files: z.array(zHitTestingFile),
|
||||
score: z.number().nullable(),
|
||||
segment: zHitTestingSegment,
|
||||
summary: z.string().nullable(),
|
||||
tsne_position: z.unknown(),
|
||||
})
|
||||
|
||||
/**
|
||||
* HitTestingResponse
|
||||
*/
|
||||
export const zHitTestingResponse = z.object({
|
||||
query: zHitTestingQuery,
|
||||
records: z.array(zHitTestingRecord),
|
||||
})
|
||||
|
||||
/**
|
||||
* IndexInfoResponse
|
||||
*/
|
||||
@ -1720,7 +1809,7 @@ export const zPostDatasetsByDatasetIdHitTestingPath = z.object({
|
||||
/**
|
||||
* Hit testing results
|
||||
*/
|
||||
export const zPostDatasetsByDatasetIdHitTestingResponse = z.record(z.string(), z.unknown())
|
||||
export const zPostDatasetsByDatasetIdHitTestingResponse = zHitTestingResponse
|
||||
|
||||
export const zGetDatasetsByDatasetIdMetadataPath = z.object({
|
||||
dataset_id: z.string(),
|
||||
@ -1834,7 +1923,7 @@ export const zPostDatasetsByDatasetIdRetrievePath = z.object({
|
||||
/**
|
||||
* Hit testing results
|
||||
*/
|
||||
export const zPostDatasetsByDatasetIdRetrieveResponse = z.record(z.string(), z.unknown())
|
||||
export const zPostDatasetsByDatasetIdRetrieveResponse = zHitTestingResponse
|
||||
|
||||
export const zGetDatasetsByDatasetIdTagsPath = z.object({
|
||||
dataset_id: z.string(),
|
||||
|
||||
Reference in New Issue
Block a user