Compare commits

..

5 Commits

16 changed files with 386 additions and 178 deletions

View File

@ -27,7 +27,7 @@ COPY api/providers ./providers
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
COPY dify-agent/src /app/dify-agent/src
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
RUN uv sync --frozen --no-dev
RUN uv sync --frozen --no-dev --no-editable
# production stage
FROM base AS production

View File

@ -1,5 +1,6 @@
import posixpath
from collections.abc import Generator
from typing import override
import oss2 as aliyun_s3
@ -29,9 +30,11 @@ class AliyunOssStorage(BaseStorage):
cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
)
@override
def save(self, filename, data):
self.client.put_object(self.__wrapper_folder_filename(filename), data)
@override
def load_once(self, filename: str) -> bytes:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
data = obj.read()
@ -39,17 +42,21 @@ class AliyunOssStorage(BaseStorage):
return b""
return data
@override
def load_stream(self, filename: str) -> Generator:
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
while chunk := obj.read(4096):
yield chunk
@override
def download(self, filename: str, target_filepath):
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
@override
def exists(self, filename: str):
return self.client.object_exists(self.__wrapper_folder_filename(filename))
@override
def delete(self, filename: str):
self.client.delete_object(self.__wrapper_folder_filename(filename))

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Generator
from typing import override
import boto3
from botocore.client import Config
@ -48,9 +49,11 @@ class AwsS3Storage(BaseStorage):
# other error, raise exception
raise
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
@override
def load_once(self, filename: str) -> bytes:
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
@ -61,6 +64,7 @@ class AwsS3Storage(BaseStorage):
raise
return data
@override
def load_stream(self, filename: str) -> Generator:
try:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
@ -73,9 +77,11 @@ class AwsS3Storage(BaseStorage):
else:
raise
@override
def download(self, filename, target_filepath):
self.client.download_file(self.bucket_name, filename, target_filepath)
@override
def exists(self, filename):
try:
self.client.head_object(Bucket=self.bucket_name, Key=filename)
@ -83,5 +89,6 @@ class AwsS3Storage(BaseStorage):
except:
return False
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,5 +1,6 @@
from collections.abc import Generator
from datetime import timedelta
from typing import override
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
@ -26,6 +27,7 @@ class AzureBlobStorage(BaseStorage):
else:
self.credential = None
@override
def save(self, filename, data):
if not self.bucket_name:
return
@ -34,6 +36,7 @@ class AzureBlobStorage(BaseStorage):
blob_container = client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
@override
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
@ -46,6 +49,7 @@ class AzureBlobStorage(BaseStorage):
raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
return data
@override
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("Azure bucket name is not configured.")
@ -55,6 +59,7 @@ class AzureBlobStorage(BaseStorage):
blob_data = blob.download_blob()
yield from blob_data.chunks()
@override
def download(self, filename, target_filepath):
if not self.bucket_name:
return
@ -66,6 +71,7 @@ class AzureBlobStorage(BaseStorage):
blob_data = blob.download_blob()
blob_data.readinto(my_blob)
@override
def exists(self, filename):
if not self.bucket_name:
return False
@ -75,6 +81,7 @@ class AzureBlobStorage(BaseStorage):
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
@override
def delete(self, filename: str):
if not self.bucket_name:
return

View File

@ -1,6 +1,7 @@
import base64
import hashlib
from collections.abc import Generator
from typing import override
from baidubce.auth.bce_credentials import BceCredentials
from baidubce.bce_client_configuration import BceClientConfiguration
@ -26,6 +27,7 @@ class BaiduObsStorage(BaseStorage):
self.client = BosClient(config=client_config)
@override
def save(self, filename, data):
md5 = hashlib.md5()
md5.update(data)
@ -34,24 +36,29 @@ class BaiduObsStorage(BaseStorage):
bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5
)
@override
def load_once(self, filename: str) -> bytes:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
data: bytes = response.data.read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)
@override
def exists(self, filename):
res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename)
if res is None:
return False
return True
@override
def delete(self, filename: str):
self.client.delete_object(bucket_name=self.bucket_name, key=filename)

View File

@ -10,7 +10,7 @@ import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Any, override
import clickzetta
from pydantic import BaseModel, model_validator
@ -251,6 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage):
# Don't raise exception, let the operation continue
# The table might exist but not be visible due to permissions
@override
def save(self, filename: str, data: bytes):
"""Save data to ClickZetta Volume.
@ -304,6 +305,7 @@ class ClickZettaVolumeStorage(BaseStorage):
# Clean up temporary file
Path(temp_file_path).unlink(missing_ok=True)
@override
def load_once(self, filename: str) -> bytes:
"""Load file content from ClickZetta Volume.
@ -364,6 +366,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s loaded from ClickZetta Volume", filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
"""Load file as stream from ClickZetta Volume.
@ -382,6 +385,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
@override
def download(self, filename: str, target_filepath: str):
"""Download file from ClickZetta Volume to local path.
@ -395,6 +399,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
@override
def exists(self, filename: str) -> bool:
"""Check if file exists in ClickZetta Volume.
@ -436,6 +441,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.warning("Error checking file existence for %s: %s", filename, e)
return False
@override
def delete(self, filename: str):
"""Delete file from ClickZetta Volume.
@ -472,6 +478,7 @@ class ClickZettaVolumeStorage(BaseStorage):
logger.debug("File %s deleted from ClickZetta Volume", filename)
@override
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""Scan files and directories in ClickZetta Volume.

View File

@ -1,7 +1,7 @@
import base64
import io
from collections.abc import Generator
from typing import Any
from typing import Any, override
from google.cloud import storage as google_cloud_storage # type: ignore
from pydantic import TypeAdapter
@ -29,12 +29,14 @@ class GoogleCloudStorage(BaseStorage):
else:
self.client = google_cloud_storage.Client()
@override
def save(self, filename, data):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
with io.BytesIO(data) as stream:
blob.upload_from_file(stream)
@override
def load_once(self, filename: str) -> bytes:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -43,6 +45,7 @@ class GoogleCloudStorage(BaseStorage):
data: bytes = blob.download_as_bytes()
return data
@override
def load_stream(self, filename: str) -> Generator:
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -52,6 +55,7 @@ class GoogleCloudStorage(BaseStorage):
while chunk := blob_stream.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename)
@ -59,11 +63,13 @@ class GoogleCloudStorage(BaseStorage):
raise FileNotFoundError("File not found")
blob.download_to_filename(target_filepath)
@override
def exists(self, filename):
bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.blob(filename)
return blob.exists()
@override
def delete(self, filename: str):
bucket = self.client.get_bucket(self.bucket_name)
bucket.delete_blob(filename)

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
from obs import ObsClient
@ -20,27 +21,33 @@ class HuaweiObsStorage(BaseStorage):
path_style=dify_config.HUAWEI_OBS_PATH_STYLE,
)
@override
def save(self, filename, data):
self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
@override
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)
@override
def exists(self, filename):
res = self._get_meta(filename)
if res is None:
return False
return True
@override
def delete(self, filename: str):
self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename)

View File

@ -2,7 +2,7 @@ import logging
import os
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import Any, override
import opendal
from dotenv import dotenv_values
@ -41,10 +41,12 @@ class OpenDALStorage(BaseStorage):
logger.debug("opendal operator created with scheme %s", scheme)
logger.debug("added retry layer to opendal operator")
@override
def save(self, filename: str, data: bytes):
self.op.write(path=filename, bs=data)
logger.debug("file %s saved", filename)
@override
def load_once(self, filename: str) -> bytes:
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -53,6 +55,7 @@ class OpenDALStorage(BaseStorage):
logger.debug("file %s loaded", filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -67,6 +70,7 @@ class OpenDALStorage(BaseStorage):
yield chunk
logger.debug("file %s loaded as stream", filename)
@override
def download(self, filename: str, target_filepath: str):
if not self.exists(filename):
raise FileNotFoundError("File not found")
@ -74,9 +78,11 @@ class OpenDALStorage(BaseStorage):
Path(target_filepath).write_bytes(self.op.read(path=filename))
logger.debug("file %s downloaded to %s", filename, target_filepath)
@override
def exists(self, filename: str) -> bool:
return self.op.exists(path=filename)
@override
def delete(self, filename: str):
if self.exists(filename):
self.op.delete(path=filename)
@ -84,6 +90,7 @@ class OpenDALStorage(BaseStorage):
return
logger.debug("file %s not found, skip delete", filename)
@override
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
if not self.exists(path):
raise FileNotFoundError("Path not found")

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
import boto3
from botocore.exceptions import ClientError
@ -22,9 +23,11 @@ class OracleOCIStorage(BaseStorage):
region_name=dify_config.OCI_REGION,
)
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
@override
def load_once(self, filename: str) -> bytes:
try:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
@ -35,6 +38,7 @@ class OracleOCIStorage(BaseStorage):
raise
return data
@override
def load_stream(self, filename: str) -> Generator:
try:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
@ -45,9 +49,11 @@ class OracleOCIStorage(BaseStorage):
else:
raise
@override
def download(self, filename, target_filepath):
self.client.download_file(self.bucket_name, filename, target_filepath)
@override
def exists(self, filename):
try:
self.client.head_object(Bucket=self.bucket_name, Key=filename)
@ -55,5 +61,6 @@ class OracleOCIStorage(BaseStorage):
except:
return False
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,6 +1,7 @@
import io
from collections.abc import Generator
from pathlib import Path
from typing import override
from supabase import Client
@ -28,29 +29,35 @@ class SupabaseStorage(BaseStorage):
if not self.bucket_exists():
self.client.storage.create_bucket(id=id, name=bucket_name)
@override
def save(self, filename, data):
self.client.storage.from_(self.bucket_name).upload(filename, data)
@override
def load_once(self, filename: str) -> bytes:
content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
return content
@override
def load_stream(self, filename: str) -> Generator:
result = self.client.storage.from_(self.bucket_name).download(filename)
byte_stream = io.BytesIO(result)
while chunk := byte_stream.read(4096): # Read in chunks of 4KB
yield chunk
@override
def download(self, filename, target_filepath):
result = self.client.storage.from_(self.bucket_name).download(filename)
Path(target_filepath).write_bytes(result)
@override
def exists(self, filename):
result = self.client.storage.from_(self.bucket_name).list(path=filename)
if len(result) > 0:
return True
return False
@override
def delete(self, filename: str):
self.client.storage.from_(self.bucket_name).remove([filename])

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
from qcloud_cos import CosConfig, CosS3Client
@ -29,23 +30,29 @@ class TencentCosStorage(BaseStorage):
)
self.client = CosS3Client(config)
@override
def save(self, filename, data):
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
@override
def load_once(self, filename: str) -> bytes:
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
return data
@override
def load_stream(self, filename: str) -> Generator:
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
yield from response["Body"].get_stream(chunk_size=4096)
@override
def download(self, filename, target_filepath):
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
response["Body"].get_stream_to_file(target_filepath)
@override
def exists(self, filename):
return self.client.object_exists(Bucket=self.bucket_name, Key=filename)
@override
def delete(self, filename: str):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
import tos
@ -27,11 +28,13 @@ class VolcengineTosStorage(BaseStorage):
region=dify_config.VOLCENGINE_TOS_REGION,
)
@override
def save(self, filename, data):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.put_object(bucket=self.bucket_name, key=filename, content=data)
@override
def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
@ -40,6 +43,7 @@ class VolcengineTosStorage(BaseStorage):
raise TypeError(f"Expected bytes, got {type(data).__name__}")
return data
@override
def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
@ -47,11 +51,13 @@ class VolcengineTosStorage(BaseStorage):
while chunk := response.read(4096):
yield chunk
@override
def download(self, filename, target_filepath):
if not self.bucket_name:
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)
@override
def exists(self, filename):
if not self.bucket_name:
return False
@ -60,6 +66,7 @@ class VolcengineTosStorage(BaseStorage):
return False
return True
@override
def delete(self, filename: str):
if not self.bucket_name:
return

View File

@ -0,0 +1,298 @@
from __future__ import annotations
from unittest.mock import patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import HTTPException
import services
from controllers.console.auth.error import MemberNotInTenantError
from controllers.console.workspace import members as members_module
from controllers.console.workspace.members import MemberCancelInviteApi, MemberUpdateRoleApi, OwnerTransfer
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class WorkspaceMembersIntegrationFactory:
@staticmethod
def create_tenant(db_session_with_containers) -> Tenant:
tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status=TenantStatus.NORMAL)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
return tenant
@staticmethod
def create_account(
db_session_with_containers,
*,
email_prefix: str,
tenant: Tenant | None = None,
role: TenantAccountRole = TenantAccountRole.NORMAL,
current: bool = False,
) -> Account:
account = Account(
name=f"Account {uuid4()}",
email=f"{email_prefix}-{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
if tenant is not None:
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=current,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account
@staticmethod
def create_owner_workspace(db_session_with_containers) -> tuple[Tenant, Account]:
tenant = WorkspaceMembersIntegrationFactory.create_tenant(db_session_with_containers)
owner = WorkspaceMembersIntegrationFactory.create_account(
db_session_with_containers,
email_prefix="owner",
tenant=tenant,
role=TenantAccountRole.OWNER,
current=True,
)
return tenant, owner
@staticmethod
def create_owner_transfer_token(account: Account) -> str:
_, token = members_module.AccountService.generate_owner_transfer_token(
account.email,
account=account,
code="123456",
additional_data={},
)
return token
@staticmethod
def get_join(db_session_with_containers, *, tenant: Tenant, account: Account) -> TenantAccountJoin:
tenant_id = tenant.id
account_id = account.id
db_session_with_containers.expire_all()
join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant_id, account_id=account_id)
.one()
)
return join
class TestMemberCancelInviteApiWithContainers:
def test_cancel_success(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(members_module.TenantService, "remove_member_from_tenant") as mock_remove_member,
):
result, status = method(api, member.id)
assert status == 200
assert result["result"] == "success"
mock_remove_member.assert_called_once()
called_tenant, called_member, called_current_user = mock_remove_member.call_args.args
assert called_tenant.id == tenant.id
assert called_member.id == member.id
assert called_current_user.id == current_user.id
def test_cancel_not_found(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
def test_cancel_cannot_operate_self(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
result, status = method(api, member.id)
assert status == 400
assert result["code"] == "cannot-operate-self"
def test_cancel_no_permission(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
result, status = method(api, member.id)
assert status == 403
assert result["code"] == "forbidden"
def test_cancel_member_not_in_tenant(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
result, status = method(api, member.id)
assert status == 404
assert result["code"] == "member-not-found"
class TestMemberUpdateRoleApiWithContainers:
def test_update_success(self, flask_app_with_containers, db_session_with_containers):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(
db_session_with_containers,
email_prefix="member",
tenant=tenant,
role=TenantAccountRole.EDITOR,
)
with (
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
result = method(api, member.id)
if isinstance(result, tuple):
result = result[0]
assert result["result"] == "success"
assert (
factory.get_join(db_session_with_containers, tenant=tenant, account=member).role == TenantAccountRole.NORMAL
)
def test_update_member_not_found(self, flask_app_with_containers, db_session_with_containers):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
with (
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
class TestOwnerTransferApiWithContainers:
def test_member_not_in_tenant(self, flask_app_with_containers, db_session_with_containers):
api = OwnerTransfer()
method = unwrap(api.post)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(db_session_with_containers, email_prefix="member")
token = factory.create_owner_transfer_token(current_user)
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with pytest.raises(MemberNotInTenantError):
method(api, member.id)
def test_member_not_found(self, flask_app_with_containers, db_session_with_containers):
api = OwnerTransfer()
method = unwrap(api.post)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
token = factory.create_owner_transfer_token(current_user)
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
def test_transfer_success(self, flask_app_with_containers, db_session_with_containers):
api = OwnerTransfer()
method = unwrap(api.post)
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
member = factory.create_account(
db_session_with_containers,
email_prefix="member",
tenant=tenant,
role=TenantAccountRole.NORMAL,
)
token = factory.create_owner_transfer_token(current_user)
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(members_module.AccountService, "send_new_owner_transfer_notify_email") as mock_new_owner_email,
patch.object(members_module.AccountService, "send_old_owner_transfer_notify_email") as mock_old_owner_email,
):
result = method(api, member.id)
assert result["result"] == "success"
assert (
factory.get_join(db_session_with_containers, tenant=tenant, account=member).role == TenantAccountRole.OWNER
)
assert (
factory.get_join(db_session_with_containers, tenant=tenant, account=current_user).role
== TenantAccountRole.ADMIN
)
mock_new_owner_email.assert_called_once()
mock_old_owner_email.assert_called_once()

View File

@ -3,22 +3,18 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import HTTPException
import services
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
MemberNotInTenantError,
NotOwnerError,
OwnerTransferLimitError,
)
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.workspace.members import (
DatasetOperatorMemberListApi,
MemberCancelInviteApi,
MemberInviteEmailApi,
MemberListApi,
MemberUpdateRoleApi,
@ -251,135 +247,7 @@ class TestMemberInviteEmailApi:
assert result["invitation_results"][0]["status"] == "failed"
class TestMemberCancelInviteApi:
def test_cancel_success(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
):
get_mock.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 400
def test_cancel_no_permission(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 403
def test_cancel_member_not_in_tenant(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get") as get_mock,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
get_mock.return_value = member
result, status = method(api, member.id)
assert status == 404
class TestMemberUpdateRoleApi:
def test_update_success(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.update_member_role"),
):
result = method(api, "id")
if isinstance(result, tuple):
result = result[0]
assert result["result"] == "success"
def test_update_invalid_role(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -391,23 +259,6 @@ class TestMemberUpdateRoleApi:
assert status == 400
def test_update_member_not_found(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.members.current_account_with_tenant",
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
),
patch("controllers.console.workspace.members.db.session.get", return_value=None),
):
with pytest.raises(HTTPException):
method(api, "id")
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app: Flask):
@ -637,27 +488,3 @@ class TestOwnerTransferApi:
):
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
member = MagicMock()
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com"},
),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
):
with pytest.raises(MemberNotInTenantError):
method(api, "2")

View File

@ -1,5 +1,5 @@
# base image
FROM node:22-alpine AS base
FROM node:22.22.1-alpine AS base
LABEL maintainer="takatost@gmail.com"
# if you located in China, you can use aliyun mirror to speed up