mirror of
https://github.com/langgenius/dify.git
synced 2026-03-11 10:17:50 +08:00
162 lines
6.9 KiB
Python
162 lines
6.9 KiB
Python
"""Unit tests for controllers.web.completion endpoints."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from flask import Flask
|
|
|
|
from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
|
from controllers.web.error import (
|
|
CompletionRequestError,
|
|
NotChatAppError,
|
|
NotCompletionAppError,
|
|
ProviderModelCurrentlyNotSupportError,
|
|
ProviderNotInitializeError,
|
|
ProviderQuotaExceededError,
|
|
)
|
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
|
from dify_graph.model_runtime.errors.invoke import InvokeError
|
|
|
|
|
|
def _completion_app() -> SimpleNamespace:
|
|
return SimpleNamespace(id="app-1", mode="completion")
|
|
|
|
|
|
def _chat_app() -> SimpleNamespace:
|
|
return SimpleNamespace(id="app-1", mode="chat")
|
|
|
|
|
|
def _end_user() -> SimpleNamespace:
|
|
return SimpleNamespace(id="eu-1")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CompletionApi
|
|
# ---------------------------------------------------------------------------
|
|
class TestCompletionApi:
    """Tests for ``CompletionApi.post`` (POST /completion-messages)."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        # A chat-mode app must be rejected by the completion endpoint.
        with app.test_request_context("/completion-messages", method="POST"):
            with pytest.raises(NotCompletionAppError):
                CompletionApi().post(_chat_app(), _end_user())

    @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"})
    @patch("controllers.web.completion.AppGenerateService.generate")
    @patch("controllers.web.completion.web_ns")
    def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
        mock_ns.payload = {"inputs": {}, "query": "test"}
        mock_gen.return_value = "response-obj"

        with app.test_request_context("/completion-messages", method="POST"):
            result = CompletionApi().post(_completion_app(), _end_user())

        assert result == {"answer": "hi"}
        # Pin the pass-through: generation runs exactly once and its output is
        # what gets wrapped into the HTTP response.
        mock_gen.assert_called_once()
        mock_compact.assert_called_once_with("response-obj")

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=ProviderTokenNotInitError(description="not init"),
    )
    @patch("controllers.web.completion.web_ns")
    def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        mock_ns.payload = {"inputs": {}}

        with app.test_request_context("/completion-messages", method="POST"):
            with pytest.raises(ProviderNotInitializeError):
                CompletionApi().post(_completion_app(), _end_user())

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=QuotaExceededError(),
    )
    @patch("controllers.web.completion.web_ns")
    def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        mock_ns.payload = {"inputs": {}}

        with app.test_request_context("/completion-messages", method="POST"):
            with pytest.raises(ProviderQuotaExceededError):
                CompletionApi().post(_completion_app(), _end_user())

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=ModelCurrentlyNotSupportError(),
    )
    @patch("controllers.web.completion.web_ns")
    def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        mock_ns.payload = {"inputs": {}}

        with app.test_request_context("/completion-messages", method="POST"):
            with pytest.raises(ProviderModelCurrentlyNotSupportError):
                CompletionApi().post(_completion_app(), _end_user())

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=InvokeError(description="rate limit"),
    )
    @patch("controllers.web.completion.web_ns")
    def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        # Mirrors TestChatApi: model invocation failures surface as CompletionRequestError.
        mock_ns.payload = {"inputs": {}}

        with app.test_request_context("/completion-messages", method="POST"):
            with pytest.raises(CompletionRequestError):
                CompletionApi().post(_completion_app(), _end_user())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CompletionStopApi
|
|
# ---------------------------------------------------------------------------
|
|
class TestCompletionStopApi:
    """Tests for ``CompletionStopApi.post`` (POST /completion-messages/<task_id>/stop)."""

    @patch("controllers.web.completion.AppTaskService.stop_task")
    def test_wrong_mode_raises(self, mock_stop: MagicMock, app: Flask) -> None:
        with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
            with pytest.raises(NotCompletionAppError):
                CompletionStopApi().post(_chat_app(), _end_user(), "task-1")

        # The mode guard must reject the request before any task is stopped.
        mock_stop.assert_not_called()

    @patch("controllers.web.completion.AppTaskService.stop_task")
    def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
        with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
            result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1")

        assert status == 200
        assert result == {"result": "success"}
        # The endpoint must actually delegate to the task service for this task;
        # task_id may be passed positionally or by keyword.
        mock_stop.assert_called_once()
        call = mock_stop.call_args
        assert "task-1" in (*call.args, *call.kwargs.values())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ChatApi
|
|
# ---------------------------------------------------------------------------
|
|
class TestChatApi:
    """Tests for ``ChatApi.post`` (POST /chat-messages)."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        # A completion-mode app must be rejected by the chat endpoint.
        with app.test_request_context("/chat-messages", method="POST"):
            with pytest.raises(NotChatAppError):
                ChatApi().post(_completion_app(), _end_user())

    @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"})
    @patch("controllers.web.completion.AppGenerateService.generate")
    @patch("controllers.web.completion.web_ns")
    def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
        mock_ns.payload = {"inputs": {}, "query": "hi"}
        mock_gen.return_value = "response"

        with app.test_request_context("/chat-messages", method="POST"):
            result = ChatApi().post(_chat_app(), _end_user())

        assert result == {"answer": "reply"}
        # Pin the pass-through: generation runs exactly once and its output is
        # what gets wrapped into the HTTP response.
        mock_gen.assert_called_once()
        mock_compact.assert_called_once_with("response")

    @patch(
        "controllers.web.completion.AppGenerateService.generate",
        side_effect=InvokeError(description="rate limit"),
    )
    @patch("controllers.web.completion.web_ns")
    def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        mock_ns.payload = {"inputs": {}, "query": "x"}

        with app.test_request_context("/chat-messages", method="POST"):
            with pytest.raises(CompletionRequestError):
                ChatApi().post(_chat_app(), _end_user())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ChatStopApi
|
|
# ---------------------------------------------------------------------------
|
|
class TestChatStopApi:
    """Tests for ``ChatStopApi.post`` (POST /chat-messages/<task_id>/stop)."""

    @patch("controllers.web.completion.AppTaskService.stop_task")
    def test_wrong_mode_raises(self, mock_stop: MagicMock, app: Flask) -> None:
        with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
            with pytest.raises(NotChatAppError):
                ChatStopApi().post(_completion_app(), _end_user(), "task-1")

        # The mode guard must reject the request before any task is stopped.
        mock_stop.assert_not_called()

    @patch("controllers.web.completion.AppTaskService.stop_task")
    def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
        with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
            result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1")

        assert status == 200
        assert result == {"result": "success"}
        # The endpoint must actually delegate to the task service for this task;
        # task_id may be passed positionally or by keyword.
        mock_stop.assert_called_once()
        call = mock_stop.call_args
        assert "task-1" in (*call.args, *call.kwargs.values())
|