refactor: select in console datasets segments and API key controllers (#34027)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-03-25 04:46:22 +01:00
committed by GitHub
parent b4e541e11a
commit 4c32acf857
4 changed files with 86 additions and 93 deletions

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -738,20 +738,23 @@ class DatasetIndexingStatusApi(Resource):
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
total_segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
.count()
or 0
)
# Create a dictionary with document attributes and additional fields
document_dict = {
@ -802,9 +805,12 @@ class DatasetApiKeyApi(Resource):
_, current_tenant_id = current_account_with_tenant()
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
db.session.scalar(
select(func.count(ApiToken.id)).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
)
)
or 0
)
if current_key_count >= self.max_keys:
@ -839,14 +845,14 @@ class DatasetApiDeleteApi(Resource):
def delete(self, api_key_id):
_, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
key = (
db.session.query(ApiToken)
key = db.session.scalar(
select(ApiToken)
.where(
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
.limit(1)
)
if key is None:
@ -857,7 +863,7 @@ class DatasetApiDeleteApi(Resource):
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.delete(key)
db.session.commit()
return {"result": "success"}, 204

View File

@ -401,10 +401,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -447,10 +447,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -494,7 +494,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
payload = BatchImportPayload.model_validate(console_ns.payload or {})
upload_file_id = payload.upload_file_id
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
if not upload_file:
raise NotFound("UploadFile not found.")
@ -559,10 +559,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -616,10 +616,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -666,10 +666,10 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
@ -714,24 +714,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.first()
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -771,24 +771,24 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
.first()
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")