mirror of
https://github.com/langgenius/dify.git
synced 2026-05-28 21:03:22 +08:00
refactor(api): migrate console/service_api.dataset.hit_testing to BaseModel (#36533)
This commit is contained in:
@ -1,62 +1,96 @@
|
||||
from flask_restx import fields
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from pydantic import field_validator
|
||||
|
||||
document_fields = {
|
||||
"id": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"name": fields.String,
|
||||
"doc_type": fields.String,
|
||||
"doc_metadata": fields.Raw,
|
||||
}
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
|
||||
segment_fields = {
|
||||
"id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"document_id": fields.String,
|
||||
"content": fields.String,
|
||||
"sign_content": fields.String,
|
||||
"answer": fields.String,
|
||||
"word_count": fields.Integer,
|
||||
"tokens": fields.Integer,
|
||||
"keywords": fields.List(fields.String),
|
||||
"index_node_id": fields.String,
|
||||
"index_node_hash": fields.String,
|
||||
"hit_count": fields.Integer,
|
||||
"enabled": fields.Boolean,
|
||||
"disabled_at": TimestampField,
|
||||
"disabled_by": fields.String,
|
||||
"status": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"indexing_at": TimestampField,
|
||||
"completed_at": TimestampField,
|
||||
"error": fields.String,
|
||||
"stopped_at": TimestampField,
|
||||
"document": fields.Nested(document_fields),
|
||||
}
|
||||
|
||||
child_chunk_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"position": fields.Integer,
|
||||
"score": fields.Float,
|
||||
}
|
||||
class HitTestingQuery(ResponseModel):
|
||||
content: str
|
||||
|
||||
files_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"size": fields.Integer,
|
||||
"extension": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"source_url": fields.String,
|
||||
}
|
||||
|
||||
hit_testing_record_fields = {
|
||||
"segment": fields.Nested(segment_fields),
|
||||
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
|
||||
"score": fields.Float,
|
||||
"tsne_position": fields.Raw,
|
||||
"files": fields.List(fields.Nested(files_fields)),
|
||||
"summary": fields.String, # Summary content if retrieved via summary index
|
||||
}
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str
|
||||
data_source_type: str
|
||||
name: str
|
||||
doc_type: str | None
|
||||
doc_metadata: Any | None
|
||||
|
||||
@field_validator("data_source_type", "doc_type", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str
|
||||
position: int
|
||||
document_id: str
|
||||
content: str
|
||||
sign_content: str | None
|
||||
answer: str | None
|
||||
word_count: int
|
||||
tokens: int
|
||||
keywords: list[str]
|
||||
index_node_id: str | None
|
||||
index_node_hash: str | None
|
||||
hit_count: int
|
||||
enabled: bool
|
||||
disabled_at: int | None
|
||||
disabled_by: str | None
|
||||
status: str
|
||||
created_by: str
|
||||
created_at: int
|
||||
indexing_at: int | None
|
||||
completed_at: int | None
|
||||
error: str | None
|
||||
stopped_at: int | None
|
||||
document: HitTestingDocument
|
||||
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
@field_validator("status", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str
|
||||
content: str
|
||||
position: int
|
||||
score: float
|
||||
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment
|
||||
child_chunks: list[HitTestingChildChunk]
|
||||
score: float | None
|
||||
tsne_position: Any | None
|
||||
files: list[HitTestingFile]
|
||||
summary: str | None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: HitTestingQuery
|
||||
records: list[HitTestingRecord]
|
||||
|
||||
|
||||
def _normalize_enum(value: Any) -> Any:
|
||||
if isinstance(value, str) or value is None:
|
||||
return value
|
||||
return getattr(value, "value", value)
|
||||
|
||||
Reference in New Issue
Block a user