mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 09:27:39 +08:00
Merge branch 'main' into 4-27-app-deploy
This commit is contained in:
@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.data_source_bearer_auth import (
|
||||
ApiKeyAuthDataSource,
|
||||
ApiKeyAuthDataSourceBinding,
|
||||
ApiKeyAuthDataSourceBindingDelete,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _payload_patch(payload: dict):
|
||||
return patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
|
||||
def test_list_data_source_auth_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSource()
|
||||
method = _unwrap(api.get)
|
||||
binding = SimpleNamespace(
|
||||
id="binding-1",
|
||||
category="api_key",
|
||||
provider="custom",
|
||||
disabled=False,
|
||||
created_at=datetime(2026, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2026, 1, 2, tzinfo=UTC),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list",
|
||||
return_value=[binding],
|
||||
) as get_provider_auth_list:
|
||||
result = method(api, "tenant-1")
|
||||
|
||||
get_provider_auth_list.assert_called_once_with("tenant-1")
|
||||
assert result["sources"][0]["id"] == "binding-1"
|
||||
assert result["sources"][0]["provider"] == "custom"
|
||||
|
||||
|
||||
def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSourceBinding()
|
||||
method = _unwrap(api.post)
|
||||
payload = {
|
||||
"category": "api_key",
|
||||
"provider": "custom",
|
||||
"credentials": {"auth_type": "api_key", "config": {"api_key": "secret"}},
|
||||
}
|
||||
|
||||
with (
|
||||
_payload_patch(payload),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth,
|
||||
):
|
||||
result, status = method(api, "tenant-1")
|
||||
|
||||
create_auth.assert_called_once_with("tenant-1", payload)
|
||||
assert result == {"result": "success"}
|
||||
assert status == 200
|
||||
|
||||
|
||||
def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSourceBindingDelete()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
|
||||
) as delete_provider_auth:
|
||||
result, status = method(api, "tenant-1", "binding-1")
|
||||
|
||||
delete_provider_auth.assert_called_once_with("tenant-1", "binding-1")
|
||||
assert result == ""
|
||||
assert status == 204
|
||||
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = "account-1"
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
def _make_oauth_provider_app() -> OAuthProviderApp:
|
||||
return OAuthProviderApp(
|
||||
app_icon="icon",
|
||||
client_id="client-1",
|
||||
client_secret="secret",
|
||||
app_label={"en-US": "Test App"},
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
scope="read",
|
||||
)
|
||||
|
||||
|
||||
def test_oauth_authorize_uses_injected_current_user() -> None:
|
||||
api = OAuthServerUserAuthorizeApi()
|
||||
method = _unwrap(api.post)
|
||||
account = _make_account()
|
||||
oauth_provider_app = _make_oauth_provider_app()
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code",
|
||||
return_value="authorization-code",
|
||||
) as sign_oauth_authorization_code:
|
||||
response = method(api, oauth_provider_app, account)
|
||||
|
||||
sign_oauth_authorization_code.assert_called_once_with("client-1", "account-1")
|
||||
assert response == {"code": "authorization-code"}
|
||||
@ -70,13 +70,14 @@ class TestExternalApiTemplateListApi:
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_apis",
|
||||
return_value=([api_item], 1),
|
||||
),
|
||||
) as get_external_knowledge_apis,
|
||||
):
|
||||
resp, status = method(api, "id")
|
||||
resp, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 1
|
||||
assert resp["data"][0]["id"] == "1"
|
||||
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
|
||||
|
||||
def test_post_forbidden(self, app: Flask, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
@ -321,13 +322,14 @@ class TestExternalApiTemplateListApiAdvanced:
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
|
||||
return_value=(templates, 25),
|
||||
),
|
||||
) as get_external_knowledge_apis,
|
||||
):
|
||||
resp, status = method(api, "id")
|
||||
resp, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 25
|
||||
assert len(resp["data"]) == 3
|
||||
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApiAdvanced:
|
||||
|
||||
@ -13,6 +13,8 @@ from controllers.console.tag.tags import (
|
||||
TagListApi,
|
||||
TagUpdateDeleteApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.enums import TagType
|
||||
from services.tag_service import UpdateTagPayload
|
||||
|
||||
@ -35,20 +37,26 @@ def app():
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user():
|
||||
return MagicMock(
|
||||
id="user-1",
|
||||
has_edit_permission=True,
|
||||
is_dataset_editor=True,
|
||||
account = Account(
|
||||
name="Admin User",
|
||||
email="admin@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = "user-1"
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def readonly_user():
|
||||
return MagicMock(
|
||||
id="user-2",
|
||||
has_edit_permission=False,
|
||||
is_dataset_editor=False,
|
||||
account = Account(
|
||||
name="Readonly User",
|
||||
email="readonly@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = "user-2"
|
||||
account.role = TenantAccountRole.NORMAL
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -80,10 +88,6 @@ class TestTagListApi:
|
||||
|
||||
with app.test_request_context("/?type=knowledge"):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.get_tags",
|
||||
return_value=[
|
||||
@ -96,7 +100,7 @@ class TestTagListApi:
|
||||
],
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}]
|
||||
@ -109,17 +113,13 @@ class TestTagListApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.save_tags",
|
||||
return_value=tag,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, admin_user)
|
||||
|
||||
assert status == 200
|
||||
assert result["name"] == "test-tag"
|
||||
@ -133,14 +133,10 @@ class TestTagListApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, readonly_user)
|
||||
|
||||
|
||||
class TestTagUpdateDeleteApi:
|
||||
@ -152,10 +148,6 @@ class TestTagUpdateDeleteApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.update_tags",
|
||||
@ -166,7 +158,7 @@ class TestTagUpdateDeleteApi:
|
||||
return_value=3,
|
||||
),
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
result, status = method(api, admin_user, "tag-1")
|
||||
|
||||
assert status == 200
|
||||
update_payload, tag_id = update_tags_mock.call_args.args
|
||||
@ -182,14 +174,10 @@ class TestTagUpdateDeleteApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "tag-1")
|
||||
method(api, readonly_user, "tag-1")
|
||||
|
||||
def test_delete_success(self, app: Flask, admin_user):
|
||||
api = TagUpdateDeleteApi()
|
||||
@ -197,10 +185,6 @@ class TestTagUpdateDeleteApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock,
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
@ -222,14 +206,10 @@ class TestTagBindingCollectionApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, admin_user)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert status == 200
|
||||
@ -241,14 +221,10 @@ class TestTagBindingCollectionApi:
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch({}),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, readonly_user)
|
||||
|
||||
|
||||
class TestTagBindingRemoveApi:
|
||||
@ -264,14 +240,10 @@ class TestTagBindingRemoveApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, admin_user)
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
delete_payload = delete_mock.call_args.args[0]
|
||||
@ -285,14 +257,10 @@ class TestTagBindingRemoveApi:
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch({}),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, readonly_user)
|
||||
|
||||
|
||||
class TestTagResponseModel:
|
||||
|
||||
141
api/tests/unit_tests/controllers/console/test_apikey.py
Normal file
141
api/tests/unit_tests/controllers/console/test_apikey.py
Normal file
@ -0,0 +1,141 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.apikey import BaseApiKeyListResource, BaseApiKeyResource
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
||||
def _make_list_resource() -> BaseApiKeyListResource:
|
||||
resource = BaseApiKeyListResource()
|
||||
resource.resource_type = ApiTokenType.APP
|
||||
resource.resource_model = App
|
||||
resource.resource_id_field = "app_id"
|
||||
resource.token_prefix = "app-"
|
||||
return resource
|
||||
|
||||
|
||||
def _make_key_resource() -> BaseApiKeyResource:
|
||||
resource = BaseApiKeyResource()
|
||||
resource.resource_type = ApiTokenType.APP
|
||||
resource.resource_model = App
|
||||
resource.resource_id_field = "app_id"
|
||||
return resource
|
||||
|
||||
|
||||
def _make_account(role: TenantAccountRole) -> Account:
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email=f"{role.value}@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = f"{role.value}-user"
|
||||
account.role = role
|
||||
return account
|
||||
|
||||
|
||||
def test_list_api_keys_uses_injected_tenant_id() -> None:
|
||||
resource = _make_list_resource()
|
||||
api_key = SimpleNamespace(
|
||||
id="key-1",
|
||||
type=ApiTokenType.APP,
|
||||
token="app-token",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource") as get_resource,
|
||||
patch("controllers.console.apikey.db") as db_mock,
|
||||
):
|
||||
db_mock.session.scalars.return_value.all.return_value = [api_key]
|
||||
|
||||
result = resource.get("app-1", "tenant-1")
|
||||
|
||||
get_resource.assert_called_once_with("app-1", "tenant-1", App)
|
||||
assert result == {
|
||||
"data": [
|
||||
{
|
||||
"id": "key-1",
|
||||
"type": "app",
|
||||
"token": "app-token",
|
||||
"last_used_at": None,
|
||||
"created_at": None,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_create_api_key_uses_injected_tenant_id() -> None:
|
||||
resource = _make_list_resource()
|
||||
raw_post = cast(
|
||||
Callable[[BaseApiKeyListResource, str, str], tuple[dict[str, object], int]],
|
||||
inspect.unwrap(BaseApiKeyListResource.post),
|
||||
)
|
||||
|
||||
def add_api_token(api_token: ApiToken) -> None:
|
||||
api_token.id = "key-1"
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource") as get_resource,
|
||||
patch("controllers.console.apikey.db") as db_mock,
|
||||
patch("controllers.console.apikey.ApiToken.generate_api_key", return_value="app-generated-token"),
|
||||
):
|
||||
db_mock.session.scalar.return_value = 0
|
||||
db_mock.session.add.side_effect = add_api_token
|
||||
|
||||
result, status = raw_post(resource, "app-1", "tenant-1")
|
||||
|
||||
get_resource.assert_called_once_with("app-1", "tenant-1", App)
|
||||
assert status == 201
|
||||
assert result["token"] == "app-generated-token"
|
||||
api_token = db_mock.session.add.call_args.args[0]
|
||||
assert api_token.app_id == "app-1"
|
||||
assert api_token.tenant_id == "tenant-1"
|
||||
assert api_token.type == ApiTokenType.APP
|
||||
db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_api_key_rejects_non_admin_account() -> None:
|
||||
resource = _make_key_resource()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource") as get_resource,
|
||||
patch("controllers.console.apikey.db") as db_mock,
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.NORMAL))
|
||||
|
||||
get_resource.assert_called_once_with("app-1", "tenant-1", App)
|
||||
db_mock.session.scalar.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_api_key_uses_injected_user_and_tenant() -> None:
|
||||
resource = _make_key_resource()
|
||||
api_key = SimpleNamespace(token="app-token", type=ApiTokenType.APP)
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource") as get_resource,
|
||||
patch("controllers.console.apikey.db") as db_mock,
|
||||
patch("controllers.console.apikey.ApiTokenCache.delete") as delete_cache,
|
||||
):
|
||||
db_mock.session.scalar.return_value = api_key
|
||||
|
||||
result, status = resource.delete("app-1", "key-1", "tenant-1", _make_account(TenantAccountRole.OWNER))
|
||||
|
||||
get_resource.assert_called_once_with("app-1", "tenant-1", App)
|
||||
delete_cache.assert_called_once_with("app-token", ApiTokenType.APP)
|
||||
db_mock.session.execute.assert_called_once()
|
||||
db_mock.session.commit.assert_called_once()
|
||||
assert result == ""
|
||||
assert status == 204
|
||||
@ -19,6 +19,8 @@ from controllers.console.files import (
|
||||
FilePreviewApi,
|
||||
FileSupportTypeApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@ -53,14 +55,15 @@ def mock_decorators():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user():
|
||||
user = MagicMock()
|
||||
user.is_dataset_editor = True
|
||||
user = Account(name="Test User", email="user-1@example.com", status=AccountStatus.ACTIVE)
|
||||
user.id = "user-1"
|
||||
user.role = TenantAccountRole.OWNER
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_tenant_id():
|
||||
return "tenant-123"
|
||||
def mock_account_context(mock_current_user):
|
||||
return mock_current_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -91,15 +94,15 @@ class TestFileApiGet:
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
def test_no_file_uploaded(self, app: Flask, mock_current_user):
|
||||
def test_no_file_uploaded(self, app: Flask, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST", data={}):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_too_many_files(self, app: Flask, mock_current_user):
|
||||
def test_too_many_files(self, app: Flask, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -114,9 +117,9 @@ class TestFileApiPost:
|
||||
mock_request.form.get.return_value = None
|
||||
|
||||
with pytest.raises(TooManyFilesError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_filename_missing(self, app: Flask, mock_current_user):
|
||||
def test_filename_missing(self, app: Flask, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -126,10 +129,10 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_dataset_upload_without_permission(self, app: Flask, mock_current_user):
|
||||
mock_current_user.is_dataset_editor = False
|
||||
mock_current_user.role = TenantAccountRole.NORMAL
|
||||
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
@ -143,7 +146,7 @@ class TestFileApiPost:
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -171,13 +174,13 @@ class TestFileApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api, mock_current_user)
|
||||
response, status = post_method(api, mock_account_context)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-123"
|
||||
assert response["name"] == "test.txt"
|
||||
|
||||
def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service):
|
||||
"""Test that invalid source parameter gets normalized to None"""
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
@ -208,7 +211,7 @@ class TestFileApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api, mock_current_user)
|
||||
response, status = post_method(api, mock_account_context)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-456"
|
||||
@ -217,7 +220,7 @@ class TestFileApiPost:
|
||||
call_kwargs = mock_file_service.upload_file.call_args[1]
|
||||
assert call_kwargs["source"] is None
|
||||
|
||||
def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -232,9 +235,9 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -249,9 +252,9 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -266,17 +269,17 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
|
||||
class TestFilePreviewApi:
|
||||
def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service):
|
||||
def test_get_preview(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FilePreviewApi()
|
||||
get_method = unwrap(api.get)
|
||||
mock_file_service.get_file_preview.return_value = "preview text"
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api, mock_current_tenant_id, "1234")
|
||||
result = get_method(api, "tenant-123", "1234")
|
||||
|
||||
assert result == {"content": "preview text"}
|
||||
|
||||
|
||||
@ -10,6 +10,8 @@ import pytest
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
|
||||
from controllers.console import remote_files as remote_files_module
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
@ -20,6 +22,17 @@ def _unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
def _make_account(account_id: str = "u1") -> Account:
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email=f"{account_id}@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = account_id
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(
|
||||
self,
|
||||
@ -48,7 +61,6 @@ def _mock_upload_dependencies(
|
||||
*,
|
||||
file_size_within_limit: bool = True,
|
||||
):
|
||||
current_user = SimpleNamespace(id="u1")
|
||||
file_info = SimpleNamespace(
|
||||
filename="report.txt",
|
||||
extension=".txt",
|
||||
@ -64,6 +76,7 @@ def _mock_upload_dependencies(
|
||||
file_service_cls = MagicMock()
|
||||
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
|
||||
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
|
||||
current_user = _make_account()
|
||||
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.file_helpers,
|
||||
@ -226,7 +239,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
|
||||
handler(api, SimpleNamespace(id="u1"))
|
||||
handler(api, _make_account())
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -243,7 +256,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
|
||||
handler(api, SimpleNamespace(id="u1"))
|
||||
handler(api, _make_account())
|
||||
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager, UserMixin
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
@ -17,8 +18,11 @@ from controllers.console.wraps import (
|
||||
only_edition_enterprise,
|
||||
only_edition_self_hosted,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from models.account import AccountStatus
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
|
||||
@ -33,6 +37,17 @@ class MockUser(UserMixin):
|
||||
return self.id
|
||||
|
||||
|
||||
def make_account(account_id: str = "account-1") -> Account:
|
||||
account = Account(
|
||||
name="Test Account",
|
||||
email=f"{account_id}@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = account_id
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
def create_app_with_login():
|
||||
"""Create a Flask app with LoginManager configured."""
|
||||
app = Flask(__name__)
|
||||
@ -84,6 +99,42 @@ class TestAccountInitialization:
|
||||
protected_view()
|
||||
|
||||
|
||||
class TestCurrentContextInjection:
|
||||
"""Test request context injection decorators."""
|
||||
|
||||
def test_with_current_tenant_id_injects_tenant_id(self):
|
||||
class Handler:
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
return current_tenant_id
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-123")):
|
||||
assert Handler().get() == "tenant-123"
|
||||
|
||||
def test_with_current_user_injects_account(self):
|
||||
current_user = make_account()
|
||||
|
||||
class Handler:
|
||||
@with_current_user
|
||||
def get(self, injected_user):
|
||||
return injected_user
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")):
|
||||
assert Handler().get() is current_user
|
||||
|
||||
def test_stacked_current_context_injectors_preserve_argument_order(self):
|
||||
current_user = make_account()
|
||||
|
||||
class Handler:
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, injected_user):
|
||||
return current_tenant_id, injected_user
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")):
|
||||
assert Handler().get() == ("tenant-123", current_user)
|
||||
|
||||
|
||||
class TestEditionChecks:
|
||||
"""Test edition-specific decorators"""
|
||||
|
||||
@ -114,7 +165,7 @@ class TestEditionChecks:
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
cloud_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
@ -177,7 +228,7 @@ class TestBillingEnabled:
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features") as get_features:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
billing_view()
|
||||
|
||||
assert exc_info.value.code == 403
|
||||
@ -204,11 +255,43 @@ class TestBillingResourceLimits:
|
||||
with patch(
|
||||
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_features", return_value=mock_features
|
||||
) as get_features:
|
||||
result = add_member()
|
||||
|
||||
# Assert
|
||||
assert result == "member_added"
|
||||
get_features.assert_called_once_with("tenant123", exclude_vector_space=True)
|
||||
|
||||
def test_should_load_vector_space_from_dedicated_quota_api(self):
|
||||
"""Test vector-space limit checks avoid loading the full feature payload."""
|
||||
# Arrange
|
||||
mock_vector_space = MagicMock()
|
||||
mock_vector_space.limit = 10
|
||||
mock_vector_space.size = 5
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
def add_segment():
|
||||
return "segment_added"
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
|
||||
):
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.BILLING_ENABLED", True),
|
||||
patch(
|
||||
"controllers.console.wraps.FeatureService.get_vector_space", return_value=mock_vector_space
|
||||
) as get_vector_space,
|
||||
patch("controllers.console.wraps.FeatureService.get_features") as get_features,
|
||||
):
|
||||
result = add_segment()
|
||||
|
||||
# Assert
|
||||
assert result == "segment_added"
|
||||
get_vector_space.assert_called_once_with("tenant123")
|
||||
get_features.assert_not_called()
|
||||
|
||||
def test_should_reject_when_over_resource_limit(self):
|
||||
"""Test that requests are rejected when over resource limits"""
|
||||
@ -230,7 +313,7 @@ class TestBillingResourceLimits:
|
||||
return_value=(MockUser("test_user"), "tenant123"),
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
add_member()
|
||||
assert exc_info.value.code == 403
|
||||
assert "members has reached the limit" in str(exc_info.value.description)
|
||||
@ -255,7 +338,7 @@ class TestBillingResourceLimits:
|
||||
return_value=(MockUser("test_user"), "tenant123"),
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
upload_document()
|
||||
assert exc_info.value.code == 403
|
||||
|
||||
@ -329,7 +412,7 @@ class TestRateLimiting:
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
knowledge_request()
|
||||
|
||||
# Verify error
|
||||
|
||||
@ -139,7 +139,7 @@ class TestTenantListApi:
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
|
||||
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
|
||||
get_features_mock.assert_called_once_with("t2")
|
||||
get_features_mock.assert_called_once_with("t2", exclude_vector_space=True)
|
||||
|
||||
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app: Flask):
|
||||
"""Test fallback to FeatureService when bulk billing returns empty result.
|
||||
@ -235,7 +235,7 @@ class TestTenantListApi:
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
get_features_mock.assert_called_once_with("t1")
|
||||
get_features_mock.assert_called_once_with("t1", exclude_vector_space=True)
|
||||
|
||||
def test_get_enterprise_only_skips_feature_service(self, app: Flask):
|
||||
api = TenantListApi()
|
||||
|
||||
@ -872,6 +872,11 @@ class TestSegmentApiPost:
|
||||
mock_features.billing.enabled = False
|
||||
mock_feature_svc.get_features.return_value = mock_features
|
||||
|
||||
mock_vector_space = Mock()
|
||||
mock_vector_space.limit = 10
|
||||
mock_vector_space.size = 0
|
||||
mock_feature_svc.get_vector_space.return_value = mock_vector_space
|
||||
|
||||
mock_rate_limit = Mock()
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
@ -1209,6 +1214,10 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_feature_svc.get_features.return_value = mock_features
|
||||
mock_vector_space = Mock()
|
||||
mock_vector_space.limit = 10
|
||||
mock_vector_space.size = 0
|
||||
mock_feature_svc.get_vector_space.return_value = mock_vector_space
|
||||
mock_rate_limit = Mock()
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
@ -1710,6 +1719,10 @@ class TestChildChunkApiPost:
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_feature_svc.get_features.return_value = mock_features
|
||||
mock_vector_space = Mock()
|
||||
mock_vector_space.limit = 10
|
||||
mock_vector_space.size = 0
|
||||
mock_feature_svc.get_vector_space.return_value = mock_vector_space
|
||||
mock_rate_limit = Mock()
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
@ -950,7 +950,8 @@ class TestDocumentAddByTextApi:
|
||||
"""Configure mocks to neutralise billing/auth decorators.
|
||||
|
||||
``cloud_edition_billing_resource_check`` calls
|
||||
``FeatureService.get_features`` and
|
||||
``FeatureService.get_vector_space`` for vector-space checks and
|
||||
``FeatureService.get_features`` for other resource checks.
|
||||
``cloud_edition_billing_rate_limit_check`` calls
|
||||
``FeatureService.get_knowledge_rate_limit``.
|
||||
Both call ``validate_and_get_api_token`` first.
|
||||
@ -963,6 +964,11 @@ class TestDocumentAddByTextApi:
|
||||
mock_features.billing.enabled = False
|
||||
mock_feature_svc.get_features.return_value = mock_features
|
||||
|
||||
mock_vector_space = Mock()
|
||||
mock_vector_space.limit = 10
|
||||
mock_vector_space.size = 0
|
||||
mock_feature_svc.get_vector_space.return_value = mock_vector_space
|
||||
|
||||
mock_rate_limit = Mock()
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
@ -1140,6 +1146,10 @@ def _setup_billing_mocks(mock_validate_token, mock_feature_svc, tenant_id: str):
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_feature_svc.get_features.return_value = mock_features
|
||||
mock_vector_space = Mock()
|
||||
mock_vector_space.limit = 10
|
||||
mock_vector_space.size = 0
|
||||
mock_feature_svc.get_vector_space.return_value = mock_vector_space
|
||||
mock_rate_limit = Mock()
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
@ -265,6 +265,65 @@ class TestCloudEditionBillingResourceCheck:
|
||||
|
||||
# Assert
|
||||
assert result == "member_added"
|
||||
mock_get_features.assert_called_once_with("tenant123", exclude_vector_space=True)
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_features")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_vector_space")
|
||||
def test_loads_vector_space_from_dedicated_quota_api(
|
||||
self, mock_get_vector_space, mock_get_features, mock_validate_token, app: Flask
|
||||
):
|
||||
"""Test vector-space resource checks avoid loading the full feature payload."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_vector_space = Mock()
|
||||
mock_vector_space.limit = 10
|
||||
mock_vector_space.size = 5
|
||||
mock_get_vector_space.return_value = mock_vector_space
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
def add_segment():
|
||||
return "segment_added"
|
||||
|
||||
# Act
|
||||
with (
|
||||
app.test_request_context("/", method="GET"),
|
||||
patch("controllers.service_api.wraps.dify_config.BILLING_ENABLED", True),
|
||||
):
|
||||
result = add_segment()
|
||||
|
||||
# Assert
|
||||
assert result == "segment_added"
|
||||
mock_get_vector_space.assert_called_once_with("tenant123")
|
||||
mock_get_features.assert_not_called()
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_features")
|
||||
def test_loads_features_when_checking_non_vector_space_limit(
|
||||
self, mock_get_features, mock_validate_token, app: Flask
|
||||
):
|
||||
"""Test non-vector-space resource checks keep using the light feature payload."""
|
||||
# Arrange
|
||||
mock_validate_token.return_value = Mock(tenant_id="tenant123")
|
||||
|
||||
mock_features = Mock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.documents_upload_quota.limit = 10
|
||||
mock_features.documents_upload_quota.size = 5
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
def upload_document():
|
||||
return "document_uploaded"
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/", method="GET"):
|
||||
result = upload_document()
|
||||
|
||||
# Assert
|
||||
assert result == "document_uploaded"
|
||||
mock_get_features.assert_called_once_with("tenant123", exclude_vector_space=True)
|
||||
|
||||
@patch("controllers.service_api.wraps.validate_and_get_api_token")
|
||||
@patch("controllers.service_api.wraps.FeatureService.get_features")
|
||||
|
||||
@ -126,7 +126,7 @@ def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
monkeypatch.setattr(
|
||||
site_module.FeatureService,
|
||||
"get_features",
|
||||
lambda tenant_id: SimpleNamespace(can_replace_logo=True),
|
||||
lambda tenant_id, **_kwargs: SimpleNamespace(can_replace_logo=True),
|
||||
)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
@ -245,7 +245,7 @@ def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: F
|
||||
monkeypatch.setattr(
|
||||
site_module.FeatureService,
|
||||
"get_features",
|
||||
lambda tenant_id: SimpleNamespace(can_replace_logo=True),
|
||||
lambda tenant_id, **_kwargs: SimpleNamespace(can_replace_logo=True),
|
||||
)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
|
||||
Reference in New Issue
Block a user