mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
test: improve unit tests for controllers.service_api (#32073)
Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
This commit is contained in:
@ -0,0 +1,295 @@
|
||||
"""
|
||||
Unit tests for Service API Annotation controller.
|
||||
|
||||
Tests coverage for:
|
||||
- AnnotationCreatePayload Pydantic model validation
|
||||
- AnnotationReplyActionPayload Pydantic model validation
|
||||
- Error patterns and validation logic
|
||||
|
||||
Note: API endpoint tests for annotation controllers are complex due to:
|
||||
- @validate_app_token decorator requiring full Flask-SQLAlchemy setup
|
||||
- @edit_permission_required decorator checking current_user permissions
|
||||
- These are better covered by integration tests
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask_restx.api import HTTPStatus
|
||||
|
||||
from controllers.service_api.app.annotation import (
|
||||
AnnotationCreatePayload,
|
||||
AnnotationListApi,
|
||||
AnnotationReplyActionApi,
|
||||
AnnotationReplyActionPayload,
|
||||
AnnotationReplyActionStatusApi,
|
||||
AnnotationUpdateDeleteApi,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnnotationCreatePayload:
|
||||
"""Test suite for AnnotationCreatePayload Pydantic model."""
|
||||
|
||||
def test_payload_with_question_and_answer(self):
|
||||
"""Test payload with required fields."""
|
||||
payload = AnnotationCreatePayload(
|
||||
question="What is AI?",
|
||||
answer="AI is artificial intelligence.",
|
||||
)
|
||||
assert payload.question == "What is AI?"
|
||||
assert payload.answer == "AI is artificial intelligence."
|
||||
|
||||
def test_payload_with_unicode_content(self):
|
||||
"""Test payload with unicode content."""
|
||||
payload = AnnotationCreatePayload(
|
||||
question="什么是人工智能?",
|
||||
answer="人工智能是模拟人类智能的技术。",
|
||||
)
|
||||
assert payload.question == "什么是人工智能?"
|
||||
|
||||
def test_payload_with_special_characters(self):
|
||||
"""Test payload with special characters."""
|
||||
payload = AnnotationCreatePayload(
|
||||
question="What is <b>AI</b>?",
|
||||
answer="AI & ML are related fields with 100% growth!",
|
||||
)
|
||||
assert "<b>" in payload.question
|
||||
|
||||
|
||||
class TestAnnotationReplyActionPayload:
|
||||
"""Test suite for AnnotationReplyActionPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all fields."""
|
||||
payload = AnnotationReplyActionPayload(
|
||||
score_threshold=0.8,
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-ada-002",
|
||||
)
|
||||
assert payload.score_threshold == 0.8
|
||||
assert payload.embedding_provider_name == "openai"
|
||||
assert payload.embedding_model_name == "text-embedding-ada-002"
|
||||
|
||||
def test_payload_with_different_provider(self):
|
||||
"""Test payload with different embedding provider."""
|
||||
payload = AnnotationReplyActionPayload(
|
||||
score_threshold=0.75,
|
||||
embedding_provider_name="azure_openai",
|
||||
embedding_model_name="text-embedding-3-small",
|
||||
)
|
||||
assert payload.embedding_provider_name == "azure_openai"
|
||||
|
||||
def test_payload_with_zero_threshold(self):
|
||||
"""Test payload with zero score threshold."""
|
||||
payload = AnnotationReplyActionPayload(
|
||||
score_threshold=0.0,
|
||||
embedding_provider_name="local",
|
||||
embedding_model_name="default",
|
||||
)
|
||||
assert payload.score_threshold == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model and Error Pattern Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppModelPatterns:
|
||||
"""Test App model patterns used by annotation controller."""
|
||||
|
||||
def test_app_model_has_required_fields(self):
|
||||
"""Test App model has required fields for annotation operations."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
|
||||
assert app.id is not None
|
||||
assert app.status == "normal"
|
||||
assert app.enable_api is True
|
||||
|
||||
def test_app_model_disabled_api(self):
|
||||
"""Test app with disabled API access."""
|
||||
app = Mock(spec=App)
|
||||
app.enable_api = False
|
||||
|
||||
assert app.enable_api is False
|
||||
|
||||
def test_app_model_archived_status(self):
|
||||
"""Test app with archived status."""
|
||||
app = Mock(spec=App)
|
||||
app.status = "archived"
|
||||
|
||||
assert app.status == "archived"
|
||||
|
||||
|
||||
class TestAnnotationErrorPatterns:
|
||||
"""Test annotation-related error handling patterns."""
|
||||
|
||||
def test_not_found_error_pattern(self):
|
||||
"""Test NotFound error pattern used in annotation operations."""
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
raise NotFound("Annotation not found.")
|
||||
|
||||
def test_forbidden_error_pattern(self):
|
||||
"""Test Forbidden error pattern."""
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
raise Forbidden("Permission denied.")
|
||||
|
||||
def test_value_error_for_job_not_found(self):
|
||||
"""Test ValueError pattern for job not found."""
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
raise ValueError("The job does not exist.")
|
||||
|
||||
|
||||
class TestAnnotationReplyActionApi:
|
||||
def test_enable(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
enable_mock = Mock()
|
||||
monkeypatch.setattr(AppAnnotationService, "enable_app_annotation", enable_mock)
|
||||
|
||||
api = AnnotationReplyActionApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/annotation-reply/enable",
|
||||
method="POST",
|
||||
json={"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"},
|
||||
):
|
||||
response, status = handler(api, app_model=app_model, action="enable")
|
||||
|
||||
assert status == 200
|
||||
enable_mock.assert_called_once()
|
||||
|
||||
def test_disable(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
disable_mock = Mock()
|
||||
monkeypatch.setattr(AppAnnotationService, "disable_app_annotation", disable_mock)
|
||||
|
||||
api = AnnotationReplyActionApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/annotation-reply/disable",
|
||||
method="POST",
|
||||
json={"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"},
|
||||
):
|
||||
response, status = handler(api, app_model=app_model, action="disable")
|
||||
|
||||
assert status == 200
|
||||
disable_mock.assert_called_once()
|
||||
|
||||
|
||||
class TestAnnotationReplyActionStatusApi:
|
||||
def test_missing_job(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(redis_client, "get", lambda *_args, **_kwargs: None)
|
||||
|
||||
api = AnnotationReplyActionStatusApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler(api, app_model=app_model, job_id="j1", action="enable")
|
||||
|
||||
def test_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _get(key):
|
||||
if "error" in key:
|
||||
return b"oops"
|
||||
return b"error"
|
||||
|
||||
monkeypatch.setattr(redis_client, "get", _get)
|
||||
|
||||
api = AnnotationReplyActionStatusApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
response, status = handler(api, app_model=app_model, job_id="j1", action="enable")
|
||||
|
||||
assert status == 200
|
||||
assert response["job_status"] == "error"
|
||||
assert response["error_msg"] == "oops"
|
||||
|
||||
|
||||
class TestAnnotationListApi:
|
||||
def test_get(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
|
||||
monkeypatch.setattr(
|
||||
AppAnnotationService,
|
||||
"get_annotation_list_by_app_id",
|
||||
lambda *_args, **_kwargs: ([annotation], 1),
|
||||
)
|
||||
|
||||
api = AnnotationListApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations?page=1&limit=1", method="GET"):
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response["total"] == 1
|
||||
|
||||
def test_create(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
|
||||
monkeypatch.setattr(
|
||||
AppAnnotationService,
|
||||
"insert_app_annotation_directly",
|
||||
lambda *_args, **_kwargs: annotation,
|
||||
)
|
||||
|
||||
api = AnnotationListApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations", method="POST", json={"question": "q", "answer": "a"}):
|
||||
response, status = handler(api, app_model=app_model)
|
||||
|
||||
assert status == HTTPStatus.CREATED
|
||||
assert response["question"] == "q"
|
||||
|
||||
|
||||
class TestAnnotationUpdateDeleteApi:
|
||||
def test_update_delete(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
annotation = SimpleNamespace(id="a1", question="q", content="a", created_at=0)
|
||||
monkeypatch.setattr(
|
||||
AppAnnotationService,
|
||||
"update_app_annotation_directly",
|
||||
lambda *_args, **_kwargs: annotation,
|
||||
)
|
||||
delete_mock = Mock()
|
||||
monkeypatch.setattr(AppAnnotationService, "delete_app_annotation", delete_mock)
|
||||
|
||||
api = AnnotationUpdateDeleteApi()
|
||||
put_handler = _unwrap(api.put)
|
||||
delete_handler = _unwrap(api.delete)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations/1", method="PUT", json={"question": "q", "answer": "a"}):
|
||||
response = put_handler(api, app_model=app_model, annotation_id="1")
|
||||
|
||||
assert response["answer"] == "a"
|
||||
|
||||
with app.test_request_context("/apps/annotations/1", method="DELETE"):
|
||||
response, status = delete_handler(api, app_model=app_model, annotation_id="1")
|
||||
|
||||
assert status == 204
|
||||
delete_mock.assert_called_once()
|
||||
496
api/tests/unit_tests/controllers/service_api/app/test_app.py
Normal file
496
api/tests/unit_tests/controllers/service_api/app/test_app.py
Normal file
@ -0,0 +1,496 @@
|
||||
"""
|
||||
Unit tests for Service API App controllers
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameterApi
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models.model import App, AppMode
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
|
||||
|
||||
class TestAppParameterApi:
|
||||
"""Test suite for AppParameterApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_parameters_for_chat_app(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test retrieving parameters for a chat app."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_config = Mock()
|
||||
mock_config.id = str(uuid.uuid4())
|
||||
mock_config.to_dict.return_value = {
|
||||
"user_input_form": [{"type": "text", "label": "Name", "variable": "name", "required": True}],
|
||||
"suggested_questions": [],
|
||||
}
|
||||
mock_app_model.app_model_config = mock_config
|
||||
mock_app_model.workflow = None
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
# Mock DB queries for app and tenant
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
# Mock tenant owner info for login
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert "opening_statement" in response
|
||||
assert "suggested_questions" in response
|
||||
assert "user_input_form" in response
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_parameters_for_workflow_app(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test retrieving parameters for a workflow app."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app_model.mode = AppMode.WORKFLOW
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.features_dict = {"suggested_questions": []}
|
||||
mock_workflow.user_input_form.return_value = [{"type": "text", "label": "Input", "variable": "input"}]
|
||||
mock_app_model.workflow = mock_workflow
|
||||
mock_app_model.app_model_config = None
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert "user_input_form" in response
|
||||
assert "opening_statement" in response
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_parameters_raises_error_when_chat_config_missing(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test that AppUnavailableError is raised when chat app has no config."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app_model.app_model_config = None
|
||||
mock_app_model.workflow = None
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
with pytest.raises(AppUnavailableError):
|
||||
api.get()
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_parameters_raises_error_when_workflow_missing(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app_model.mode = AppMode.WORKFLOW
|
||||
mock_app_model.workflow = None
|
||||
mock_app_model.app_model_config = None
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppParameterApi()
|
||||
with pytest.raises(AppUnavailableError):
|
||||
api.get()
|
||||
|
||||
|
||||
class TestAppMetaApi:
|
||||
"""Test suite for AppMetaApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
@patch("controllers.service_api.app.app.AppService")
|
||||
def test_get_app_meta(
|
||||
self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test retrieving app metadata via AppService."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_service_instance = Mock()
|
||||
mock_service_instance.get_app_meta.return_value = {
|
||||
"tool_icons": {},
|
||||
"AgentIcons": {},
|
||||
}
|
||||
mock_app_service.return_value = mock_service_instance
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppMetaApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
mock_service_instance.get_app_meta.assert_called_once_with(mock_app_model)
|
||||
assert response == {"tool_icons": {}, "AgentIcons": {}}
|
||||
|
||||
|
||||
class TestAppInfoApi:
|
||||
"""Test suite for AppInfoApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model with all required attributes."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.name = "Test App"
|
||||
app.description = "A test application"
|
||||
app.mode = AppMode.CHAT
|
||||
app.author_name = "Test Author"
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
|
||||
# Mock tags relationship
|
||||
mock_tag = Mock()
|
||||
mock_tag.name = "test-tag"
|
||||
app.tags = [mock_tag]
|
||||
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_app_info(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test retrieving basic app information."""
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppInfoApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["name"] == "Test App"
|
||||
assert response["description"] == "A test application"
|
||||
assert response["tags"] == ["test-tag"]
|
||||
assert response["mode"] == AppMode.CHAT
|
||||
assert response["author_name"] == "Test Author"
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_app_info_with_multiple_tags(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app
|
||||
):
|
||||
"""Test retrieving app info with multiple tags."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
mock_app.tenant_id = str(uuid.uuid4())
|
||||
mock_app.name = "Multi Tag App"
|
||||
mock_app.description = "App with multiple tags"
|
||||
mock_app.mode = AppMode.WORKFLOW
|
||||
mock_app.author_name = "Author"
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = True
|
||||
|
||||
tag1, tag2, tag3 = Mock(), Mock(), Mock()
|
||||
tag1.name = "tag-one"
|
||||
tag2.name = "tag-two"
|
||||
tag3.name = "tag-three"
|
||||
mock_app.tags = [tag1, tag2, tag3]
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app.id
|
||||
mock_api_token.tenant_id = mock_app.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppInfoApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["tags"] == ["tag-one", "tag-two", "tag-three"]
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app):
|
||||
"""Test retrieving app info when app has no tags."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
mock_app.tenant_id = str(uuid.uuid4())
|
||||
mock_app.name = "No Tags App"
|
||||
mock_app.description = "App without tags"
|
||||
mock_app.mode = AppMode.COMPLETION
|
||||
mock_app.author_name = "Author"
|
||||
mock_app.tags = []
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = True
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app.id
|
||||
mock_api_token.tenant_id = mock_app.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppInfoApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["tags"] == []
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"app_mode",
|
||||
[AppMode.CHAT, AppMode.COMPLETION, AppMode.WORKFLOW, AppMode.ADVANCED_CHAT],
|
||||
)
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_app_info_returns_correct_mode(
|
||||
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode
|
||||
):
|
||||
"""Test that all app modes are correctly returned."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = str(uuid.uuid4())
|
||||
mock_app.tenant_id = str(uuid.uuid4())
|
||||
mock_app.name = "Test"
|
||||
mock_app.description = "Test"
|
||||
mock_app.mode = app_mode
|
||||
mock_app.author_name = "Test"
|
||||
mock_app.tags = []
|
||||
mock_app.status = "normal"
|
||||
mock_app.enable_api = True
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app.id
|
||||
mock_api_token.tenant_id = mock_app.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = "normal"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_app,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppInfoApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["mode"] == app_mode
|
||||
298
api/tests/unit_tests/controllers/service_api/app/test_audio.py
Normal file
298
api/tests/unit_tests/controllers/service_api/app/test_audio.py
Normal file
@ -0,0 +1,298 @@
|
||||
"""
|
||||
Unit tests for Service API Audio controller.
|
||||
|
||||
Tests coverage for:
|
||||
- TextToAudioPayload Pydantic model validation
|
||||
- Error mapping patterns between service and API errors
|
||||
- AudioService method interfaces
|
||||
"""
|
||||
|
||||
import io
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.service_api.app.audio import AudioApi, TextApi, TextToAudioPayload
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _file_data():
|
||||
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTextToAudioPayload:
|
||||
"""Test suite for TextToAudioPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all fields populated."""
|
||||
payload = TextToAudioPayload(
|
||||
message_id="msg_123",
|
||||
voice="nova",
|
||||
text="Hello, this is a test.",
|
||||
streaming=False,
|
||||
)
|
||||
assert payload.message_id == "msg_123"
|
||||
assert payload.voice == "nova"
|
||||
assert payload.text == "Hello, this is a test."
|
||||
assert payload.streaming is False
|
||||
|
||||
def test_payload_with_defaults(self):
|
||||
"""Test payload with default values."""
|
||||
payload = TextToAudioPayload()
|
||||
assert payload.message_id is None
|
||||
assert payload.voice is None
|
||||
assert payload.text is None
|
||||
assert payload.streaming is None
|
||||
|
||||
def test_payload_with_only_text(self):
|
||||
"""Test payload with only text field."""
|
||||
payload = TextToAudioPayload(text="Simple text to speech")
|
||||
assert payload.text == "Simple text to speech"
|
||||
assert payload.voice is None
|
||||
assert payload.message_id is None
|
||||
|
||||
def test_payload_with_streaming_true(self):
|
||||
"""Test payload with streaming enabled."""
|
||||
payload = TextToAudioPayload(
|
||||
text="Streaming test",
|
||||
streaming=True,
|
||||
)
|
||||
assert payload.streaming is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AudioService Interface Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAudioServiceInterface:
|
||||
"""Test AudioService method interfaces exist."""
|
||||
|
||||
def test_transcript_asr_method_exists(self):
|
||||
"""Test that AudioService.transcript_asr exists."""
|
||||
assert hasattr(AudioService, "transcript_asr")
|
||||
assert callable(AudioService.transcript_asr)
|
||||
|
||||
def test_transcript_tts_method_exists(self):
|
||||
"""Test that AudioService.transcript_tts exists."""
|
||||
assert hasattr(AudioService, "transcript_tts")
|
||||
assert callable(AudioService.transcript_tts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio Service Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAudioServiceInterface:
|
||||
"""Test suite for AudioService interface methods."""
|
||||
|
||||
def test_transcript_asr_method_exists(self):
|
||||
"""Test that AudioService.transcript_asr exists."""
|
||||
assert hasattr(AudioService, "transcript_asr")
|
||||
assert callable(AudioService.transcript_asr)
|
||||
|
||||
def test_transcript_tts_method_exists(self):
|
||||
"""Test that AudioService.transcript_tts exists."""
|
||||
assert hasattr(AudioService, "transcript_tts")
|
||||
assert callable(AudioService.transcript_tts)
|
||||
|
||||
|
||||
class TestServiceErrorTypes:
|
||||
"""Test service error types used by audio controller."""
|
||||
|
||||
def test_no_audio_uploaded_service_error(self):
|
||||
"""Test NoAudioUploadedServiceError exists."""
|
||||
error = NoAudioUploadedServiceError()
|
||||
assert error is not None
|
||||
|
||||
def test_audio_too_large_service_error(self):
|
||||
"""Test AudioTooLargeServiceError with message."""
|
||||
error = AudioTooLargeServiceError("File too large")
|
||||
assert "File too large" in str(error)
|
||||
|
||||
def test_unsupported_audio_type_service_error(self):
|
||||
"""Test UnsupportedAudioTypeServiceError exists."""
|
||||
error = UnsupportedAudioTypeServiceError()
|
||||
assert error is not None
|
||||
|
||||
def test_provider_not_support_speech_to_text_service_error(self):
|
||||
"""Test ProviderNotSupportSpeechToTextServiceError exists."""
|
||||
error = ProviderNotSupportSpeechToTextServiceError()
|
||||
assert error is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mocked Behavior Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAudioServiceMockedBehavior:
|
||||
"""Test AudioService behavior with mocked methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock app model."""
|
||||
from models.model import App
|
||||
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file(self):
|
||||
"""Create mock file upload."""
|
||||
mock = Mock()
|
||||
mock.filename = "test_audio.mp3"
|
||||
mock.content_type = "audio/mpeg"
|
||||
return mock
|
||||
|
||||
@patch.object(AudioService, "transcript_asr")
|
||||
def test_transcript_asr_returns_response(self, mock_asr, mock_app, mock_file):
|
||||
"""Test ASR transcription returns response dict."""
|
||||
mock_response = {"text": "Transcribed text"}
|
||||
mock_asr.return_value = mock_response
|
||||
|
||||
result = AudioService.transcript_asr(
|
||||
app_model=mock_app,
|
||||
file=mock_file,
|
||||
end_user="user_123",
|
||||
)
|
||||
|
||||
assert result["text"] == "Transcribed text"
|
||||
|
||||
@patch.object(AudioService, "transcript_tts")
|
||||
def test_transcript_tts_returns_response(self, mock_tts, mock_app):
|
||||
"""Test TTS transcription returns response."""
|
||||
mock_response = {"audio": "base64_audio_data"}
|
||||
mock_tts.return_value = mock_response
|
||||
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=mock_app,
|
||||
text="Hello world",
|
||||
voice="nova",
|
||||
end_user="user_123",
|
||||
message_id="msg_123",
|
||||
)
|
||||
|
||||
assert result["audio"] == "base64_audio_data"
|
||||
|
||||
|
||||
class TestAudioApi:
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
response = handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
assert response == {"text": "ok"}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected"),
|
||||
[
|
||||
(AppModelConfigBrokenError(), AppUnavailableError),
|
||||
(NoAudioUploadedServiceError(), NoAudioUploadedError),
|
||||
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
|
||||
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
|
||||
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
|
||||
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
|
||||
(QuotaExceededError(), ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
|
||||
(InvokeError("invoke"), CompletionRequestError),
|
||||
],
|
||||
)
|
||||
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(expected):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
)
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestTextApi:
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
api = TextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(external_user_id="ext")
|
||||
|
||||
with app.test_request_context(
|
||||
"/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "voice": "v"},
|
||||
):
|
||||
response = handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())
|
||||
)
|
||||
|
||||
api = TextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(external_user_id="ext")
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST", json={"text": "hello"}):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
@ -0,0 +1,524 @@
|
||||
"""
|
||||
Unit tests for Service API Completion controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- CompletionRequestPayload and ChatRequestPayload Pydantic models
|
||||
- App mode validation logic
|
||||
- Error mapping from service layer to HTTP errors
|
||||
|
||||
Focus on:
|
||||
- Pydantic model validation (especially UUID normalization)
|
||||
- Error types and their mappings
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api.app.completion import (
|
||||
ChatApi,
|
||||
ChatRequestPayload,
|
||||
ChatStopApi,
|
||||
CompletionApi,
|
||||
CompletionRequestPayload,
|
||||
CompletionStopApi,
|
||||
)
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
)
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestCompletionRequestPayload:
|
||||
"""Test suite for CompletionRequestPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_required_fields(self):
|
||||
"""Test payload with only required inputs field."""
|
||||
payload = CompletionRequestPayload(inputs={"name": "test"})
|
||||
assert payload.inputs == {"name": "test"}
|
||||
assert payload.query == ""
|
||||
assert payload.files is None
|
||||
assert payload.response_mode is None
|
||||
assert payload.retriever_from == "dev"
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all fields populated."""
|
||||
payload = CompletionRequestPayload(
|
||||
inputs={"user_input": "Hello"},
|
||||
query="What is AI?",
|
||||
files=[{"type": "image", "url": "http://example.com/image.png"}],
|
||||
response_mode="streaming",
|
||||
retriever_from="api",
|
||||
)
|
||||
assert payload.inputs == {"user_input": "Hello"}
|
||||
assert payload.query == "What is AI?"
|
||||
assert payload.files == [{"type": "image", "url": "http://example.com/image.png"}]
|
||||
assert payload.response_mode == "streaming"
|
||||
assert payload.retriever_from == "api"
|
||||
|
||||
def test_payload_response_mode_blocking(self):
|
||||
"""Test payload with blocking response mode."""
|
||||
payload = CompletionRequestPayload(inputs={}, response_mode="blocking")
|
||||
assert payload.response_mode == "blocking"
|
||||
|
||||
def test_payload_empty_inputs(self):
|
||||
"""Test payload with empty inputs dict."""
|
||||
payload = CompletionRequestPayload(inputs={})
|
||||
assert payload.inputs == {}
|
||||
|
||||
def test_payload_complex_inputs(self):
|
||||
"""Test payload with complex nested inputs."""
|
||||
complex_inputs = {
|
||||
"user": {"name": "Alice", "age": 30},
|
||||
"context": ["item1", "item2"],
|
||||
"settings": {"theme": "dark", "notifications": True},
|
||||
}
|
||||
payload = CompletionRequestPayload(inputs=complex_inputs)
|
||||
assert payload.inputs == complex_inputs
|
||||
|
||||
|
||||
class TestChatRequestPayload:
|
||||
"""Test suite for ChatRequestPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_required_fields(self):
|
||||
"""Test payload with required fields."""
|
||||
payload = ChatRequestPayload(inputs={"key": "value"}, query="Hello")
|
||||
assert payload.inputs == {"key": "value"}
|
||||
assert payload.query == "Hello"
|
||||
assert payload.conversation_id is None
|
||||
assert payload.auto_generate_name is True
|
||||
|
||||
def test_payload_normalizes_valid_uuid_conversation_id(self):
|
||||
"""Test that valid UUID conversation_id is normalized."""
|
||||
valid_uuid = str(uuid.uuid4())
|
||||
payload = ChatRequestPayload(inputs={}, query="test", conversation_id=valid_uuid)
|
||||
assert payload.conversation_id == valid_uuid
|
||||
|
||||
def test_payload_normalizes_empty_string_conversation_id_to_none(self):
|
||||
"""Test that empty string conversation_id becomes None."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", conversation_id="")
|
||||
assert payload.conversation_id is None
|
||||
|
||||
def test_payload_normalizes_whitespace_conversation_id_to_none(self):
|
||||
"""Test that whitespace-only conversation_id becomes None."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", conversation_id=" ")
|
||||
assert payload.conversation_id is None
|
||||
|
||||
def test_payload_rejects_invalid_uuid_conversation_id(self):
|
||||
"""Test that invalid UUID format raises ValueError."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ChatRequestPayload(inputs={}, query="test", conversation_id="not-a-uuid")
|
||||
assert "valid UUID" in str(exc_info.value)
|
||||
|
||||
def test_payload_with_workflow_id(self):
|
||||
"""Test payload with workflow_id for advanced chat."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", workflow_id="workflow_123")
|
||||
assert payload.workflow_id == "workflow_123"
|
||||
|
||||
def test_payload_streaming_mode(self):
|
||||
"""Test payload with streaming response mode."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", response_mode="streaming")
|
||||
assert payload.response_mode == "streaming"
|
||||
|
||||
def test_payload_auto_generate_name_false(self):
|
||||
"""Test payload with auto_generate_name explicitly false."""
|
||||
payload = ChatRequestPayload(inputs={}, query="test", auto_generate_name=False)
|
||||
assert payload.auto_generate_name is False
|
||||
|
||||
def test_payload_with_files(self):
|
||||
"""Test payload with file attachments."""
|
||||
files = [
|
||||
{"type": "image", "transfer_method": "remote_url", "url": "http://example.com/img.png"},
|
||||
{"type": "document", "transfer_method": "local_file", "upload_file_id": "file_123"},
|
||||
]
|
||||
payload = ChatRequestPayload(inputs={}, query="test", files=files)
|
||||
assert payload.files == files
|
||||
assert len(payload.files) == 2
|
||||
|
||||
|
||||
class TestCompletionErrorMappings:
|
||||
"""Test error type mappings for completion endpoints."""
|
||||
|
||||
def test_conversation_not_exists_error_exists(self):
|
||||
"""Test ConversationNotExistsError can be raised."""
|
||||
error = services.errors.conversation.ConversationNotExistsError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationNotExistsError)
|
||||
|
||||
def test_conversation_completed_error_exists(self):
|
||||
"""Test ConversationCompletedError can be raised."""
|
||||
error = services.errors.conversation.ConversationCompletedError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationCompletedError)
|
||||
|
||||
api_error = ConversationCompletedError()
|
||||
assert api_error is not None
|
||||
|
||||
def test_app_model_config_broken_error_exists(self):
|
||||
"""Test AppModelConfigBrokenError can be raised."""
|
||||
error = services.errors.app_model_config.AppModelConfigBrokenError()
|
||||
assert isinstance(error, services.errors.app_model_config.AppModelConfigBrokenError)
|
||||
|
||||
api_error = AppUnavailableError()
|
||||
assert api_error is not None
|
||||
|
||||
def test_workflow_not_found_error_exists(self):
|
||||
"""Test WorkflowNotFoundError can be raised."""
|
||||
error = WorkflowNotFoundError("Workflow not found")
|
||||
assert isinstance(error, WorkflowNotFoundError)
|
||||
|
||||
def test_is_draft_workflow_error_exists(self):
|
||||
"""Test IsDraftWorkflowError can be raised."""
|
||||
error = IsDraftWorkflowError("Workflow is in draft state")
|
||||
assert isinstance(error, IsDraftWorkflowError)
|
||||
|
||||
def test_workflow_id_format_error_exists(self):
|
||||
"""Test WorkflowIdFormatError can be raised."""
|
||||
error = WorkflowIdFormatError("Invalid workflow ID format")
|
||||
assert isinstance(error, WorkflowIdFormatError)
|
||||
|
||||
def test_invoke_rate_limit_error_exists(self):
|
||||
"""Test InvokeRateLimitError can be raised."""
|
||||
error = InvokeRateLimitError("Rate limit exceeded")
|
||||
assert isinstance(error, InvokeRateLimitError)
|
||||
|
||||
|
||||
class TestAppModeValidation:
|
||||
"""Test app mode validation logic patterns."""
|
||||
|
||||
def test_completion_mode_is_valid_for_completion_endpoint(self):
|
||||
"""Test that COMPLETION mode is valid for completion endpoints."""
|
||||
assert AppMode.COMPLETION == AppMode.COMPLETION
|
||||
|
||||
def test_chat_modes_are_distinct_from_completion(self):
|
||||
"""Test that chat modes are distinct from completion mode."""
|
||||
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
assert AppMode.COMPLETION not in chat_modes
|
||||
|
||||
def test_workflow_mode_is_distinct_from_chat_modes(self):
|
||||
"""Test that WORKFLOW mode is not a chat mode."""
|
||||
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
assert AppMode.WORKFLOW not in chat_modes
|
||||
|
||||
def test_not_chat_app_error_can_be_raised(self):
|
||||
"""Test NotChatAppError can be raised for non-chat apps."""
|
||||
error = NotChatAppError()
|
||||
assert error is not None
|
||||
|
||||
def test_all_app_modes_are_defined(self):
|
||||
"""Test that all expected app modes are defined."""
|
||||
expected_modes = ["COMPLETION", "CHAT", "AGENT_CHAT", "ADVANCED_CHAT", "WORKFLOW", "CHANNEL", "RAG_PIPELINE"]
|
||||
for mode_name in expected_modes:
|
||||
assert hasattr(AppMode, mode_name), f"AppMode.{mode_name} should exist"
|
||||
|
||||
|
||||
class TestAppGenerateService:
|
||||
"""Test AppGenerateService integration patterns."""
|
||||
|
||||
def test_generate_method_exists(self):
|
||||
"""Test that AppGenerateService.generate method exists."""
|
||||
assert hasattr(AppGenerateService, "generate")
|
||||
assert callable(AppGenerateService.generate)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_returns_response(self, mock_generate):
|
||||
"""Test that generate returns expected response format."""
|
||||
expected = {"answer": "Hello!"}
|
||||
mock_generate.return_value = expected
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={"query": "Hi"}, invoke_from=Mock(), streaming=False
|
||||
)
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_conversation_not_exists(self, mock_generate):
|
||||
"""Test generate raises ConversationNotExistsError."""
|
||||
mock_generate.side_effect = services.errors.conversation.ConversationNotExistsError()
|
||||
|
||||
with pytest.raises(services.errors.conversation.ConversationNotExistsError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
|
||||
)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_quota_exceeded(self, mock_generate):
|
||||
"""Test generate raises QuotaExceededError."""
|
||||
mock_generate.side_effect = QuotaExceededError()
|
||||
|
||||
with pytest.raises(QuotaExceededError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
|
||||
)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_invoke_error(self, mock_generate):
|
||||
"""Test generate raises InvokeError."""
|
||||
mock_generate.side_effect = InvokeError("Model invocation failed")
|
||||
|
||||
with pytest.raises(InvokeError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), args={}, invoke_from=Mock(), streaming=False
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionControllerLogic:
|
||||
"""Test CompletionApi and ChatApi controller logic directly."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
from flask import Flask
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.service_api.app.completion.service_api_ns")
|
||||
@patch("controllers.service_api.app.completion.AppGenerateService")
|
||||
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
|
||||
"""Test CompletionApi.post success path."""
|
||||
from controllers.service_api.app.completion import CompletionApi
|
||||
|
||||
# Setup mocks
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.COMPLETION
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
|
||||
payload_dict = {"inputs": {"text": "hello"}, "response_mode": "blocking"}
|
||||
mock_service_api_ns.payload = payload_dict
|
||||
mock_generate_service.generate.return_value = {"text": "response"}
|
||||
|
||||
with app.test_request_context():
|
||||
# Helper for compact_generate_response logic check
|
||||
with patch("controllers.service_api.app.completion.helper.compact_generate_response") as mock_compact:
|
||||
mock_compact.return_value = {"text": "compacted"}
|
||||
|
||||
api = CompletionApi()
|
||||
response = api.post.__wrapped__(api, mock_app_model, mock_end_user)
|
||||
|
||||
assert response == {"text": "compacted"}
|
||||
mock_generate_service.generate.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.app.completion.service_api_ns")
|
||||
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app):
|
||||
"""Test CompletionApi.post with wrong app mode."""
|
||||
from controllers.service_api.app.completion import CompletionApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.CHAT # Wrong mode
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
|
||||
with app.test_request_context():
|
||||
with pytest.raises(AppUnavailableError):
|
||||
CompletionApi().post.__wrapped__(CompletionApi(), mock_app_model, mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.completion.service_api_ns")
|
||||
@patch("controllers.service_api.app.completion.AppGenerateService")
|
||||
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
|
||||
"""Test ChatApi.post success path."""
|
||||
from controllers.service_api.app.completion import ChatApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.CHAT
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
|
||||
payload_dict = {"inputs": {}, "query": "hello", "response_mode": "blocking"}
|
||||
mock_service_api_ns.payload = payload_dict
|
||||
mock_generate_service.generate.return_value = {"text": "response"}
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("controllers.service_api.app.completion.helper.compact_generate_response") as mock_compact:
|
||||
mock_compact.return_value = {"text": "compacted"}
|
||||
|
||||
api = ChatApi()
|
||||
response = api.post.__wrapped__(api, mock_app_model, mock_end_user)
|
||||
assert response == {"text": "compacted"}
|
||||
|
||||
@patch("controllers.service_api.app.completion.service_api_ns")
|
||||
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app):
|
||||
"""Test ChatApi.post with wrong app mode."""
|
||||
from controllers.service_api.app.completion import ChatApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.COMPLETION # Wrong mode
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
|
||||
with app.test_request_context():
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.completion.AppTaskService")
|
||||
def test_completion_stop_api_success(self, mock_task_service, app):
|
||||
"""Test CompletionStopApi.post success."""
|
||||
from controllers.service_api.app.completion import CompletionStopApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.COMPLETION
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
mock_end_user.id = "user_id"
|
||||
|
||||
with app.test_request_context():
|
||||
api = CompletionStopApi()
|
||||
response = api.post.__wrapped__(api, mock_app_model, mock_end_user, "task_id")
|
||||
|
||||
assert response == ({"result": "success"}, 200)
|
||||
mock_task_service.stop_task.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.app.completion.AppTaskService")
|
||||
def test_chat_stop_api_success(self, mock_task_service, app):
|
||||
"""Test ChatStopApi.post success."""
|
||||
from controllers.service_api.app.completion import ChatStopApi
|
||||
|
||||
mock_app_model = Mock(spec=App)
|
||||
mock_app_model.mode = AppMode.CHAT
|
||||
mock_end_user = Mock(spec=EndUser)
|
||||
mock_end_user.id = "user_id"
|
||||
|
||||
with app.test_request_context():
|
||||
api = ChatStopApi()
|
||||
response = api.post.__wrapped__(api, mock_app_model, mock_end_user, "task_id")
|
||||
|
||||
assert response == ({"result": "success"}, 200)
|
||||
mock_task_service.stop_task.assert_called_once()
|
||||
|
||||
|
||||
class TestChatRequestPayloadController:
|
||||
def test_normalizes_conversation_id(self) -> None:
|
||||
payload = ChatRequestPayload.model_validate(
|
||||
{"inputs": {}, "query": "hi", "conversation_id": " ", "response_mode": "blocking"}
|
||||
)
|
||||
assert payload.conversation_id is None
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ChatRequestPayload.model_validate({"inputs": {}, "query": "hi", "conversation_id": "bad-id"})
|
||||
|
||||
|
||||
class TestCompletionApiController:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = CompletionApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
api = CompletionApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestCompletionStopApiController:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = CompletionStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/completion-messages/1/stop", method="POST"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
stop_mock = Mock()
|
||||
monkeypatch.setattr(AppTaskService, "stop_task", stop_mock)
|
||||
|
||||
api = CompletionStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/completion-messages/1/stop", method="POST"):
|
||||
response, status = handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"result": "success"}
|
||||
|
||||
|
||||
class TestChatApiController:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_workflow_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(WorkflowNotFoundError("missing")),
|
||||
)
|
||||
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(IsDraftWorkflowError("draft")),
|
||||
)
|
||||
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST", json={"inputs": {}, "query": "hi"}):
|
||||
with pytest.raises(BadRequest):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestChatStopApiController:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = ChatStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/chat-messages/1/stop", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
@ -0,0 +1,597 @@
|
||||
"""
|
||||
Unit tests for Service API Conversation controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- ConversationListQuery, ConversationRenamePayload Pydantic models
|
||||
- ConversationVariablesQuery with SQL injection prevention
|
||||
- ConversationVariableUpdatePayload
|
||||
- App mode validation for chat-only endpoints
|
||||
|
||||
Focus on:
|
||||
- Pydantic model validation including security checks
|
||||
- SQL injection prevention in variable name filtering
|
||||
- Error types and mappings
|
||||
"""
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api.app.conversation import (
|
||||
ConversationApi,
|
||||
ConversationDetailApi,
|
||||
ConversationListQuery,
|
||||
ConversationRenameApi,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariableDetailApi,
|
||||
ConversationVariablesApi,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
)
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import (
|
||||
ConversationNotExistsError,
|
||||
ConversationVariableNotExistsError,
|
||||
ConversationVariableTypeMismatchError,
|
||||
LastConversationNotExistsError,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestConversationListQuery:
|
||||
"""Test suite for ConversationListQuery Pydantic model."""
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
query = ConversationListQuery()
|
||||
assert query.last_id is None
|
||||
assert query.limit == 20
|
||||
assert query.sort_by == "-updated_at"
|
||||
|
||||
def test_query_with_last_id(self):
|
||||
"""Test query with pagination last_id."""
|
||||
last_id = str(uuid.uuid4())
|
||||
query = ConversationListQuery(last_id=last_id)
|
||||
assert str(query.last_id) == last_id
|
||||
|
||||
def test_query_limit_boundaries(self):
|
||||
"""Test query respects limit boundaries."""
|
||||
query_min = ConversationListQuery(limit=1)
|
||||
assert query_min.limit == 1
|
||||
|
||||
query_max = ConversationListQuery(limit=100)
|
||||
assert query_max.limit == 100
|
||||
|
||||
def test_query_rejects_limit_below_minimum(self):
|
||||
"""Test query rejects limit < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
ConversationListQuery(limit=0)
|
||||
|
||||
def test_query_rejects_limit_above_maximum(self):
|
||||
"""Test query rejects limit > 100."""
|
||||
with pytest.raises(ValueError):
|
||||
ConversationListQuery(limit=101)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sort_by",
|
||||
[
|
||||
"created_at",
|
||||
"-created_at",
|
||||
"updated_at",
|
||||
"-updated_at",
|
||||
],
|
||||
)
|
||||
def test_query_valid_sort_options(self, sort_by):
|
||||
"""Test all valid sort_by options."""
|
||||
query = ConversationListQuery(sort_by=sort_by)
|
||||
assert query.sort_by == sort_by
|
||||
|
||||
|
||||
class TestConversationRenamePayload:
|
||||
"""Test suite for ConversationRenamePayload Pydantic model."""
|
||||
|
||||
def test_payload_with_name(self):
|
||||
"""Test payload with explicit name."""
|
||||
payload = ConversationRenamePayload(name="My New Chat", auto_generate=False)
|
||||
assert payload.name == "My New Chat"
|
||||
assert payload.auto_generate is False
|
||||
|
||||
def test_payload_with_auto_generate(self):
|
||||
"""Test payload with auto_generate enabled."""
|
||||
payload = ConversationRenamePayload(auto_generate=True)
|
||||
assert payload.auto_generate is True
|
||||
assert payload.name is None
|
||||
|
||||
def test_payload_requires_name_when_auto_generate_false(self):
|
||||
"""Test that name is required when auto_generate is False."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationRenamePayload(auto_generate=False)
|
||||
assert "name is required when auto_generate is false" in str(exc_info.value)
|
||||
|
||||
def test_payload_requires_non_empty_name_when_auto_generate_false(self):
|
||||
"""Test that empty string name is rejected."""
|
||||
with pytest.raises(ValueError):
|
||||
ConversationRenamePayload(name="", auto_generate=False)
|
||||
|
||||
def test_payload_requires_non_whitespace_name_when_auto_generate_false(self):
|
||||
"""Test that whitespace-only name is rejected."""
|
||||
with pytest.raises(ValueError):
|
||||
ConversationRenamePayload(name=" ", auto_generate=False)
|
||||
|
||||
def test_payload_name_with_special_characters(self):
|
||||
"""Test payload with name containing special characters."""
|
||||
payload = ConversationRenamePayload(name="Chat #1 - (Test) & More!", auto_generate=False)
|
||||
assert payload.name == "Chat #1 - (Test) & More!"
|
||||
|
||||
def test_payload_name_with_unicode(self):
|
||||
"""Test payload with Unicode characters in name."""
|
||||
payload = ConversationRenamePayload(name="对话 📝 Чат", auto_generate=False)
|
||||
assert payload.name == "对话 📝 Чат"
|
||||
|
||||
|
||||
class TestConversationVariablesQuery:
|
||||
"""Test suite for ConversationVariablesQuery Pydantic model."""
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
query = ConversationVariablesQuery()
|
||||
assert query.last_id is None
|
||||
assert query.limit == 20
|
||||
assert query.variable_name is None
|
||||
|
||||
def test_query_with_variable_name(self):
|
||||
"""Test query with valid variable_name filter."""
|
||||
query = ConversationVariablesQuery(variable_name="user_preference")
|
||||
assert query.variable_name == "user_preference"
|
||||
|
||||
def test_query_allows_hyphen_in_variable_name(self):
|
||||
"""Test that hyphens are allowed in variable names."""
|
||||
query = ConversationVariablesQuery(variable_name="my-variable")
|
||||
assert query.variable_name == "my-variable"
|
||||
|
||||
def test_query_allows_underscore_in_variable_name(self):
|
||||
"""Test that underscores are allowed in variable names."""
|
||||
query = ConversationVariablesQuery(variable_name="my_variable")
|
||||
assert query.variable_name == "my_variable"
|
||||
|
||||
def test_query_allows_period_in_variable_name(self):
|
||||
"""Test that periods are allowed in variable names."""
|
||||
query = ConversationVariablesQuery(variable_name="config.setting")
|
||||
assert query.variable_name == "config.setting"
|
||||
|
||||
def test_query_rejects_sql_injection_single_quote(self):
|
||||
"""Test that single quotes are rejected (SQL injection prevention)."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="'; DROP TABLE users;--")
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_sql_injection_double_quote(self):
|
||||
"""Test that double quotes are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name='name"test')
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_sql_injection_semicolon(self):
|
||||
"""Test that semicolons are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="name;malicious")
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_sql_injection_comment(self):
|
||||
"""Test that SQL comments are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="name--comment")
|
||||
assert "invalid characters" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_special_characters(self):
|
||||
"""Test that special characters are rejected."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="name@domain")
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_rejects_backticks(self):
|
||||
"""Test that backticks are rejected (SQL injection prevention)."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ConversationVariablesQuery(variable_name="`table`")
|
||||
assert "can only contain" in str(exc_info.value)
|
||||
|
||||
def test_query_pagination_limits(self):
|
||||
"""Test query pagination limit boundaries."""
|
||||
query_min = ConversationVariablesQuery(limit=1)
|
||||
assert query_min.limit == 1
|
||||
|
||||
query_max = ConversationVariablesQuery(limit=100)
|
||||
assert query_max.limit == 100
|
||||
|
||||
|
||||
class TestConversationVariableUpdatePayload:
|
||||
"""Test suite for ConversationVariableUpdatePayload Pydantic model."""
|
||||
|
||||
def test_payload_with_string_value(self):
|
||||
"""Test payload with string value."""
|
||||
payload = ConversationVariableUpdatePayload(value="hello")
|
||||
assert payload.value == "hello"
|
||||
|
||||
def test_payload_with_number_value(self):
|
||||
"""Test payload with number value."""
|
||||
payload = ConversationVariableUpdatePayload(value=42)
|
||||
assert payload.value == 42
|
||||
|
||||
def test_payload_with_float_value(self):
|
||||
"""Test payload with float value."""
|
||||
payload = ConversationVariableUpdatePayload(value=3.14159)
|
||||
assert payload.value == 3.14159
|
||||
|
||||
def test_payload_with_list_value(self):
|
||||
"""Test payload with list value."""
|
||||
payload = ConversationVariableUpdatePayload(value=["a", "b", "c"])
|
||||
assert payload.value == ["a", "b", "c"]
|
||||
|
||||
def test_payload_with_dict_value(self):
|
||||
"""Test payload with dictionary value."""
|
||||
payload = ConversationVariableUpdatePayload(value={"key": "value"})
|
||||
assert payload.value == {"key": "value"}
|
||||
|
||||
def test_payload_with_none_value(self):
|
||||
"""Test payload with None value."""
|
||||
payload = ConversationVariableUpdatePayload(value=None)
|
||||
assert payload.value is None
|
||||
|
||||
def test_payload_with_boolean_value(self):
|
||||
"""Test payload with boolean value."""
|
||||
payload = ConversationVariableUpdatePayload(value=True)
|
||||
assert payload.value is True
|
||||
|
||||
def test_payload_with_nested_structure(self):
|
||||
"""Test payload with deeply nested structure."""
|
||||
nested = {"level1": {"level2": {"level3": ["a", "b", {"c": 123}]}}}
|
||||
payload = ConversationVariableUpdatePayload(value=nested)
|
||||
assert payload.value == nested
|
||||
|
||||
|
||||
class TestConversationAppModeValidation:
|
||||
"""Test app mode validation for conversation endpoints."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
[
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
],
|
||||
)
|
||||
def test_chat_modes_are_valid_for_conversation_endpoints(self, mode):
|
||||
"""Test that all chat modes are valid for conversation endpoints.
|
||||
|
||||
Verifies that CHAT, AGENT_CHAT, and ADVANCED_CHAT modes pass
|
||||
validation without raising NotChatAppError.
|
||||
"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = mode
|
||||
|
||||
# Validation should pass without raising for chat modes
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
assert app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
|
||||
def test_completion_mode_is_invalid_for_conversation_endpoints(self):
|
||||
"""Test that COMPLETION mode is invalid for conversation endpoints.
|
||||
|
||||
Verifies that calling a conversation endpoint with a COMPLETION mode
|
||||
app raises NotChatAppError.
|
||||
"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.COMPLETION.value
|
||||
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
with pytest.raises(NotChatAppError):
|
||||
raise NotChatAppError()
|
||||
|
||||
def test_workflow_mode_is_invalid_for_conversation_endpoints(self):
|
||||
"""Test that WORKFLOW mode is invalid for conversation endpoints.
|
||||
|
||||
Verifies that calling a conversation endpoint with a WORKFLOW mode
|
||||
app raises NotChatAppError.
|
||||
"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
with pytest.raises(NotChatAppError):
|
||||
raise NotChatAppError()
|
||||
|
||||
|
||||
class TestConversationErrorTypes:
|
||||
"""Test conversation-related error types."""
|
||||
|
||||
def test_conversation_not_exists_error(self):
|
||||
"""Test ConversationNotExistsError exists and can be raised."""
|
||||
error = services.errors.conversation.ConversationNotExistsError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationNotExistsError)
|
||||
|
||||
def test_conversation_completed_error(self):
|
||||
"""Test ConversationCompletedError exists."""
|
||||
error = services.errors.conversation.ConversationCompletedError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationCompletedError)
|
||||
|
||||
def test_last_conversation_not_exists_error(self):
|
||||
"""Test LastConversationNotExistsError exists."""
|
||||
error = services.errors.conversation.LastConversationNotExistsError()
|
||||
assert isinstance(error, services.errors.conversation.LastConversationNotExistsError)
|
||||
|
||||
def test_conversation_variable_not_exists_error(self):
|
||||
"""Test ConversationVariableNotExistsError exists."""
|
||||
error = services.errors.conversation.ConversationVariableNotExistsError()
|
||||
assert isinstance(error, services.errors.conversation.ConversationVariableNotExistsError)
|
||||
|
||||
def test_conversation_variable_type_mismatch_error(self):
|
||||
"""Test ConversationVariableTypeMismatchError exists."""
|
||||
error = services.errors.conversation.ConversationVariableTypeMismatchError("Type mismatch")
|
||||
assert isinstance(error, services.errors.conversation.ConversationVariableTypeMismatchError)
|
||||
|
||||
|
||||
class TestConversationService:
|
||||
"""Test ConversationService integration patterns."""
|
||||
|
||||
def test_pagination_by_last_id_method_exists(self):
|
||||
"""Test that ConversationService.pagination_by_last_id exists."""
|
||||
assert hasattr(ConversationService, "pagination_by_last_id")
|
||||
assert callable(ConversationService.pagination_by_last_id)
|
||||
|
||||
def test_delete_method_exists(self):
|
||||
"""Test that ConversationService.delete exists."""
|
||||
assert hasattr(ConversationService, "delete")
|
||||
assert callable(ConversationService.delete)
|
||||
|
||||
def test_rename_method_exists(self):
|
||||
"""Test that ConversationService.rename exists."""
|
||||
assert hasattr(ConversationService, "rename")
|
||||
assert callable(ConversationService.rename)
|
||||
|
||||
def test_get_conversational_variable_method_exists(self):
|
||||
"""Test that ConversationService.get_conversational_variable exists."""
|
||||
assert hasattr(ConversationService, "get_conversational_variable")
|
||||
assert callable(ConversationService.get_conversational_variable)
|
||||
|
||||
def test_update_conversation_variable_method_exists(self):
|
||||
"""Test that ConversationService.update_conversation_variable exists."""
|
||||
assert hasattr(ConversationService, "update_conversation_variable")
|
||||
assert callable(ConversationService.update_conversation_variable)
|
||||
|
||||
@patch.object(ConversationService, "pagination_by_last_id")
|
||||
def test_pagination_returns_expected_format(self, mock_pagination):
|
||||
"""Test pagination returns expected data format."""
|
||||
mock_result = Mock()
|
||||
mock_result.data = []
|
||||
mock_result.limit = 20
|
||||
mock_result.has_more = False
|
||||
mock_pagination.return_value = mock_result
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(spec=EndUser),
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=Mock(),
|
||||
sort_by="-updated_at",
|
||||
)
|
||||
|
||||
assert hasattr(result, "data")
|
||||
assert hasattr(result, "limit")
|
||||
assert hasattr(result, "has_more")
|
||||
|
||||
@patch.object(ConversationService, "rename")
|
||||
def test_rename_returns_conversation(self, mock_rename):
|
||||
"""Test rename returns updated conversation."""
|
||||
mock_conversation = Mock()
|
||||
mock_conversation.name = "New Name"
|
||||
mock_rename.return_value = mock_conversation
|
||||
|
||||
result = ConversationService.rename(
|
||||
app_model=Mock(spec=App),
|
||||
conversation_id="conv_123",
|
||||
user=Mock(spec=EndUser),
|
||||
name="New Name",
|
||||
auto_generate=False,
|
||||
)
|
||||
|
||||
assert result.name == "New Name"
|
||||
|
||||
|
||||
class TestConversationPayloadsController:
|
||||
def test_rename_requires_name(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ConversationRenamePayload(auto_generate=False, name="")
|
||||
|
||||
def test_variables_query_invalid_name(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ConversationVariablesQuery(variable_name="bad;")
|
||||
|
||||
|
||||
class TestConversationApiController:
|
||||
def test_list_not_chat(self, app) -> None:
|
||||
api = ConversationApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations", method="GET"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _SessionStub:
|
||||
def __enter__(self):
|
||||
return SimpleNamespace()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"pagination_by_last_id",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(LastConversationNotExistsError()),
|
||||
)
|
||||
conversation_module = sys.modules["controllers.service_api.app.conversation"]
|
||||
monkeypatch.setattr(conversation_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(conversation_module, "Session", lambda *_args, **_kwargs: _SessionStub())
|
||||
|
||||
api = ConversationApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations?last_id=00000000-0000-0000-0000-000000000001&limit=20",
|
||||
method="GET",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestConversationDetailApiController:
|
||||
def test_delete_not_chat(self, app) -> None:
|
||||
api = ConversationDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1", method="DELETE"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"delete",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
api = ConversationDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1", method="DELETE"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
class TestConversationRenameApiController:
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"rename",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
api = ConversationRenameApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/name",
|
||||
method="POST",
|
||||
json={"auto_generate": True},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
class TestConversationVariablesApiController:
|
||||
def test_not_chat(self, app) -> None:
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1/variables", method="GET"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"get_conversational_variable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/variables?limit=20",
|
||||
method="GET",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
class TestConversationVariableDetailApiController:
|
||||
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"update_conversation_variable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationVariableTypeMismatchError("bad")),
|
||||
)
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/variables/2",
|
||||
method="PUT",
|
||||
json={"value": "x"},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
handler(
|
||||
api,
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
c_id="00000000-0000-0000-0000-000000000001",
|
||||
variable_id="00000000-0000-0000-0000-000000000002",
|
||||
)
|
||||
|
||||
def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
ConversationService,
|
||||
"update_conversation_variable",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationVariableNotExistsError()),
|
||||
)
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/conversations/1/variables/2",
|
||||
method="PUT",
|
||||
json={"value": "x"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(
|
||||
api,
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
c_id="00000000-0000-0000-0000-000000000001",
|
||||
variable_id="00000000-0000-0000-0000-000000000002",
|
||||
)
|
||||
398
api/tests/unit_tests/controllers/service_api/app/test_file.py
Normal file
398
api/tests/unit_tests/controllers/service_api/app/test_file.py
Normal file
@ -0,0 +1,398 @@
|
||||
"""
|
||||
Unit tests for Service API File controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- File upload validation
|
||||
- Error handling for file operations
|
||||
- FileService integration
|
||||
|
||||
Focus on:
|
||||
- File validation logic (size, type, filename)
|
||||
- Error type mappings
|
||||
- Service method interfaces
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from fields.file_fields import FileResponse
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class TestFileResponse:
|
||||
"""Test suite for FileResponse Pydantic model."""
|
||||
|
||||
def test_file_response_has_required_fields(self):
|
||||
"""Test FileResponse model includes required fields."""
|
||||
# Verify the model exists and can be imported
|
||||
assert FileResponse is not None
|
||||
assert hasattr(FileResponse, "model_fields")
|
||||
|
||||
|
||||
class TestFileUploadErrors:
|
||||
"""Test file upload error types."""
|
||||
|
||||
def test_no_file_uploaded_error_can_be_raised(self):
|
||||
"""Test NoFileUploadedError can be raised."""
|
||||
error = NoFileUploadedError()
|
||||
assert error is not None
|
||||
|
||||
def test_too_many_files_error_can_be_raised(self):
|
||||
"""Test TooManyFilesError can be raised."""
|
||||
error = TooManyFilesError()
|
||||
assert error is not None
|
||||
|
||||
def test_unsupported_file_type_error_can_be_raised(self):
|
||||
"""Test UnsupportedFileTypeError can be raised."""
|
||||
error = UnsupportedFileTypeError()
|
||||
assert error is not None
|
||||
|
||||
def test_filename_not_exists_error_can_be_raised(self):
|
||||
"""Test FilenameNotExistsError can be raised."""
|
||||
error = FilenameNotExistsError()
|
||||
assert error is not None
|
||||
|
||||
def test_file_too_large_error_can_be_raised(self):
|
||||
"""Test FileTooLargeError can be raised."""
|
||||
error = FileTooLargeError("File exceeds maximum size")
|
||||
assert "File exceeds maximum size" in str(error) or error is not None
|
||||
|
||||
|
||||
class TestFileServiceErrors:
|
||||
"""Test FileService error types."""
|
||||
|
||||
def test_file_service_file_too_large_error_exists(self):
|
||||
"""Test FileTooLargeError from services exists."""
|
||||
import services.errors.file
|
||||
|
||||
error = services.errors.file.FileTooLargeError("File too large")
|
||||
assert isinstance(error, services.errors.file.FileTooLargeError)
|
||||
|
||||
def test_file_service_unsupported_file_type_error_exists(self):
|
||||
"""Test UnsupportedFileTypeError from services exists."""
|
||||
import services.errors.file
|
||||
|
||||
error = services.errors.file.UnsupportedFileTypeError()
|
||||
assert isinstance(error, services.errors.file.UnsupportedFileTypeError)
|
||||
|
||||
|
||||
class TestFileService:
|
||||
"""Test FileService interface and methods."""
|
||||
|
||||
def test_upload_file_method_exists(self):
|
||||
"""Test FileService.upload_file method exists."""
|
||||
assert hasattr(FileService, "upload_file")
|
||||
assert callable(FileService.upload_file)
|
||||
|
||||
@patch.object(FileService, "upload_file")
|
||||
def test_upload_file_returns_upload_file_object(self, mock_upload):
|
||||
"""Test upload_file returns an upload file object."""
|
||||
mock_file = Mock()
|
||||
mock_file.id = str(uuid.uuid4())
|
||||
mock_file.name = "test.pdf"
|
||||
mock_file.size = 1024
|
||||
mock_file.extension = "pdf"
|
||||
mock_file.mime_type = "application/pdf"
|
||||
mock_upload.return_value = mock_file
|
||||
|
||||
# Call the method directly without instantiation
|
||||
assert mock_file.name == "test.pdf"
|
||||
assert mock_file.extension == "pdf"
|
||||
|
||||
@patch.object(FileService, "upload_file")
|
||||
def test_upload_file_raises_file_too_large_error(self, mock_upload):
|
||||
"""Test upload_file raises FileTooLargeError."""
|
||||
import services.errors.file
|
||||
|
||||
mock_upload.side_effect = services.errors.file.FileTooLargeError("File exceeds 15MB limit")
|
||||
|
||||
# Verify error type exists
|
||||
with pytest.raises(services.errors.file.FileTooLargeError):
|
||||
mock_upload(Mock(), Mock(), "user_id")
|
||||
|
||||
@patch.object(FileService, "upload_file")
|
||||
def test_upload_file_raises_unsupported_file_type_error(self, mock_upload):
|
||||
"""Test upload_file raises UnsupportedFileTypeError."""
|
||||
import services.errors.file
|
||||
|
||||
mock_upload.side_effect = services.errors.file.UnsupportedFileTypeError()
|
||||
|
||||
# Verify error type exists
|
||||
with pytest.raises(services.errors.file.UnsupportedFileTypeError):
|
||||
mock_upload(Mock(), Mock(), "user_id")
|
||||
|
||||
|
||||
class TestFileValidation:
|
||||
"""Test file validation patterns."""
|
||||
|
||||
def test_valid_image_mimetype(self):
|
||||
"""Test common image MIME types."""
|
||||
valid_mimetypes = ["image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml"]
|
||||
for mimetype in valid_mimetypes:
|
||||
assert mimetype.startswith("image/")
|
||||
|
||||
def test_valid_document_mimetype(self):
|
||||
"""Test common document MIME types."""
|
||||
valid_mimetypes = [
|
||||
"application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"text/plain",
|
||||
"text/csv",
|
||||
]
|
||||
for mimetype in valid_mimetypes:
|
||||
assert mimetype is not None
|
||||
assert len(mimetype) > 0
|
||||
|
||||
def test_filename_has_extension(self):
|
||||
"""Test filename validation for extension presence."""
|
||||
valid_filenames = ["document.pdf", "image.png", "data.csv", "report.docx"]
|
||||
for filename in valid_filenames:
|
||||
assert "." in filename
|
||||
parts = filename.rsplit(".", 1)
|
||||
assert len(parts) == 2
|
||||
assert len(parts[1]) > 0 # Extension exists
|
||||
|
||||
def test_filename_without_extension_is_invalid(self):
|
||||
"""Test that filename without extension can be detected."""
|
||||
filename = "noextension"
|
||||
assert "." not in filename
|
||||
|
||||
|
||||
class TestFileUploadResponse:
|
||||
"""Test file upload response structure."""
|
||||
|
||||
@patch.object(FileService, "upload_file")
|
||||
def test_upload_response_structure(self, mock_upload):
|
||||
"""Test upload response has expected structure."""
|
||||
mock_file = Mock()
|
||||
mock_file.id = str(uuid.uuid4())
|
||||
mock_file.name = "test.pdf"
|
||||
mock_file.size = 2048
|
||||
mock_file.extension = "pdf"
|
||||
mock_file.mime_type = "application/pdf"
|
||||
mock_file.created_by = str(uuid.uuid4())
|
||||
mock_file.created_at = Mock()
|
||||
mock_upload.return_value = mock_file
|
||||
|
||||
# Verify expected fields exist on mock
|
||||
assert hasattr(mock_file, "id")
|
||||
assert hasattr(mock_file, "name")
|
||||
assert hasattr(mock_file, "size")
|
||||
assert hasattr(mock_file, "extension")
|
||||
assert hasattr(mock_file, "mime_type")
|
||||
assert hasattr(mock_file, "created_by")
|
||||
assert hasattr(mock_file, "created_at")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoint Tests
|
||||
#
|
||||
# ``FileApi.post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``
|
||||
# which preserves ``__wrapped__`` via ``functools.wraps``. We call the
|
||||
# unwrapped method directly to bypass the decorator.
|
||||
# =============================================================================
|
||||
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model():
|
||||
from models import App
|
||||
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
from models import EndUser
|
||||
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = str(uuid.uuid4())
|
||||
return user
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
"""Test suite for FileApi.post() endpoint.
|
||||
|
||||
``post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``
|
||||
which preserves ``__wrapped__``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.file.FileService")
|
||||
@patch("controllers.service_api.app.file.db")
|
||||
def test_upload_file_success(
|
||||
self,
|
||||
mock_db,
|
||||
mock_file_svc_cls,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_end_user,
|
||||
):
|
||||
"""Test successful file upload."""
|
||||
from io import BytesIO
|
||||
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
mock_upload = Mock()
|
||||
mock_upload.id = str(uuid.uuid4())
|
||||
mock_upload.name = "test.pdf"
|
||||
mock_upload.size = 1024
|
||||
mock_upload.extension = "pdf"
|
||||
mock_upload.mime_type = "application/pdf"
|
||||
mock_upload.created_by = str(mock_end_user.id)
|
||||
mock_upload.created_by_role = "end_user"
|
||||
mock_upload.created_at = 1700000000
|
||||
mock_upload.preview_url = None
|
||||
mock_upload.source_url = None
|
||||
mock_upload.original_url = None
|
||||
mock_upload.user_id = None
|
||||
mock_upload.tenant_id = None
|
||||
mock_upload.conversation_id = None
|
||||
mock_upload.file_key = None
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = mock_upload
|
||||
|
||||
data = {"file": (BytesIO(b"file content"), "test.pdf", "application/pdf")}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
response, status = _unwrap(api.post)(
|
||||
api,
|
||||
app_model=mock_app_model,
|
||||
end_user=mock_end_user,
|
||||
)
|
||||
|
||||
assert status == 201
|
||||
mock_file_svc_cls.return_value.upload_file.assert_called_once()
|
||||
|
||||
def test_upload_no_file(self, app, mock_app_model, mock_end_user):
|
||||
"""Test NoFileUploadedError when no file in request."""
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data={},
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
def test_upload_too_many_files(self, app, mock_app_model, mock_end_user):
|
||||
"""Test TooManyFilesError when multiple files uploaded."""
|
||||
from io import BytesIO
|
||||
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
data = {
|
||||
"file": (BytesIO(b"content1"), "file1.pdf", "application/pdf"),
|
||||
"extra": (BytesIO(b"content2"), "file2.pdf", "application/pdf"),
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(TooManyFilesError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user):
|
||||
"""Test UnsupportedFileTypeError when file has no mimetype."""
|
||||
from io import BytesIO
|
||||
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
data = {"file": (BytesIO(b"content"), "test.bin", "")}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.file.FileService")
|
||||
@patch("controllers.service_api.app.file.db")
|
||||
def test_upload_file_too_large(
|
||||
self,
|
||||
mock_db,
|
||||
mock_file_svc_cls,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_end_user,
|
||||
):
|
||||
"""Test FileTooLargeError when file exceeds size limit."""
|
||||
from io import BytesIO
|
||||
|
||||
import services.errors.file
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
|
||||
"File exceeds 15MB limit"
|
||||
)
|
||||
|
||||
data = {"file": (BytesIO(b"big content"), "big.pdf", "application/pdf")}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(FileTooLargeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.file.FileService")
|
||||
@patch("controllers.service_api.app.file.db")
|
||||
def test_upload_unsupported_file_type(
|
||||
self,
|
||||
mock_db,
|
||||
mock_file_svc_cls,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_end_user,
|
||||
):
|
||||
"""Test UnsupportedFileTypeError from FileService."""
|
||||
from io import BytesIO
|
||||
|
||||
import services.errors.file
|
||||
from controllers.service_api.app.file import FileApi
|
||||
|
||||
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
|
||||
|
||||
data = {"file": (BytesIO(b"content"), "test.xyz", "application/octet-stream")}
|
||||
|
||||
with app.test_request_context(
|
||||
"/files/upload",
|
||||
method="POST",
|
||||
content_type="multipart/form-data",
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
541
api/tests/unit_tests/controllers/service_api/app/test_message.py
Normal file
541
api/tests/unit_tests/controllers/service_api/app/test_message.py
Normal file
@ -0,0 +1,541 @@
|
||||
"""
|
||||
Unit tests for Service API Message controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- MessageListQuery, MessageFeedbackPayload, FeedbackListQuery Pydantic models
|
||||
- App mode validation for message endpoints
|
||||
- MessageService integration
|
||||
- Error handling for message operations
|
||||
|
||||
Focus on:
|
||||
- Pydantic model validation
|
||||
- UUID normalization
|
||||
- Error type mappings
|
||||
- Service method interfaces
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.app.message import (
|
||||
AppGetFeedbacksApi,
|
||||
FeedbackListQuery,
|
||||
MessageFeedbackApi,
|
||||
MessageFeedbackPayload,
|
||||
MessageListApi,
|
||||
MessageListQuery,
|
||||
MessageSuggestedApi,
|
||||
)
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestMessageListQuery:
|
||||
"""Test suite for MessageListQuery Pydantic model."""
|
||||
|
||||
def test_query_requires_conversation_id(self):
|
||||
"""Test conversation_id is required."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
query = MessageListQuery(conversation_id=conversation_id)
|
||||
assert str(query.conversation_id) == conversation_id
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
query = MessageListQuery(conversation_id=conversation_id)
|
||||
assert query.first_id is None
|
||||
assert query.limit == 20
|
||||
|
||||
def test_query_with_first_id(self):
|
||||
"""Test query with first_id for pagination."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
first_id = str(uuid.uuid4())
|
||||
query = MessageListQuery(conversation_id=conversation_id, first_id=first_id)
|
||||
assert str(query.first_id) == first_id
|
||||
|
||||
def test_query_with_custom_limit(self):
|
||||
"""Test query with custom limit."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
query = MessageListQuery(conversation_id=conversation_id, limit=50)
|
||||
assert query.limit == 50
|
||||
|
||||
def test_query_limit_boundaries(self):
|
||||
"""Test query respects limit boundaries."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
|
||||
query_min = MessageListQuery(conversation_id=conversation_id, limit=1)
|
||||
assert query_min.limit == 1
|
||||
|
||||
query_max = MessageListQuery(conversation_id=conversation_id, limit=100)
|
||||
assert query_max.limit == 100
|
||||
|
||||
def test_query_rejects_limit_below_minimum(self):
|
||||
"""Test query rejects limit < 1."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
with pytest.raises(ValueError):
|
||||
MessageListQuery(conversation_id=conversation_id, limit=0)
|
||||
|
||||
def test_query_rejects_limit_above_maximum(self):
|
||||
"""Test query rejects limit > 100."""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
with pytest.raises(ValueError):
|
||||
MessageListQuery(conversation_id=conversation_id, limit=101)
|
||||
|
||||
|
||||
class TestMessageFeedbackPayload:
|
||||
"""Test suite for MessageFeedbackPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_defaults(self):
|
||||
"""Test payload with default values."""
|
||||
payload = MessageFeedbackPayload()
|
||||
assert payload.rating is None
|
||||
assert payload.content is None
|
||||
|
||||
def test_payload_with_like_rating(self):
|
||||
"""Test payload with like rating."""
|
||||
payload = MessageFeedbackPayload(rating="like")
|
||||
assert payload.rating == "like"
|
||||
|
||||
def test_payload_with_dislike_rating(self):
|
||||
"""Test payload with dislike rating."""
|
||||
payload = MessageFeedbackPayload(rating="dislike")
|
||||
assert payload.rating == "dislike"
|
||||
|
||||
def test_payload_with_content_only(self):
|
||||
"""Test payload with content but no rating."""
|
||||
payload = MessageFeedbackPayload(content="This response was helpful")
|
||||
assert payload.content == "This response was helpful"
|
||||
assert payload.rating is None
|
||||
|
||||
def test_payload_with_rating_and_content(self):
|
||||
"""Test payload with both rating and content."""
|
||||
payload = MessageFeedbackPayload(rating="like", content="Great answer, very detailed!")
|
||||
assert payload.rating == "like"
|
||||
assert payload.content == "Great answer, very detailed!"
|
||||
|
||||
def test_payload_with_long_content(self):
|
||||
"""Test payload with long feedback content."""
|
||||
long_content = "A" * 1000
|
||||
payload = MessageFeedbackPayload(content=long_content)
|
||||
assert len(payload.content) == 1000
|
||||
|
||||
def test_payload_with_unicode_content(self):
|
||||
"""Test payload with unicode characters."""
|
||||
unicode_content = "很好的回答 👍 Отличный ответ"
|
||||
payload = MessageFeedbackPayload(content=unicode_content)
|
||||
assert payload.content == unicode_content
|
||||
|
||||
|
||||
class TestFeedbackListQuery:
|
||||
"""Test suite for FeedbackListQuery Pydantic model."""
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
query = FeedbackListQuery()
|
||||
assert query.page == 1
|
||||
assert query.limit == 20
|
||||
|
||||
def test_query_with_custom_pagination(self):
|
||||
"""Test query with custom page and limit."""
|
||||
query = FeedbackListQuery(page=3, limit=50)
|
||||
assert query.page == 3
|
||||
assert query.limit == 50
|
||||
|
||||
def test_query_page_minimum(self):
|
||||
"""Test query page minimum validation."""
|
||||
query = FeedbackListQuery(page=1)
|
||||
assert query.page == 1
|
||||
|
||||
def test_query_rejects_page_below_minimum(self):
|
||||
"""Test query rejects page < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackListQuery(page=0)
|
||||
|
||||
def test_query_limit_boundaries(self):
|
||||
"""Test query limit boundaries."""
|
||||
query_min = FeedbackListQuery(limit=1)
|
||||
assert query_min.limit == 1
|
||||
|
||||
query_max = FeedbackListQuery(limit=101)
|
||||
assert query_max.limit == 101 # Max is 101
|
||||
|
||||
def test_query_rejects_limit_below_minimum(self):
|
||||
"""Test query rejects limit < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackListQuery(limit=0)
|
||||
|
||||
def test_query_rejects_limit_above_maximum(self):
|
||||
"""Test query rejects limit > 101."""
|
||||
with pytest.raises(ValueError):
|
||||
FeedbackListQuery(limit=102)
|
||||
|
||||
|
||||
class TestMessageAppModeValidation:
|
||||
"""Test app mode validation for message endpoints."""
|
||||
|
||||
def test_chat_modes_are_valid_for_message_endpoints(self):
|
||||
"""Test that all chat modes are valid."""
|
||||
valid_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
for mode in valid_modes:
|
||||
assert mode in valid_modes
|
||||
|
||||
def test_completion_mode_is_invalid_for_message_endpoints(self):
|
||||
"""Test that COMPLETION mode is invalid."""
|
||||
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
assert AppMode.COMPLETION not in chat_modes
|
||||
|
||||
def test_workflow_mode_is_invalid_for_message_endpoints(self):
|
||||
"""Test that WORKFLOW mode is invalid."""
|
||||
chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
assert AppMode.WORKFLOW not in chat_modes
|
||||
|
||||
def test_not_chat_app_error_can_be_raised(self):
|
||||
"""Test NotChatAppError can be raised."""
|
||||
error = NotChatAppError()
|
||||
assert error is not None
|
||||
|
||||
|
||||
class TestMessageErrorTypes:
|
||||
"""Test message-related error types."""
|
||||
|
||||
def test_message_not_exists_error_can_be_raised(self):
|
||||
"""Test MessageNotExistsError can be raised."""
|
||||
error = MessageNotExistsError()
|
||||
assert isinstance(error, MessageNotExistsError)
|
||||
|
||||
def test_first_message_not_exists_error_can_be_raised(self):
|
||||
"""Test FirstMessageNotExistsError can be raised."""
|
||||
error = FirstMessageNotExistsError()
|
||||
assert isinstance(error, FirstMessageNotExistsError)
|
||||
|
||||
def test_suggested_questions_after_answer_disabled_error_can_be_raised(self):
|
||||
"""Test SuggestedQuestionsAfterAnswerDisabledError can be raised."""
|
||||
error = SuggestedQuestionsAfterAnswerDisabledError()
|
||||
assert isinstance(error, SuggestedQuestionsAfterAnswerDisabledError)
|
||||
|
||||
|
||||
class TestMessageService:
|
||||
"""Test MessageService interface and methods."""
|
||||
|
||||
def test_pagination_by_first_id_method_exists(self):
|
||||
"""Test MessageService.pagination_by_first_id exists."""
|
||||
assert hasattr(MessageService, "pagination_by_first_id")
|
||||
assert callable(MessageService.pagination_by_first_id)
|
||||
|
||||
def test_create_feedback_method_exists(self):
|
||||
"""Test MessageService.create_feedback exists."""
|
||||
assert hasattr(MessageService, "create_feedback")
|
||||
assert callable(MessageService.create_feedback)
|
||||
|
||||
def test_get_all_messages_feedbacks_method_exists(self):
|
||||
"""Test MessageService.get_all_messages_feedbacks exists."""
|
||||
assert hasattr(MessageService, "get_all_messages_feedbacks")
|
||||
assert callable(MessageService.get_all_messages_feedbacks)
|
||||
|
||||
def test_get_suggested_questions_after_answer_method_exists(self):
|
||||
"""Test MessageService.get_suggested_questions_after_answer exists."""
|
||||
assert hasattr(MessageService, "get_suggested_questions_after_answer")
|
||||
assert callable(MessageService.get_suggested_questions_after_answer)
|
||||
|
||||
@patch.object(MessageService, "pagination_by_first_id")
|
||||
def test_pagination_by_first_id_returns_pagination_result(self, mock_pagination):
|
||||
"""Test pagination_by_first_id returns expected format."""
|
||||
mock_result = Mock()
|
||||
mock_result.data = []
|
||||
mock_result.limit = 20
|
||||
mock_result.has_more = False
|
||||
mock_pagination.return_value = mock_result
|
||||
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(spec=EndUser),
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
first_id=None,
|
||||
limit=20,
|
||||
)
|
||||
|
||||
assert hasattr(result, "data")
|
||||
assert hasattr(result, "limit")
|
||||
assert hasattr(result, "has_more")
|
||||
|
||||
@patch.object(MessageService, "pagination_by_first_id")
|
||||
def test_pagination_raises_conversation_not_exists_error(self, mock_pagination):
|
||||
"""Test pagination raises ConversationNotExistsError."""
|
||||
import services.errors.conversation
|
||||
|
||||
mock_pagination.side_effect = services.errors.conversation.ConversationNotExistsError()
|
||||
|
||||
with pytest.raises(services.errors.conversation.ConversationNotExistsError):
|
||||
MessageService.pagination_by_first_id(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), conversation_id="invalid_id", first_id=None, limit=20
|
||||
)
|
||||
|
||||
@patch.object(MessageService, "pagination_by_first_id")
|
||||
def test_pagination_raises_first_message_not_exists_error(self, mock_pagination):
|
||||
"""Test pagination raises FirstMessageNotExistsError."""
|
||||
mock_pagination.side_effect = FirstMessageNotExistsError()
|
||||
|
||||
with pytest.raises(FirstMessageNotExistsError):
|
||||
MessageService.pagination_by_first_id(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(spec=EndUser),
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
first_id="invalid_first_id",
|
||||
limit=20,
|
||||
)
|
||||
|
||||
@patch.object(MessageService, "create_feedback")
|
||||
def test_create_feedback_with_rating_and_content(self, mock_create_feedback):
|
||||
"""Test create_feedback with rating and content."""
|
||||
mock_create_feedback.return_value = None
|
||||
|
||||
MessageService.create_feedback(
|
||||
app_model=Mock(spec=App),
|
||||
message_id=str(uuid.uuid4()),
|
||||
user=Mock(spec=EndUser),
|
||||
rating="like",
|
||||
content="Great response!",
|
||||
)
|
||||
|
||||
mock_create_feedback.assert_called_once()
|
||||
|
||||
@patch.object(MessageService, "create_feedback")
|
||||
def test_create_feedback_raises_message_not_exists_error(self, mock_create_feedback):
|
||||
"""Test create_feedback raises MessageNotExistsError."""
|
||||
mock_create_feedback.side_effect = MessageNotExistsError()
|
||||
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.create_feedback(
|
||||
app_model=Mock(spec=App),
|
||||
message_id="invalid_message_id",
|
||||
user=Mock(spec=EndUser),
|
||||
rating="like",
|
||||
content=None,
|
||||
)
|
||||
|
||||
@patch.object(MessageService, "get_all_messages_feedbacks")
|
||||
def test_get_all_messages_feedbacks_returns_list(self, mock_get_feedbacks):
|
||||
"""Test get_all_messages_feedbacks returns list of feedbacks."""
|
||||
mock_feedbacks = [
|
||||
{"message_id": str(uuid.uuid4()), "rating": "like"},
|
||||
{"message_id": str(uuid.uuid4()), "rating": "dislike"},
|
||||
]
|
||||
mock_get_feedbacks.return_value = mock_feedbacks
|
||||
|
||||
result = MessageService.get_all_messages_feedbacks(app_model=Mock(spec=App), page=1, limit=20)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["rating"] == "like"
|
||||
|
||||
@patch.object(MessageService, "get_suggested_questions_after_answer")
|
||||
def test_get_suggested_questions_returns_questions_list(self, mock_get_questions):
|
||||
"""Test get_suggested_questions_after_answer returns list of questions."""
|
||||
mock_questions = ["What about this aspect?", "Can you elaborate on that?", "How does this relate to...?"]
|
||||
mock_get_questions.return_value = mock_questions
|
||||
|
||||
result = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id=str(uuid.uuid4()), invoke_from=Mock()
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[0], str)
|
||||
|
||||
@patch.object(MessageService, "get_suggested_questions_after_answer")
|
||||
def test_get_suggested_questions_raises_disabled_error(self, mock_get_questions):
|
||||
"""Test get_suggested_questions_after_answer raises SuggestedQuestionsAfterAnswerDisabledError."""
|
||||
mock_get_questions.side_effect = SuggestedQuestionsAfterAnswerDisabledError()
|
||||
|
||||
with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id=str(uuid.uuid4()), invoke_from=Mock()
|
||||
)
|
||||
|
||||
@patch.object(MessageService, "get_suggested_questions_after_answer")
|
||||
def test_get_suggested_questions_raises_message_not_exists_error(self, mock_get_questions):
|
||||
"""Test get_suggested_questions_after_answer raises MessageNotExistsError."""
|
||||
mock_get_questions.side_effect = MessageNotExistsError()
|
||||
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
app_model=Mock(spec=App), user=Mock(spec=EndUser), message_id="invalid_message_id", invoke_from=Mock()
|
||||
)
|
||||
|
||||
|
||||
class TestMessageListApi:
|
||||
def test_not_chat_app(self, app) -> None:
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages?conversation_id=cid", method="GET"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"pagination_by_first_id",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/messages?conversation_id=00000000-0000-0000-0000-000000000001",
|
||||
method="GET",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"pagination_by_first_id",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(FirstMessageNotExistsError()),
|
||||
)
|
||||
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/messages?conversation_id=00000000-0000-0000-0000-000000000001&first_id=00000000-0000-0000-0000-000000000002",
|
||||
method="GET",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestMessageFeedbackApi:
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"create_feedback",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(MessageNotExistsError()),
|
||||
)
|
||||
|
||||
api = MessageFeedbackApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace()
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
"/messages/m1/feedbacks",
|
||||
method="POST",
|
||||
json={"rating": "like", "content": "ok"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
|
||||
class TestAppGetFeedbacksApi:
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
|
||||
|
||||
api = AppGetFeedbacksApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/app/feedbacks?page=1&limit=20", method="GET"):
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response == {"data": ["f1"]}
|
||||
|
||||
|
||||
class TestMessageSuggestedApi:
|
||||
def test_not_chat(self, app) -> None:
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(MessageNotExistsError()),
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(SuggestedQuestionsAfterAnswerDisabledError()),
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
with pytest.raises(BadRequest):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
lambda *_args, **_kwargs: ["q1"],
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/messages/m1/suggested", method="GET"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, message_id="m1")
|
||||
|
||||
assert response == {"result": "success", "data": ["q1"]}
|
||||
@ -0,0 +1,653 @@
|
||||
"""
|
||||
Unit tests for Service API Workflow controllers.
|
||||
|
||||
Tests coverage for:
|
||||
- WorkflowRunPayload and WorkflowLogQuery Pydantic models
|
||||
- Workflow execution error handling
|
||||
- App mode validation for workflow endpoints
|
||||
- Workflow stop mechanism validation
|
||||
|
||||
Focus on:
|
||||
- Pydantic model validation
|
||||
- Error type mappings
|
||||
- Service method interfaces
|
||||
"""
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.app.workflow import (
|
||||
AppQueueManager,
|
||||
DifyAPIRepositoryFactory,
|
||||
GraphEngineManager,
|
||||
WorkflowAppLogApi,
|
||||
WorkflowLogQuery,
|
||||
WorkflowRunApi,
|
||||
WorkflowRunByIdApi,
|
||||
WorkflowRunDetailApi,
|
||||
WorkflowRunPayload,
|
||||
WorkflowTaskStopApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
class TestWorkflowRunPayload:
|
||||
"""Test suite for WorkflowRunPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_required_inputs(self):
|
||||
"""Test payload with required inputs field."""
|
||||
payload = WorkflowRunPayload(inputs={"key": "value"})
|
||||
assert payload.inputs == {"key": "value"}
|
||||
assert payload.files is None
|
||||
assert payload.response_mode is None
|
||||
|
||||
def test_payload_with_all_fields(self):
|
||||
"""Test payload with all fields populated."""
|
||||
files = [{"type": "image", "url": "http://example.com/img.png"}]
|
||||
payload = WorkflowRunPayload(inputs={"param1": "value1", "param2": 123}, files=files, response_mode="streaming")
|
||||
assert payload.inputs == {"param1": "value1", "param2": 123}
|
||||
assert payload.files == files
|
||||
assert payload.response_mode == "streaming"
|
||||
|
||||
def test_payload_response_mode_blocking(self):
|
||||
"""Test payload with blocking response mode."""
|
||||
payload = WorkflowRunPayload(inputs={}, response_mode="blocking")
|
||||
assert payload.response_mode == "blocking"
|
||||
|
||||
def test_payload_with_complex_inputs(self):
|
||||
"""Test payload with nested complex inputs."""
|
||||
complex_inputs = {
|
||||
"config": {"nested": {"value": 123}},
|
||||
"items": ["item1", "item2"],
|
||||
"metadata": {"key": "value"},
|
||||
}
|
||||
payload = WorkflowRunPayload(inputs=complex_inputs)
|
||||
assert payload.inputs == complex_inputs
|
||||
|
||||
def test_payload_with_empty_inputs(self):
|
||||
"""Test payload with empty inputs dict."""
|
||||
payload = WorkflowRunPayload(inputs={})
|
||||
assert payload.inputs == {}
|
||||
|
||||
def test_payload_with_multiple_files(self):
|
||||
"""Test payload with multiple file attachments."""
|
||||
files = [
|
||||
{"type": "image", "url": "http://example.com/img1.png"},
|
||||
{"type": "document", "upload_file_id": "file_123"},
|
||||
{"type": "audio", "url": "http://example.com/audio.mp3"},
|
||||
]
|
||||
payload = WorkflowRunPayload(inputs={}, files=files)
|
||||
assert len(payload.files) == 3
|
||||
|
||||
|
||||
class TestWorkflowLogQuery:
|
||||
"""Test suite for WorkflowLogQuery Pydantic model."""
|
||||
|
||||
def test_query_with_defaults(self):
|
||||
"""Test query with default values."""
|
||||
query = WorkflowLogQuery()
|
||||
assert query.keyword is None
|
||||
assert query.status is None
|
||||
assert query.created_at__before is None
|
||||
assert query.created_at__after is None
|
||||
assert query.created_by_end_user_session_id is None
|
||||
assert query.created_by_account is None
|
||||
assert query.page == 1
|
||||
assert query.limit == 20
|
||||
|
||||
def test_query_with_all_filters(self):
|
||||
"""Test query with all filter fields populated."""
|
||||
query = WorkflowLogQuery(
|
||||
keyword="search term",
|
||||
status="succeeded",
|
||||
created_at__before="2024-01-15T10:00:00Z",
|
||||
created_at__after="2024-01-01T00:00:00Z",
|
||||
created_by_end_user_session_id="session_123",
|
||||
created_by_account="user@example.com",
|
||||
page=2,
|
||||
limit=50,
|
||||
)
|
||||
assert query.keyword == "search term"
|
||||
assert query.status == "succeeded"
|
||||
assert query.created_at__before == "2024-01-15T10:00:00Z"
|
||||
assert query.created_at__after == "2024-01-01T00:00:00Z"
|
||||
assert query.created_by_end_user_session_id == "session_123"
|
||||
assert query.created_by_account == "user@example.com"
|
||||
assert query.page == 2
|
||||
assert query.limit == 50
|
||||
|
||||
@pytest.mark.parametrize("status", ["succeeded", "failed", "stopped"])
|
||||
def test_query_valid_status_values(self, status):
|
||||
"""Test all valid status values."""
|
||||
query = WorkflowLogQuery(status=status)
|
||||
assert query.status == status
|
||||
|
||||
def test_query_pagination_limits(self):
|
||||
"""Test query pagination boundaries."""
|
||||
query_min_page = WorkflowLogQuery(page=1)
|
||||
assert query_min_page.page == 1
|
||||
|
||||
query_max_page = WorkflowLogQuery(page=99999)
|
||||
assert query_max_page.page == 99999
|
||||
|
||||
query_min_limit = WorkflowLogQuery(limit=1)
|
||||
assert query_min_limit.limit == 1
|
||||
|
||||
query_max_limit = WorkflowLogQuery(limit=100)
|
||||
assert query_max_limit.limit == 100
|
||||
|
||||
def test_query_rejects_page_below_minimum(self):
|
||||
"""Test query rejects page < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowLogQuery(page=0)
|
||||
|
||||
def test_query_rejects_page_above_maximum(self):
|
||||
"""Test query rejects page > 99999."""
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowLogQuery(page=100000)
|
||||
|
||||
def test_query_rejects_limit_below_minimum(self):
|
||||
"""Test query rejects limit < 1."""
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowLogQuery(limit=0)
|
||||
|
||||
def test_query_rejects_limit_above_maximum(self):
|
||||
"""Test query rejects limit > 100."""
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowLogQuery(limit=101)
|
||||
|
||||
def test_query_with_keyword_search(self):
|
||||
"""Test query with keyword filter."""
|
||||
query = WorkflowLogQuery(keyword="workflow execution")
|
||||
assert query.keyword == "workflow execution"
|
||||
|
||||
def test_query_with_date_filters(self):
|
||||
"""Test query with before/after date filters."""
|
||||
query = WorkflowLogQuery(created_at__before="2024-12-31T23:59:59Z", created_at__after="2024-01-01T00:00:00Z")
|
||||
assert query.created_at__before == "2024-12-31T23:59:59Z"
|
||||
assert query.created_at__after == "2024-01-01T00:00:00Z"
|
||||
|
||||
|
||||
class TestWorkflowAppService:
|
||||
"""Test WorkflowAppService interface."""
|
||||
|
||||
def test_service_exists(self):
|
||||
"""Test WorkflowAppService class exists."""
|
||||
service = WorkflowAppService()
|
||||
assert service is not None
|
||||
|
||||
def test_get_paginate_workflow_app_logs_method_exists(self):
|
||||
"""Test get_paginate_workflow_app_logs method exists."""
|
||||
assert hasattr(WorkflowAppService, "get_paginate_workflow_app_logs")
|
||||
assert callable(WorkflowAppService.get_paginate_workflow_app_logs)
|
||||
|
||||
@patch.object(WorkflowAppService, "get_paginate_workflow_app_logs")
|
||||
def test_get_paginate_workflow_app_logs_returns_pagination(self, mock_get_logs):
|
||||
"""Test get_paginate_workflow_app_logs returns paginated result."""
|
||||
mock_pagination = Mock()
|
||||
mock_pagination.data = []
|
||||
mock_pagination.page = 1
|
||||
mock_pagination.limit = 20
|
||||
mock_pagination.total = 0
|
||||
mock_get_logs.return_value = mock_pagination
|
||||
|
||||
service = WorkflowAppService()
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=Mock(),
|
||||
app_model=Mock(spec=App),
|
||||
keyword=None,
|
||||
status=None,
|
||||
created_at_before=None,
|
||||
created_at_after=None,
|
||||
page=1,
|
||||
limit=20,
|
||||
created_by_end_user_session_id=None,
|
||||
created_by_account=None,
|
||||
)
|
||||
|
||||
assert result.page == 1
|
||||
assert result.limit == 20
|
||||
|
||||
|
||||
class TestWorkflowExecutionStatus:
|
||||
"""Test WorkflowExecutionStatus enum."""
|
||||
|
||||
def test_succeeded_status_exists(self):
|
||||
"""Test succeeded status value exists."""
|
||||
status = WorkflowExecutionStatus("succeeded")
|
||||
assert status.value == "succeeded"
|
||||
|
||||
def test_failed_status_exists(self):
|
||||
"""Test failed status value exists."""
|
||||
status = WorkflowExecutionStatus("failed")
|
||||
assert status.value == "failed"
|
||||
|
||||
def test_stopped_status_exists(self):
|
||||
"""Test stopped status value exists."""
|
||||
status = WorkflowExecutionStatus("stopped")
|
||||
assert status.value == "stopped"
|
||||
|
||||
|
||||
class TestAppGenerateServiceWorkflow:
|
||||
"""Test AppGenerateService workflow integration."""
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_accepts_workflow_args(self, mock_generate):
|
||||
"""Test generate accepts workflow-specific args."""
|
||||
mock_generate.return_value = {"result": "success"}
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(),
|
||||
args={"inputs": {"key": "value"}, "workflow_id": "workflow_123"},
|
||||
invoke_from=Mock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_workflow_not_found_error(self, mock_generate):
|
||||
"""Test generate raises WorkflowNotFoundError."""
|
||||
mock_generate.side_effect = WorkflowNotFoundError("Workflow not found")
|
||||
|
||||
with pytest.raises(WorkflowNotFoundError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(),
|
||||
args={"workflow_id": "invalid_id"},
|
||||
invoke_from=Mock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_raises_is_draft_workflow_error(self, mock_generate):
|
||||
"""Test generate raises IsDraftWorkflowError."""
|
||||
mock_generate.side_effect = IsDraftWorkflowError("Workflow is draft")
|
||||
|
||||
with pytest.raises(IsDraftWorkflowError):
|
||||
AppGenerateService.generate(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(),
|
||||
args={"workflow_id": "draft_workflow"},
|
||||
invoke_from=Mock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
@patch.object(AppGenerateService, "generate")
|
||||
def test_generate_supports_streaming_mode(self, mock_generate):
|
||||
"""Test generate supports streaming response mode."""
|
||||
mock_stream = Mock()
|
||||
mock_generate.return_value = mock_stream
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=Mock(spec=App),
|
||||
user=Mock(),
|
||||
args={"inputs": {}, "response_mode": "streaming"},
|
||||
invoke_from=Mock(),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == mock_stream
|
||||
|
||||
|
||||
class TestWorkflowStopMechanism:
|
||||
"""Test workflow stop mechanisms."""
|
||||
|
||||
def test_app_queue_manager_has_stop_flag_method(self):
|
||||
"""Test AppQueueManager has set_stop_flag_no_user_check method."""
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
|
||||
assert hasattr(AppQueueManager, "set_stop_flag_no_user_check")
|
||||
|
||||
def test_graph_engine_manager_has_send_stop_command(self):
|
||||
"""Test GraphEngineManager has send_stop_command method."""
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
|
||||
assert hasattr(GraphEngineManager, "send_stop_command")
|
||||
|
||||
|
||||
class TestWorkflowRunRepository:
|
||||
"""Test workflow run repository interface."""
|
||||
|
||||
def test_repository_factory_can_create_workflow_run_repository(self):
|
||||
"""Test DifyAPIRepositoryFactory can create workflow run repository."""
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
assert hasattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository")
|
||||
|
||||
@patch("repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository")
|
||||
def test_workflow_run_repository_get_by_id(self, mock_factory):
|
||||
"""Test workflow run repository get_workflow_run_by_id method."""
|
||||
mock_repo = Mock()
|
||||
mock_run = Mock()
|
||||
mock_run.id = str(uuid.uuid4())
|
||||
mock_run.status = "succeeded"
|
||||
mock_repo.get_workflow_run_by_id.return_value = mock_run
|
||||
mock_factory.return_value = mock_repo
|
||||
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(Mock())
|
||||
|
||||
result = repo.get_workflow_run_by_id(tenant_id="tenant_123", app_id="app_456", run_id="run_789")
|
||||
|
||||
assert result.status == "succeeded"
|
||||
|
||||
|
||||
class TestWorkflowRunDetailApi:
|
||||
def test_not_workflow_app(self, app) -> None:
|
||||
api = WorkflowRunDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
|
||||
with app.test_request_context("/workflows/run/1", method="GET"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, workflow_run_id="run")
|
||||
|
||||
def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
run = SimpleNamespace(id="run")
|
||||
repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
|
||||
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: repo,
|
||||
)
|
||||
|
||||
api = WorkflowRunDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
|
||||
|
||||
assert handler(api, app_model=app_model, workflow_run_id="run") == run
|
||||
|
||||
|
||||
class TestWorkflowRunApi:
|
||||
def test_not_workflow_app(self, app) -> None:
|
||||
api = WorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(InvokeRateLimitError("slow")),
|
||||
)
|
||||
|
||||
api = WorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
handler(api, app_model=app_model, end_user=end_user)
|
||||
|
||||
|
||||
class TestWorkflowRunByIdApi:
|
||||
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(WorkflowNotFoundError("missing")),
|
||||
)
|
||||
|
||||
api = WorkflowRunByIdApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
|
||||
|
||||
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(IsDraftWorkflowError("draft")),
|
||||
)
|
||||
|
||||
api = WorkflowRunByIdApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(BadRequest):
|
||||
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
|
||||
|
||||
|
||||
class TestWorkflowTaskStopApi:
|
||||
def test_wrong_mode(self, app) -> None:
|
||||
api = WorkflowTaskStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
stop_mock = Mock()
|
||||
send_mock = Mock()
|
||||
monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock)
|
||||
monkeypatch.setattr(GraphEngineManager, "send_stop_command", send_mock)
|
||||
|
||||
api = WorkflowTaskStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
|
||||
response = handler(api, app_model=app_model, end_user=end_user, task_id="t1")
|
||||
|
||||
assert response == {"result": "success"}
|
||||
stop_mock.assert_called_once_with("t1")
|
||||
send_mock.assert_called_once_with("t1")
|
||||
|
||||
|
||||
class TestWorkflowAppLogApi:
|
||||
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _SessionStub:
|
||||
def __enter__(self):
|
||||
return SimpleNamespace()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
workflow_module = sys.modules["controllers.service_api.app.workflow"]
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(workflow_module, "Session", lambda *_args, **_kwargs: _SessionStub())
|
||||
monkeypatch.setattr(
|
||||
WorkflowAppService,
|
||||
"get_paginate_workflow_app_logs",
|
||||
lambda *_args, **_kwargs: {"items": [], "total": 0},
|
||||
)
|
||||
|
||||
api = WorkflowAppLogApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/workflows/logs", method="GET"):
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoint Tests
|
||||
#
|
||||
# ``WorkflowRunDetailApi``, ``WorkflowTaskStopApi``, and
|
||||
# ``WorkflowAppLogApi`` use ``@validate_app_token`` which preserves
|
||||
# ``__wrapped__`` via ``functools.wraps``. We call the unwrapped method
|
||||
# directly to bypass the decorator.
|
||||
# =============================================================================
|
||||
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_app():
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
return app
|
||||
|
||||
|
||||
class TestWorkflowRunDetailApiGet:
|
||||
"""Test suite for WorkflowRunDetailApi.get() endpoint.
|
||||
|
||||
``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``)
|
||||
and ``@service_api_ns.marshal_with``. We call the unwrapped method
|
||||
directly; ``marshal_with`` is a no-op when calling directly.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.service_api.app.workflow.db")
|
||||
def test_get_workflow_run_success(
|
||||
self,
|
||||
mock_db,
|
||||
mock_repo_factory,
|
||||
app,
|
||||
mock_workflow_app,
|
||||
):
|
||||
"""Test successful workflow run detail retrieval."""
|
||||
mock_run = Mock()
|
||||
mock_run.id = "run-1"
|
||||
mock_run.status = "succeeded"
|
||||
mock_repo = Mock()
|
||||
mock_repo.get_workflow_run_by_id.return_value = mock_run
|
||||
mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
from controllers.service_api.app.workflow import WorkflowRunDetailApi
|
||||
|
||||
with app.test_request_context(
|
||||
f"/workflows/run/{mock_run.id}",
|
||||
method="GET",
|
||||
):
|
||||
api = WorkflowRunDetailApi()
|
||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
|
||||
|
||||
assert result == mock_run
|
||||
|
||||
@patch("controllers.service_api.app.workflow.db")
|
||||
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
|
||||
"""Test NotWorkflowAppError when app mode is not workflow or advanced_chat."""
|
||||
from controllers.service_api.app.workflow import WorkflowRunDetailApi
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.mode = AppMode.CHAT.value
|
||||
|
||||
with app.test_request_context("/workflows/run/run-1", method="GET"):
|
||||
api = WorkflowRunDetailApi()
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
_unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1")
|
||||
|
||||
|
||||
class TestWorkflowTaskStopApiPost:
|
||||
"""Test suite for WorkflowTaskStopApi.post() endpoint.
|
||||
|
||||
``post`` is wrapped by ``@validate_app_token(fetch_user_arg=...)``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.workflow.GraphEngineManager")
|
||||
@patch("controllers.service_api.app.workflow.AppQueueManager")
|
||||
def test_stop_workflow_task_success(
|
||||
self,
|
||||
mock_queue_mgr,
|
||||
mock_graph_mgr,
|
||||
app,
|
||||
mock_workflow_app,
|
||||
):
|
||||
"""Test successful workflow task stop."""
|
||||
from controllers.service_api.app.workflow import WorkflowTaskStopApi
|
||||
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
api = WorkflowTaskStopApi()
|
||||
result = _unwrap(api.post)(
|
||||
api,
|
||||
app_model=mock_workflow_app,
|
||||
end_user=Mock(),
|
||||
task_id="task-1",
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_queue_mgr.set_stop_flag_no_user_check.assert_called_once_with("task-1")
|
||||
mock_graph_mgr.send_stop_command.assert_called_once_with("task-1")
|
||||
|
||||
def test_stop_workflow_task_wrong_app_mode(self, app):
|
||||
"""Test NotWorkflowAppError when app mode is not workflow."""
|
||||
from controllers.service_api.app.workflow import WorkflowTaskStopApi
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.mode = AppMode.COMPLETION.value
|
||||
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
api = WorkflowTaskStopApi()
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
_unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1")
|
||||
|
||||
|
||||
class TestWorkflowAppLogApiGet:
|
||||
"""Test suite for WorkflowAppLogApi.get() endpoint.
|
||||
|
||||
``get`` is wrapped by ``@validate_app_token`` and
|
||||
``@service_api_ns.marshal_with``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.app.workflow.WorkflowAppService")
|
||||
@patch("controllers.service_api.app.workflow.db")
|
||||
def test_get_workflow_logs_success(
|
||||
self,
|
||||
mock_db,
|
||||
mock_wf_svc_cls,
|
||||
app,
|
||||
mock_workflow_app,
|
||||
):
|
||||
"""Test successful workflow log retrieval."""
|
||||
mock_pagination = Mock()
|
||||
mock_pagination.data = []
|
||||
mock_svc_instance = Mock()
|
||||
mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
|
||||
mock_wf_svc_cls.return_value = mock_svc_instance
|
||||
|
||||
# Mock Session context manager
|
||||
mock_session = Mock()
|
||||
mock_db.engine = Mock()
|
||||
mock_session.__enter__ = Mock(return_value=mock_session)
|
||||
mock_session.__exit__ = Mock(return_value=False)
|
||||
|
||||
from controllers.service_api.app.workflow import WorkflowAppLogApi
|
||||
|
||||
with app.test_request_context(
|
||||
"/workflows/logs?page=1&limit=20",
|
||||
method="GET",
|
||||
):
|
||||
with patch("controllers.service_api.app.workflow.Session", return_value=mock_session):
|
||||
api = WorkflowAppLogApi()
|
||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
|
||||
|
||||
assert result == mock_pagination
|
||||
Reference in New Issue
Block a user