test: improve unit tests for controllers.service_api (#32073)

Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
This commit is contained in:
Dev Sharma
2026-02-25 12:15:50 +05:30
committed by GitHub
parent 212756c315
commit d773096146
24 changed files with 11279 additions and 2 deletions

View File

@ -0,0 +1,295 @@
"""
Unit tests for Service API Annotation controller.
Tests coverage for:
- AnnotationCreatePayload Pydantic model validation
- AnnotationReplyActionPayload Pydantic model validation
- Error patterns and validation logic
Note: API endpoint tests for annotation controllers are complex due to:
- @validate_app_token decorator requiring full Flask-SQLAlchemy setup
- @edit_permission_required decorator checking current_user permissions
- These are better covered by integration tests
"""
import uuid
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask_restx.api import HTTPStatus
from controllers.service_api.app.annotation import (
AnnotationCreatePayload,
AnnotationListApi,
AnnotationReplyActionApi,
AnnotationReplyActionPayload,
AnnotationReplyActionStatusApi,
AnnotationUpdateDeleteApi,
)
from extensions.ext_redis import redis_client
from models.model import App
from services.annotation_service import AppAnnotationService
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
# ---------------------------------------------------------------------------
# Pydantic Model Tests
# ---------------------------------------------------------------------------
class TestAnnotationCreatePayload:
"""Test suite for AnnotationCreatePayload Pydantic model."""
def test_payload_with_question_and_answer(self):
"""Test payload with required fields."""
payload = AnnotationCreatePayload(
question="What is AI?",
answer="AI is artificial intelligence.",
)
assert payload.question == "What is AI?"
assert payload.answer == "AI is artificial intelligence."
def test_payload_with_unicode_content(self):
"""Test payload with unicode content."""
payload = AnnotationCreatePayload(
question="什么是人工智能?",
answer="人工智能是模拟人类智能的技术。",
)
assert payload.question == "什么是人工智能?"
def test_payload_with_special_characters(self):
"""Test payload with special characters."""
payload = AnnotationCreatePayload(
question="What is <b>AI</b>?",
answer="AI & ML are related fields with 100% growth!",
)
assert "<b>" in payload.question
class TestAnnotationReplyActionPayload:
"""Test suite for AnnotationReplyActionPayload Pydantic model."""
def test_payload_with_all_fields(self):
"""Test payload with all fields."""
payload = AnnotationReplyActionPayload(
score_threshold=0.8,
embedding_provider_name="openai",
embedding_model_name="text-embedding-ada-002",
)
assert payload.score_threshold == 0.8
assert payload.embedding_provider_name == "openai"
assert payload.embedding_model_name == "text-embedding-ada-002"
def test_payload_with_different_provider(self):
"""Test payload with different embedding provider."""
payload = AnnotationReplyActionPayload(
score_threshold=0.75,
embedding_provider_name="azure_openai",
embedding_model_name="text-embedding-3-small",
)
assert payload.embedding_provider_name == "azure_openai"
def test_payload_with_zero_threshold(self):
"""Test payload with zero score threshold."""
payload = AnnotationReplyActionPayload(
score_threshold=0.0,
embedding_provider_name="local",
embedding_model_name="default",
)
assert payload.score_threshold == 0.0
# ---------------------------------------------------------------------------
# Model and Error Pattern Tests
# ---------------------------------------------------------------------------
class TestAppModelPatterns:
"""Test App model patterns used by annotation controller."""
def test_app_model_has_required_fields(self):
"""Test App model has required fields for annotation operations."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.status = "normal"
app.enable_api = True
assert app.id is not None
assert app.status == "normal"
assert app.enable_api is True
def test_app_model_disabled_api(self):
"""Test app with disabled API access."""
app = Mock(spec=App)
app.enable_api = False
assert app.enable_api is False
def test_app_model_archived_status(self):
"""Test app with archived status."""
app = Mock(spec=App)
app.status = "archived"
assert app.status == "archived"
class TestAnnotationErrorPatterns:
"""Test annotation-related error handling patterns."""
def test_not_found_error_pattern(self):
"""Test NotFound error pattern used in annotation operations."""
from werkzeug.exceptions import NotFound
with pytest.raises(NotFound):
raise NotFound("Annotation not found.")
def test_forbidden_error_pattern(self):
"""Test Forbidden error pattern."""
from werkzeug.exceptions import Forbidden
with pytest.raises(Forbidden):
raise Forbidden("Permission denied.")
def test_value_error_for_job_not_found(self):
"""Test ValueError pattern for job not found."""
with pytest.raises(ValueError, match="does not exist"):
raise ValueError("The job does not exist.")
class TestAnnotationReplyActionApi:
def test_enable(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
enable_mock = Mock()
monkeypatch.setattr(AppAnnotationService, "enable_app_annotation", enable_mock)
api = AnnotationReplyActionApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context(
"/apps/annotation-reply/enable",
method="POST",
json={"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"},
):
response, status = handler(api, app_model=app_model, action="enable")
assert status == 200
enable_mock.assert_called_once()
def test_disable(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
disable_mock = Mock()
monkeypatch.setattr(AppAnnotationService, "disable_app_annotation", disable_mock)
api = AnnotationReplyActionApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context(
"/apps/annotation-reply/disable",
method="POST",
json={"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"},
):
response, status = handler(api, app_model=app_model, action="disable")
assert status == 200
disable_mock.assert_called_once()
class TestAnnotationReplyActionStatusApi:
def test_missing_job(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(redis_client, "get", lambda *_args, **_kwargs: None)
api = AnnotationReplyActionStatusApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app")
with pytest.raises(ValueError):
handler(api, app_model=app_model, job_id="j1", action="enable")
def test_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
def _get(key):
if "error" in key:
return b"oops"
return b"error"
monkeypatch.setattr(redis_client, "get", _get)
api = AnnotationReplyActionStatusApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app")
response, status = handler(api, app_model=app_model, job_id="j1", action="enable")
assert status == 200
assert response["job_status"] == "error"
assert response["error_msg"] == "oops"
class TestAnnotationListApi:
def test_get(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
monkeypatch.setattr(
AppAnnotationService,
"get_annotation_list_by_app_id",
lambda *_args, **_kwargs: ([annotation], 1),
)
api = AnnotationListApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations?page=1&limit=1", method="GET"):
response = handler(api, app_model=app_model)
assert response["total"] == 1
def test_create(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
monkeypatch.setattr(
AppAnnotationService,
"insert_app_annotation_directly",
lambda *_args, **_kwargs: annotation,
)
api = AnnotationListApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations", method="POST", json={"question": "q", "answer": "a"}):
response, status = handler(api, app_model=app_model)
assert status == HTTPStatus.CREATED
assert response["question"] == "q"
class TestAnnotationUpdateDeleteApi:
def test_update_delete(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
monkeypatch.setattr(
AppAnnotationService,
"update_app_annotation_directly",
lambda *_args, **_kwargs: annotation,
)
delete_mock = Mock()
monkeypatch.setattr(AppAnnotationService, "delete_app_annotation", delete_mock)
api = AnnotationUpdateDeleteApi()
put_handler = _unwrap(api.put)
delete_handler = _unwrap(api.delete)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations/1", method="PUT", json={"question": "q", "answer": "a"}):
response = put_handler(api, app_model=app_model, annotation_id="1")
assert response["answer"] == "a"
with app.test_request_context("/apps/annotations/1", method="DELETE"):
response, status = delete_handler(api, app_model=app_model, annotation_id="1")
assert status == 204
delete_mock.assert_called_once()

View File

@ -0,0 +1,496 @@
"""
Unit tests for Service API App controllers
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameterApi
from controllers.service_api.app.error import AppUnavailableError
from models.model import App, AppMode
from tests.unit_tests.conftest import setup_mock_tenant_account_query
class TestAppParameterApi:
"""Test suite for AppParameterApi"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.mode = AppMode.CHAT
app.status = "normal"
app.enable_api = True
return app
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_chat_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test retrieving parameters for a chat app."""
# Arrange
mock_current_app.login_manager = Mock()
mock_config = Mock()
mock_config.id = str(uuid.uuid4())
mock_config.to_dict.return_value = {
"user_input_form": [{"type": "text", "label": "Name", "variable": "name", "required": True}],
"suggested_questions": [],
}
mock_app_model.app_model_config = mock_config
mock_app_model.workflow = None
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
# Mock DB queries for app and tenant
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
# Mock tenant owner info for login
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
response = api.get()
# Assert
assert "opening_statement" in response
assert "suggested_questions" in response
assert "user_input_form" in response
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_workflow_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test retrieving parameters for a workflow app."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app_model.mode = AppMode.WORKFLOW
mock_workflow = Mock()
mock_workflow.features_dict = {"suggested_questions": []}
mock_workflow.user_input_form.return_value = [{"type": "text", "label": "Input", "variable": "input"}]
mock_app_model.workflow = mock_workflow
mock_app_model.app_model_config = None
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
response = api.get()
# Assert
assert "user_input_form" in response
assert "opening_statement" in response
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_chat_config_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test that AppUnavailableError is raised when chat app has no config."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app_model.app_model_config = None
mock_app_model.workflow = None
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
with pytest.raises(AppUnavailableError):
api.get()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_workflow_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app_model.mode = AppMode.WORKFLOW
mock_app_model.workflow = None
mock_app_model.app_model_config = None
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act & Assert
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppParameterApi()
with pytest.raises(AppUnavailableError):
api.get()
class TestAppMetaApi:
"""Test suite for AppMetaApi"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.status = "normal"
app.enable_api = True
return app
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.app.app.AppService")
def test_get_app_meta(
self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test retrieving app metadata via AppService."""
# Arrange
mock_current_app.login_manager = Mock()
mock_service_instance = Mock()
mock_service_instance.get_app_meta.return_value = {
"tool_icons": {},
"AgentIcons": {},
}
mock_app_service.return_value = mock_service_instance
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppMetaApi()
response = api.get()
# Assert
mock_service_instance.get_app_meta.assert_called_once_with(mock_app_model)
assert response == {"tool_icons": {}, "AgentIcons": {}}
class TestAppInfoApi:
"""Test suite for AppInfoApi"""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model with all required attributes."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.name = "Test App"
app.description = "A test application"
app.mode = AppMode.CHAT
app.author_name = "Test Author"
app.status = "normal"
app.enable_api = True
# Mock tags relationship
mock_tag = Mock()
mock_tag.name = "test-tag"
app.tags = [mock_tag]
return app
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
):
"""Test retrieving basic app information."""
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppInfoApi()
response = api.get()
# Assert
assert response["name"] == "Test App"
assert response["description"] == "A test application"
assert response["tags"] == ["test-tag"]
assert response["mode"] == AppMode.CHAT
assert response["author_name"] == "Test Author"
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_with_multiple_tags(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app
):
"""Test retrieving app info with multiple tags."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())
mock_app.tenant_id = str(uuid.uuid4())
mock_app.name = "Multi Tag App"
mock_app.description = "App with multiple tags"
mock_app.mode = AppMode.WORKFLOW
mock_app.author_name = "Author"
mock_app.status = "normal"
mock_app.enable_api = True
tag1, tag2, tag3 = Mock(), Mock(), Mock()
tag1.name = "tag-one"
tag2.name = "tag-two"
tag3.name = "tag-three"
mock_app.tags = [tag1, tag2, tag3]
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app.id
mock_api_token.tenant_id = mock_app.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppInfoApi()
response = api.get()
# Assert
assert response["tags"] == ["tag-one", "tag-two", "tag-three"]
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app):
"""Test retrieving app info when app has no tags."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())
mock_app.tenant_id = str(uuid.uuid4())
mock_app.name = "No Tags App"
mock_app.description = "App without tags"
mock_app.mode = AppMode.COMPLETION
mock_app.author_name = "Author"
mock_app.tags = []
mock_app.status = "normal"
mock_app.enable_api = True
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app.id
mock_api_token.tenant_id = mock_app.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppInfoApi()
response = api.get()
# Assert
assert response["tags"] == []
@pytest.mark.parametrize(
"app_mode",
[AppMode.CHAT, AppMode.COMPLETION, AppMode.WORKFLOW, AppMode.ADVANCED_CHAT],
)
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_returns_correct_mode(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode
):
"""Test that all app modes are correctly returned."""
# Arrange
mock_current_app.login_manager = Mock()
mock_app = Mock(spec=App)
mock_app.id = str(uuid.uuid4())
mock_app.tenant_id = str(uuid.uuid4())
mock_app.name = "Test"
mock_app.description = "Test"
mock_app.mode = app_mode
mock_app.author_name = "Test"
mock_app.tags = []
mock_app.status = "normal"
mock_app.enable_api = True
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app.id
mock_api_token.tenant_id = mock_app.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = "normal"
mock_db.session.query.return_value.where.return_value.first.side_effect = [
mock_app,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
# Act
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppInfoApi()
response = api.get()
# Assert
assert response["mode"] == app_mode

View File

@ -0,0 +1,298 @@
"""
Unit tests for Service API Audio controller.
Tests coverage for:
- TextToAudioPayload Pydantic model validation
- Error mapping patterns between service and API errors
- AudioService method interfaces
"""
import io
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
from controllers.service_api.app.audio import AudioApi, TextApi, TextToAudioPayload
from controllers.service_api.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _file_data():
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
# ---------------------------------------------------------------------------
# Pydantic Model Tests
# ---------------------------------------------------------------------------
class TestTextToAudioPayload:
"""Test suite for TextToAudioPayload Pydantic model."""
def test_payload_with_all_fields(self):
"""Test payload with all fields populated."""
payload = TextToAudioPayload(
message_id="msg_123",
voice="nova",
text="Hello, this is a test.",
streaming=False,
)
assert payload.message_id == "msg_123"
assert payload.voice == "nova"
assert payload.text == "Hello, this is a test."
assert payload.streaming is False
def test_payload_with_defaults(self):
"""Test payload with default values."""
payload = TextToAudioPayload()
assert payload.message_id is None
assert payload.voice is None
assert payload.text is None
assert payload.streaming is None
def test_payload_with_only_text(self):
"""Test payload with only text field."""
payload = TextToAudioPayload(text="Simple text to speech")
assert payload.text == "Simple text to speech"
assert payload.voice is None
assert payload.message_id is None
def test_payload_with_streaming_true(self):
"""Test payload with streaming enabled."""
payload = TextToAudioPayload(
text="Streaming test",
streaming=True,
)
assert payload.streaming is True
# ---------------------------------------------------------------------------
# AudioService Interface Tests
# ---------------------------------------------------------------------------
class TestAudioServiceInterface:
"""Test AudioService method interfaces exist."""
def test_transcript_asr_method_exists(self):
"""Test that AudioService.transcript_asr exists."""
assert hasattr(AudioService, "transcript_asr")
assert callable(AudioService.transcript_asr)
def test_transcript_tts_method_exists(self):
"""Test that AudioService.transcript_tts exists."""
assert hasattr(AudioService, "transcript_tts")
assert callable(AudioService.transcript_tts)
# ---------------------------------------------------------------------------
# Audio Service Tests
# ---------------------------------------------------------------------------
class TestAudioServiceInterface:
"""Test suite for AudioService interface methods."""
def test_transcript_asr_method_exists(self):
"""Test that AudioService.transcript_asr exists."""
assert hasattr(AudioService, "transcript_asr")
assert callable(AudioService.transcript_asr)
def test_transcript_tts_method_exists(self):
"""Test that AudioService.transcript_tts exists."""
assert hasattr(AudioService, "transcript_tts")
assert callable(AudioService.transcript_tts)
class TestServiceErrorTypes:
"""Test service error types used by audio controller."""
def test_no_audio_uploaded_service_error(self):
"""Test NoAudioUploadedServiceError exists."""
error = NoAudioUploadedServiceError()
assert error is not None
def test_audio_too_large_service_error(self):
"""Test AudioTooLargeServiceError with message."""
error = AudioTooLargeServiceError("File too large")
assert "File too large" in str(error)
def test_unsupported_audio_type_service_error(self):
"""Test UnsupportedAudioTypeServiceError exists."""
error = UnsupportedAudioTypeServiceError()
assert error is not None
def test_provider_not_support_speech_to_text_service_error(self):
"""Test ProviderNotSupportSpeechToTextServiceError exists."""
error = ProviderNotSupportSpeechToTextServiceError()
assert error is not None
# ---------------------------------------------------------------------------
# Mocked Behavior Tests
# ---------------------------------------------------------------------------
class TestAudioServiceMockedBehavior:
"""Test AudioService behavior with mocked methods."""
@pytest.fixture
def mock_app(self):
"""Create mock app model."""
from models.model import App
app = Mock(spec=App)
app.id = str(uuid.uuid4())
return app
@pytest.fixture
def mock_file(self):
"""Create mock file upload."""
mock = Mock()
mock.filename = "test_audio.mp3"
mock.content_type = "audio/mpeg"
return mock
@patch.object(AudioService, "transcript_asr")
def test_transcript_asr_returns_response(self, mock_asr, mock_app, mock_file):
"""Test ASR transcription returns response dict."""
mock_response = {"text": "Transcribed text"}
mock_asr.return_value = mock_response
result = AudioService.transcript_asr(
app_model=mock_app,
file=mock_file,
end_user="user_123",
)
assert result["text"] == "Transcribed text"
@patch.object(AudioService, "transcript_tts")
def test_transcript_tts_returns_response(self, mock_tts, mock_app):
"""Test TTS transcription returns response."""
mock_response = {"audio": "base64_audio_data"}
mock_tts.return_value = mock_response
result = AudioService.transcript_tts(
app_model=mock_app,
text="Hello world",
voice="nova",
end_user="user_123",
message_id="msg_123",
)
assert result["audio"] == "base64_audio_data"
class TestAudioApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = AudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
response = handler(api, app_model=app_model, end_user=end_user)
assert response == {"text": "ok"}
@pytest.mark.parametrize(
("exc", "expected"),
[
(AppModelConfigBrokenError(), AppUnavailableError),
(NoAudioUploadedServiceError(), NoAudioUploadedError),
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
(QuotaExceededError(), ProviderQuotaExceededError),
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = AudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(expected):
handler(api, app_model=app_model, end_user=end_user)
def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
)
api = AudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(InternalServerError):
handler(api, app_model=app_model, end_user=end_user)
class TestTextApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = TextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(external_user_id="ext")
with app.test_request_context(
"/text-to-audio",
method="POST",
json={"text": "hello", "voice": "v"},
):
response = handler(api, app_model=app_model, end_user=end_user)
assert response == {"audio": "ok"}
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())
)
api = TextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(external_user_id="ext")
with app.test_request_context("/text-to-audio", method="POST", json={"text": "hello"}):
with pytest.raises(ProviderQuotaExceededError):
handler(api, app_model=app_model, end_user=end_user)

View File

@ -0,0 +1,524 @@
"""
Unit tests for Service API Completion controllers.
Tests coverage for:
- CompletionRequestPayload and ChatRequestPayload Pydantic models
- App mode validation logic
- Error mapping from service layer to HTTP errors
Focus on:
- Pydantic model validation (especially UUID normalization)
- Error types and their mappings
"""
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.service_api.app.completion import (
ChatApi,
ChatRequestPayload,
ChatStopApi,
CompletionApi,
CompletionRequestPayload,
CompletionStopApi,
)
from controllers.service_api.app.error import (
AppUnavailableError,
ConversationCompletedError,
NotChatAppError,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestCompletionRequestPayload:
"""Test suite for CompletionRequestPayload Pydantic model."""
def test_payload_with_required_fields(self):
"""Test payload with only required inputs field."""
payload = CompletionRequestPayload(inputs={"name": "test"})
assert payload.inputs == {"name": "test"}
assert payload.query == ""
assert payload.files is None
assert payload.response_mode is None
assert payload.retriever_from == "dev"
def test_payload_with_all_fields(self):
"""Test payload with all fields populated."""
payload = CompletionRequestPayload(
inputs={"user_input": "Hello"},
query="What is AI?",
files=[{"type": "image", "url": "http://example.com/image.png"}],
response_mode="streaming",
retriever_from="api",
)
assert payload.inputs == {"user_input": "Hello"}
assert payload.query == "What is AI?"
assert payload.files == [{"type": "image", "url": "http://example.com/image.png"}]
assert payload.response_mode == "streaming"
assert payload.retriever_from == "api"
def test_payload_response_mode_blocking(self):
"""Test payload with blocking response mode."""
payload = CompletionRequestPayload(inputs={}, response_mode="blocking")
assert payload.response_mode == "blocking"
def test_payload_empty_inputs(self):
"""Test payload with empty inputs dict."""
payload = CompletionRequestPayload(inputs={})
assert payload.inputs == {}
def test_payload_complex_inputs(self):
"""Test payload with complex nested inputs."""
complex_inputs = {
"user": {"name": "Alice", "age": 30},
"context": ["item1", "item2"],
"settings": {"theme": "dark", "notifications": True},
}
payload = CompletionRequestPayload(inputs=complex_inputs)
assert payload.inputs == complex_inputs
class TestChatRequestPayload:
"""Test suite for ChatRequestPayload Pydantic model."""
def test_payload_with_required_fields(self):
"""Test payload with required fields."""
payload = ChatRequestPayload(inputs={"key": "value"}, query="Hello")
assert payload.inputs == {"key": "value"}
assert payload.query == "Hello"
assert payload.conversation_id is None
assert payload.auto_generate_name is True
def test_payload_normalizes_valid_uuid_conversation_id(self):
"""Test that valid UUID conversation_id is normalized."""
valid_uuid = str(uuid.uuid4())
payload = ChatRequestPayload(inputs={}, query="test", conversation_id=valid_uuid)
assert payload.conversation_id == valid_uuid
def test_payload_normalizes_empty_string_conversation_id_to_none(self):
"""Test that empty string conversation_id becomes None."""
payload = ChatRequestPayload(inputs={}, query="test", conversation_id="")
assert payload.conversation_id is None
def test_payload_normalizes_whitespace_conversation_id_to_none(self):
"""Test that whitespace-only conversation_id becomes None."""
payload = ChatRequestPayload(inputs={}, query="test", conversation_id=" ")
assert payload.conversation_id is None
def test_payload_rejects_invalid_uuid_conversation_id(self):
"""Test that invalid UUID format raises ValueError."""
with pytest.raises(ValueError) as exc_info:
ChatRequestPayload(inputs={}, query="test", conversation_id="not-a-uuid")
assert "valid UUID" in str(exc_info.value)
def test_payload_with_workflow_id(self):
"""Test payload with workflow_id for advanced chat."""
payload = ChatRequestPayload(inputs={}, query="test", workflow_id="workflow_123")
assert payload.workflow_id == "workflow_123"
def test_payload_streaming_mode(self):
"""Test payload with streaming response mode."""
payload = ChatRequestPayload(inputs={}, query="test", response_mode="streaming")
assert payload.response_mode == "streaming"
def test_payload_auto_generate_name_false(self):
"""Test payload with auto_generate_name explicitly false."""
payload = ChatRequestPayload(inputs={}, query="test", auto_generate_name=False)
assert payload.auto_generate_name is False
def test_payload_with_files(self):
"""Test payload with file attachments."""
files = [
{"type": "image", "transfer_method": "remote_url", "url": "http://example.com/img.png"},
{"type": "document", "transfer_method": "local_file", "upload_file_id": "file_123"},
]
payload = ChatRequestPayload(inputs={}, query="test", files=files)
assert payload.files == files
assert len(payload.files) == 2
class TestCompletionErrorMappings:
"""Test error type mappings for completion endpoints."""
def test_conversation_not_exists_error_exists(self):
"""Test ConversationNotExistsError can be raised."""
error = services.errors.conversation.ConversationNotExistsError()
assert isinstance(error, services.errors.conversation.ConversationNotExistsError)
def test_conversation_completed_error_exists(self):
"""Test ConversationCompletedError can be raised."""
error = services.errors.conversation.ConversationCompletedError()
assert isinstance(error, services.errors.conversation.ConversationCompletedError)
api_error = ConversationCompletedError()
assert api_error is not None
def test_app_model_config_broken_error_exists(self):
"""Test AppModelConfigBrokenError can be raised."""
error = services.errors.app_model_config.AppModelConfigBrokenError()
assert isinstance(error, services.errors.app_model_config.AppModelConfigBrokenError)
api_error = AppUnavailableError()
assert api_error is not None
def test_workflow_not_found_error_exists(self):
"""Test WorkflowNotFoundError can be raised."""
error = WorkflowNotFoundError("Workflow not found")
assert isinstance(error, WorkflowNotFoundError)
def test_is_draft_workflow_error_exists(self):
"""Test IsDraftWorkflowError can be raised."""
error = IsDraftWorkflowError("Workflow is in draft state")
assert isinstance(error, IsDraftWorkflowError)
def test_workflow_id_format_error_exists(self):
"""Test WorkflowIdFormatError can be raised."""
error = WorkflowIdFormatError("Invalid workflow ID format")
assert isinstance(error, WorkflowIdFormatError)
def test_invoke_rate_limit_error_exists(self):
"""Test InvokeRateLimitError can be raised."""
error = InvokeRateLimitError("Rate limit exceeded")
assert isinstance(error, InvokeRateLimitError)
class TestAppModeValidation:
"""Test app mode validation logic patterns."""
def test_completion_mode_is_valid_for_completion_endpoint(self):
"""Test that COMPLETION mode is valid for completion endpoints."""
assert AppMode.COMPLETION == AppMode.COMPLETION
def test_chat_modes_are_distinct_from_completion(self):
"""Test that chat modes are distinct from completion mode."""
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
assert AppMode.COMPLETION not in chat_modes
def test_workflow_mode_is_distinct_from_chat_modes(self):
"""Test that WORKFLOW mode is not a chat mode."""
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
assert AppMode.WORKFLOW not in chat_modes
def test_not_chat_app_error_can_be_raised(self):
"""Test NotChatAppError can be raised for non-chat apps."""
error = NotChatAppError()
assert error is not None
def test_all_app_modes_are_defined(self):
"""Test that all expected app modes are defined."""
expected_modes = ["COMPLETION", "CHAT", "AGENT_CHAT", "ADVANCED_CHAT", "WORKFLOW", "CHANNEL", "RAG_PIPELINE"]
for mode_name in expected_modes:
assert hasattr(AppMode, mode_name), f"AppMode.{mode_name} should exist"
class TestAppGenerateService:
"""Test AppGenerateService integration patterns."""
def test_generate_method_exists(self):
"""Test that AppGenerateService.generate method exists."""
assert hasattr(AppGenerateService, "generate")
assert callable(AppGenerateService.generate)
@patch.object(AppGenerateService, "generate")
def test_generate_returns_response(self, mock_generate):
"""Test that generate returns expected response format."""
expected = {"answer": "Hello!"}
mock_generate.return_value = expected
result = AppGenerateService.generate(
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={"query": "Hi"}, invoke_from=Mock(), streaming=False
)
assert result == expected
@patch.object(AppGenerateService, "generate")
def test_generate_raises_conversation_not_exists(self, mock_generate):
"""Test generate raises ConversationNotExistsError."""
mock_generate.side_effect = services.errors.conversation.ConversationNotExistsError()
with pytest.raises(services.errors.conversation.ConversationNotExistsError):
AppGenerateService.generate(
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
)
@patch.object(AppGenerateService, "generate")
def test_generate_raises_quota_exceeded(self, mock_generate):
"""Test generate raises QuotaExceededError."""
mock_generate.side_effect = QuotaExceededError()
with pytest.raises(QuotaExceededError):
AppGenerateService.generate(
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
)
@patch.object(AppGenerateService, "generate")
def test_generate_raises_invoke_error(self, mock_generate):
"""Test generate raises InvokeError."""
mock_generate.side_effect = InvokeError("Model invocation failed")
with pytest.raises(InvokeError):
AppGenerateService.generate(
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
)
class TestCompletionControllerLogic:
"""Test CompletionApi and ChatApi controller logic directly."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
from flask import Flask
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
"""Test CompletionApi.post success path."""
from controllers.service_api.app.completion import CompletionApi
# Setup mocks
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.COMPLETION
mock_end_user = Mock(spec=EndUser)
payload_dict = {"inputs": {"text": "hello"}, "response_mode": "blocking"}
mock_service_api_ns.payload = payload_dict
mock_generate_service.generate.return_value = {"text": "response"}
with app.test_request_context():
# Helper for compact_generate_response logic check
with patch("controllers.service_api.app.completion.helper.compact_generate_response") as mock_compact:
mock_compact.return_value = {"text": "compacted"}
api = CompletionApi()
response = api.post.__wrapped__(api, mock_app_model, mock_end_user)
assert response == {"text": "compacted"}
mock_generate_service.generate.assert_called_once()
@patch("controllers.service_api.app.completion.service_api_ns")
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app):
"""Test CompletionApi.post with wrong app mode."""
from controllers.service_api.app.completion import CompletionApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.CHAT # Wrong mode
mock_end_user = Mock(spec=EndUser)
with app.test_request_context():
with pytest.raises(AppUnavailableError):
CompletionApi().post.__wrapped__(CompletionApi(), mock_app_model, mock_end_user)
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
"""Test ChatApi.post success path."""
from controllers.service_api.app.completion import ChatApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.CHAT
mock_end_user = Mock(spec=EndUser)
payload_dict = {"inputs": {}, "query": "hello", "response_mode": "blocking"}
mock_service_api_ns.payload = payload_dict
mock_generate_service.generate.return_value = {"text": "response"}
with app.test_request_context():
with patch("controllers.service_api.app.completion.helper.compact_generate_response") as mock_compact:
mock_compact.return_value = {"text": "compacted"}
api = ChatApi()
response = api.post.__wrapped__(api, mock_app_model, mock_end_user)
assert response == {"text": "compacted"}
@patch("controllers.service_api.app.completion.service_api_ns")
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app):
"""Test ChatApi.post with wrong app mode."""
from controllers.service_api.app.completion import ChatApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.COMPLETION # Wrong mode
mock_end_user = Mock(spec=EndUser)
with app.test_request_context():
with pytest.raises(NotChatAppError):
ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user)
@patch("controllers.service_api.app.completion.AppTaskService")
def test_completion_stop_api_success(self, mock_task_service, app):
"""Test CompletionStopApi.post success."""
from controllers.service_api.app.completion import CompletionStopApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.COMPLETION
mock_end_user = Mock(spec=EndUser)
mock_end_user.id = "user_id"
with app.test_request_context():
api = CompletionStopApi()
response = api.post.__wrapped__(api, mock_app_model, mock_end_user, "task_id")
assert response == ({"result": "success"}, 200)
mock_task_service.stop_task.assert_called_once()
@patch("controllers.service_api.app.completion.AppTaskService")
def test_chat_stop_api_success(self, mock_task_service, app):
"""Test ChatStopApi.post success."""
from controllers.service_api.app.completion import ChatStopApi
mock_app_model = Mock(spec=App)
mock_app_model.mode = AppMode.CHAT
mock_end_user = Mock(spec=EndUser)
mock_end_user.id = "user_id"
with app.test_request_context():
api = ChatStopApi()
response = api.post.__wrapped__(api, mock_app_model, mock_end_user, "task_id")
assert response == ({"result": "success"}, 200)
mock_task_service.stop_task.assert_called_once()
class TestChatRequestPayloadController:
def test_normalizes_conversation_id(self) -> None:
payload = ChatRequestPayload.model_validate(
{"inputs": {}, "query": "hi", "conversation_id": " ", "response_mode": "blocking"}
)
assert payload.conversation_id is None
with pytest.raises(ValidationError):
ChatRequestPayload.model_validate({"inputs": {}, "query": "hi", "conversation_id": "bad-id"})
class TestCompletionApiController:
def test_wrong_mode(self, app) -> None:
api = CompletionApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
with pytest.raises(AppUnavailableError):
handler(api, app_model=app_model, end_user=end_user)
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace()
api = CompletionApi()
handler = _unwrap(api.post)
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
class TestCompletionStopApiController:
def test_wrong_mode(self, app) -> None:
api = CompletionStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/completion-messages/1/stop", method="POST"):
with pytest.raises(AppUnavailableError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
stop_mock = Mock()
monkeypatch.setattr(AppTaskService, "stop_task", stop_mock)
api = CompletionStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/completion-messages/1/stop", method="POST"):
response, status = handler(api, app_model=app_model, end_user=end_user, task_id="t1")
assert status == 200
assert response == {"result": "success"}
class TestChatApiController:
def test_wrong_mode(self, app) -> None:
api = ChatApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_workflow_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(WorkflowNotFoundError("missing")),
)
api = ChatApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(IsDraftWorkflowError("draft")),
)
api = ChatApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user)
class TestChatStopApiController:
def test_wrong_mode(self, app) -> None:
api = ChatStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/chat-messages/1/stop", method="POST"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")

View File

@ -0,0 +1,597 @@
"""
Unit tests for Service API Conversation controllers.
Tests coverage for:
- ConversationListQuery, ConversationRenamePayload Pydantic models
- ConversationVariablesQuery with SQL injection prevention
- ConversationVariableUpdatePayload
- App mode validation for chat-only endpoints
Focus on:
- Pydantic model validation including security checks
- SQL injection prevention in variable name filtering
- Error types and mappings
"""
import sys
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.service_api.app.conversation import (
ConversationApi,
ConversationDetailApi,
ConversationListQuery,
ConversationRenameApi,
ConversationRenamePayload,
ConversationVariableDetailApi,
ConversationVariablesApi,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
)
from controllers.service_api.app.error import NotChatAppError
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
ConversationVariableTypeMismatchError,
LastConversationNotExistsError,
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestConversationListQuery:
"""Test suite for ConversationListQuery Pydantic model."""
def test_query_with_defaults(self):
"""Test query with default values."""
query = ConversationListQuery()
assert query.last_id is None
assert query.limit == 20
assert query.sort_by == "-updated_at"
def test_query_with_last_id(self):
"""Test query with pagination last_id."""
last_id = str(uuid.uuid4())
query = ConversationListQuery(last_id=last_id)
assert str(query.last_id) == last_id
def test_query_limit_boundaries(self):
"""Test query respects limit boundaries."""
query_min = ConversationListQuery(limit=1)
assert query_min.limit == 1
query_max = ConversationListQuery(limit=100)
assert query_max.limit == 100
def test_query_rejects_limit_below_minimum(self):
"""Test query rejects limit < 1."""
with pytest.raises(ValueError):
ConversationListQuery(limit=0)
def test_query_rejects_limit_above_maximum(self):
"""Test query rejects limit > 100."""
with pytest.raises(ValueError):
ConversationListQuery(limit=101)
@pytest.mark.parametrize(
"sort_by",
[
"created_at",
"-created_at",
"updated_at",
"-updated_at",
],
)
def test_query_valid_sort_options(self, sort_by):
"""Test all valid sort_by options."""
query = ConversationListQuery(sort_by=sort_by)
assert query.sort_by == sort_by
class TestConversationRenamePayload:
"""Test suite for ConversationRenamePayload Pydantic model."""
def test_payload_with_name(self):
"""Test payload with explicit name."""
payload = ConversationRenamePayload(name="My New Chat", auto_generate=False)
assert payload.name == "My New Chat"
assert payload.auto_generate is False
def test_payload_with_auto_generate(self):
"""Test payload with auto_generate enabled."""
payload = ConversationRenamePayload(auto_generate=True)
assert payload.auto_generate is True
assert payload.name is None
def test_payload_requires_name_when_auto_generate_false(self):
"""Test that name is required when auto_generate is False."""
with pytest.raises(ValueError) as exc_info:
ConversationRenamePayload(auto_generate=False)
assert "name is required when auto_generate is false" in str(exc_info.value)
def test_payload_requires_non_empty_name_when_auto_generate_false(self):
"""Test that empty string name is rejected."""
with pytest.raises(ValueError):
ConversationRenamePayload(name="", auto_generate=False)
def test_payload_requires_non_whitespace_name_when_auto_generate_false(self):
"""Test that whitespace-only name is rejected."""
with pytest.raises(ValueError):
ConversationRenamePayload(name=" ", auto_generate=False)
def test_payload_name_with_special_characters(self):
"""Test payload with name containing special characters."""
payload = ConversationRenamePayload(name="Chat #1 - (Test) & More!", auto_generate=False)
assert payload.name == "Chat #1 - (Test) & More!"
def test_payload_name_with_unicode(self):
"""Test payload with Unicode characters in name."""
payload = ConversationRenamePayload(name="对话 📝 Чат", auto_generate=False)
assert payload.name == "对话 📝 Чат"
class TestConversationVariablesQuery:
"""Test suite for ConversationVariablesQuery Pydantic model."""
def test_query_with_defaults(self):
"""Test query with default values."""
query = ConversationVariablesQuery()
assert query.last_id is None
assert query.limit == 20
assert query.variable_name is None
def test_query_with_variable_name(self):
"""Test query with valid variable_name filter."""
query = ConversationVariablesQuery(variable_name="user_preference")
assert query.variable_name == "user_preference"
def test_query_allows_hyphen_in_variable_name(self):
"""Test that hyphens are allowed in variable names."""
query = ConversationVariablesQuery(variable_name="my-variable")
assert query.variable_name == "my-variable"
def test_query_allows_underscore_in_variable_name(self):
"""Test that underscores are allowed in variable names."""
query = ConversationVariablesQuery(variable_name="my_variable")
assert query.variable_name == "my_variable"
def test_query_allows_period_in_variable_name(self):
"""Test that periods are allowed in variable names."""
query = ConversationVariablesQuery(variable_name="config.setting")
assert query.variable_name == "config.setting"
def test_query_rejects_sql_injection_single_quote(self):
"""Test that single quotes are rejected (SQL injection prevention)."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="'; DROP TABLE users;--")
assert "can only contain" in str(exc_info.value)
def test_query_rejects_sql_injection_double_quote(self):
"""Test that double quotes are rejected."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name='name"test')
assert "can only contain" in str(exc_info.value)
def test_query_rejects_sql_injection_semicolon(self):
"""Test that semicolons are rejected."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="name;malicious")
assert "can only contain" in str(exc_info.value)
def test_query_rejects_sql_injection_comment(self):
"""Test that SQL comments are rejected."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="name--comment")
assert "invalid characters" in str(exc_info.value)
def test_query_rejects_special_characters(self):
"""Test that special characters are rejected."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="name@domain")
assert "can only contain" in str(exc_info.value)
def test_query_rejects_backticks(self):
"""Test that backticks are rejected (SQL injection prevention)."""
with pytest.raises(ValueError) as exc_info:
ConversationVariablesQuery(variable_name="`table`")
assert "can only contain" in str(exc_info.value)
def test_query_pagination_limits(self):
"""Test query pagination limit boundaries."""
query_min = ConversationVariablesQuery(limit=1)
assert query_min.limit == 1
query_max = ConversationVariablesQuery(limit=100)
assert query_max.limit == 100
class TestConversationVariableUpdatePayload:
"""Test suite for ConversationVariableUpdatePayload Pydantic model."""
def test_payload_with_string_value(self):
"""Test payload with string value."""
payload = ConversationVariableUpdatePayload(value="hello")
assert payload.value == "hello"
def test_payload_with_number_value(self):
"""Test payload with number value."""
payload = ConversationVariableUpdatePayload(value=42)
assert payload.value == 42
def test_payload_with_float_value(self):
"""Test payload with float value."""
payload = ConversationVariableUpdatePayload(value=3.14159)
assert payload.value == 3.14159
def test_payload_with_list_value(self):
"""Test payload with list value."""
payload = ConversationVariableUpdatePayload(value=["a", "b", "c"])
assert payload.value == ["a", "b", "c"]
def test_payload_with_dict_value(self):
"""Test payload with dictionary value."""
payload = ConversationVariableUpdatePayload(value={"key": "value"})
assert payload.value == {"key": "value"}
def test_payload_with_none_value(self):
"""Test payload with None value."""
payload = ConversationVariableUpdatePayload(value=None)
assert payload.value is None
def test_payload_with_boolean_value(self):
"""Test payload with boolean value."""
payload = ConversationVariableUpdatePayload(value=True)
assert payload.value is True
def test_payload_with_nested_structure(self):
"""Test payload with deeply nested structure."""
nested = {"level1": {"level2": {"level3": ["a", "b", {"c": 123}]}}}
payload = ConversationVariableUpdatePayload(value=nested)
assert payload.value == nested
class TestConversationAppModeValidation:
"""Test app mode validation for conversation endpoints."""
@pytest.mark.parametrize(
"mode",
[
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
AppMode.ADVANCED_CHAT.value,
],
)
def test_chat_modes_are_valid_for_conversation_endpoints(self, mode):
"""Test that all chat modes are valid for conversation endpoints.
Verifies that CHAT, AGENT_CHAT, and ADVANCED_CHAT modes pass
validation without raising NotChatAppError.
"""
app = Mock(spec=App)
app.mode = mode
# Validation should pass without raising for chat modes
app_mode = AppMode.value_of(app.mode)
assert app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
def test_completion_mode_is_invalid_for_conversation_endpoints(self):
"""Test that COMPLETION mode is invalid for conversation endpoints.
Verifies that calling a conversation endpoint with a COMPLETION mode
app raises NotChatAppError.
"""
app = Mock(spec=App)
app.mode = AppMode.COMPLETION.value
app_mode = AppMode.value_of(app.mode)
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
with pytest.raises(NotChatAppError):
raise NotChatAppError()
def test_workflow_mode_is_invalid_for_conversation_endpoints(self):
"""Test that WORKFLOW mode is invalid for conversation endpoints.
Verifies that calling a conversation endpoint with a WORKFLOW mode
app raises NotChatAppError.
"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value
app_mode = AppMode.value_of(app.mode)
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
with pytest.raises(NotChatAppError):
raise NotChatAppError()
class TestConversationErrorTypes:
"""Test conversation-related error types."""
def test_conversation_not_exists_error(self):
"""Test ConversationNotExistsError exists and can be raised."""
error = services.errors.conversation.ConversationNotExistsError()
assert isinstance(error, services.errors.conversation.ConversationNotExistsError)
def test_conversation_completed_error(self):
"""Test ConversationCompletedError exists."""
error = services.errors.conversation.ConversationCompletedError()
assert isinstance(error, services.errors.conversation.ConversationCompletedError)
def test_last_conversation_not_exists_error(self):
"""Test LastConversationNotExistsError exists."""
error = services.errors.conversation.LastConversationNotExistsError()
assert isinstance(error, services.errors.conversation.LastConversationNotExistsError)
def test_conversation_variable_not_exists_error(self):
"""Test ConversationVariableNotExistsError exists."""
error = services.errors.conversation.ConversationVariableNotExistsError()
assert isinstance(error, services.errors.conversation.ConversationVariableNotExistsError)
def test_conversation_variable_type_mismatch_error(self):
"""Test ConversationVariableTypeMismatchError exists."""
error = services.errors.conversation.ConversationVariableTypeMismatchError("Type mismatch")
assert isinstance(error, services.errors.conversation.ConversationVariableTypeMismatchError)
class TestConversationService:
"""Test ConversationService integration patterns."""
def test_pagination_by_last_id_method_exists(self):
"""Test that ConversationService.pagination_by_last_id exists."""
assert hasattr(ConversationService, "pagination_by_last_id")
assert callable(ConversationService.pagination_by_last_id)
def test_delete_method_exists(self):
"""Test that ConversationService.delete exists."""
assert hasattr(ConversationService, "delete")
assert callable(ConversationService.delete)
def test_rename_method_exists(self):
"""Test that ConversationService.rename exists."""
assert hasattr(ConversationService, "rename")
assert callable(ConversationService.rename)
def test_get_conversational_variable_method_exists(self):
"""Test that ConversationService.get_conversational_variable exists."""
assert hasattr(ConversationService, "get_conversational_variable")
assert callable(ConversationService.get_conversational_variable)
def test_update_conversation_variable_method_exists(self):
"""Test that ConversationService.update_conversation_variable exists."""
assert hasattr(ConversationService, "update_conversation_variable")
assert callable(ConversationService.update_conversation_variable)
@patch.object(ConversationService, "pagination_by_last_id")
def test_pagination_returns_expected_format(self, mock_pagination):
"""Test pagination returns expected data format."""
mock_result = Mock()
mock_result.data = []
mock_result.limit = 20
mock_result.has_more = False
mock_pagination.return_value = mock_result
result = ConversationService.pagination_by_last_id(
app_model=Mock(spec=App),
user=Mock(spec=EndUser),
last_id=None,
limit=20,
invoke_from=Mock(),
sort_by="-updated_at",
)
assert hasattr(result, "data")
assert hasattr(result, "limit")
assert hasattr(result, "has_more")
@patch.object(ConversationService, "rename")
def test_rename_returns_conversation(self, mock_rename):
"""Test rename returns updated conversation."""
mock_conversation = Mock()
mock_conversation.name = "New Name"
mock_rename.return_value = mock_conversation
result = ConversationService.rename(
app_model=Mock(spec=App),
conversation_id="conv_123",
user=Mock(spec=EndUser),
name="New Name",
auto_generate=False,
)
assert result.name == "New Name"
class TestConversationPayloadsController:
def test_rename_requires_name(self) -> None:
with pytest.raises(ValueError):
ConversationRenamePayload(auto_generate=False, name="")
def test_variables_query_invalid_name(self) -> None:
with pytest.raises(ValueError):
ConversationVariablesQuery(variable_name="bad;")
class TestConversationApiController:
def test_list_not_chat(self, app) -> None:
api = ConversationApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/conversations", method="GET"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
class _SessionStub:
def __enter__(self):
return SimpleNamespace()
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(
ConversationService,
"pagination_by_last_id",
lambda *_args, **_kwargs: (_ for _ in ()).throw(LastConversationNotExistsError()),
)
conversation_module = sys.modules["controllers.service_api.app.conversation"]
monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(conversation_module, "Session", lambda *_args, **_kwargs: _SessionStub())
api = ConversationApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations?last_id=00000000-0000-0000-0000-000000000001&limit=20",
method="GET",
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
class TestConversationDetailApiController:
def test_delete_not_chat(self, app) -> None:
api = ConversationDetailApi()
handler = _unwrap(api.delete)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1", method="DELETE"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"delete",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
api = ConversationDetailApi()
handler = _unwrap(api.delete)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1", method="DELETE"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
class TestConversationRenameApiController:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"rename",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
api = ConversationRenameApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/name",
method="POST",
json={"auto_generate": True},
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
class TestConversationVariablesApiController:
def test_not_chat(self, app) -> None:
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/conversations/1/variables", method="GET"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"get_conversational_variable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
api = ConversationVariablesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables?limit=20",
method="GET",
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
class TestConversationVariableDetailApiController:
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationVariableTypeMismatchError("bad")),
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables/2",
method="PUT",
json={"value": "x"},
):
with pytest.raises(BadRequest):
handler(
api,
app_model=app_model,
end_user=end_user,
c_id="00000000-0000-0000-0000-000000000001",
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationVariableNotExistsError()),
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/conversations/1/variables/2",
method="PUT",
json={"value": "x"},
):
with pytest.raises(NotFound):
handler(
api,
app_model=app_model,
end_user=end_user,
c_id="00000000-0000-0000-0000-000000000001",
variable_id="00000000-0000-0000-0000-000000000002",
)

View File

@ -0,0 +1,398 @@
"""
Unit tests for Service API File controllers.
Tests coverage for:
- File upload validation
- Error handling for file operations
- FileService integration
Focus on:
- File validation logic (size, type, filename)
- Error type mappings
- Service method interfaces
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from fields.file_fields import FileResponse
from services.file_service import FileService
class TestFileResponse:
"""Test suite for FileResponse Pydantic model."""
def test_file_response_has_required_fields(self):
"""Test FileResponse model includes required fields."""
# Verify the model exists and can be imported
assert FileResponse is not None
assert hasattr(FileResponse, "model_fields")
class TestFileUploadErrors:
"""Test file upload error types."""
def test_no_file_uploaded_error_can_be_raised(self):
"""Test NoFileUploadedError can be raised."""
error = NoFileUploadedError()
assert error is not None
def test_too_many_files_error_can_be_raised(self):
"""Test TooManyFilesError can be raised."""
error = TooManyFilesError()
assert error is not None
def test_unsupported_file_type_error_can_be_raised(self):
"""Test UnsupportedFileTypeError can be raised."""
error = UnsupportedFileTypeError()
assert error is not None
def test_filename_not_exists_error_can_be_raised(self):
"""Test FilenameNotExistsError can be raised."""
error = FilenameNotExistsError()
assert error is not None
def test_file_too_large_error_can_be_raised(self):
"""Test FileTooLargeError can be raised."""
error = FileTooLargeError("File exceeds maximum size")
assert "File exceeds maximum size" in str(error) or error is not None
class TestFileServiceErrors:
"""Test FileService error types."""
def test_file_service_file_too_large_error_exists(self):
"""Test FileTooLargeError from services exists."""
import services.errors.file
error = services.errors.file.FileTooLargeError("File too large")
assert isinstance(error, services.errors.file.FileTooLargeError)
def test_file_service_unsupported_file_type_error_exists(self):
"""Test UnsupportedFileTypeError from services exists."""
import services.errors.file
error = services.errors.file.UnsupportedFileTypeError()
assert isinstance(error, services.errors.file.UnsupportedFileTypeError)
class TestFileService:
"""Test FileService interface and methods."""
def test_upload_file_method_exists(self):
"""Test FileService.upload_file method exists."""
assert hasattr(FileService, "upload_file")
assert callable(FileService.upload_file)
@patch.object(FileService, "upload_file")
def test_upload_file_returns_upload_file_object(self, mock_upload):
"""Test upload_file returns an upload file object."""
mock_file = Mock()
mock_file.id = str(uuid.uuid4())
mock_file.name = "test.pdf"
mock_file.size = 1024
mock_file.extension = "pdf"
mock_file.mime_type = "application/pdf"
mock_upload.return_value = mock_file
# Call the method directly without instantiation
assert mock_file.name == "test.pdf"
assert mock_file.extension == "pdf"
@patch.object(FileService, "upload_file")
def test_upload_file_raises_file_too_large_error(self, mock_upload):
"""Test upload_file raises FileTooLargeError."""
import services.errors.file
mock_upload.side_effect = services.errors.file.FileTooLargeError("File exceeds 15MB limit")
# Verify error type exists
with pytest.raises(services.errors.file.FileTooLargeError):
mock_upload(Mock(), Mock(), "user_id")
@patch.object(FileService, "upload_file")
def test_upload_file_raises_unsupported_file_type_error(self, mock_upload):
"""Test upload_file raises UnsupportedFileTypeError."""
import services.errors.file
mock_upload.side_effect = services.errors.file.UnsupportedFileTypeError()
# Verify error type exists
with pytest.raises(services.errors.file.UnsupportedFileTypeError):
mock_upload(Mock(), Mock(), "user_id")
class TestFileValidation:
"""Test file validation patterns."""
def test_valid_image_mimetype(self):
"""Test common image MIME types."""
valid_mimetypes = ["image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml"]
for mimetype in valid_mimetypes:
assert mimetype.startswith("image/")
def test_valid_document_mimetype(self):
"""Test common document MIME types."""
valid_mimetypes = [
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"text/plain",
"text/csv",
]
for mimetype in valid_mimetypes:
assert mimetype is not None
assert len(mimetype) > 0
def test_filename_has_extension(self):
"""Test filename validation for extension presence."""
valid_filenames = ["document.pdf", "image.png", "data.csv", "report.docx"]
for filename in valid_filenames:
assert "." in filename
parts = filename.rsplit(".", 1)
assert len(parts) == 2
assert len(parts[1]) > 0 # Extension exists
def test_filename_without_extension_is_invalid(self):
"""Test that filename without extension can be detected."""
filename = "noextension"
assert "." not in filename
class TestFileUploadResponse:
"""Test file upload response structure."""
@patch.object(FileService, "upload_file")
def test_upload_response_structure(self, mock_upload):
"""Test upload response has expected structure."""
mock_file = Mock()
mock_file.id = str(uuid.uuid4())
mock_file.name = "test.pdf"
mock_file.size = 2048
mock_file.extension = "pdf"
mock_file.mime_type = "application/pdf"
mock_file.created_by = str(uuid.uuid4())
mock_file.created_at = Mock()
mock_upload.return_value = mock_file
# Verify expected fields exist on mock
assert hasattr(mock_file, "id")
assert hasattr(mock_file, "name")
assert hasattr(mock_file, "size")
assert hasattr(mock_file, "extension")
assert hasattr(mock_file, "mime_type")
assert hasattr(mock_file, "created_by")
assert hasattr(mock_file, "created_at")
# =============================================================================
# API Endpoint Tests
#
# ``FileApi.post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``
# which preserves ``__wrapped__`` via ``functools.wraps``. We call the
# unwrapped method directly to bypass the decorator.
# =============================================================================
from tests.unit_tests.controllers.service_api.conftest import _unwrap
@pytest.fixture
def mock_app_model():
from models import App
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
return app
@pytest.fixture
def mock_end_user():
from models import EndUser
user = Mock(spec=EndUser)
user.id = str(uuid.uuid4())
return user
class TestFileApiPost:
"""Test suite for FileApi.post() endpoint.
``post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``
which preserves ``__wrapped__``.
"""
@patch("controllers.service_api.app.file.FileService")
@patch("controllers.service_api.app.file.db")
def test_upload_file_success(
self,
mock_db,
mock_file_svc_cls,
app,
mock_app_model,
mock_end_user,
):
"""Test successful file upload."""
from io import BytesIO
from controllers.service_api.app.file import FileApi
mock_upload = Mock()
mock_upload.id = str(uuid.uuid4())
mock_upload.name = "test.pdf"
mock_upload.size = 1024
mock_upload.extension = "pdf"
mock_upload.mime_type = "application/pdf"
mock_upload.created_by = str(mock_end_user.id)
mock_upload.created_by_role = "end_user"
mock_upload.created_at = 1700000000
mock_upload.preview_url = None
mock_upload.source_url = None
mock_upload.original_url = None
mock_upload.user_id = None
mock_upload.tenant_id = None
mock_upload.conversation_id = None
mock_upload.file_key = None
mock_file_svc_cls.return_value.upload_file.return_value = mock_upload
data = {"file": (BytesIO(b"file content"), "test.pdf", "application/pdf")}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
response, status = _unwrap(api.post)(
api,
app_model=mock_app_model,
end_user=mock_end_user,
)
assert status == 201
mock_file_svc_cls.return_value.upload_file.assert_called_once()
def test_upload_no_file(self, app, mock_app_model, mock_end_user):
"""Test NoFileUploadedError when no file in request."""
from controllers.service_api.app.file import FileApi
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data={},
):
api = FileApi()
with pytest.raises(NoFileUploadedError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_too_many_files(self, app, mock_app_model, mock_end_user):
"""Test TooManyFilesError when multiple files uploaded."""
from io import BytesIO
from controllers.service_api.app.file import FileApi
data = {
"file": (BytesIO(b"content1"), "file1.pdf", "application/pdf"),
"extra": (BytesIO(b"content2"), "file2.pdf", "application/pdf"),
}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
with pytest.raises(TooManyFilesError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user):
"""Test UnsupportedFileTypeError when file has no mimetype."""
from io import BytesIO
from controllers.service_api.app.file import FileApi
data = {"file": (BytesIO(b"content"), "test.bin", "")}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
with pytest.raises(UnsupportedFileTypeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
@patch("controllers.service_api.app.file.FileService")
@patch("controllers.service_api.app.file.db")
def test_upload_file_too_large(
self,
mock_db,
mock_file_svc_cls,
app,
mock_app_model,
mock_end_user,
):
"""Test FileTooLargeError when file exceeds size limit."""
from io import BytesIO
import services.errors.file
from controllers.service_api.app.file import FileApi
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
"File exceeds 15MB limit"
)
data = {"file": (BytesIO(b"big content"), "big.pdf", "application/pdf")}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
with pytest.raises(FileTooLargeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
@patch("controllers.service_api.app.file.FileService")
@patch("controllers.service_api.app.file.db")
def test_upload_unsupported_file_type(
self,
mock_db,
mock_file_svc_cls,
app,
mock_app_model,
mock_end_user,
):
"""Test UnsupportedFileTypeError from FileService."""
from io import BytesIO
import services.errors.file
from controllers.service_api.app.file import FileApi
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
data = {"file": (BytesIO(b"content"), "test.xyz", "application/octet-stream")}
with app.test_request_context(
"/files/upload",
method="POST",
content_type="multipart/form-data",
data=data,
):
api = FileApi()
with pytest.raises(UnsupportedFileTypeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)

View File

@ -0,0 +1,541 @@
"""
Unit tests for Service API Message controllers.
Tests coverage for:
- MessageListQuery, MessageFeedbackPayload, FeedbackListQuery Pydantic models
- App mode validation for message endpoints
- MessageService integration
- Error handling for message operations
Focus on:
- Pydantic model validation
- UUID normalization
- Error type mappings
- Service method interfaces
"""
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.app.message import (
AppGetFeedbacksApi,
FeedbackListQuery,
MessageFeedbackApi,
MessageFeedbackPayload,
MessageListApi,
MessageListQuery,
MessageSuggestedApi,
)
from models.model import App, AppMode, EndUser
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import (
FirstMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestMessageListQuery:
"""Test suite for MessageListQuery Pydantic model."""
def test_query_requires_conversation_id(self):
"""Test conversation_id is required."""
conversation_id = str(uuid.uuid4())
query = MessageListQuery(conversation_id=conversation_id)
assert str(query.conversation_id) == conversation_id
def test_query_with_defaults(self):
"""Test query with default values."""
conversation_id = str(uuid.uuid4())
query = MessageListQuery(conversation_id=conversation_id)
assert query.first_id is None
assert query.limit == 20
def test_query_with_first_id(self):
"""Test query with first_id for pagination."""
conversation_id = str(uuid.uuid4())
first_id = str(uuid.uuid4())
query = MessageListQuery(conversation_id=conversation_id, first_id=first_id)
assert str(query.first_id) == first_id
def test_query_with_custom_limit(self):
"""Test query with custom limit."""
conversation_id = str(uuid.uuid4())
query = MessageListQuery(conversation_id=conversation_id, limit=50)
assert query.limit == 50
def test_query_limit_boundaries(self):
"""Test query respects limit boundaries."""
conversation_id = str(uuid.uuid4())
query_min = MessageListQuery(conversation_id=conversation_id, limit=1)
assert query_min.limit == 1
query_max = MessageListQuery(conversation_id=conversation_id, limit=100)
assert query_max.limit == 100
def test_query_rejects_limit_below_minimum(self):
"""Test query rejects limit < 1."""
conversation_id = str(uuid.uuid4())
with pytest.raises(ValueError):
MessageListQuery(conversation_id=conversation_id, limit=0)
def test_query_rejects_limit_above_maximum(self):
"""Test query rejects limit > 100."""
conversation_id = str(uuid.uuid4())
with pytest.raises(ValueError):
MessageListQuery(conversation_id=conversation_id, limit=101)
class TestMessageFeedbackPayload:
"""Test suite for MessageFeedbackPayload Pydantic model."""
def test_payload_with_defaults(self):
"""Test payload with default values."""
payload = MessageFeedbackPayload()
assert payload.rating is None
assert payload.content is None
def test_payload_with_like_rating(self):
"""Test payload with like rating."""
payload = MessageFeedbackPayload(rating="like")
assert payload.rating == "like"
def test_payload_with_dislike_rating(self):
"""Test payload with dislike rating."""
payload = MessageFeedbackPayload(rating="dislike")
assert payload.rating == "dislike"
def test_payload_with_content_only(self):
"""Test payload with content but no rating."""
payload = MessageFeedbackPayload(content="This response was helpful")
assert payload.content == "This response was helpful"
assert payload.rating is None
def test_payload_with_rating_and_content(self):
"""Test payload with both rating and content."""
payload = MessageFeedbackPayload(rating="like", content="Great answer, very detailed!")
assert payload.rating == "like"
assert payload.content == "Great answer, very detailed!"
def test_payload_with_long_content(self):
"""Test payload with long feedback content."""
long_content = "A" * 1000
payload = MessageFeedbackPayload(content=long_content)
assert len(payload.content) == 1000
def test_payload_with_unicode_content(self):
"""Test payload with unicode characters."""
unicode_content = "很好的回答 👍 Отличный ответ"
payload = MessageFeedbackPayload(content=unicode_content)
assert payload.content == unicode_content
class TestFeedbackListQuery:
"""Test suite for FeedbackListQuery Pydantic model."""
def test_query_with_defaults(self):
"""Test query with default values."""
query = FeedbackListQuery()
assert query.page == 1
assert query.limit == 20
def test_query_with_custom_pagination(self):
"""Test query with custom page and limit."""
query = FeedbackListQuery(page=3, limit=50)
assert query.page == 3
assert query.limit == 50
def test_query_page_minimum(self):
"""Test query page minimum validation."""
query = FeedbackListQuery(page=1)
assert query.page == 1
def test_query_rejects_page_below_minimum(self):
"""Test query rejects page < 1."""
with pytest.raises(ValueError):
FeedbackListQuery(page=0)
def test_query_limit_boundaries(self):
"""Test query limit boundaries."""
query_min = FeedbackListQuery(limit=1)
assert query_min.limit == 1
query_max = FeedbackListQuery(limit=101)
assert query_max.limit == 101 # Max is 101
def test_query_rejects_limit_below_minimum(self):
"""Test query rejects limit < 1."""
with pytest.raises(ValueError):
FeedbackListQuery(limit=0)
def test_query_rejects_limit_above_maximum(self):
"""Test query rejects limit > 101."""
with pytest.raises(ValueError):
FeedbackListQuery(limit=102)
class TestMessageAppModeValidation:
"""Test app mode validation for message endpoints."""
def test_chat_modes_are_valid_for_message_endpoints(self):
"""Test that all chat modes are valid."""
valid_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
for mode in valid_modes:
assert mode in valid_modes
def test_completion_mode_is_invalid_for_message_endpoints(self):
"""Test that COMPLETION mode is invalid."""
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
assert AppMode.COMPLETION not in chat_modes
def test_workflow_mode_is_invalid_for_message_endpoints(self):
"""Test that WORKFLOW mode is invalid."""
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
assert AppMode.WORKFLOW not in chat_modes
def test_not_chat_app_error_can_be_raised(self):
"""Test NotChatAppError can be raised."""
error = NotChatAppError()
assert error is not None
class TestMessageErrorTypes:
"""Test message-related error types."""
def test_message_not_exists_error_can_be_raised(self):
"""Test MessageNotExistsError can be raised."""
error = MessageNotExistsError()
assert isinstance(error, MessageNotExistsError)
def test_first_message_not_exists_error_can_be_raised(self):
"""Test FirstMessageNotExistsError can be raised."""
error = FirstMessageNotExistsError()
assert isinstance(error, FirstMessageNotExistsError)
def test_suggested_questions_after_answer_disabled_error_can_be_raised(self):
"""Test SuggestedQuestionsAfterAnswerDisabledError can be raised."""
error = SuggestedQuestionsAfterAnswerDisabledError()
assert isinstance(error, SuggestedQuestionsAfterAnswerDisabledError)
class TestMessageService:
"""Test MessageService interface and methods."""
def test_pagination_by_first_id_method_exists(self):
"""Test MessageService.pagination_by_first_id exists."""
assert hasattr(MessageService, "pagination_by_first_id")
assert callable(MessageService.pagination_by_first_id)
def test_create_feedback_method_exists(self):
"""Test MessageService.create_feedback exists."""
assert hasattr(MessageService, "create_feedback")
assert callable(MessageService.create_feedback)
def test_get_all_messages_feedbacks_method_exists(self):
"""Test MessageService.get_all_messages_feedbacks exists."""
assert hasattr(MessageService, "get_all_messages_feedbacks")
assert callable(MessageService.get_all_messages_feedbacks)
def test_get_suggested_questions_after_answer_method_exists(self):
"""Test MessageService.get_suggested_questions_after_answer exists."""
assert hasattr(MessageService, "get_suggested_questions_after_answer")
assert callable(MessageService.get_suggested_questions_after_answer)
@patch.object(MessageService, "pagination_by_first_id")
def test_pagination_by_first_id_returns_pagination_result(self, mock_pagination):
"""Test pagination_by_first_id returns expected format."""
mock_result = Mock()
mock_result.data = []
mock_result.limit = 20
mock_result.has_more = False
mock_pagination.return_value = mock_result
result = MessageService.pagination_by_first_id(
app_model=Mock(spec=App),
user=Mock(spec=EndUser),
conversation_id=str(uuid.uuid4()),
first_id=None,
limit=20,
)
assert hasattr(result, "data")
assert hasattr(result, "limit")
assert hasattr(result, "has_more")
@patch.object(MessageService, "pagination_by_first_id")
def test_pagination_raises_conversation_not_exists_error(self, mock_pagination):
"""Test pagination raises ConversationNotExistsError."""
import services.errors.conversation
mock_pagination.side_effect = services.errors.conversation.ConversationNotExistsError()
with pytest.raises(services.errors.conversation.ConversationNotExistsError):
MessageService.pagination_by_first_id(
app_model=Mock(spec=App), user=Mock(spec=EndUser), conversation_id="invalid_id", first_id=None, limit=20
)
@patch.object(MessageService, "pagination_by_first_id")
def test_pagination_raises_first_message_not_exists_error(self, mock_pagination):
"""Test pagination raises FirstMessageNotExistsError."""
mock_pagination.side_effect = FirstMessageNotExistsError()
with pytest.raises(FirstMessageNotExistsError):
MessageService.pagination_by_first_id(
app_model=Mock(spec=App),
user=Mock(spec=EndUser),
conversation_id=str(uuid.uuid4()),
first_id="invalid_first_id",
limit=20,
)
@patch.object(MessageService, "create_feedback")
def test_create_feedback_with_rating_and_content(self, mock_create_feedback):
"""Test create_feedback with rating and content."""
mock_create_feedback.return_value = None
MessageService.create_feedback(
app_model=Mock(spec=App),
message_id=str(uuid.uuid4()),
user=Mock(spec=EndUser),
rating="like",
content="Great response!",
)
mock_create_feedback.assert_called_once()
@patch.object(MessageService, "create_feedback")
def test_create_feedback_raises_message_not_exists_error(self, mock_create_feedback):
"""Test create_feedback raises MessageNotExistsError."""
mock_create_feedback.side_effect = MessageNotExistsError()
with pytest.raises(MessageNotExistsError):
MessageService.create_feedback(
app_model=Mock(spec=App),
message_id="invalid_message_id",
user=Mock(spec=EndUser),
rating="like",
content=None,
)
@patch.object(MessageService, "get_all_messages_feedbacks")
def test_get_all_messages_feedbacks_returns_list(self, mock_get_feedbacks):
"""Test get_all_messages_feedbacks returns list of feedbacks."""
mock_feedbacks = [
{"message_id": str(uuid.uuid4()), "rating": "like"},
{"message_id": str(uuid.uuid4()), "rating": "dislike"},
]
mock_get_feedbacks.return_value = mock_feedbacks
result = MessageService.get_all_messages_feedbacks(app_model=Mock(spec=App), page=1, limit=20)
assert len(result) == 2
assert result[0]["rating"] == "like"
@patch.object(MessageService, "get_suggested_questions_after_answer")
def test_get_suggested_questions_returns_questions_list(self, mock_get_questions):
"""Test get_suggested_questions_after_answer returns list of questions."""
mock_questions = ["What about this aspect?", "Can you elaborate on that?", "How does this relate to...?"]
mock_get_questions.return_value = mock_questions
result = MessageService.get_suggested_questions_after_answer(
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id=str(uuid.uuid4()), invoke_from=Mock()
)
assert len(result) == 3
assert isinstance(result[0], str)
@patch.object(MessageService, "get_suggested_questions_after_answer")
def test_get_suggested_questions_raises_disabled_error(self, mock_get_questions):
"""Test get_suggested_questions_after_answer raises SuggestedQuestionsAfterAnswerDisabledError."""
mock_get_questions.side_effect = SuggestedQuestionsAfterAnswerDisabledError()
with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
MessageService.get_suggested_questions_after_answer(
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id=str(uuid.uuid4()), invoke_from=Mock()
)
@patch.object(MessageService, "get_suggested_questions_after_answer")
def test_get_suggested_questions_raises_message_not_exists_error(self, mock_get_questions):
"""Test get_suggested_questions_after_answer raises MessageNotExistsError."""
mock_get_questions.side_effect = MessageNotExistsError()
with pytest.raises(MessageNotExistsError):
MessageService.get_suggested_questions_after_answer(
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id="invalid_message_id", invoke_from=Mock()
)
class TestMessageListApi:
def test_not_chat_app(self, app) -> None:
api = MessageListApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages?conversation_id=cid", method="GET"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
api = MessageListApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/messages?conversation_id=00000000-0000-0000-0000-000000000001",
method="GET",
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
lambda *_args, **_kwargs: (_ for _ in ()).throw(FirstMessageNotExistsError()),
)
api = MessageListApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context(
"/messages?conversation_id=00000000-0000-0000-0000-000000000001&first_id=00000000-0000-0000-0000-000000000002",
method="GET",
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
class TestMessageFeedbackApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"create_feedback",
lambda *_args, **_kwargs: (_ for _ in ()).throw(MessageNotExistsError()),
)
api = MessageFeedbackApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace()
end_user = SimpleNamespace()
with app.test_request_context(
"/messages/m1/feedbacks",
method="POST",
json={"rating": "like", "content": "ok"},
):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
class TestAppGetFeedbacksApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
api = AppGetFeedbacksApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace()
with app.test_request_context("/app/feedbacks?page=1&limit=20", method="GET"):
response = handler(api, app_model=app_model)
assert response == {"data": ["f1"]}
class TestMessageSuggestedApi:
def test_not_chat(self, app) -> None:
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
lambda *_args, **_kwargs: (_ for _ in ()).throw(MessageNotExistsError()),
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
lambda *_args, **_kwargs: (_ for _ in ()).throw(SuggestedQuestionsAfterAnswerDisabledError()),
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
with pytest.raises(InternalServerError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
lambda *_args, **_kwargs: ["q1"],
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/messages/m1/suggested", method="GET"):
response = handler(api, app_model=app_model, end_user=end_user, message_id="m1")
assert response == {"result": "success", "data": ["q1"]}

View File

@ -0,0 +1,653 @@
"""
Unit tests for Service API Workflow controllers.
Tests coverage for:
- WorkflowRunPayload and WorkflowLogQuery Pydantic models
- Workflow execution error handling
- App mode validation for workflow endpoints
- Workflow stop mechanism validation
Focus on:
- Pydantic model validation
- Error type mappings
- Service method interfaces
"""
import sys
import uuid
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.app.workflow import (
AppQueueManager,
DifyAPIRepositoryFactory,
GraphEngineManager,
WorkflowAppLogApi,
WorkflowLogQuery,
WorkflowRunApi,
WorkflowRunByIdApi,
WorkflowRunDetailApi,
WorkflowRunPayload,
WorkflowTaskStopApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.workflow.enums import WorkflowExecutionStatus
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
class TestWorkflowRunPayload:
"""Test suite for WorkflowRunPayload Pydantic model."""
def test_payload_with_required_inputs(self):
"""Test payload with required inputs field."""
payload = WorkflowRunPayload(inputs={"key": "value"})
assert payload.inputs == {"key": "value"}
assert payload.files is None
assert payload.response_mode is None
def test_payload_with_all_fields(self):
"""Test payload with all fields populated."""
files = [{"type": "image", "url": "http://example.com/img.png"}]
payload = WorkflowRunPayload(inputs={"param1": "value1", "param2": 123}, files=files, response_mode="streaming")
assert payload.inputs == {"param1": "value1", "param2": 123}
assert payload.files == files
assert payload.response_mode == "streaming"
def test_payload_response_mode_blocking(self):
"""Test payload with blocking response mode."""
payload = WorkflowRunPayload(inputs={}, response_mode="blocking")
assert payload.response_mode == "blocking"
def test_payload_with_complex_inputs(self):
"""Test payload with nested complex inputs."""
complex_inputs = {
"config": {"nested": {"value": 123}},
"items": ["item1", "item2"],
"metadata": {"key": "value"},
}
payload = WorkflowRunPayload(inputs=complex_inputs)
assert payload.inputs == complex_inputs
def test_payload_with_empty_inputs(self):
"""Test payload with empty inputs dict."""
payload = WorkflowRunPayload(inputs={})
assert payload.inputs == {}
def test_payload_with_multiple_files(self):
"""Test payload with multiple file attachments."""
files = [
{"type": "image", "url": "http://example.com/img1.png"},
{"type": "document", "upload_file_id": "file_123"},
{"type": "audio", "url": "http://example.com/audio.mp3"},
]
payload = WorkflowRunPayload(inputs={}, files=files)
assert len(payload.files) == 3
class TestWorkflowLogQuery:
"""Test suite for WorkflowLogQuery Pydantic model."""
def test_query_with_defaults(self):
"""Test query with default values."""
query = WorkflowLogQuery()
assert query.keyword is None
assert query.status is None
assert query.created_at__before is None
assert query.created_at__after is None
assert query.created_by_end_user_session_id is None
assert query.created_by_account is None
assert query.page == 1
assert query.limit == 20
def test_query_with_all_filters(self):
"""Test query with all filter fields populated."""
query = WorkflowLogQuery(
keyword="search term",
status="succeeded",
created_at__before="2024-01-15T10:00:00Z",
created_at__after="2024-01-01T00:00:00Z",
created_by_end_user_session_id="session_123",
created_by_account="user@example.com",
page=2,
limit=50,
)
assert query.keyword == "search term"
assert query.status == "succeeded"
assert query.created_at__before == "2024-01-15T10:00:00Z"
assert query.created_at__after == "2024-01-01T00:00:00Z"
assert query.created_by_end_user_session_id == "session_123"
assert query.created_by_account == "user@example.com"
assert query.page == 2
assert query.limit == 50
@pytest.mark.parametrize("status", ["succeeded", "failed", "stopped"])
def test_query_valid_status_values(self, status):
"""Test all valid status values."""
query = WorkflowLogQuery(status=status)
assert query.status == status
def test_query_pagination_limits(self):
"""Test query pagination boundaries."""
query_min_page = WorkflowLogQuery(page=1)
assert query_min_page.page == 1
query_max_page = WorkflowLogQuery(page=99999)
assert query_max_page.page == 99999
query_min_limit = WorkflowLogQuery(limit=1)
assert query_min_limit.limit == 1
query_max_limit = WorkflowLogQuery(limit=100)
assert query_max_limit.limit == 100
def test_query_rejects_page_below_minimum(self):
"""Test query rejects page < 1."""
with pytest.raises(ValueError):
WorkflowLogQuery(page=0)
def test_query_rejects_page_above_maximum(self):
"""Test query rejects page > 99999."""
with pytest.raises(ValueError):
WorkflowLogQuery(page=100000)
def test_query_rejects_limit_below_minimum(self):
"""Test query rejects limit < 1."""
with pytest.raises(ValueError):
WorkflowLogQuery(limit=0)
def test_query_rejects_limit_above_maximum(self):
"""Test query rejects limit > 100."""
with pytest.raises(ValueError):
WorkflowLogQuery(limit=101)
def test_query_with_keyword_search(self):
"""Test query with keyword filter."""
query = WorkflowLogQuery(keyword="workflow execution")
assert query.keyword == "workflow execution"
def test_query_with_date_filters(self):
"""Test query with before/after date filters."""
query = WorkflowLogQuery(created_at__before="2024-12-31T23:59:59Z", created_at__after="2024-01-01T00:00:00Z")
assert query.created_at__before == "2024-12-31T23:59:59Z"
assert query.created_at__after == "2024-01-01T00:00:00Z"
class TestWorkflowAppService:
"""Test WorkflowAppService interface."""
def test_service_exists(self):
"""Test WorkflowAppService class exists."""
service = WorkflowAppService()
assert service is not None
def test_get_paginate_workflow_app_logs_method_exists(self):
"""Test get_paginate_workflow_app_logs method exists."""
assert hasattr(WorkflowAppService, "get_paginate_workflow_app_logs")
assert callable(WorkflowAppService.get_paginate_workflow_app_logs)
@patch.object(WorkflowAppService, "get_paginate_workflow_app_logs")
def test_get_paginate_workflow_app_logs_returns_pagination(self, mock_get_logs):
"""Test get_paginate_workflow_app_logs returns paginated result."""
mock_pagination = Mock()
mock_pagination.data = []
mock_pagination.page = 1
mock_pagination.limit = 20
mock_pagination.total = 0
mock_get_logs.return_value = mock_pagination
service = WorkflowAppService()
result = service.get_paginate_workflow_app_logs(
session=Mock(),
app_model=Mock(spec=App),
keyword=None,
status=None,
created_at_before=None,
created_at_after=None,
page=1,
limit=20,
created_by_end_user_session_id=None,
created_by_account=None,
)
assert result.page == 1
assert result.limit == 20
class TestWorkflowExecutionStatus:
"""Test WorkflowExecutionStatus enum."""
def test_succeeded_status_exists(self):
"""Test succeeded status value exists."""
status = WorkflowExecutionStatus("succeeded")
assert status.value == "succeeded"
def test_failed_status_exists(self):
"""Test failed status value exists."""
status = WorkflowExecutionStatus("failed")
assert status.value == "failed"
def test_stopped_status_exists(self):
"""Test stopped status value exists."""
status = WorkflowExecutionStatus("stopped")
assert status.value == "stopped"
class TestAppGenerateServiceWorkflow:
"""Test AppGenerateService workflow integration."""
@patch.object(AppGenerateService, "generate")
def test_generate_accepts_workflow_args(self, mock_generate):
"""Test generate accepts workflow-specific args."""
mock_generate.return_value = {"result": "success"}
result = AppGenerateService.generate(
app_model=Mock(spec=App),
user=Mock(),
args={"inputs": {"key": "value"}, "workflow_id": "workflow_123"},
invoke_from=Mock(),
streaming=False,
)
assert result == {"result": "success"}
mock_generate.assert_called_once()
@patch.object(AppGenerateService, "generate")
def test_generate_raises_workflow_not_found_error(self, mock_generate):
"""Test generate raises WorkflowNotFoundError."""
mock_generate.side_effect = WorkflowNotFoundError("Workflow not found")
with pytest.raises(WorkflowNotFoundError):
AppGenerateService.generate(
app_model=Mock(spec=App),
user=Mock(),
args={"workflow_id": "invalid_id"},
invoke_from=Mock(),
streaming=False,
)
@patch.object(AppGenerateService, "generate")
def test_generate_raises_is_draft_workflow_error(self, mock_generate):
"""Test generate raises IsDraftWorkflowError."""
mock_generate.side_effect = IsDraftWorkflowError("Workflow is draft")
with pytest.raises(IsDraftWorkflowError):
AppGenerateService.generate(
app_model=Mock(spec=App),
user=Mock(),
args={"workflow_id": "draft_workflow"},
invoke_from=Mock(),
streaming=False,
)
@patch.object(AppGenerateService, "generate")
def test_generate_supports_streaming_mode(self, mock_generate):
"""Test generate supports streaming response mode."""
mock_stream = Mock()
mock_generate.return_value = mock_stream
result = AppGenerateService.generate(
app_model=Mock(spec=App),
user=Mock(),
args={"inputs": {}, "response_mode": "streaming"},
invoke_from=Mock(),
streaming=True,
)
assert result == mock_stream
class TestWorkflowStopMechanism:
"""Test workflow stop mechanisms."""
def test_app_queue_manager_has_stop_flag_method(self):
"""Test AppQueueManager has set_stop_flag_no_user_check method."""
from core.app.apps.base_app_queue_manager import AppQueueManager
assert hasattr(AppQueueManager, "set_stop_flag_no_user_check")
def test_graph_engine_manager_has_send_stop_command(self):
"""Test GraphEngineManager has send_stop_command method."""
from core.workflow.graph_engine.manager import GraphEngineManager
assert hasattr(GraphEngineManager, "send_stop_command")
class TestWorkflowRunRepository:
"""Test workflow run repository interface."""
def test_repository_factory_can_create_workflow_run_repository(self):
"""Test DifyAPIRepositoryFactory can create workflow run repository."""
from repositories.factory import DifyAPIRepositoryFactory
assert hasattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository")
@patch("repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository")
def test_workflow_run_repository_get_by_id(self, mock_factory):
"""Test workflow run repository get_workflow_run_by_id method."""
mock_repo = Mock()
mock_run = Mock()
mock_run.id = str(uuid.uuid4())
mock_run.status = "succeeded"
mock_repo.get_workflow_run_by_id.return_value = mock_run
mock_factory.return_value = mock_repo
from repositories.factory import DifyAPIRepositoryFactory
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(Mock())
result = repo.get_workflow_run_by_id(tenant_id="tenant_123", app_id="app_456", run_id="run_789")
assert result.status == "succeeded"
class TestWorkflowRunDetailApi:
def test_not_workflow_app(self, app) -> None:
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
with app.test_request_context("/workflows/run/1", method="GET"):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, workflow_run_id="run")
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
run = SimpleNamespace(id="run")
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
workflow_module = sys.modules["controllers.service_api.app.workflow"]
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: repo,
)
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
assert handler(api, app_model=app_model, workflow_run_id="run") == run
class TestWorkflowRunApi:
def test_not_workflow_app(self, app) -> None:
api = WorkflowRunApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(InvokeRateLimitError("slow")),
)
api = WorkflowRunApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
with pytest.raises(InvokeRateLimitHttpError):
handler(api, app_model=app_model, end_user=end_user)
class TestWorkflowRunByIdApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(WorkflowNotFoundError("missing")),
)
api = WorkflowRunByIdApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(IsDraftWorkflowError("draft")),
)
api = WorkflowRunByIdApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
class TestWorkflowTaskStopApi:
def test_wrong_mode(self, app) -> None:
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
stop_mock = Mock()
send_mock = Mock()
monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock)
monkeypatch.setattr(GraphEngineManager, "send_stop_command", send_mock)
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
response = handler(api, app_model=app_model, end_user=end_user, task_id="t1")
assert response == {"result": "success"}
stop_mock.assert_called_once_with("t1")
send_mock.assert_called_once_with("t1")
class TestWorkflowAppLogApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
class _SessionStub:
def __enter__(self):
return SimpleNamespace()
def __exit__(self, exc_type, exc, tb):
return False
workflow_module = sys.modules["controllers.service_api.app.workflow"]
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(workflow_module, "Session", lambda *_args, **_kwargs: _SessionStub())
monkeypatch.setattr(
WorkflowAppService,
"get_paginate_workflow_app_logs",
lambda *_args, **_kwargs: {"items": [], "total": 0},
)
api = WorkflowAppLogApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/workflows/logs", method="GET"):
response = handler(api, app_model=app_model)
assert response == {"items": [], "total": 0}
# =============================================================================
# API Endpoint Tests
#
# ``WorkflowRunDetailApi``, ``WorkflowTaskStopApi``, and
# ``WorkflowAppLogApi`` use ``@validate_app_token`` which preserves
# ``__wrapped__`` via ``functools.wraps``. We call the unwrapped method
# directly to bypass the decorator.
# =============================================================================
from tests.unit_tests.controllers.service_api.conftest import _unwrap
@pytest.fixture
def mock_workflow_app():
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.mode = AppMode.WORKFLOW.value
return app
class TestWorkflowRunDetailApiGet:
"""Test suite for WorkflowRunDetailApi.get() endpoint.
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``)
and ``@service_api_ns.marshal_with``. We call the unwrapped method
directly; ``marshal_with`` is a no-op when calling directly.
"""
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_success(
self,
mock_db,
mock_repo_factory,
app,
mock_workflow_app,
):
"""Test successful workflow run detail retrieval."""
mock_run = Mock()
mock_run.id = "run-1"
mock_run.status = "succeeded"
mock_repo = Mock()
mock_repo.get_workflow_run_by_id.return_value = mock_run
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
from controllers.service_api.app.workflow import WorkflowRunDetailApi
with app.test_request_context(
f"/workflows/run/{mock_run.id}",
method="GET",
):
api = WorkflowRunDetailApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
assert result == mock_run
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
"""Test NotWorkflowAppError when app mode is not workflow or advanced_chat."""
from controllers.service_api.app.workflow import WorkflowRunDetailApi
mock_app = Mock(spec=App)
mock_app.mode = AppMode.CHAT.value
with app.test_request_context("/workflows/run/run-1", method="GET"):
api = WorkflowRunDetailApi()
with pytest.raises(NotWorkflowAppError):
_unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1")
class TestWorkflowTaskStopApiPost:
"""Test suite for WorkflowTaskStopApi.post() endpoint.
``post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``.
"""
@patch("controllers.service_api.app.workflow.GraphEngineManager")
@patch("controllers.service_api.app.workflow.AppQueueManager")
def test_stop_workflow_task_success(
self,
mock_queue_mgr,
mock_graph_mgr,
app,
mock_workflow_app,
):
"""Test successful workflow task stop."""
from controllers.service_api.app.workflow import WorkflowTaskStopApi
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
api = WorkflowTaskStopApi()
result = _unwrap(api.post)(
api,
app_model=mock_workflow_app,
end_user=Mock(),
task_id="task-1",
)
assert result == {"result": "success"}
mock_queue_mgr.set_stop_flag_no_user_check.assert_called_once_with("task-1")
mock_graph_mgr.send_stop_command.assert_called_once_with("task-1")
def test_stop_workflow_task_wrong_app_mode(self, app):
"""Test NotWorkflowAppError when app mode is not workflow."""
from controllers.service_api.app.workflow import WorkflowTaskStopApi
mock_app = Mock(spec=App)
mock_app.mode = AppMode.COMPLETION.value
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
api = WorkflowTaskStopApi()
with pytest.raises(NotWorkflowAppError):
_unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1")
class TestWorkflowAppLogApiGet:
"""Test suite for WorkflowAppLogApi.get() endpoint.
``get`` is wrapped by ``@validate_app_token`` and
``@service_api_ns.marshal_with``.
"""
@patch("controllers.service_api.app.workflow.WorkflowAppService")
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_logs_success(
self,
mock_db,
mock_wf_svc_cls,
app,
mock_workflow_app,
):
"""Test successful workflow log retrieval."""
mock_pagination = Mock()
mock_pagination.data = []
mock_svc_instance = Mock()
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
mock_wf_svc_cls.return_value = mock_svc_instance
# Mock Session context manager
mock_session = Mock()
mock_db.engine = Mock()
mock_session.__enter__ = Mock(return_value=mock_session)
mock_session.__exit__ = Mock(return_value=False)
from controllers.service_api.app.workflow import WorkflowAppLogApi
with app.test_request_context(
"/workflows/logs?page=1&limit=20",
method="GET",
):
with patch("controllers.service_api.app.workflow.Session", return_value=mock_session):
api = WorkflowAppLogApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
assert result == mock_pagination