Merge branch 'main' into jzh

This commit is contained in:
JzoNg
2026-04-15 15:01:33 +08:00
12 changed files with 401 additions and 337 deletions

View File

@ -8,7 +8,7 @@ from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@ -37,7 +37,7 @@ from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportMode
from services.entities.dsl_entities import ImportMode, ImportStatus
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
InfoList,
@ -623,7 +623,7 @@ class AppCopyApi(Resource):
args = CopyAppPayload.model_validate(console_ns.payload or {})
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
result = import_service.import_app(
@ -636,6 +636,13 @@ class AppCopyApi(Resource):
icon=args.icon,
icon_background=args.icon_background,
)
if result.status == ImportStatus.FAILED:
session.rollback()
return result.model_dump(mode="json"), 400
if result.status == ImportStatus.PENDING:
session.rollback()
return result.model_dump(mode="json"), 202
session.commit()
# Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,6 +1,6 @@
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console.app.wraps import get_app_model
@ -52,8 +52,9 @@ class AppImportApi(Resource):
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with sessionmaker(db.engine).begin() as session:
# AppDslService performs internal commits for some creation paths, so use a plain
# Session here instead of nesting it inside sessionmaker(...).begin().
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Import app
account = current_user
@ -69,6 +70,10 @@ class AppImportApi(Resource):
icon_background=args.icon_background,
app_id=args.app_id,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource):
# Check user role first
current_user, _ = current_account_with_tenant()
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_model: App):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@ -4,7 +4,7 @@ from collections.abc import Sequence
from typing import Any
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource, fields, marshal, marshal_with
from graphon.enums import NodeType
from graphon.file import File
from graphon.graph_engine.manager import GraphEngineManager
@ -942,7 +942,6 @@ class PublishedAllWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_model)
@edit_permission_required
def get(self, app_model: App):
"""
@ -970,9 +969,10 @@ class PublishedAllWorkflowApi(Resource):
user_id=user_id,
named_only=named_only,
)
serialized_workflows = marshal(workflows, workflow_fields_copy)
return {
"items": workflows,
"items": serialized_workflows,
"page": page,
"limit": limit,
"has_more": has_more,

View File

@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id)
with sessionmaker(db.engine).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource):
name=args.name,
description=args.description,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400

View File

@ -1,11 +0,0 @@
import pytest
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
CODE_LANGUAGE = "unsupported_language"
def test_unsupported_with_code_template():
with pytest.raises(CodeExecutionError) as e:
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"

View File

@ -1,36 +0,0 @@
from textwrap import dedent
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
CODE_LANGUAGE = CodeLanguage.PYTHON3
def test_python3_plain():
code = 'print("Hello World")'
result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
assert result == "Hello World\n"
def test_python3_json():
code = dedent("""
import json
print(json.dumps({'Hello': 'World'}))
""")
result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
assert result == '{"Hello": "World"}\n'
def test_python3_with_code_template():
result = CodeExecutor.execute_workflow_code_template(
language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"}
)
assert result == {"result": "HelloWorld"}
def test_python3_get_runner_script():
runner_script = Python3TemplateTransformer.get_runner_script()
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Python3TemplateTransformer._result_tag) == 2

View File

@ -96,6 +96,56 @@ class TestAppImportApi:
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
fake_session = MagicMock()
fake_session.__enter__.return_value = fake_session
fake_session.__exit__.return_value = None
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
fake_session.commit.assert_called_once_with()
fake_session.rollback.assert_not_called()
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
fake_session = MagicMock()
fake_session.__enter__.return_value = fake_session
fake_session.__exit__.return_value = None
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
fake_session.rollback.assert_called_once_with()
fake_session.commit.assert_not_called()
assert status == 400
assert response["status"] == ImportStatus.FAILED
class TestAppImportConfirmApi:
@pytest.fixture

View File

@ -0,0 +1,110 @@
"""
Testcontainers integration tests for Service API Site controller.
"""
from __future__ import annotations
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.service_api.app.site import AppSiteApi
from models.account import Tenant, TenantStatus
from models.model import App, AppMode, Site
@pytest.fixture
def app(flask_app_with_containers) -> Flask:
return flask_app_with_containers
def _unwrap(method):
fn = method
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return fn
def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
tenant = Tenant(name="service-api-site-tenant", status=status)
db_session.add(tenant)
db_session.commit()
return tenant
def _create_app(db_session: Session, tenant_id: str) -> App:
app_model = App(
tenant_id=tenant_id,
mode=AppMode.CHAT,
name="service-api-site-app",
enable_site=True,
enable_api=True,
status="normal",
)
db_session.add(app_model)
db_session.commit()
return app_model
def _create_site(db_session: Session, app_id: str) -> Site:
site = Site(
app_id=app_id,
title="Service API Site",
icon_type="emoji",
icon="robot",
icon_background="#ffffff",
description="Service API test site",
default_language="en-US",
prompt_public=True,
show_workflow_steps=True,
customize_token_strategy="not_allow",
use_icon_as_answer_icon=False,
chat_color_theme="light",
chat_color_theme_inverted=False,
)
db_session.add(site)
db_session.commit()
return site
class TestAppSiteApi:
def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None:
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
_create_site(db_session_with_containers, app_model.id)
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
api = AppSiteApi()
response = _unwrap(api.get)(api, app_model=app_model)
assert response["title"] == "Service API Site"
assert response["icon"] == "robot"
assert response["description"] == "Service API test site"
def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None:
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
_unwrap(api.get)(api, app_model=app_model)
def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None:
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
_create_site(db_session_with_containers, app_model.id)
archived_tenant = db_session_with_containers.get(Tenant, tenant.id)
assert archived_tenant is not None
archived_tenant.status = TenantStatus.ARCHIVE
db_session_with_containers.commit()
app_model = db_session_with_containers.get(App, app_model.id)
assert app_model is not None
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
_unwrap(api.get)(api, app_model=app_model)

View File

@ -0,0 +1,139 @@
"""Unit tests for console app import endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
self.app_id = app_id
def model_dump(self, mode: str = "json"):
return {"status": self.status, "app_id": self.app_id}
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
fake_session = MagicMock()
fake_session.__enter__.return_value = fake_session
fake_session.__exit__.return_value = None
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
return fake_session
class TestAppImportApi:
@pytest.fixture
def api(self):
return app_import_module.AppImportApi()
def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
method = _unwrap(api.post)
_install_features(monkeypatch, enabled=False)
session = _mock_session(monkeypatch)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.rollback.assert_called_once_with()
session.commit.assert_not_called()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
method = _unwrap(api.post)
_install_features(monkeypatch, enabled=False)
session = _mock_session(monkeypatch)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once_with()
session.rollback.assert_not_called()
assert status == 202
assert response["status"] == ImportStatus.PENDING
def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
method = _unwrap(api.post)
_install_features(monkeypatch, enabled=True)
session = _mock_session(monkeypatch)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
)
update_access = MagicMock()
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once_with()
session.rollback.assert_not_called()
update_access.assert_called_once_with("app-123", "private")
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
class TestAppImportConfirmApi:
@pytest.fixture
def api(self):
return app_import_module.AppImportConfirmApi()
def test_import_confirm_returns_failed_status_and_rolls_back(
self, api, app, monkeypatch: pytest.MonkeyPatch
) -> None:
method = _unwrap(api.post)
session = _mock_session(monkeypatch)
monkeypatch.setattr(
app_import_module.AppDslService,
"confirm_import",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(import_id="import-1")
session.rollback.assert_called_once_with()
session.commit.assert_not_called()
assert status == 400
assert response["status"] == ImportStatus.FAILED

View File

@ -258,6 +258,63 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
assert exc.value.description == "invalid workflow graph"
def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.PublishedAllWorkflowApi()
handler = _unwrap(api.get)
session_state = {"open": False}
class _SessionContext:
def __enter__(self):
session_state["open"] = True
return object()
def __exit__(self, exc_type, exc, tb):
session_state["open"] = False
return False
class _SessionMaker:
def begin(self):
return _SessionContext()
class _Workflow:
@property
def id(self):
assert session_state["open"] is True
return "w1"
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
lambda: SimpleNamespace(
get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False),
),
)
def _fake_marshal(items, fields):
assert session_state["open"] is True
return [{"id": item.id} for item in items]
monkeypatch.setattr(workflow_module, "marshal", _fake_marshal)
with app.test_request_context(
"/apps/app/workflows",
method="GET",
query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
):
response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
assert response == {
"items": [{"id": "w1"}],
"page": 1,
"limit": 10,
"has_more": False,
}
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)

View File

@ -102,16 +102,16 @@ class TestEnterpriseAppDSLImport:
@pytest.fixture
def _mock_import_deps(self):
"""Patch db, sessionmaker, and AppDslService for import handler tests."""
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock())
mock_session_ctx.__exit__ = MagicMock(return_value=False)
mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx)))
"""Patch db, Session, and AppDslService for import handler tests."""
mock_session = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=False)
with (
patch("controllers.inner_api.app.dsl.db"),
patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker),
patch("controllers.inner_api.app.dsl.Session", return_value=mock_session),
patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls,
):
self._mock_session = mock_session
self._mock_dsl = MagicMock()
mock_dsl_cls.return_value = self._mock_dsl
yield
@ -147,6 +147,8 @@ class TestEnterpriseAppDSLImport:
assert status_code == 200
assert body["status"] == "completed"
mock_account.set_tenant_id.assert_called_once_with("ws-123")
self._mock_session.commit.assert_called_once_with()
self._mock_session.rollback.assert_not_called()
@pytest.mark.usefixtures("_mock_import_deps")
@patch("controllers.inner_api.app.dsl._get_active_account")
@ -162,6 +164,8 @@ class TestEnterpriseAppDSLImport:
assert status_code == 202
assert body["status"] == "pending"
self._mock_session.commit.assert_called_once_with()
self._mock_session.rollback.assert_not_called()
@pytest.mark.usefixtures("_mock_import_deps")
@patch("controllers.inner_api.app.dsl._get_active_account")
@ -177,6 +181,8 @@ class TestEnterpriseAppDSLImport:
assert status_code == 400
assert body["status"] == "failed"
self._mock_session.rollback.assert_called_once_with()
self._mock_session.commit.assert_not_called()
@patch("controllers.inner_api.app.dsl._get_active_account")
def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask):

View File

@ -1,270 +0,0 @@
"""
Unit tests for Service API Site controller
"""
import uuid
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.service_api.app.site import AppSiteApi
from models.account import TenantStatus
from models.model import App, Site
from tests.unit_tests.conftest import setup_mock_tenant_account_query
class TestAppSiteApi:
"""Test suite for AppSiteApi"""
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model with tenant."""
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.enable_api = True
mock_tenant = Mock()
mock_tenant.id = app.tenant_id
mock_tenant.status = TenantStatus.NORMAL
app.tenant = mock_tenant
return app
@pytest.fixture
def mock_site(self):
"""Create a mock Site model."""
site = Mock(spec=Site)
site.id = str(uuid.uuid4())
site.app_id = str(uuid.uuid4())
site.title = "Test Site"
site.icon = "icon-url"
site.icon_background = "#ffffff"
site.description = "Site description"
site.copyright = "Copyright 2024"
site.privacy_policy = "Privacy policy text"
site.custom_disclaimer = "Custom disclaimer"
site.default_language = "en-US"
site.prompt_public = True
site.show_workflow_steps = True
site.use_icon_as_answer_icon = False
site.chat_color_theme = "light"
site.chat_color_theme_inverted = False
site.icon_type = "image"
site.created_at = "2024-01-01T00:00:00"
site.updated_at = "2024-01-01T00:00:00"
return site
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_success(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
mock_site,
):
"""Test successful retrieval of site configuration."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
# Mock wraps.db for authentication
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site.db for site query
mock_db.session.scalar.return_value = mock_site
# Act
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
response = api.get()
# Assert
assert response["title"] == "Test Site"
assert response["icon"] == "icon-url"
assert response["description"] == "Site description"
mock_db.session.scalar.assert_called_once()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_not_found(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
):
"""Test that Forbidden is raised when site is not found."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site query to return None
mock_db.session.scalar.return_value = None
# Act & Assert
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
api.get()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_tenant_archived(
self,
mock_wraps_db,
mock_validate_token,
mock_current_app,
mock_db,
mock_user_logged_in,
app,
mock_app_model,
mock_site,
):
"""Test that Forbidden is raised when tenant is archived."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
# Mock site query
mock_db.session.scalar.return_value = mock_site
# Set tenant status to archived AFTER authentication
mock_app_model.tenant.status = TenantStatus.ARCHIVE
# Act & Assert
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
api.get()
@patch("controllers.service_api.wraps.user_logged_in")
@patch("controllers.service_api.app.site.db")
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_site_queries_by_app_id(
self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model
):
"""Test that site is queried using the app model's id."""
# Arrange
mock_current_app.login_manager = Mock()
# Mock authentication
mock_api_token = Mock()
mock_api_token.app_id = mock_app_model.id
mock_api_token.tenant_id = mock_app_model.tenant_id
mock_validate_token.return_value = mock_api_token
mock_tenant = Mock()
mock_tenant.status = TenantStatus.NORMAL
mock_app_model.tenant = mock_tenant
mock_wraps_db.session.get.side_effect = [
mock_app_model,
mock_tenant,
]
mock_account = Mock()
mock_account.current_tenant = mock_tenant
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
mock_site = Mock(spec=Site)
mock_site.id = str(uuid.uuid4())
mock_site.app_id = mock_app_model.id
mock_site.title = "Test Site"
mock_site.icon = "icon-url"
mock_site.icon_background = "#ffffff"
mock_site.description = "Site description"
mock_site.copyright = "Copyright 2024"
mock_site.privacy_policy = "Privacy policy text"
mock_site.custom_disclaimer = "Custom disclaimer"
mock_site.default_language = "en-US"
mock_site.prompt_public = True
mock_site.show_workflow_steps = True
mock_site.use_icon_as_answer_icon = False
mock_site.chat_color_theme = "light"
mock_site.chat_color_theme_inverted = False
mock_site.icon_type = "image"
mock_site.created_at = "2024-01-01T00:00:00"
mock_site.updated_at = "2024-01-01T00:00:00"
mock_db.session.scalar.return_value = mock_site
# Act
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
api = AppSiteApi()
api.get()
# Assert
# The query was executed successfully (site returned), which validates the correct query was made
mock_db.session.scalar.assert_called_once()