mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
Merge remote-tracking branch 'upstream/main' into feat/human-input-merge-again
This commit is contained in:
286
api/tests/unit_tests/libs/test_archive_storage.py
Normal file
286
api/tests/unit_tests/libs/test_archive_storage.py
Normal file
@ -0,0 +1,286 @@
|
||||
import base64
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from unittest.mock import ANY, MagicMock
|
||||
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from libs import archive_storage as storage_module
|
||||
from libs.archive_storage import (
|
||||
ArchiveStorage,
|
||||
ArchiveStorageError,
|
||||
ArchiveStorageNotConfiguredError,
|
||||
)
|
||||
|
||||
BUCKET_NAME = "archive-bucket"
|
||||
|
||||
|
||||
def _configure_storage(monkeypatch, **overrides):
|
||||
defaults = {
|
||||
"ARCHIVE_STORAGE_ENABLED": True,
|
||||
"ARCHIVE_STORAGE_ENDPOINT": "https://storage.example.com",
|
||||
"ARCHIVE_STORAGE_ARCHIVE_BUCKET": BUCKET_NAME,
|
||||
"ARCHIVE_STORAGE_ACCESS_KEY": "access",
|
||||
"ARCHIVE_STORAGE_SECRET_KEY": "secret",
|
||||
"ARCHIVE_STORAGE_REGION": "auto",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
for key, value in defaults.items():
|
||||
monkeypatch.setattr(storage_module.dify_config, key, value, raising=False)
|
||||
|
||||
|
||||
def _client_error(code: str) -> ClientError:
|
||||
return ClientError({"Error": {"Code": code}}, "Operation")
|
||||
|
||||
|
||||
def _mock_client(monkeypatch):
|
||||
client = MagicMock()
|
||||
client.head_bucket.return_value = None
|
||||
# Configure put_object to return a proper ETag that matches the MD5 hash
|
||||
# The ETag format is typically the MD5 hash wrapped in quotes
|
||||
|
||||
def mock_put_object(**kwargs):
|
||||
md5_hash = kwargs.get("Body", b"")
|
||||
if isinstance(md5_hash, bytes):
|
||||
md5_hash = hashlib.md5(md5_hash).hexdigest()
|
||||
else:
|
||||
md5_hash = hashlib.md5(md5_hash.encode()).hexdigest()
|
||||
response = MagicMock()
|
||||
response.get.return_value = f'"{md5_hash}"'
|
||||
return response
|
||||
|
||||
client.put_object.side_effect = mock_put_object
|
||||
boto_client = MagicMock(return_value=client)
|
||||
monkeypatch.setattr(storage_module.boto3, "client", boto_client)
|
||||
return client, boto_client
|
||||
|
||||
|
||||
def test_init_disabled(monkeypatch):
|
||||
_configure_storage(monkeypatch, ARCHIVE_STORAGE_ENABLED=False)
|
||||
with pytest.raises(ArchiveStorageNotConfiguredError, match="not enabled"):
|
||||
ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
|
||||
def test_init_missing_config(monkeypatch):
|
||||
_configure_storage(monkeypatch, ARCHIVE_STORAGE_ENDPOINT=None)
|
||||
with pytest.raises(ArchiveStorageNotConfiguredError, match="incomplete"):
|
||||
ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
|
||||
def test_init_bucket_not_found(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
client.head_bucket.side_effect = _client_error("404")
|
||||
|
||||
with pytest.raises(ArchiveStorageNotConfiguredError, match="does not exist"):
|
||||
ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
|
||||
def test_init_bucket_access_denied(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
client.head_bucket.side_effect = _client_error("403")
|
||||
|
||||
with pytest.raises(ArchiveStorageNotConfiguredError, match="Access denied"):
|
||||
ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
|
||||
def test_init_bucket_other_error(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
client.head_bucket.side_effect = _client_error("500")
|
||||
|
||||
with pytest.raises(ArchiveStorageError, match="Failed to access archive bucket"):
|
||||
ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
|
||||
def test_init_sets_client(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, boto_client = _mock_client(monkeypatch)
|
||||
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
boto_client.assert_called_once_with(
|
||||
"s3",
|
||||
endpoint_url="https://storage.example.com",
|
||||
aws_access_key_id="access",
|
||||
aws_secret_access_key="secret",
|
||||
region_name="auto",
|
||||
config=ANY,
|
||||
)
|
||||
assert storage.client is client
|
||||
assert storage.bucket == BUCKET_NAME
|
||||
|
||||
|
||||
def test_put_object_returns_checksum(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
data = b"hello"
|
||||
checksum = storage.put_object("key", data)
|
||||
|
||||
expected_md5 = hashlib.md5(data).hexdigest()
|
||||
expected_content_md5 = base64.b64encode(hashlib.md5(data).digest()).decode()
|
||||
client.put_object.assert_called_once_with(
|
||||
Bucket="archive-bucket",
|
||||
Key="key",
|
||||
Body=data,
|
||||
ContentMD5=expected_content_md5,
|
||||
)
|
||||
assert checksum == expected_md5
|
||||
|
||||
|
||||
def test_put_object_raises_on_error(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
client.put_object.side_effect = _client_error("500")
|
||||
|
||||
with pytest.raises(ArchiveStorageError, match="Failed to upload object"):
|
||||
storage.put_object("key", b"data")
|
||||
|
||||
|
||||
def test_get_object_returns_bytes(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
body = MagicMock()
|
||||
body.read.return_value = b"payload"
|
||||
client.get_object.return_value = {"Body": body}
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
assert storage.get_object("key") == b"payload"
|
||||
|
||||
|
||||
def test_get_object_missing(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
client.get_object.side_effect = _client_error("NoSuchKey")
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="Archive object not found"):
|
||||
storage.get_object("missing")
|
||||
|
||||
|
||||
def test_get_object_stream(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
body = MagicMock()
|
||||
body.iter_chunks.return_value = [b"a", b"b"]
|
||||
client.get_object.return_value = {"Body": body}
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
assert list(storage.get_object_stream("key")) == [b"a", b"b"]
|
||||
|
||||
|
||||
def test_get_object_stream_missing(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
client.get_object.side_effect = _client_error("NoSuchKey")
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="Archive object not found"):
|
||||
list(storage.get_object_stream("missing"))
|
||||
|
||||
|
||||
def test_object_exists(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
assert storage.object_exists("key") is True
|
||||
client.head_object.side_effect = _client_error("404")
|
||||
assert storage.object_exists("missing") is False
|
||||
|
||||
|
||||
def test_delete_object_error(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
client.delete_object.side_effect = _client_error("500")
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
with pytest.raises(ArchiveStorageError, match="Failed to delete object"):
|
||||
storage.delete_object("key")
|
||||
|
||||
|
||||
def test_list_objects(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
paginator = MagicMock()
|
||||
paginator.paginate.return_value = [
|
||||
{"Contents": [{"Key": "a"}, {"Key": "b"}]},
|
||||
{"Contents": [{"Key": "c"}]},
|
||||
]
|
||||
client.get_paginator.return_value = paginator
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
assert storage.list_objects("prefix") == ["a", "b", "c"]
|
||||
paginator.paginate.assert_called_once_with(Bucket="archive-bucket", Prefix="prefix")
|
||||
|
||||
|
||||
def test_list_objects_error(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
paginator = MagicMock()
|
||||
paginator.paginate.side_effect = _client_error("500")
|
||||
client.get_paginator.return_value = paginator
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
with pytest.raises(ArchiveStorageError, match="Failed to list objects"):
|
||||
storage.list_objects("prefix")
|
||||
|
||||
|
||||
def test_generate_presigned_url(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
client.generate_presigned_url.return_value = "http://signed-url"
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
url = storage.generate_presigned_url("key", expires_in=123)
|
||||
|
||||
client.generate_presigned_url.assert_called_once_with(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": "archive-bucket", "Key": "key"},
|
||||
ExpiresIn=123,
|
||||
)
|
||||
assert url == "http://signed-url"
|
||||
|
||||
|
||||
def test_generate_presigned_url_error(monkeypatch):
|
||||
_configure_storage(monkeypatch)
|
||||
client, _ = _mock_client(monkeypatch)
|
||||
client.generate_presigned_url.side_effect = _client_error("500")
|
||||
storage = ArchiveStorage(bucket=BUCKET_NAME)
|
||||
|
||||
with pytest.raises(ArchiveStorageError, match="Failed to generate pre-signed URL"):
|
||||
storage.generate_presigned_url("key")
|
||||
|
||||
|
||||
def test_serialization_roundtrip():
|
||||
records = [
|
||||
{
|
||||
"id": "1",
|
||||
"created_at": datetime(2024, 1, 1, 12, 0, 0),
|
||||
"payload": {"nested": "value"},
|
||||
"items": [{"name": "a"}],
|
||||
},
|
||||
{"id": "2", "value": 123},
|
||||
]
|
||||
|
||||
data = ArchiveStorage.serialize_to_jsonl(records)
|
||||
decoded = ArchiveStorage.deserialize_from_jsonl(data)
|
||||
|
||||
assert decoded[0]["id"] == "1"
|
||||
assert decoded[0]["payload"]["nested"] == "value"
|
||||
assert decoded[0]["items"][0]["name"] == "a"
|
||||
assert "2024-01-01T12:00:00" in decoded[0]["created_at"]
|
||||
assert decoded[1]["value"] == 123
|
||||
|
||||
|
||||
def test_content_md5_matches_checksum():
|
||||
data = b"checksum"
|
||||
expected = base64.b64encode(hashlib.md5(data).digest()).decode()
|
||||
|
||||
assert ArchiveStorage._content_md5(data) == expected
|
||||
assert ArchiveStorage.compute_checksum(data) == hashlib.md5(data).hexdigest()
|
||||
@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite():
|
||||
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
||||
|
||||
|
||||
def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||
# Force exc_info() to return (None,None,None) only during request
|
||||
import libs.external_api as ext
|
||||
def test_external_api_param_mapping_and_quota():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
orig_exc_info = ext.sys.exc_info
|
||||
try:
|
||||
ext.sys.exc_info = lambda: (None, None, None)
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
finally:
|
||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
|
||||
|
||||
def test_unauthorized_and_force_logout_clears_cookies():
|
||||
|
||||
@ -2,7 +2,7 @@ from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.helper import OptionalTimestampField, extract_tenant_id
|
||||
from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
@ -78,3 +78,51 @@ class TestOptionalTimestampField:
|
||||
value = datetime(2024, 1, 2, 3, 4, 5)
|
||||
|
||||
assert field.format(value) == int(value.timestamp())
|
||||
|
||||
|
||||
class TestEscapeLikePattern:
|
||||
"""Test cases for the escape_like_pattern utility function."""
|
||||
|
||||
def test_escape_percent_character(self):
|
||||
"""Test escaping percent character."""
|
||||
result = escape_like_pattern("50% discount")
|
||||
assert result == "50\\% discount"
|
||||
|
||||
def test_escape_underscore_character(self):
|
||||
"""Test escaping underscore character."""
|
||||
result = escape_like_pattern("test_data")
|
||||
assert result == "test\\_data"
|
||||
|
||||
def test_escape_backslash_character(self):
|
||||
"""Test escaping backslash character."""
|
||||
result = escape_like_pattern("path\\to\\file")
|
||||
assert result == "path\\\\to\\\\file"
|
||||
|
||||
def test_escape_combined_special_characters(self):
|
||||
"""Test escaping multiple special characters together."""
|
||||
result = escape_like_pattern("file_50%\\path")
|
||||
assert result == "file\\_50\\%\\\\path"
|
||||
|
||||
def test_escape_empty_string(self):
|
||||
"""Test escaping empty string returns empty string."""
|
||||
result = escape_like_pattern("")
|
||||
assert result == ""
|
||||
|
||||
def test_escape_none_handling(self):
|
||||
"""Test escaping None returns None (falsy check handles it)."""
|
||||
# The function checks `if not pattern`, so None is falsy and returns as-is
|
||||
result = escape_like_pattern(None)
|
||||
assert result is None
|
||||
|
||||
def test_escape_normal_string_no_change(self):
|
||||
"""Test that normal strings without special characters are unchanged."""
|
||||
result = escape_like_pattern("normal text")
|
||||
assert result == "normal text"
|
||||
|
||||
def test_escape_order_matters(self):
|
||||
"""Test that backslash is escaped first to prevent double escaping."""
|
||||
# If we escape % first, then escape \, we might get wrong results
|
||||
# This test ensures the order is correct: \ first, then % and _
|
||||
result = escape_like_pattern("test\\%_value")
|
||||
# Should be: test\\\%\_value
|
||||
assert result == "test\\\\\\%\\_value"
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -17,7 +17,7 @@ def test_smtp_plain_success(mock_smtp_cls: MagicMock):
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10, local_hostname=ANY)
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
@ -38,7 +38,7 @@ def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
|
||||
)
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10, local_hostname=ANY)
|
||||
assert mock_smtp.ehlo.call_count == 2
|
||||
mock_smtp.starttls.assert_called_once()
|
||||
mock_smtp.login.assert_called_once_with("user", "pass")
|
||||
|
||||
142
api/tests/unit_tests/libs/test_workspace_permission.py
Normal file
142
api/tests/unit_tests/libs/test_workspace_permission.py
Normal file
@ -0,0 +1,142 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.workspace_permission import (
|
||||
check_workspace_member_invite_permission,
|
||||
check_workspace_owner_transfer_permission,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkspacePermissionHelper:
|
||||
"""Test workspace permission helper functions."""
|
||||
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
def test_community_edition_allows_invite(self, mock_enterprise_service, mock_config):
|
||||
"""Community edition should always allow invitations without calling any service."""
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
|
||||
# Should not raise
|
||||
check_workspace_member_invite_permission("test-workspace-id")
|
||||
|
||||
# EnterpriseService should NOT be called in community edition
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
|
||||
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
@patch("libs.workspace_permission.FeatureService")
|
||||
def test_community_edition_allows_transfer(self, mock_feature_service, mock_config):
|
||||
"""Community edition should check billing plan but not call enterprise service."""
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
mock_features = Mock()
|
||||
mock_features.is_allow_transfer_workspace = True
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
|
||||
# Should not raise
|
||||
check_workspace_owner_transfer_permission("test-workspace-id")
|
||||
|
||||
mock_feature_service.get_features.assert_called_once_with("test-workspace-id")
|
||||
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
def test_enterprise_blocks_invite_when_disabled(self, mock_config, mock_enterprise_service):
|
||||
"""Enterprise edition should block invitations when workspace policy is False."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
mock_permission = Mock()
|
||||
mock_permission.allow_member_invite = False
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
|
||||
|
||||
with pytest.raises(Forbidden, match="Workspace policy prohibits member invitations"):
|
||||
check_workspace_member_invite_permission("test-workspace-id")
|
||||
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
|
||||
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
def test_enterprise_allows_invite_when_enabled(self, mock_config, mock_enterprise_service):
|
||||
"""Enterprise edition should allow invitations when workspace policy is True."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
mock_permission = Mock()
|
||||
mock_permission.allow_member_invite = True
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
|
||||
|
||||
# Should not raise
|
||||
check_workspace_member_invite_permission("test-workspace-id")
|
||||
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
|
||||
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
@patch("libs.workspace_permission.FeatureService")
|
||||
def test_billing_plan_blocks_transfer(self, mock_feature_service, mock_config, mock_enterprise_service):
|
||||
"""SANDBOX billing plan should block owner transfer before checking enterprise policy."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_features = Mock()
|
||||
mock_features.is_allow_transfer_workspace = False # SANDBOX plan
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
|
||||
with pytest.raises(Forbidden, match="Your current plan does not allow workspace ownership transfer"):
|
||||
check_workspace_owner_transfer_permission("test-workspace-id")
|
||||
|
||||
# Enterprise service should NOT be called since billing plan already blocks
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
|
||||
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
@patch("libs.workspace_permission.FeatureService")
|
||||
def test_enterprise_blocks_transfer_when_disabled(self, mock_feature_service, mock_config, mock_enterprise_service):
|
||||
"""Enterprise edition should block transfer when workspace policy is False."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_features = Mock()
|
||||
mock_features.is_allow_transfer_workspace = True # Billing plan allows
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
|
||||
mock_permission = Mock()
|
||||
mock_permission.allow_owner_transfer = False # Workspace policy blocks
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
|
||||
|
||||
with pytest.raises(Forbidden, match="Workspace policy prohibits ownership transfer"):
|
||||
check_workspace_owner_transfer_permission("test-workspace-id")
|
||||
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
|
||||
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
@patch("libs.workspace_permission.FeatureService")
|
||||
def test_enterprise_allows_transfer_when_both_enabled(
|
||||
self, mock_feature_service, mock_config, mock_enterprise_service
|
||||
):
|
||||
"""Enterprise edition should allow transfer when both billing and workspace policy allow."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_features = Mock()
|
||||
mock_features.is_allow_transfer_workspace = True # Billing plan allows
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
|
||||
mock_permission = Mock()
|
||||
mock_permission.allow_owner_transfer = True # Workspace policy allows
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
|
||||
|
||||
# Should not raise
|
||||
check_workspace_owner_transfer_permission("test-workspace-id")
|
||||
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
|
||||
|
||||
@patch("libs.workspace_permission.logger")
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
def test_enterprise_service_error_fails_open(self, mock_config, mock_enterprise_service, mock_logger):
|
||||
"""On enterprise service error, should fail-open (allow) and log error."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Simulate enterprise service error
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.side_effect = Exception("Service unavailable")
|
||||
|
||||
# Should not raise (fail-open)
|
||||
check_workspace_member_invite_permission("test-workspace-id")
|
||||
|
||||
# Should log the error
|
||||
mock_logger.exception.assert_called_once()
|
||||
assert "Failed to check workspace invite permission" in str(mock_logger.exception.call_args)
|
||||
Reference in New Issue
Block a user