refactor: select in console app message controller (#33893)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo
2026-03-23 08:38:04 +01:00
committed by GitHub
parent a942d4c926
commit 02e13e6d05
2 changed files with 31 additions and 29 deletions

View File

@ -4,7 +4,7 @@ from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@ -244,27 +244,25 @@ class ChatMessageListApi(Resource):
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = (
db.session.query(Conversation)
conversation = db.session.scalar(
select(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
.limit(1)
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args.first_id:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
first_message = db.session.scalar(
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
)
if not first_message:
raise NotFound("First message not found")
history_messages = (
db.session.query(Message)
history_messages = db.session.scalars(
select(Message)
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
@ -272,16 +270,14 @@ class ChatMessageListApi(Resource):
)
.order_by(Message.created_at.desc())
.limit(args.limit)
.all()
)
).all()
else:
history_messages = (
db.session.query(Message)
history_messages = db.session.scalars(
select(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args.limit)
.all()
)
).all()
# Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit:
@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource):
message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")
@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
count = db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
)
return {"count": count}
@ -479,7 +479,9 @@ class MessageApi(Resource):
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")

View File

@ -170,7 +170,7 @@ class TestMessageEndpoints:
mock_app_model,
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
with pytest.raises(NotFound):
api.get(**v_args)
@ -198,11 +198,11 @@ class TestMessageEndpoints:
mock_msg.message = {}
mock_msg.message_metadata_dict = {}
# mock returns
q_mock = mock_db.data_query
q_mock.where.return_value.first.side_effect = [mock_conv]
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
mock_db.session.scalar.return_value = False
# scalar() is called twice: first for conversation lookup, second for has_more check
mock_db.session.scalar.side_effect = [mock_conv, False]
scalars_result = MagicMock()
scalars_result.all.return_value = [mock_msg]
mock_db.session.scalars.return_value = scalars_result
resp = api.get(**v_args)
assert resp["limit"] == 1
@ -219,7 +219,7 @@ class TestMessageEndpoints:
mock_app_model,
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
with pytest.raises(NotFound):
api.post(**v_args)
@ -231,7 +231,7 @@ class TestMessageEndpoints:
) as (api, mock_db, v_args):
mock_msg = MagicMock()
mock_msg.admin_feedback = None
mock_db.data_query.where.return_value.first.return_value = mock_msg
mock_db.session.scalar.return_value = mock_msg
resp = api.post(**v_args)
assert resp == {"result": "success"}
@ -240,7 +240,7 @@ class TestMessageEndpoints:
with setup_test_context(
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
) as (api, mock_db, v_args):
mock_db.data_query.where.return_value.count.return_value = 5
mock_db.session.scalar.return_value = 5
resp = api.get(**v_args)
assert resp == {"count": 5}
@ -314,7 +314,7 @@ class TestMessageEndpoints:
mock_msg.message = {}
mock_msg.message_metadata_dict = {}
mock_db.data_query.where.return_value.first.return_value = mock_msg
mock_db.session.scalar.return_value = mock_msg
resp = api.get(**v_args)
assert resp["id"] == "msg_123"