mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
refactor: replace request.args.get with Pydantic BaseModel validation (#31104)
Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@ -87,6 +87,14 @@ class TagUnbindingPayload(BaseModel):
|
||||
target_id: str
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
include_all: bool = Field(default=False, description="Include all datasets")
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
DatasetCreatePayload,
|
||||
@ -96,6 +104,7 @@ register_schema_models(
|
||||
TagDeletePayload,
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
)
|
||||
|
||||
|
||||
@ -113,15 +122,11 @@ class DatasetListApi(DatasetApiResource):
|
||||
)
|
||||
def get(self, tenant_id):
|
||||
"""Resource for getting datasets."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
query = DatasetListQuery.model_validate(request.args.to_dict(flat=False))
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
include_all = request.args.get("include_all", default="false").lower() == "true"
|
||||
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, tenant_id, current_user, search, tag_ids, include_all
|
||||
query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
|
||||
)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
@ -147,7 +152,13 @@ class DatasetListApi(DatasetApiResource):
|
||||
item["embedding_available"] = False
|
||||
else:
|
||||
item["embedding_available"] = True
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
"limit": query.limit,
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
|
||||
@ -69,7 +69,14 @@ class DocumentTextUpdate(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
|
||||
class DocumentListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
status: str | None = Field(default=None, description="Document status filter")
|
||||
|
||||
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery]:
|
||||
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
|
||||
|
||||
|
||||
@ -460,34 +467,33 @@ class DocumentListApi(DatasetApiResource):
|
||||
def get(self, tenant_id, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
status = request.args.get("status", default=None, type=str)
|
||||
query_params = DocumentListQuery.model_validate(request.args.to_dict())
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
|
||||
|
||||
if status:
|
||||
query = DocumentService.apply_display_status_filter(query, status)
|
||||
if query_params.status:
|
||||
query = DocumentService.apply_display_status_filter(query, query_params.status)
|
||||
|
||||
if search:
|
||||
search = f"%{search}%"
|
||||
if query_params.keyword:
|
||||
search = f"%{query_params.keyword}%"
|
||||
query = query.where(Document.name.like(search))
|
||||
|
||||
query = query.order_by(desc(Document.created_at), desc(Document.position))
|
||||
|
||||
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
paginated_documents = db.paginate(
|
||||
select=query, page=query_params.page, per_page=query_params.limit, max_per_page=100, error_out=False
|
||||
)
|
||||
documents = paginated_documents.items
|
||||
|
||||
response = {
|
||||
"data": marshal(documents, document_fields),
|
||||
"has_more": len(documents) == limit,
|
||||
"limit": limit,
|
||||
"has_more": len(documents) == query_params.limit,
|
||||
"limit": query_params.limit,
|
||||
"total": paginated_documents.total,
|
||||
"page": page,
|
||||
"page": query_params.page,
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user