Files
dify/api/controllers/console/datasets/hit_testing_base.py
FFXN 1e4b2995d4 feat: dev snippet fronted (#36748)
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: EvanYao826 <155432245+EvanYao826@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: 盐粒 Yanli <yanli@dify.ai>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Tianle <40735546+Tianlel@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
Co-authored-by: zyssyz123 <916125788@qq.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: chariri <w@chariri.moe>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Nian <11332799+Lillian68@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Carmen Fernández Ruiz <279459669+zeus1959@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: L1nSn0w <l1nsn0w@qq.com>
Co-authored-by: Evan <2869018789@qq.com>
Co-authored-by: Escape0707 <tothesong@gmail.com>
Co-authored-by: Jingyi <jingyi.qi@dify.ai>
Co-authored-by: Amr Sherif <140330826+amr-sheriff@users.noreply.github.com>
Co-authored-by: ZHOU ZHICHEN <118870511+zhuiguangzhe2003@users.noreply.github.com>
Co-authored-by: unknown <EI05187@apwx.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
2026-05-28 11:01:49 +08:00

138 lines
5.4 KiB
Python

import logging
from typing import Any, cast
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import current_user
from models.account import Account
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.hit_testing_service import HitTestingService
# Module-level logger; used below to record unexpected hit-testing failures.
logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
    """Schema used to validate an incoming hit-testing request payload.

    Only ``query`` is required; every other field is optional and defaults
    to ``None``.
    """

    # The query text; validation rejects values longer than 250 characters.
    query: str = Field(max_length=250)

    # Optional retrieval configuration, validated as a ``RetrievalModel``.
    retrieval_model: RetrievalModel | None = None

    # Optional free-form mapping of external retrieval settings.
    external_retrieval_model: dict[str, Any] | None = None

    # Optional list of attachment identifiers supplied with the query.
    attachment_ids: list[str] | None = None
class DatasetsHitTestingBase:
    """Static helpers shared by the dataset hit-testing controllers.

    The helpers cover the stages of a hit-testing request: loading the
    dataset and checking access, validating the payload, running the
    retrieval, and normalising the result records.
    """

    @staticmethod
    def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
        """Ensure collection fields match the API schema before response validation.

        Each record (and its nested ``segment`` dict, when present) is
        shallow-copied, so the dictionaries passed in are never mutated.

        Normalisation applied to every record:
          * ``segment.sign_content`` is added as ``None`` when missing.
          * ``segment.keywords``, ``child_chunks`` and ``files`` become ``[]``
            when missing or ``None``.
          * ``tsne_position`` and ``summary`` are added as ``None`` when missing.

        Args:
            records: The raw ``records`` value from the hit-testing response.

        Returns:
            A new list of normalised record dictionaries.

        Raises:
            ValueError: If ``records`` is not a list or any item is not a dict.
        """
        if not isinstance(records, list):
            raise ValueError("Invalid hit testing records response")

        normalized_records: list[dict[str, Any]] = []
        for record in records:
            if not isinstance(record, dict):
                raise ValueError("Invalid hit testing record response")

            # Copy first so the caller's record is left untouched.
            normalized_record = dict(record)

            segment = normalized_record.get("segment")
            if isinstance(segment, dict):
                normalized_segment = dict(segment)
                normalized_segment.setdefault("sign_content", None)
                if normalized_segment.get("keywords") is None:
                    normalized_segment["keywords"] = []
                normalized_record["segment"] = normalized_segment

            # Collection fields must be lists, never None or absent.
            if normalized_record.get("child_chunks") is None:
                normalized_record["child_chunks"] = []
            if normalized_record.get("files") is None:
                normalized_record["files"] = []

            # Scalar fields only need to exist; an explicit value is kept as-is.
            normalized_record.setdefault("tsne_position", None)
            normalized_record.setdefault("summary", None)

            normalized_records.append(normalized_record)

        return normalized_records

    @staticmethod
    def get_and_validate_dataset(dataset_id: str) -> Dataset:
        """Load a dataset and verify that the current user may access it.

        Args:
            dataset_id: Identifier passed to ``DatasetService.get_dataset``.

        Returns:
            The dataset returned by ``DatasetService.get_dataset``.

        Raises:
            NotFound: If no dataset is returned for ``dataset_id``.
            Forbidden: If the permission check raises ``NoPermissionError``.
        """
        assert isinstance(current_user, Account)

        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            # Chain explicitly so the traceback shows this as a translation.
            raise Forbidden(str(e)) from e

        return dataset

    @staticmethod
    def hit_testing_args_check(args: dict[str, Any]) -> None:
        """Delegate argument checking to ``HitTestingService.hit_testing_args_check``."""
        HitTestingService.hit_testing_args_check(args)

    @staticmethod
    def parse_args(payload: dict[str, Any] | None) -> dict[str, Any]:
        """Validate and return hit-testing arguments from an incoming payload.

        Args:
            payload: The raw request payload; ``None`` (or any falsy value) is
                validated as an empty mapping.

        Returns:
            The validated payload dumped to a dict, with ``None`` fields omitted.
        """
        hit_testing_payload = HitTestingPayload.model_validate(payload or {})
        return hit_testing_payload.model_dump(exclude_none=True)

    @staticmethod
    def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]:
        """Run hit testing against ``dataset`` and return a normalised response.

        Args:
            dataset: The dataset to query.
            args: Parsed arguments; the keys ``query``, ``retrieval_model``,
                ``external_retrieval_model`` and ``attachment_ids`` are read.

        Returns:
            A dict with ``query`` (``{"content": <str>}``) and ``records`` (the
            service records passed through ``_prepare_hit_testing_records``).

        Raises:
            DatasetNotInitializedError: On ``IndexNotInitializedError``.
            ProviderNotInitializeError: On ``ProviderTokenNotInitError`` or
                ``LLMBadRequestError``.
            ProviderQuotaExceededError: On ``QuotaExceededError``.
            ProviderModelCurrentlyNotSupportError: On
                ``ModelCurrentlyNotSupportError``.
            CompletionRequestError: On ``InvokeError``.
            ValueError: On a malformed service response, or any ``ValueError``
                raised during retrieval (re-raised as a plain ``ValueError``
                carrying the same message).
            InternalServerError: On any other exception, after logging it.
        """
        assert isinstance(current_user, Account)

        # Every translated exception below is chained with ``from`` so the
        # original failure stays attached as ``__cause__``.
        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=cast(str, args.get("query")),
                account=current_user,
                retrieval_model=args.get("retrieval_model"),
                external_retrieval_model=cast(dict[str, Any], args.get("external_retrieval_model")),
                attachment_ids=args.get("attachment_ids"),
                limit=10,
            )

            query = response.get("query")
            if not isinstance(query, dict) or not isinstance(query.get("content"), str):
                raise ValueError("Invalid hit testing query response")

            return {
                "query": {"content": query["content"]},
                "records": DatasetsHitTestingBase._prepare_hit_testing_records(response.get("records", [])),
            }
        except services.errors.index.IndexNotInitializedError as e:
            raise DatasetNotInitializedError() from e
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as e:
            raise ProviderQuotaExceededError() from e
        except ModelCurrentlyNotSupportError as e:
            raise ProviderModelCurrentlyNotSupportError() from e
        except LLMBadRequestError as e:
            raise ProviderNotInitializeError(
                "No Embedding Model or Reranking Model available. Please configure a valid provider "
                "in the Settings -> Model Provider."
            ) from e
        except InvokeError as e:
            raise CompletionRequestError(e.description) from e
        except ValueError as e:
            # Re-raised as a plain ValueError with the same message.
            raise ValueError(str(e)) from e
        except Exception as e:
            logger.exception("Hit testing failed.")
            raise InternalServerError(str(e)) from e