Merge branch 'main' into 4-27-app-deploy

This commit is contained in:
Stephen Zhou
2026-05-27 10:25:12 +08:00
91 changed files with 1784 additions and 475 deletions

View File

@ -0,0 +1,85 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import PropertyMock, patch
from controllers.console import console_ns
from controllers.console.auth.data_source_bearer_auth import (
ApiKeyAuthDataSource,
ApiKeyAuthDataSourceBinding,
ApiKeyAuthDataSourceBindingDelete,
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _payload_patch(payload: dict):
return patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
)
def test_list_data_source_auth_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSource()
method = _unwrap(api.get)
binding = SimpleNamespace(
id="binding-1",
category="api_key",
provider="custom",
disabled=False,
created_at=datetime(2026, 1, 1, tzinfo=UTC),
updated_at=datetime(2026, 1, 2, tzinfo=UTC),
)
with patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list",
return_value=[binding],
) as get_provider_auth_list:
result = method(api, "tenant-1")
get_provider_auth_list.assert_called_once_with("tenant-1")
assert result["sources"][0]["id"] == "binding-1"
assert result["sources"][0]["provider"] == "custom"
def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSourceBinding()
method = _unwrap(api.post)
payload = {
"category": "api_key",
"provider": "custom",
"credentials": {"auth_type": "api_key", "config": {"api_key": "secret"}},
}
with (
_payload_patch(payload),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth,
):
result, status = method(api, "tenant-1")
create_auth.assert_called_once_with("tenant-1", payload)
assert result == {"result": "success"}
assert status == 200
def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSourceBindingDelete()
method = _unwrap(api.delete)
with patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
) as delete_provider_auth:
result, status = method(api, "tenant-1", "binding-1")
delete_provider_auth.assert_called_once_with("tenant-1", "binding-1")
assert result == ""
assert status == 204

View File

@ -0,0 +1,52 @@
from __future__ import annotations
from unittest.mock import patch
from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi
from models import Account
from models.account import AccountStatus, TenantAccountRole
from models.model import OAuthProviderApp
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _make_account() -> Account:
account = Account(
name="Test User",
email="test@example.com",
status=AccountStatus.ACTIVE,
)
account.id = "account-1"
account.role = TenantAccountRole.OWNER
return account
def _make_oauth_provider_app() -> OAuthProviderApp:
return OAuthProviderApp(
app_icon="icon",
client_id="client-1",
client_secret="secret",
app_label={"en-US": "Test App"},
redirect_uris=["https://example.com/callback"],
scope="read",
)
def test_oauth_authorize_uses_injected_current_user() -> None:
api = OAuthServerUserAuthorizeApi()
method = _unwrap(api.post)
account = _make_account()
oauth_provider_app = _make_oauth_provider_app()
with patch(
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code",
return_value="authorization-code",
) as sign_oauth_authorization_code:
response = method(api, oauth_provider_app, account)
sign_oauth_authorization_code.assert_called_once_with("client-1", "account-1")
assert response == {"code": "authorization-code"}

View File

@ -70,13 +70,14 @@ class TestExternalApiTemplateListApi:
ExternalDatasetService,
"get_external_knowledge_apis",
return_value=([api_item], 1),
),
) as get_external_knowledge_apis,
):
resp, status = method(api, "id")
resp, status = method(api, "tenant-1")
assert status == 200
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
def test_post_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
@ -321,13 +322,14 @@ class TestExternalApiTemplateListApiAdvanced:
patch(
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
return_value=(templates, 25),
),
) as get_external_knowledge_apis,
):
resp, status = method(api, "id")
resp, status = method(api, "tenant-1")
assert status == 200
assert resp["total"] == 25
assert len(resp["data"]) == 3
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
class TestExternalDatasetCreateApiAdvanced:

View File

@ -13,6 +13,8 @@ from controllers.console.tag.tags import (
TagListApi,
TagUpdateDeleteApi,
)
from models import Account
from models.account import AccountStatus, TenantAccountRole
from models.enums import TagType
from services.tag_service import UpdateTagPayload
@ -35,20 +37,26 @@ def app():
@pytest.fixture
def admin_user():
return MagicMock(
id="user-1",
has_edit_permission=True,
is_dataset_editor=True,
account = Account(
name="Admin User",
email="admin@example.com",
status=AccountStatus.ACTIVE,
)
account.id = "user-1"
account.role = TenantAccountRole.OWNER
return account
@pytest.fixture
def readonly_user():
return MagicMock(
id="user-2",
has_edit_permission=False,
is_dataset_editor=False,
account = Account(
name="Readonly User",
email="readonly@example.com",
status=AccountStatus.ACTIVE,
)
account.id = "user-2"
account.role = TenantAccountRole.NORMAL
return account
@pytest.fixture
@ -80,10 +88,6 @@ class TestTagListApi:
with app.test_request_context("/?type=knowledge"):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.tag.tags.TagService.get_tags",
return_value=[
@ -96,7 +100,7 @@ class TestTagListApi:
],
),
):
result, status = method(api)
result, status = method(api, "tenant-1")
assert status == 200
assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}]
@ -109,17 +113,13 @@ class TestTagListApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch(
"controllers.console.tag.tags.TagService.save_tags",
return_value=tag,
),
):
result, status = method(api)
result, status = method(api, admin_user)
assert status == 200
assert result["name"] == "test-tag"
@ -133,14 +133,10 @@ class TestTagListApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch(payload),
):
with pytest.raises(Forbidden):
method(api)
method(api, readonly_user)
class TestTagUpdateDeleteApi:
@ -152,10 +148,6 @@ class TestTagUpdateDeleteApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch(
"controllers.console.tag.tags.TagService.update_tags",
@ -166,7 +158,7 @@ class TestTagUpdateDeleteApi:
return_value=3,
),
):
result, status = method(api, "tag-1")
result, status = method(api, admin_user, "tag-1")
assert status == 200
update_payload, tag_id = update_tags_mock.call_args.args
@ -182,14 +174,10 @@ class TestTagUpdateDeleteApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch(payload),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
method(api, readonly_user, "tag-1")
def test_delete_success(self, app: Flask, admin_user):
api = TagUpdateDeleteApi()
@ -197,10 +185,6 @@ class TestTagUpdateDeleteApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, "tenant-1"),
),
patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock,
):
result, status = method(api, "tag-1")
@ -222,14 +206,10 @@ class TestTagBindingCollectionApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
result, status = method(api, admin_user)
save_mock.assert_called_once()
assert status == 200
@ -241,14 +221,10 @@ class TestTagBindingCollectionApi:
with app.test_request_context("/", json={}):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch({}),
):
with pytest.raises(Forbidden):
method(api)
method(api, readonly_user)
class TestTagBindingRemoveApi:
@ -264,14 +240,10 @@ class TestTagBindingRemoveApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api)
result, status = method(api, admin_user)
delete_mock.assert_called_once()
delete_payload = delete_mock.call_args.args[0]
@ -285,14 +257,10 @@ class TestTagBindingRemoveApi:
with app.test_request_context("/", json={}):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch({}),
):
with pytest.raises(Forbidden):
method(api)
method(api, readonly_user)
class TestTagResponseModel:

View File

@ -0,0 +1,141 @@
from __future__ import annotations
import inspect
from collections.abc import Callable
from types import SimpleNamespace
from typing import cast
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.apikey import BaseApiKeyListResource, BaseApiKeyResource
from models import Account
from models.account import AccountStatus, TenantAccountRole
from models.enums import ApiTokenType
from models.model import ApiToken, App
def _make_list_resource() -> BaseApiKeyListResource:
resource = BaseApiKeyListResource()
resource.resource_type = ApiTokenType.APP
resource.resource_model = App
resource.resource_id_field = "app_id"
resource.token_prefix = "app-"
return resource
def _make_key_resource() -> BaseApiKeyResource:
resource = BaseApiKeyResource()
resource.resource_type = ApiTokenType.APP
resource.resource_model = App
resource.resource_id_field = "app_id"
return resource
def _make_account(role: TenantAccountRole) -> Account:
account = Account(
name="Test User",
email=f"{role.value}@example.com",
status=AccountStatus.ACTIVE,
)
account.id = f"{role.value}-user"
account.role = role
return account
def test_list_api_keys_uses_injected_tenant_id() -> None:
resource = _make_list_resource()
api_key = SimpleNamespace(
id="key-1",
type=ApiTokenType.APP,
token="app-token",
last_used_at=None,
created_at=None,
)
with (
patch("controllers.console.apikey._get_resource") as get_resource,
patch("controllers.console.apikey.db") as db_mock,
):
db_mock.session.scalars.return_value.all.return_value = [api_key]
result = resource.get("app-1", "tenant-1")
get_resource.assert_called_once_with("app-1", "tenant-1", App)
assert result == {
"data": [
{
"id": "key-1",
"type": "app",
"token": "app-token",
"last_used_at": None,
"created_at": None,
}
]
}
def test_create_api_key_uses_injected_tenant_id() -> None:
resource = _make_list_resource()
raw_post = cast(
Callable[[BaseApiKeyListResource, str, str], tuple[dict[str, object], int]],
inspect.unwrap(BaseApiKeyListResource.post),
)
def add_api_token(api_token: ApiToken) -> None:
api_token.id = "key-1"
with (
patch("controllers.console.apikey._get_resource") as get_resource,
patch("controllers.console.apikey.db") as db_mock,
patch("controllers.console.apikey.ApiToken.generate_api_key", return_value="app-generated-token"),
):
db_mock.session.scalar.return_value = 0
db_mock.session.add.side_effect = add_api_token
result, status = raw_post(resource, "app-1", "tenant-1")
get_resource.assert_called_once_with("app-1", "tenant-1", App)
assert status == 201
assert result["token"] == "app-generated-token"
api_token = db_mock.session.add.call_args.args[0]
assert api_token.app_id == "app-1"
assert api_token.tenant_id == "tenant-1"
assert api_token.type == ApiTokenType.APP
db_mock.session.commit.assert_called_once()
def test_delete_api_key_rejects_non_admin_account() -> None:
resource = _make_key_resource()
with (
patch("controllers.console.apikey._get_resource") as get_resource,
patch("controllers.console.apikey.db") as db_mock,
):
with pytest.raises(Forbidden):
resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.NORMAL))
get_resource.assert_called_once_with("app-1", "tenant-1", App)
db_mock.session.scalar.assert_not_called()
def test_delete_api_key_uses_injected_user_and_tenant() -> None:
resource = _make_key_resource()
api_key = SimpleNamespace(token="app-token", type=ApiTokenType.APP)
with (
patch("controllers.console.apikey._get_resource") as get_resource,
patch("controllers.console.apikey.db") as db_mock,
patch("controllers.console.apikey.ApiTokenCache.delete") as delete_cache,
):
db_mock.session.scalar.return_value = api_key
result, status = resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.OWNER))
get_resource.assert_called_once_with("app-1", "tenant-1", App)
delete_cache.assert_called_once_with("app-token", ApiTokenType.APP)
db_mock.session.execute.assert_called_once()
db_mock.session.commit.assert_called_once()
assert result == ""
assert status == 204

View File

@ -19,6 +19,8 @@ from controllers.console.files import (
FilePreviewApi,
FileSupportTypeApi,
)
from models import Account
from models.account import AccountStatus, TenantAccountRole
def unwrap(func):
@ -53,14 +55,15 @@ def mock_decorators():
@pytest.fixture
def mock_current_user():
user = MagicMock()
user.is_dataset_editor = True
user = Account(name="Test User", email="user-1@example.com", status=AccountStatus.ACTIVE)
user.id = "user-1"
user.role = TenantAccountRole.OWNER
return user
@pytest.fixture
def mock_current_tenant_id():
return "tenant-123"
def mock_account_context(mock_current_user):
return mock_current_user
@pytest.fixture
@ -91,15 +94,15 @@ class TestFileApiGet:
class TestFileApiPost:
def test_no_file_uploaded(self, app: Flask, mock_current_user):
def test_no_file_uploaded(self, app: Flask, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST", data={}):
with pytest.raises(NoFileUploadedError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_too_many_files(self, app: Flask, mock_current_user):
def test_too_many_files(self, app: Flask, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
@ -114,9 +117,9 @@ class TestFileApiPost:
mock_request.form.get.return_value = None
with pytest.raises(TooManyFilesError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_filename_missing(self, app: Flask, mock_current_user):
def test_filename_missing(self, app: Flask, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
@ -126,10 +129,10 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(FilenameNotExistsError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_dataset_upload_without_permission(self, app: Flask, mock_current_user):
mock_current_user.is_dataset_editor = False
mock_current_user.role = TenantAccountRole.NORMAL
api = FileApi()
post_method = unwrap(api.post)
@ -143,7 +146,7 @@ class TestFileApiPost:
with pytest.raises(Forbidden):
post_method(api, mock_current_user)
def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service):
def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -171,13 +174,13 @@ class TestFileApiPost:
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api, mock_current_user)
response, status = post_method(api, mock_account_context)
assert status == 201
assert response["id"] == "file-id-123"
assert response["name"] == "test.txt"
def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service):
def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service):
"""Test that invalid source parameter gets normalized to None"""
api = FileApi()
post_method = unwrap(api.post)
@ -208,7 +211,7 @@ class TestFileApiPost:
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api, mock_current_user)
response, status = post_method(api, mock_account_context)
assert status == 201
assert response["id"] == "file-id-456"
@ -217,7 +220,7 @@ class TestFileApiPost:
call_kwargs = mock_file_service.upload_file.call_args[1]
assert call_kwargs["source"] is None
def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service):
def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -232,9 +235,9 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(FileTooLargeError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service):
def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -249,9 +252,9 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(UnsupportedFileTypeError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service):
def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -266,17 +269,17 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(BlockedFileExtensionError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
class TestFilePreviewApi:
def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service):
def test_get_preview(self, app: Flask, mock_account_context, mock_file_service):
api = FilePreviewApi()
get_method = unwrap(api.get)
mock_file_service.get_file_preview.return_value = "preview text"
with app.test_request_context():
result = get_method(api, mock_current_tenant_id, "1234")
result = get_method(api, "tenant-123", "1234")
assert result == {"content": "preview text"}

View File

@ -10,6 +10,8 @@ import pytest
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
from controllers.console import remote_files as remote_files_module
from models import Account
from models.account import AccountStatus, TenantAccountRole
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
@ -20,6 +22,17 @@ def _unwrap(func):
return func
def _make_account(account_id: str = "u1") -> Account:
account = Account(
name="Test User",
email=f"{account_id}@example.com",
status=AccountStatus.ACTIVE,
)
account.id = account_id
account.role = TenantAccountRole.OWNER
return account
class _FakeResponse:
def __init__(
self,
@ -48,7 +61,6 @@ def _mock_upload_dependencies(
*,
file_size_within_limit: bool = True,
):
current_user = SimpleNamespace(id="u1")
file_info = SimpleNamespace(
filename="report.txt",
extension=".txt",
@ -64,6 +76,7 @@ def _mock_upload_dependencies(
file_service_cls = MagicMock()
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
current_user = _make_account()
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
remote_files_module.file_helpers,
@ -226,7 +239,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
handler(api, SimpleNamespace(id="u1"))
handler(api, _make_account())
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -243,7 +256,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
handler(api, SimpleNamespace(id="u1"))
handler(api, _make_account())
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_login import LoginManager, UserMixin
from werkzeug.exceptions import HTTPException
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
from controllers.console.workspace.error import AccountNotInitializedError
@ -17,8 +18,11 @@ from controllers.console.wraps import (
only_edition_enterprise,
only_edition_self_hosted,
setup_required,
with_current_tenant_id,
with_current_user,
)
from models.account import AccountStatus
from models import Account
from models.account import AccountStatus, TenantAccountRole
from services.feature_service import LicenseStatus
@ -33,6 +37,17 @@ class MockUser(UserMixin):
return self.id
def make_account(account_id: str = "account-1") -> Account:
account = Account(
name="Test Account",
email=f"{account_id}@example.com",
status=AccountStatus.ACTIVE,
)
account.id = account_id
account.role = TenantAccountRole.OWNER
return account
def create_app_with_login():
"""Create a Flask app with LoginManager configured."""
app = Flask(__name__)
@ -84,6 +99,42 @@ class TestAccountInitialization:
protected_view()
class TestCurrentContextInjection:
"""Test request context injection decorators."""
def test_with_current_tenant_id_injects_tenant_id(self):
class Handler:
@with_current_tenant_id
def get(self, current_tenant_id: str):
return current_tenant_id
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-123")):
assert Handler().get() == "tenant-123"
def test_with_current_user_injects_account(self):
current_user = make_account()
class Handler:
@with_current_user
def get(self, injected_user):
return injected_user
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")):
assert Handler().get() is current_user
def test_stacked_current_context_injectors_preserve_argument_order(self):
current_user = make_account()
class Handler:
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, injected_user):
return current_tenant_id, injected_user
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")):
assert Handler().get() == ("tenant-123", current_user)
class TestEditionChecks:
"""Test edition-specific decorators"""
@ -114,7 +165,7 @@ class TestEditionChecks:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
cloud_view()
assert exc_info.value.code == 404
@ -177,7 +228,7 @@ class TestBillingEnabled:
with app.test_request_context():
with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False):
with patch("controllers.console.wraps.FeatureService.get_features") as get_features:
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
billing_view()
assert exc_info.value.code == 403
@ -204,11 +255,43 @@ class TestBillingResourceLimits:
with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with patch(
"controllers.console.wraps.FeatureService.get_features", return_value=mock_features
) as get_features:
result = add_member()
# Assert
assert result == "member_added"
get_features.assert_called_once_with("tenant123", exclude_vector_space=True)
def test_should_load_vector_space_from_dedicated_quota_api(self):
"""Test vector-space limit checks avoid loading the full feature payload."""
# Arrange
mock_vector_space = MagicMock()
mock_vector_space.limit = 10
mock_vector_space.size = 5
@cloud_edition_billing_resource_check("vector_space")
def add_segment():
return "segment_added"
# Act
with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with (
patch("controllers.console.wraps.dify_config.BILLING_ENABLED", True),
patch(
"controllers.console.wraps.FeatureService.get_vector_space", return_value=mock_vector_space
) as get_vector_space,
patch("controllers.console.wraps.FeatureService.get_features") as get_features,
):
result = add_segment()
# Assert
assert result == "segment_added"
get_vector_space.assert_called_once_with("tenant123")
get_features.assert_not_called()
def test_should_reject_when_over_resource_limit(self):
"""Test that requests are rejected when over resource limits"""
@ -230,7 +313,7 @@ class TestBillingResourceLimits:
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
add_member()
assert exc_info.value.code == 403
assert "members has reached the limit" in str(exc_info.value.description)
@ -255,7 +338,7 @@ class TestBillingResourceLimits:
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
upload_document()
assert exc_info.value.code == 403
@ -329,7 +412,7 @@ class TestRateLimiting:
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
knowledge_request()
# Verify error

View File

@ -139,7 +139,7 @@ class TestTenantListApi:
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_called_once_with("t2")
get_features_mock.assert_called_once_with("t2", exclude_vector_space=True)
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app: Flask):
"""Test fallback to FeatureService when bulk billing returns empty result.
@ -235,7 +235,7 @@ class TestTenantListApi:
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
get_features_mock.assert_called_once_with("t1")
get_features_mock.assert_called_once_with("t1", exclude_vector_space=True)
def test_get_enterprise_only_skips_feature_service(self, app: Flask):
api = TenantListApi()

View File

@ -872,6 +872,11 @@ class TestSegmentApiPost:
mock_features.billing.enabled = False
mock_feature_svc.get_features.return_value = mock_features
mock_vector_space = Mock()
mock_vector_space.limit = 10
mock_vector_space.size = 0
mock_feature_svc.get_vector_space.return_value = mock_vector_space
mock_rate_limit = Mock()
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@ -1209,6 +1214,10 @@ class TestDatasetSegmentApiUpdate:
mock_features = Mock()
mock_features.billing.enabled = False
mock_feature_svc.get_features.return_value = mock_features
mock_vector_space = Mock()
mock_vector_space.limit = 10
mock_vector_space.size = 0
mock_feature_svc.get_vector_space.return_value = mock_vector_space
mock_rate_limit = Mock()
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@ -1710,6 +1719,10 @@ class TestChildChunkApiPost:
mock_features = Mock()
mock_features.billing.enabled = False
mock_feature_svc.get_features.return_value = mock_features
mock_vector_space = Mock()
mock_vector_space.limit = 10
mock_vector_space.size = 0
mock_feature_svc.get_vector_space.return_value = mock_vector_space
mock_rate_limit = Mock()
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit

View File

@ -950,7 +950,8 @@ class TestDocumentAddByTextApi:
"""Configure mocks to neutralise billing/auth decorators.
``cloud_edition_billing_resource_check`` calls
``FeatureService.get_features`` and
``FeatureService.get_vector_space`` for vector-space checks and
``FeatureService.get_features`` for other resource checks.
``cloud_edition_billing_rate_limit_check`` calls
``FeatureService.get_knowledge_rate_limit``.
Both call ``validate_and_get_api_token`` first.
@ -963,6 +964,11 @@ class TestDocumentAddByTextApi:
mock_features.billing.enabled = False
mock_feature_svc.get_features.return_value = mock_features
mock_vector_space = Mock()
mock_vector_space.limit = 10
mock_vector_space.size = 0
mock_feature_svc.get_vector_space.return_value = mock_vector_space
mock_rate_limit = Mock()
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
@ -1140,6 +1146,10 @@ def _setup_billing_mocks(mock_validate_token, mock_feature_svc, tenant_id: str):
mock_features = Mock()
mock_features.billing.enabled = False
mock_feature_svc.get_features.return_value = mock_features
mock_vector_space = Mock()
mock_vector_space.limit = 10
mock_vector_space.size = 0
mock_feature_svc.get_vector_space.return_value = mock_vector_space
mock_rate_limit = Mock()
mock_rate_limit.enabled = False
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit

View File

@ -265,6 +265,65 @@ class TestCloudEditionBillingResourceCheck:
# Assert
assert result == "member_added"
mock_get_features.assert_called_once_with("tenant123", exclude_vector_space=True)
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
@patch("controllers.service_api.wraps.FeatureService.get_vector_space")
def test_loads_vector_space_from_dedicated_quota_api(
self, mock_get_vector_space, mock_get_features, mock_validate_token, app: Flask
):
"""Test vector-space resource checks avoid loading the full feature payload."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_vector_space = Mock()
mock_vector_space.limit = 10
mock_vector_space.size = 5
mock_get_vector_space.return_value = mock_vector_space
@cloud_edition_billing_resource_check("vector_space", "dataset")
def add_segment():
return "segment_added"
# Act
with (
app.test_request_context("/", method="GET"),
patch("controllers.service_api.wraps.dify_config.BILLING_ENABLED", True),
):
result = add_segment()
# Assert
assert result == "segment_added"
mock_get_vector_space.assert_called_once_with("tenant123")
mock_get_features.assert_not_called()
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_loads_features_when_checking_non_vector_space_limit(
self, mock_get_features, mock_validate_token, app: Flask
):
"""Test non-vector-space resource checks keep using the light feature payload."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
mock_features = Mock()
mock_features.billing.enabled = True
mock_features.documents_upload_quota.limit = 10
mock_features.documents_upload_quota.size = 5
mock_get_features.return_value = mock_features
@cloud_edition_billing_resource_check("documents", "dataset")
def upload_document():
return "document_uploaded"
# Act
with app.test_request_context("/", method="GET"):
result = upload_document()
# Assert
assert result == "document_uploaded"
mock_get_features.assert_called_once_with("tenant123", exclude_vector_space=True)
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")

View File

@ -126,7 +126,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
monkeypatch.setattr(
site_module.FeatureService,
"get_features",
lambda tenant_id: SimpleNamespace(can_replace_logo=True),
lambda tenant_id, **_kwargs: SimpleNamespace(can_replace_logo=True),
)
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
@ -245,7 +245,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F
monkeypatch.setattr(
site_module.FeatureService,
"get_features",
lambda tenant_id: SimpleNamespace(can_replace_logo=True),
lambda tenant_id, **_kwargs: SimpleNamespace(can_replace_logo=True),
)
with app.test_request_context("/api/form/human_input/token-1", method="GET"):