mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 18:27:27 +08:00
Merge branch 'main' into jzh
This commit is contained in:
@ -8,7 +8,7 @@ from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@ -37,7 +37,7 @@ from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
InfoList,
|
||||
@ -623,7 +623,7 @@ class AppCopyApi(Resource):
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
result = import_service.import_app(
|
||||
@ -636,6 +636,13 @@ class AppCopyApi(Resource):
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
return result.model_dump(mode="json"), 400
|
||||
if result.status == ImportStatus.PENDING:
|
||||
session.rollback()
|
||||
return result.model_dump(mode="json"), 202
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@ -52,8 +52,9 @@ class AppImportApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# AppDslService performs internal commits for some creation paths, so use a plain
|
||||
# Session here instead of nesting it inside sessionmaker(...).begin().
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
@ -69,6 +70,10 @@ class AppImportApi(Resource):
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
# update web app setting as private
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
@ -942,7 +942,6 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -970,9 +969,10 @@ class PublishedAllWorkflowApi(Resource):
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
)
|
||||
serialized_workflows = marshal(workflows, workflow_fields_copy)
|
||||
|
||||
return {
|
||||
"items": workflows,
|
||||
"items": serialized_workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
|
||||
@ -9,7 +9,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource):
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
|
||||
@ -1,11 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
|
||||
CODE_LANGUAGE = "unsupported_language"
|
||||
|
||||
|
||||
def test_unsupported_with_code_template():
|
||||
with pytest.raises(CodeExecutionError) as e:
|
||||
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
|
||||
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"
|
||||
@ -1,36 +0,0 @@
|
||||
from textwrap import dedent
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
|
||||
CODE_LANGUAGE = CodeLanguage.PYTHON3
|
||||
|
||||
|
||||
def test_python3_plain():
|
||||
code = 'print("Hello World")'
|
||||
result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
|
||||
assert result == "Hello World\n"
|
||||
|
||||
|
||||
def test_python3_json():
|
||||
code = dedent("""
|
||||
import json
|
||||
print(json.dumps({'Hello': 'World'}))
|
||||
""")
|
||||
result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code)
|
||||
assert result == '{"Hello": "World"}\n'
|
||||
|
||||
|
||||
def test_python3_with_code_template():
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"}
|
||||
)
|
||||
assert result == {"result": "HelloWorld"}
|
||||
|
||||
|
||||
def test_python3_get_runner_script():
|
||||
runner_script = Python3TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Python3TemplateTransformer._result_tag) == 2
|
||||
@ -96,6 +96,56 @@ class TestAppImportApi:
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.__enter__.return_value = fake_session
|
||||
fake_session.__exit__.return_value = None
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
fake_session.commit.assert_called_once_with()
|
||||
fake_session.rollback.assert_not_called()
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.__enter__.return_value = fake_session
|
||||
fake_session.__exit__.return_value = None
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
fake_session.rollback.assert_called_once_with()
|
||||
fake_session.commit.assert_not_called()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
class TestAppImportConfirmApi:
|
||||
@pytest.fixture
|
||||
|
||||
@ -0,0 +1,110 @@
|
||||
"""
|
||||
Testcontainers integration tests for Service API Site controller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.service_api.app.site import AppSiteApi
|
||||
from models.account import Tenant, TenantStatus
|
||||
from models.model import App, AppMode, Site
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(flask_app_with_containers) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
def _unwrap(method):
|
||||
fn = method
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__
|
||||
return fn
|
||||
|
||||
|
||||
def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
|
||||
tenant = Tenant(name="service-api-site-tenant", status=status)
|
||||
db_session.add(tenant)
|
||||
db_session.commit()
|
||||
return tenant
|
||||
|
||||
|
||||
def _create_app(db_session: Session, tenant_id: str) -> App:
|
||||
app_model = App(
|
||||
tenant_id=tenant_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="service-api-site-app",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
status="normal",
|
||||
)
|
||||
db_session.add(app_model)
|
||||
db_session.commit()
|
||||
return app_model
|
||||
|
||||
|
||||
def _create_site(db_session: Session, app_id: str) -> Site:
|
||||
site = Site(
|
||||
app_id=app_id,
|
||||
title="Service API Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#ffffff",
|
||||
description="Service API test site",
|
||||
default_language="en-US",
|
||||
prompt_public=True,
|
||||
show_workflow_steps=True,
|
||||
customize_token_strategy="not_allow",
|
||||
use_icon_as_answer_icon=False,
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
)
|
||||
db_session.add(site)
|
||||
db_session.commit()
|
||||
return site
|
||||
|
||||
|
||||
class TestAppSiteApi:
|
||||
def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None:
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
_create_site(db_session_with_containers, app_model.id)
|
||||
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
|
||||
api = AppSiteApi()
|
||||
response = _unwrap(api.get)(api, app_model=app_model)
|
||||
|
||||
assert response["title"] == "Service API Site"
|
||||
assert response["icon"] == "robot"
|
||||
assert response["description"] == "Service API test site"
|
||||
|
||||
def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None:
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
|
||||
api = AppSiteApi()
|
||||
with pytest.raises(Forbidden):
|
||||
_unwrap(api.get)(api, app_model=app_model)
|
||||
|
||||
def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None:
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
app_model = _create_app(db_session_with_containers, tenant.id)
|
||||
_create_site(db_session_with_containers, app_model.id)
|
||||
|
||||
archived_tenant = db_session_with_containers.get(Tenant, tenant.id)
|
||||
assert archived_tenant is not None
|
||||
archived_tenant.status = TenantStatus.ARCHIVE
|
||||
db_session_with_containers.commit()
|
||||
|
||||
app_model = db_session_with_containers.get(App, app_model.id)
|
||||
assert app_model is not None
|
||||
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
|
||||
api = AppSiteApi()
|
||||
with pytest.raises(Forbidden):
|
||||
_unwrap(api.get)(api, app_model=app_model)
|
||||
@ -0,0 +1,139 @@
|
||||
"""Unit tests for console app import endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
fake_session = MagicMock()
|
||||
fake_session.__enter__.return_value = fake_session
|
||||
fake_session.__exit__.return_value = None
|
||||
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
return fake_session
|
||||
|
||||
|
||||
class TestAppImportApi:
|
||||
@pytest.fixture
|
||||
def api(self):
|
||||
return app_import_module.AppImportApi()
|
||||
|
||||
def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.rollback.assert_called_once_with()
|
||||
session.commit.assert_not_called()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once_with()
|
||||
session.rollback.assert_not_called()
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once_with()
|
||||
session.rollback.assert_not_called()
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
class TestAppImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def api(self):
|
||||
return app_import_module.AppImportConfirmApi()
|
||||
|
||||
def test_import_confirm_returns_failed_status_and_rolls_back(
|
||||
self, api, app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
session.rollback.assert_called_once_with()
|
||||
session.commit.assert_not_called()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
@ -258,6 +258,63 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
|
||||
assert exc.value.description == "invalid workflow graph"
|
||||
|
||||
|
||||
def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.PublishedAllWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
session_state = {"open": False}
|
||||
|
||||
class _SessionContext:
|
||||
def __enter__(self):
|
||||
session_state["open"] = True
|
||||
return object()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
session_state["open"] = False
|
||||
return False
|
||||
|
||||
class _SessionMaker:
|
||||
def begin(self):
|
||||
return _SessionContext()
|
||||
|
||||
class _Workflow:
|
||||
@property
|
||||
def id(self):
|
||||
assert session_state["open"] is True
|
||||
return "w1"
|
||||
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
"WorkflowService",
|
||||
lambda: SimpleNamespace(
|
||||
get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False),
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_marshal(items, fields):
|
||||
assert session_state["open"] is True
|
||||
return [{"id": item.id} for item in items]
|
||||
|
||||
monkeypatch.setattr(workflow_module, "marshal", _fake_marshal)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows",
|
||||
method="GET",
|
||||
query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
|
||||
):
|
||||
response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
|
||||
|
||||
assert response == {
|
||||
"items": [{"id": "w1"}],
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"has_more": False,
|
||||
}
|
||||
|
||||
|
||||
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
|
||||
|
||||
@ -102,16 +102,16 @@ class TestEnterpriseAppDSLImport:
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_import_deps(self):
|
||||
"""Patch db, sessionmaker, and AppDslService for import handler tests."""
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx)))
|
||||
"""Patch db, Session, and AppDslService for import handler tests."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=False)
|
||||
with (
|
||||
patch("controllers.inner_api.app.dsl.db"),
|
||||
patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker),
|
||||
patch("controllers.inner_api.app.dsl.Session", return_value=mock_session),
|
||||
patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls,
|
||||
):
|
||||
self._mock_session = mock_session
|
||||
self._mock_dsl = MagicMock()
|
||||
mock_dsl_cls.return_value = self._mock_dsl
|
||||
yield
|
||||
@ -147,6 +147,8 @@ class TestEnterpriseAppDSLImport:
|
||||
assert status_code == 200
|
||||
assert body["status"] == "completed"
|
||||
mock_account.set_tenant_id.assert_called_once_with("ws-123")
|
||||
self._mock_session.commit.assert_called_once_with()
|
||||
self._mock_session.rollback.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("_mock_import_deps")
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
@ -162,6 +164,8 @@ class TestEnterpriseAppDSLImport:
|
||||
|
||||
assert status_code == 202
|
||||
assert body["status"] == "pending"
|
||||
self._mock_session.commit.assert_called_once_with()
|
||||
self._mock_session.rollback.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("_mock_import_deps")
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
@ -177,6 +181,8 @@ class TestEnterpriseAppDSLImport:
|
||||
|
||||
assert status_code == 400
|
||||
assert body["status"] == "failed"
|
||||
self._mock_session.rollback.assert_called_once_with()
|
||||
self._mock_session.commit.assert_not_called()
|
||||
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask):
|
||||
|
||||
@ -1,270 +0,0 @@
|
||||
"""
|
||||
Unit tests for Service API Site controller
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.service_api.app.site import AppSiteApi
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from tests.unit_tests.conftest import setup_mock_tenant_account_query
|
||||
|
||||
|
||||
class TestAppSiteApi:
|
||||
"""Test suite for AppSiteApi"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model with tenant."""
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.enable_api = True
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.id = app.tenant_id
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
app.tenant = mock_tenant
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_site(self):
|
||||
"""Create a mock Site model."""
|
||||
site = Mock(spec=Site)
|
||||
site.id = str(uuid.uuid4())
|
||||
site.app_id = str(uuid.uuid4())
|
||||
site.title = "Test Site"
|
||||
site.icon = "icon-url"
|
||||
site.icon_background = "#ffffff"
|
||||
site.description = "Site description"
|
||||
site.copyright = "Copyright 2024"
|
||||
site.privacy_policy = "Privacy policy text"
|
||||
site.custom_disclaimer = "Custom disclaimer"
|
||||
site.default_language = "en-US"
|
||||
site.prompt_public = True
|
||||
site.show_workflow_steps = True
|
||||
site.use_icon_as_answer_icon = False
|
||||
site.chat_color_theme = "light"
|
||||
site.chat_color_theme_inverted = False
|
||||
site.icon_type = "image"
|
||||
site.created_at = "2024-01-01T00:00:00"
|
||||
site.updated_at = "2024-01-01T00:00:00"
|
||||
return site
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_site_success(
|
||||
self,
|
||||
mock_wraps_db,
|
||||
mock_validate_token,
|
||||
mock_current_app,
|
||||
mock_db,
|
||||
mock_user_logged_in,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_site,
|
||||
):
|
||||
"""Test successful retrieval of site configuration."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
# Mock wraps.db for authentication
|
||||
mock_wraps_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site.db for site query
|
||||
mock_db.session.scalar.return_value = mock_site
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppSiteApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["title"] == "Test Site"
|
||||
assert response["icon"] == "icon-url"
|
||||
assert response["description"] == "Site description"
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_site_not_found(
|
||||
self,
|
||||
mock_wraps_db,
|
||||
mock_validate_token,
|
||||
mock_current_app,
|
||||
mock_db,
|
||||
mock_user_logged_in,
|
||||
app,
|
||||
mock_app_model,
|
||||
):
|
||||
"""Test that Forbidden is raised when site is not found."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
mock_wraps_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site query to return None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppSiteApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.get()
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_site_tenant_archived(
|
||||
self,
|
||||
mock_wraps_db,
|
||||
mock_validate_token,
|
||||
mock_current_app,
|
||||
mock_db,
|
||||
mock_user_logged_in,
|
||||
app,
|
||||
mock_app_model,
|
||||
mock_site,
|
||||
):
|
||||
"""Test that Forbidden is raised when tenant is archived."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
|
||||
mock_wraps_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
# Mock site query
|
||||
mock_db.session.scalar.return_value = mock_site
|
||||
|
||||
# Set tenant status to archived AFTER authentication
|
||||
mock_app_model.tenant.status = TenantStatus.ARCHIVE
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppSiteApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.get()
|
||||
|
||||
@patch("controllers.service_api.wraps.user_logged_in")
|
||||
@patch("controllers.service_api.app.site.db")
|
||||
@patch("controllers.service_api.wraps.current_app")
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.db")
|
||||
def test_get_site_queries_by_app_id(
|
||||
self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model
|
||||
):
|
||||
"""Test that site is queried using the app model's id."""
|
||||
# Arrange
|
||||
mock_current_app.login_manager = Mock()
|
||||
|
||||
# Mock authentication
|
||||
mock_api_token = Mock()
|
||||
mock_api_token.app_id = mock_app_model.id
|
||||
mock_api_token.tenant_id = mock_app_model.tenant_id
|
||||
mock_validate_token.return_value = mock_api_token
|
||||
|
||||
mock_tenant = Mock()
|
||||
mock_tenant.status = TenantStatus.NORMAL
|
||||
mock_app_model.tenant = mock_tenant
|
||||
|
||||
mock_wraps_db.session.get.side_effect = [
|
||||
mock_app_model,
|
||||
mock_tenant,
|
||||
]
|
||||
|
||||
mock_account = Mock()
|
||||
mock_account.current_tenant = mock_tenant
|
||||
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account)
|
||||
|
||||
mock_site = Mock(spec=Site)
|
||||
mock_site.id = str(uuid.uuid4())
|
||||
mock_site.app_id = mock_app_model.id
|
||||
mock_site.title = "Test Site"
|
||||
mock_site.icon = "icon-url"
|
||||
mock_site.icon_background = "#ffffff"
|
||||
mock_site.description = "Site description"
|
||||
mock_site.copyright = "Copyright 2024"
|
||||
mock_site.privacy_policy = "Privacy policy text"
|
||||
mock_site.custom_disclaimer = "Custom disclaimer"
|
||||
mock_site.default_language = "en-US"
|
||||
mock_site.prompt_public = True
|
||||
mock_site.show_workflow_steps = True
|
||||
mock_site.use_icon_as_answer_icon = False
|
||||
mock_site.chat_color_theme = "light"
|
||||
mock_site.chat_color_theme_inverted = False
|
||||
mock_site.icon_type = "image"
|
||||
mock_site.created_at = "2024-01-01T00:00:00"
|
||||
mock_site.updated_at = "2024-01-01T00:00:00"
|
||||
mock_db.session.scalar.return_value = mock_site
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}):
|
||||
api = AppSiteApi()
|
||||
api.get()
|
||||
|
||||
# Assert
|
||||
# The query was executed successfully (site returned), which validates the correct query was made
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
Reference in New Issue
Block a user