Migrate to DeclarativeBaseModel

This commit is contained in:
Yeuoly
2024-10-21 20:38:27 +08:00
parent 53e1b45d40
commit 11270a7ef2
9 changed files with 138 additions and 78 deletions

View File

@ -5,7 +5,8 @@ from datetime import datetime, timezone
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
from sqlalchemy import asc, desc, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -104,7 +105,8 @@ class GetProcessRuleApi(Resource):
rules = DocumentService.DEFAULT_RULES["rules"]
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
with Session(db.engine) as session:
document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none()
dataset = DatasetService.get_dataset(document.dataset_id)
@ -167,7 +169,10 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
with Session(db.engine) as session:
query = session.execute(
select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
).all()
if search:
search = f"%{search}%"
@ -204,18 +209,25 @@ class DatasetDocumentListApi(Resource):
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
if fetch:
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
with Session(db.engine) as session:
for document in documents:
completed_segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
else:
data = marshal(documents, document_fields)
response = {