mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
test: improve code-cov for controller tests (#33225)
This commit is contained in:
320
api/tests/unit_tests/controllers/console/app/test_message.py
Normal file
320
api/tests/unit_tests/controllers/console/app/test_message.py
Normal file
@ -0,0 +1,320 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.message import (
|
||||
ChatMessageListApi,
|
||||
ChatMessagesQuery,
|
||||
FeedbackExportQuery,
|
||||
MessageAnnotationCountApi,
|
||||
MessageApi,
|
||||
MessageFeedbackApi,
|
||||
MessageFeedbackExportApi,
|
||||
MessageFeedbackPayload,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from models import App, AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
from models.account import Account, AccountStatus
|
||||
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "user_123"
|
||||
account.timezone = "UTC"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.is_admin_or_owner = True
|
||||
account.current_tenant.current_role = "owner"
|
||||
account.has_edit_permission = True
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app_123"
|
||||
app_model.mode = AppMode.CHAT
|
||||
app_model.tenant_id = "tenant_123"
|
||||
return app_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_csrf():
|
||||
with patch("libs.login.check_csrf_token") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
import contextlib
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def setup_test_context(
|
||||
test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None
|
||||
):
|
||||
with (
|
||||
patch("extensions.ext_database.db") as mock_db,
|
||||
patch("controllers.console.app.wraps.db", mock_db),
|
||||
patch("controllers.console.wraps.db", mock_db),
|
||||
patch("controllers.console.app.message.db", mock_db),
|
||||
patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")),
|
||||
):
|
||||
# Set up a generic query mock that usually returns mock_app_model when getting app
|
||||
app_query_mock = MagicMock()
|
||||
app_query_mock.filter.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.where.return_value.first.return_value = mock_app_model
|
||||
app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model
|
||||
|
||||
data_query_mock = MagicMock()
|
||||
|
||||
def query_side_effect(*args, **kwargs):
|
||||
if args and hasattr(args[0], "__name__") and args[0].__name__ == "App":
|
||||
return app_query_mock
|
||||
return data_query_mock
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
mock_db.data_query = data_query_mock
|
||||
|
||||
# Let the caller override the stat db query logic
|
||||
proxy_mock = LocalProxy(lambda: mock_account)
|
||||
|
||||
query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()])
|
||||
full_path = f"{route_path}?{query_string}" if qs else route_path
|
||||
|
||||
with (
|
||||
patch("libs.login.current_user", proxy_mock),
|
||||
patch("flask_login.current_user", proxy_mock),
|
||||
patch("controllers.console.app.message.attach_message_extra_contents", return_value=None),
|
||||
):
|
||||
with test_app.test_request_context(full_path, method=method, json=payload):
|
||||
request.view_args = {"app_id": "app_123"}
|
||||
|
||||
if "suggested-questions" in route_path:
|
||||
# simplistic extraction for message_id
|
||||
parts = route_path.split("chat-messages/")
|
||||
if len(parts) > 1:
|
||||
request.view_args["message_id"] = parts[1].split("/")[0]
|
||||
elif "messages/" in route_path and "chat-messages" not in route_path:
|
||||
parts = route_path.split("messages/")
|
||||
if len(parts) > 1:
|
||||
request.view_args["message_id"] = parts[1].split("/")[0]
|
||||
|
||||
api_instance = endpoint_class()
|
||||
|
||||
# Check if it has a dispatch_request or method
|
||||
if hasattr(api_instance, method.lower()):
|
||||
yield api_instance, mock_db, request.view_args
|
||||
|
||||
|
||||
class TestMessageValidators:
|
||||
def test_chat_messages_query_validators(self):
|
||||
# Test empty_to_none
|
||||
assert ChatMessagesQuery.empty_to_none("") is None
|
||||
assert ChatMessagesQuery.empty_to_none("val") == "val"
|
||||
|
||||
# Test validate_uuid
|
||||
assert ChatMessagesQuery.validate_uuid(None) is None
|
||||
assert (
|
||||
ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_message_feedback_validators(self):
|
||||
assert (
|
||||
MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000")
|
||||
== "123e4567-e89b-12d3-a456-426614174000"
|
||||
)
|
||||
|
||||
def test_feedback_export_validators(self):
|
||||
assert FeedbackExportQuery.parse_bool(None) is None
|
||||
assert FeedbackExportQuery.parse_bool(True) is True
|
||||
assert FeedbackExportQuery.parse_bool("1") is True
|
||||
assert FeedbackExportQuery.parse_bool("0") is False
|
||||
assert FeedbackExportQuery.parse_bool("off") is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackExportQuery.parse_bool("invalid")
|
||||
|
||||
|
||||
class TestMessageEndpoints:
|
||||
def test_chat_message_list_not_found(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
ChatMessageListApi,
|
||||
"/apps/app_123/chat-messages",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
api.get(**v_args)
|
||||
|
||||
def test_chat_message_list_success(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
ChatMessageListApi,
|
||||
"/apps/app_123/chat-messages",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_conv = MagicMock()
|
||||
mock_conv.id = "123e4567-e89b-12d3-a456-426614174000"
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = "msg_123"
|
||||
mock_msg.feedbacks = []
|
||||
mock_msg.annotation = None
|
||||
mock_msg.annotation_hit_history = None
|
||||
mock_msg.agent_thoughts = []
|
||||
mock_msg.message_files = []
|
||||
mock_msg.extra_contents = []
|
||||
mock_msg.message = {}
|
||||
mock_msg.message_metadata_dict = {}
|
||||
|
||||
# mock returns
|
||||
q_mock = mock_db.data_query
|
||||
q_mock.where.return_value.first.side_effect = [mock_conv]
|
||||
q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg]
|
||||
mock_db.session.scalar.return_value = False
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["limit"] == 1
|
||||
assert resp["has_more"] is False
|
||||
assert len(resp["data"]) == 1
|
||||
|
||||
def test_message_feedback_not_found(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageFeedbackApi,
|
||||
"/apps/app_123/feedbacks",
|
||||
"POST",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"},
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
api.post(**v_args)
|
||||
|
||||
def test_message_feedback_success(self, app, mock_account, mock_app_model):
|
||||
payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"}
|
||||
with setup_test_context(
|
||||
app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload
|
||||
) as (api, mock_db, v_args):
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.admin_feedback = None
|
||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
||||
|
||||
resp = api.post(**v_args)
|
||||
assert resp == {"result": "success"}
|
||||
|
||||
def test_message_annotation_count(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
mock_db.data_query.where.return_value.count.return_value = 5
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"count": 5}
|
||||
|
||||
@patch("controllers.console.app.message.MessageService")
|
||||
def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model):
|
||||
mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"]
|
||||
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageSuggestedQuestionApi,
|
||||
"/apps/app_123/chat-messages/msg_123/suggested-questions",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
) as (api, mock_db, v_args):
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"data": ["q1", "q2"]}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected_exc"),
|
||||
[
|
||||
(MessageNotExistsError, NotFound),
|
||||
(ConversationNotExistsError, NotFound),
|
||||
(ProviderTokenNotInitError, ProviderNotInitializeError),
|
||||
(QuotaExceededError, ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError),
|
||||
(SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError),
|
||||
(Exception, InternalServerError),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.app.message.MessageService")
|
||||
def test_message_suggested_questions_errors(
|
||||
self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model
|
||||
):
|
||||
mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc()
|
||||
|
||||
with setup_test_context(
|
||||
app,
|
||||
MessageSuggestedQuestionApi,
|
||||
"/apps/app_123/chat-messages/msg_123/suggested-questions",
|
||||
"GET",
|
||||
mock_account,
|
||||
mock_app_model,
|
||||
) as (api, mock_db, v_args):
|
||||
with pytest.raises(expected_exc):
|
||||
api.get(**v_args)
|
||||
|
||||
@patch("services.feedback_service.FeedbackService.export_feedbacks")
|
||||
def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model):
|
||||
mock_export.return_value = {"exported": True}
|
||||
|
||||
with setup_test_context(
|
||||
app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
resp = api.get(**v_args)
|
||||
assert resp == {"exported": True}
|
||||
|
||||
def test_message_api_get_success(self, app, mock_account, mock_app_model):
|
||||
with setup_test_context(
|
||||
app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model
|
||||
) as (api, mock_db, v_args):
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = "msg_123"
|
||||
mock_msg.feedbacks = []
|
||||
mock_msg.annotation = None
|
||||
mock_msg.annotation_hit_history = None
|
||||
mock_msg.agent_thoughts = []
|
||||
mock_msg.message_files = []
|
||||
mock_msg.extra_contents = []
|
||||
mock_msg.message = {}
|
||||
mock_msg.message_metadata_dict = {}
|
||||
|
||||
mock_db.data_query.where.return_value.first.return_value = mock_msg
|
||||
|
||||
resp = api.get(**v_args)
|
||||
assert resp["id"] == "msg_123"
|
||||
Reference in New Issue
Block a user