mirror of
https://github.com/langgenius/dify.git
synced 2026-05-28 21:03:22 +08:00
Compare commits
5 Commits
feat/evalu
...
feat/snipp
| Author | SHA1 | Date | |
|---|---|---|---|
| c46a313d78 | |||
| 57b02e341c | |||
| b94ff65e9f | |||
| 678260e34e | |||
| 739e34d08a |
@ -27,7 +27,7 @@ COPY api/providers ./providers
|
||||
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
|
||||
COPY dify-agent/src /app/dify-agent/src
|
||||
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
RUN uv sync --frozen --no-dev --no-editable
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import posixpath
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
import oss2 as aliyun_s3
|
||||
|
||||
@ -29,9 +30,11 @@ class AliyunOssStorage(BaseStorage):
|
||||
cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
|
||||
)
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(self.__wrapper_folder_filename(filename), data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||
data = obj.read()
|
||||
@ -39,17 +42,21 @@ class AliyunOssStorage(BaseStorage):
|
||||
return b""
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||
while chunk := obj.read(4096):
|
||||
yield chunk
|
||||
|
||||
@override
|
||||
def download(self, filename: str, target_filepath):
|
||||
self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename: str):
|
||||
return self.client.object_exists(self.__wrapper_folder_filename(filename))
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(self.__wrapper_folder_filename(filename))
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
@ -48,9 +49,11 @@ class AwsS3Storage(BaseStorage):
|
||||
# other error, raise exception
|
||||
raise
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
try:
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
@ -61,6 +64,7 @@ class AwsS3Storage(BaseStorage):
|
||||
raise
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||
@ -73,9 +77,11 @@ class AwsS3Storage(BaseStorage):
|
||||
else:
|
||||
raise
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
self.client.download_file(self.bucket_name, filename, target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
try:
|
||||
self.client.head_object(Bucket=self.bucket_name, Key=filename)
|
||||
@ -83,5 +89,6 @@ class AwsS3Storage(BaseStorage):
|
||||
except:
|
||||
return False
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from typing import override
|
||||
|
||||
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
|
||||
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
|
||||
@ -26,6 +27,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
else:
|
||||
self.credential = None
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
@ -34,6 +36,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
blob_container = client.get_container_client(container=self.bucket_name)
|
||||
blob_container.upload_blob(filename, data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
if not self.bucket_name:
|
||||
raise FileNotFoundError("Azure bucket name is not configured.")
|
||||
@ -46,6 +49,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
raise TypeError(f"Expected bytes from blob.readall(), got {type(data).__name__}")
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
if not self.bucket_name:
|
||||
raise FileNotFoundError("Azure bucket name is not configured.")
|
||||
@ -55,6 +59,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
blob_data = blob.download_blob()
|
||||
yield from blob_data.chunks()
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
@ -66,6 +71,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
blob_data = blob.download_blob()
|
||||
blob_data.readinto(my_blob)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
if not self.bucket_name:
|
||||
return False
|
||||
@ -75,6 +81,7 @@ class AzureBlobStorage(BaseStorage):
|
||||
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
|
||||
return blob.exists()
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import hashlib
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
from baidubce.auth.bce_credentials import BceCredentials
|
||||
from baidubce.bce_client_configuration import BceClientConfiguration
|
||||
@ -26,6 +27,7 @@ class BaiduObsStorage(BaseStorage):
|
||||
|
||||
self.client = BosClient(config=client_config)
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
md5 = hashlib.md5()
|
||||
md5.update(data)
|
||||
@ -34,24 +36,29 @@ class BaiduObsStorage(BaseStorage):
|
||||
bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5
|
||||
)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
response = self.client.get_object(bucket_name=self.bucket_name, key=filename)
|
||||
data: bytes = response.data.read()
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
|
||||
while chunk := response.read(4096):
|
||||
yield chunk
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename)
|
||||
if res is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(bucket_name=self.bucket_name, key=filename)
|
||||
|
||||
@ -10,7 +10,7 @@ import tempfile
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import clickzetta
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -251,6 +251,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
# Don't raise exception, let the operation continue
|
||||
# The table might exist but not be visible due to permissions
|
||||
|
||||
@override
|
||||
def save(self, filename: str, data: bytes):
|
||||
"""Save data to ClickZetta Volume.
|
||||
|
||||
@ -304,6 +305,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
# Clean up temporary file
|
||||
Path(temp_file_path).unlink(missing_ok=True)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
"""Load file content from ClickZetta Volume.
|
||||
|
||||
@ -364,6 +366,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
logger.debug("File %s loaded from ClickZetta Volume", filename)
|
||||
return content
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
"""Load file as stream from ClickZetta Volume.
|
||||
|
||||
@ -382,6 +385,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
|
||||
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
|
||||
|
||||
@override
|
||||
def download(self, filename: str, target_filepath: str):
|
||||
"""Download file from ClickZetta Volume to local path.
|
||||
|
||||
@ -395,6 +399,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
|
||||
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename: str) -> bool:
|
||||
"""Check if file exists in ClickZetta Volume.
|
||||
|
||||
@ -436,6 +441,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
logger.warning("Error checking file existence for %s: %s", filename, e)
|
||||
return False
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
"""Delete file from ClickZetta Volume.
|
||||
|
||||
@ -472,6 +478,7 @@ class ClickZettaVolumeStorage(BaseStorage):
|
||||
|
||||
logger.debug("File %s deleted from ClickZetta Volume", filename)
|
||||
|
||||
@override
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
"""Scan files and directories in ClickZetta Volume.
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import io
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from google.cloud import storage as google_cloud_storage # type: ignore
|
||||
from pydantic import TypeAdapter
|
||||
@ -29,12 +29,14 @@ class GoogleCloudStorage(BaseStorage):
|
||||
else:
|
||||
self.client = google_cloud_storage.Client()
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.blob(filename)
|
||||
with io.BytesIO(data) as stream:
|
||||
blob.upload_from_file(stream)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.get_blob(filename)
|
||||
@ -43,6 +45,7 @@ class GoogleCloudStorage(BaseStorage):
|
||||
data: bytes = blob.download_as_bytes()
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.get_blob(filename)
|
||||
@ -52,6 +55,7 @@ class GoogleCloudStorage(BaseStorage):
|
||||
while chunk := blob_stream.read(4096):
|
||||
yield chunk
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.get_blob(filename)
|
||||
@ -59,11 +63,13 @@ class GoogleCloudStorage(BaseStorage):
|
||||
raise FileNotFoundError("File not found")
|
||||
blob.download_to_filename(target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
blob = bucket.blob(filename)
|
||||
return blob.exists()
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
bucket.delete_blob(filename)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
from obs import ObsClient
|
||||
|
||||
@ -20,27 +21,33 @@ class HuaweiObsStorage(BaseStorage):
|
||||
path_style=dify_config.HUAWEI_OBS_PATH_STYLE,
|
||||
)
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
|
||||
while chunk := response.read(4096):
|
||||
yield chunk
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
res = self._get_meta(filename)
|
||||
if res is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import opendal
|
||||
from dotenv import dotenv_values
|
||||
@ -41,10 +41,12 @@ class OpenDALStorage(BaseStorage):
|
||||
logger.debug("opendal operator created with scheme %s", scheme)
|
||||
logger.debug("added retry layer to opendal operator")
|
||||
|
||||
@override
|
||||
def save(self, filename: str, data: bytes):
|
||||
self.op.write(path=filename, bs=data)
|
||||
logger.debug("file %s saved", filename)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
if not self.exists(filename):
|
||||
raise FileNotFoundError("File not found")
|
||||
@ -53,6 +55,7 @@ class OpenDALStorage(BaseStorage):
|
||||
logger.debug("file %s loaded", filename)
|
||||
return content
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
if not self.exists(filename):
|
||||
raise FileNotFoundError("File not found")
|
||||
@ -67,6 +70,7 @@ class OpenDALStorage(BaseStorage):
|
||||
yield chunk
|
||||
logger.debug("file %s loaded as stream", filename)
|
||||
|
||||
@override
|
||||
def download(self, filename: str, target_filepath: str):
|
||||
if not self.exists(filename):
|
||||
raise FileNotFoundError("File not found")
|
||||
@ -74,9 +78,11 @@ class OpenDALStorage(BaseStorage):
|
||||
Path(target_filepath).write_bytes(self.op.read(path=filename))
|
||||
logger.debug("file %s downloaded to %s", filename, target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename: str) -> bool:
|
||||
return self.op.exists(path=filename)
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
if self.exists(filename):
|
||||
self.op.delete(path=filename)
|
||||
@ -84,6 +90,7 @@ class OpenDALStorage(BaseStorage):
|
||||
return
|
||||
logger.debug("file %s not found, skip delete", filename)
|
||||
|
||||
@override
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
if not self.exists(path):
|
||||
raise FileNotFoundError("Path not found")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
@ -22,9 +23,11 @@ class OracleOCIStorage(BaseStorage):
|
||||
region_name=dify_config.OCI_REGION,
|
||||
)
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
try:
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||
@ -35,6 +38,7 @@ class OracleOCIStorage(BaseStorage):
|
||||
raise
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
try:
|
||||
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||
@ -45,9 +49,11 @@ class OracleOCIStorage(BaseStorage):
|
||||
else:
|
||||
raise
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
self.client.download_file(self.bucket_name, filename, target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
try:
|
||||
self.client.head_object(Bucket=self.bucket_name, Key=filename)
|
||||
@ -55,5 +61,6 @@ class OracleOCIStorage(BaseStorage):
|
||||
except:
|
||||
return False
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import io
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import override
|
||||
|
||||
from supabase import Client
|
||||
|
||||
@ -28,29 +29,35 @@ class SupabaseStorage(BaseStorage):
|
||||
if not self.bucket_exists():
|
||||
self.client.storage.create_bucket(id=id, name=bucket_name)
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
self.client.storage.from_(self.bucket_name).upload(filename, data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
content: bytes = self.client.storage.from_(self.bucket_name).download(filename)
|
||||
return content
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
result = self.client.storage.from_(self.bucket_name).download(filename)
|
||||
byte_stream = io.BytesIO(result)
|
||||
while chunk := byte_stream.read(4096): # Read in chunks of 4KB
|
||||
yield chunk
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
result = self.client.storage.from_(self.bucket_name).download(filename)
|
||||
Path(target_filepath).write_bytes(result)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
result = self.client.storage.from_(self.bucket_name).list(path=filename)
|
||||
if len(result) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.storage.from_(self.bucket_name).remove([filename])
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
|
||||
@ -29,23 +30,29 @@ class TencentCosStorage(BaseStorage):
|
||||
)
|
||||
self.client = CosS3Client(config)
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read()
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||
yield from response["Body"].get_stream(chunk_size=4096)
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||
response["Body"].get_stream_to_file(target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
return self.client.object_exists(Bucket=self.bucket_name, Key=filename)
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
import tos
|
||||
|
||||
@ -27,11 +28,13 @@ class VolcengineTosStorage(BaseStorage):
|
||||
region=dify_config.VOLCENGINE_TOS_REGION,
|
||||
)
|
||||
|
||||
@override
|
||||
def save(self, filename, data):
|
||||
if not self.bucket_name:
|
||||
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
|
||||
self.client.put_object(bucket=self.bucket_name, key=filename, content=data)
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
if not self.bucket_name:
|
||||
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
|
||||
@ -40,6 +43,7 @@ class VolcengineTosStorage(BaseStorage):
|
||||
raise TypeError(f"Expected bytes, got {type(data).__name__}")
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
if not self.bucket_name:
|
||||
raise FileNotFoundError("VOLCENGINE_TOS_BUCKET_NAME is not set")
|
||||
@ -47,11 +51,13 @@ class VolcengineTosStorage(BaseStorage):
|
||||
while chunk := response.read(4096):
|
||||
yield chunk
|
||||
|
||||
@override
|
||||
def download(self, filename, target_filepath):
|
||||
if not self.bucket_name:
|
||||
raise ValueError("VOLCENGINE_TOS_BUCKET_NAME is not set")
|
||||
self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)
|
||||
|
||||
@override
|
||||
def exists(self, filename):
|
||||
if not self.bucket_name:
|
||||
return False
|
||||
@ -60,6 +66,7 @@ class VolcengineTosStorage(BaseStorage):
|
||||
return False
|
||||
return True
|
||||
|
||||
@override
|
||||
def delete(self, filename: str):
|
||||
if not self.bucket_name:
|
||||
return
|
||||
|
||||
@ -0,0 +1,298 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import MemberNotInTenantError
|
||||
from controllers.console.workspace import members as members_module
|
||||
from controllers.console.workspace.members import MemberCancelInviteApi, MemberUpdateRoleApi, OwnerTransfer
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class WorkspaceMembersIntegrationFactory:
|
||||
@staticmethod
|
||||
def create_tenant(db_session_with_containers) -> Tenant:
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status=TenantStatus.NORMAL)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def create_account(
|
||||
db_session_with_containers,
|
||||
*,
|
||||
email_prefix: str,
|
||||
tenant: Tenant | None = None,
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
current: bool = False,
|
||||
) -> Account:
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"{email_prefix}-{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
if tenant is not None:
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role,
|
||||
current=current,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_owner_workspace(db_session_with_containers) -> tuple[Tenant, Account]:
|
||||
tenant = WorkspaceMembersIntegrationFactory.create_tenant(db_session_with_containers)
|
||||
owner = WorkspaceMembersIntegrationFactory.create_account(
|
||||
db_session_with_containers,
|
||||
email_prefix="owner",
|
||||
tenant=tenant,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
return tenant, owner
|
||||
|
||||
@staticmethod
|
||||
def create_owner_transfer_token(account: Account) -> str:
|
||||
_, token = members_module.AccountService.generate_owner_transfer_token(
|
||||
account.email,
|
||||
account=account,
|
||||
code="123456",
|
||||
additional_data={},
|
||||
)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_join(db_session_with_containers, *, tenant: Tenant, account: Account) -> TenantAccountJoin:
|
||||
tenant_id = tenant.id
|
||||
account_id = account.id
|
||||
db_session_with_containers.expire_all()
|
||||
join = (
|
||||
db_session_with_containers.query(TenantAccountJoin)
|
||||
.filter_by(tenant_id=tenant_id, account_id=account_id)
|
||||
.one()
|
||||
)
|
||||
return join
|
||||
|
||||
|
||||
class TestMemberCancelInviteApiWithContainers:
|
||||
def test_cancel_success(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
member = factory.create_account(db_session_with_containers, email_prefix="member")
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(members_module.TenantService, "remove_member_from_tenant") as mock_remove_member,
|
||||
):
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
mock_remove_member.assert_called_once()
|
||||
called_tenant, called_member, called_current_user = mock_remove_member.call_args.args
|
||||
assert called_tenant.id == tenant.id
|
||||
assert called_member.id == member.id
|
||||
assert called_current_user.id == current_user.id
|
||||
|
||||
def test_cancel_not_found(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, str(uuid4()))
|
||||
|
||||
def test_cancel_cannot_operate_self(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
member = factory.create_account(db_session_with_containers, email_prefix="member")
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(
|
||||
members_module.TenantService,
|
||||
"remove_member_from_tenant",
|
||||
side_effect=services.errors.account.CannotOperateSelfError("x"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 400
|
||||
assert result["code"] == "cannot-operate-self"
|
||||
|
||||
def test_cancel_no_permission(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
member = factory.create_account(db_session_with_containers, email_prefix="member")
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(
|
||||
members_module.TenantService,
|
||||
"remove_member_from_tenant",
|
||||
side_effect=services.errors.account.NoPermissionError("x"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 403
|
||||
assert result["code"] == "forbidden"
|
||||
|
||||
def test_cancel_member_not_in_tenant(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
member = factory.create_account(db_session_with_containers, email_prefix="member")
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(
|
||||
members_module.TenantService,
|
||||
"remove_member_from_tenant",
|
||||
side_effect=services.errors.account.MemberNotInTenantError(),
|
||||
),
|
||||
):
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 404
|
||||
assert result["code"] == "member-not-found"
|
||||
|
||||
|
||||
class TestMemberUpdateRoleApiWithContainers:
|
||||
def test_update_success(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
member = factory.create_account(
|
||||
db_session_with_containers,
|
||||
email_prefix="member",
|
||||
tenant=tenant,
|
||||
role=TenantAccountRole.EDITOR,
|
||||
)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
result = method(api, member.id)
|
||||
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert (
|
||||
factory.get_join(db_session_with_containers, tenant=tenant, account=member).role == TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
def test_update_member_not_found(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, str(uuid4()))
|
||||
|
||||
|
||||
class TestOwnerTransferApiWithContainers:
|
||||
def test_member_not_in_tenant(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
member = factory.create_account(db_session_with_containers, email_prefix="member")
|
||||
token = factory.create_owner_transfer_token(current_user)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"token": token}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
with pytest.raises(MemberNotInTenantError):
|
||||
method(api, member.id)
|
||||
|
||||
def test_member_not_found(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
token = factory.create_owner_transfer_token(current_user)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"token": token}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, str(uuid4()))
|
||||
|
||||
def test_transfer_success(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
member = factory.create_account(
|
||||
db_session_with_containers,
|
||||
email_prefix="member",
|
||||
tenant=tenant,
|
||||
role=TenantAccountRole.NORMAL,
|
||||
)
|
||||
token = factory.create_owner_transfer_token(current_user)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"token": token}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(members_module.AccountService, "send_new_owner_transfer_notify_email") as mock_new_owner_email,
|
||||
patch.object(members_module.AccountService, "send_old_owner_transfer_notify_email") as mock_old_owner_email,
|
||||
):
|
||||
result = method(api, member.id)
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert (
|
||||
factory.get_join(db_session_with_containers, tenant=tenant, account=member).role == TenantAccountRole.OWNER
|
||||
)
|
||||
assert (
|
||||
factory.get_join(db_session_with_containers, tenant=tenant, account=current_user).role
|
||||
== TenantAccountRole.ADMIN
|
||||
)
|
||||
mock_new_owner_email.assert_called_once()
|
||||
mock_old_owner_email.assert_called_once()
|
||||
@ -3,22 +3,18 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
EmailCodeError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
MemberNotInTenantError,
|
||||
NotOwnerError,
|
||||
OwnerTransferLimitError,
|
||||
)
|
||||
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
|
||||
from controllers.console.workspace.members import (
|
||||
DatasetOperatorMemberListApi,
|
||||
MemberCancelInviteApi,
|
||||
MemberInviteEmailApi,
|
||||
MemberListApi,
|
||||
MemberUpdateRoleApi,
|
||||
@ -251,135 +247,7 @@ class TestMemberInviteEmailApi:
|
||||
assert result["invitation_results"][0]["status"] == "failed"
|
||||
|
||||
|
||||
class TestMemberCancelInviteApi:
|
||||
def test_cancel_success(self, app: Flask):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
|
||||
):
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_cancel_not_found(self, app: Flask):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
):
|
||||
get_mock.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "x")
|
||||
|
||||
def test_cancel_cannot_operate_self(self, app: Flask):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.CannotOperateSelfError("x"),
|
||||
),
|
||||
):
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 400
|
||||
|
||||
def test_cancel_no_permission(self, app: Flask):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.NoPermissionError("x"),
|
||||
),
|
||||
):
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 403
|
||||
|
||||
def test_cancel_member_not_in_tenant(self, app: Flask):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get") as get_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.MemberNotInTenantError(),
|
||||
),
|
||||
):
|
||||
get_mock.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 404
|
||||
|
||||
|
||||
class TestMemberUpdateRoleApi:
|
||||
def test_update_success(self, app: Flask):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
payload = {"role": "normal"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=member),
|
||||
patch("controllers.console.workspace.members.TenantService.update_member_role"),
|
||||
):
|
||||
result = method(api, "id")
|
||||
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_update_invalid_role(self, app: Flask):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
@ -391,23 +259,6 @@ class TestMemberUpdateRoleApi:
|
||||
|
||||
assert status == 400
|
||||
|
||||
def test_update_member_not_found(self, app: Flask):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"role": "normal"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.members.current_account_with_tenant",
|
||||
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=None),
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "id")
|
||||
|
||||
|
||||
class TestDatasetOperatorMemberListApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
@ -637,27 +488,3 @@ class TestOwnerTransferApi:
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api, "2")
|
||||
|
||||
def test_member_not_in_tenant(self, app: Flask):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "a@test.com"},
|
||||
),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=member),
|
||||
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
|
||||
):
|
||||
with pytest.raises(MemberNotInTenantError):
|
||||
method(api, "2")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# base image
|
||||
FROM node:22-alpine AS base
|
||||
FROM node:22.22.1-alpine AS base
|
||||
LABEL maintainer="takatost@gmail.com"
|
||||
|
||||
# if you located in China, you can use aliyun mirror to speed up
|
||||
|
||||
Reference in New Issue
Block a user