mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 11:45:05 +08:00
feat: Add Clickzetta Lakehouse vector database integration (#22551)
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@ -69,6 +69,19 @@ class Storage:
|
||||
from extensions.storage.supabase_storage import SupabaseStorage
|
||||
|
||||
return SupabaseStorage
|
||||
case StorageType.CLICKZETTA_VOLUME:
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
def create_clickzetta_volume_storage():
|
||||
# ClickZettaVolumeConfig will automatically read from environment variables
|
||||
# and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
|
||||
volume_config = ClickZettaVolumeConfig()
|
||||
return ClickZettaVolumeStorage(volume_config)
|
||||
|
||||
return create_clickzetta_volume_storage
|
||||
case _:
|
||||
raise ValueError(f"unsupported storage type {storage_type}")
|
||||
|
||||
|
||||
5
api/extensions/storage/clickzetta_volume/__init__.py
Normal file
5
api/extensions/storage/clickzetta_volume/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
from .clickzetta_volume_storage import ClickZettaVolumeStorage
|
||||
|
||||
__all__ = ["ClickZettaVolumeStorage"]
|
||||
@ -0,0 +1,530 @@
|
||||
"""ClickZetta Volume Storage Implementation
|
||||
|
||||
This module provides storage backend using ClickZetta Volume functionality.
|
||||
Supports Table Volume, User Volume, and External Volume types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import clickzetta # type: ignore[import]
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
from .volume_permissions import VolumePermissionManager, check_volume_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickZettaVolumeConfig(BaseModel):
|
||||
"""Configuration for ClickZetta Volume storage."""
|
||||
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
instance: str = ""
|
||||
service: str = "api.clickzetta.com"
|
||||
workspace: str = "quick_start"
|
||||
vcluster: str = "default_ap"
|
||||
schema_name: str = "dify"
|
||||
volume_type: str = "table" # table|user|external
|
||||
volume_name: Optional[str] = None # For external volumes
|
||||
table_prefix: str = "dataset_" # Prefix for table volume names
|
||||
dify_prefix: str = "dify_km" # Directory prefix for User Volume
|
||||
permission_check: bool = True # Enable/disable permission checking
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""Validate the configuration values.
|
||||
|
||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
then fall back to CLICKZETTA_* environment variables (for vector DB config).
|
||||
"""
|
||||
import os
|
||||
|
||||
# Helper function to get environment variable with fallback
|
||||
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
|
||||
# First try CLICKZETTA_VOLUME_* specific config
|
||||
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
|
||||
if volume_value:
|
||||
return str(volume_value)
|
||||
|
||||
# Then try environment variables
|
||||
volume_env = os.getenv(volume_key)
|
||||
if volume_env:
|
||||
return volume_env
|
||||
|
||||
# Fall back to existing CLICKZETTA_* config
|
||||
fallback_env = os.getenv(fallback_key)
|
||||
if fallback_env:
|
||||
return fallback_env
|
||||
|
||||
return default or ""
|
||||
|
||||
# Apply environment variables with fallback to existing CLICKZETTA_* config
|
||||
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
|
||||
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
|
||||
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
|
||||
values.setdefault(
|
||||
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
|
||||
)
|
||||
values.setdefault(
|
||||
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
|
||||
)
|
||||
values.setdefault(
|
||||
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
|
||||
)
|
||||
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
|
||||
|
||||
# Volume-specific configurations (no fallback to vector DB config)
|
||||
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
|
||||
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
||||
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
||||
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
||||
# 暂时禁用权限检查功能,直接设置为false
|
||||
values.setdefault("permission_check", False)
|
||||
|
||||
# Validate required fields
|
||||
if not values.get("username"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
|
||||
if not values.get("instance"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
|
||||
|
||||
# Validate volume type
|
||||
volume_type = values["volume_type"]
|
||||
if volume_type not in ["table", "user", "external"]:
|
||||
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
|
||||
|
||||
if volume_type == "external" and not values.get("volume_name"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
def __init__(self, config: ClickZettaVolumeConfig):
|
||||
"""Initialize ClickZetta Volume storage.
|
||||
|
||||
Args:
|
||||
config: ClickZetta Volume configuration
|
||||
"""
|
||||
self._config = config
|
||||
self._connection = None
|
||||
self._permission_manager: VolumePermissionManager | None = None
|
||||
self._init_connection()
|
||||
self._init_permission_manager()
|
||||
|
||||
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
||||
|
||||
def _init_connection(self):
|
||||
"""Initialize ClickZetta connection."""
|
||||
try:
|
||||
self._connection = clickzetta.connect(
|
||||
username=self._config.username,
|
||||
password=self._config.password,
|
||||
instance=self._config.instance,
|
||||
service=self._config.service,
|
||||
workspace=self._config.workspace,
|
||||
vcluster=self._config.vcluster,
|
||||
schema=self._config.schema_name,
|
||||
)
|
||||
logger.debug("ClickZetta connection established")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to connect to ClickZetta")
|
||||
raise
|
||||
|
||||
def _init_permission_manager(self):
|
||||
"""Initialize permission manager."""
|
||||
try:
|
||||
self._permission_manager = VolumePermissionManager(
|
||||
self._connection, self._config.volume_type, self._config.volume_name
|
||||
)
|
||||
logger.debug("Permission manager initialized")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to initialize permission manager")
|
||||
raise
|
||||
|
||||
def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str:
|
||||
"""Get the appropriate volume path based on volume type."""
|
||||
if self._config.volume_type == "user":
|
||||
# Add dify prefix for User Volume to organize files
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
elif self._config.volume_type == "table":
|
||||
# Check if this should use User Volume (special directories)
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
# Use User Volume with dify prefix for special directories
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
|
||||
if dataset_id:
|
||||
return f"{self._config.table_prefix}{dataset_id}/{filename}"
|
||||
else:
|
||||
# Extract dataset_id from filename if not provided
|
||||
# Format: dataset_id/filename
|
||||
if "/" in filename:
|
||||
return filename
|
||||
else:
|
||||
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
|
||||
elif self._config.volume_type == "external":
|
||||
return filename
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str:
|
||||
"""Get SQL prefix for volume operations."""
|
||||
if self._config.volume_type == "user":
|
||||
return "USER VOLUME"
|
||||
elif self._config.volume_type == "table":
|
||||
# For Dify's current file storage pattern, most files are stored in
|
||||
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
|
||||
# These should use USER VOLUME for better compatibility
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return "USER VOLUME"
|
||||
|
||||
# Only use TABLE VOLUME for actual dataset-specific paths
|
||||
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
|
||||
if dataset_id:
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
else:
|
||||
# Default table name for generic operations
|
||||
table_name = "default_dataset"
|
||||
return f"TABLE VOLUME {table_name}"
|
||||
elif self._config.volume_type == "external":
|
||||
return f"VOLUME {self._config.volume_name}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _execute_sql(self, sql: str, fetch: bool = False):
|
||||
"""Execute SQL command."""
|
||||
try:
|
||||
if self._connection is None:
|
||||
raise RuntimeError("Connection not initialized")
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
if fetch:
|
||||
return cursor.fetchall()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("SQL execution failed: %s", sql)
|
||||
raise
|
||||
|
||||
def _ensure_table_volume_exists(self, dataset_id: str) -> None:
|
||||
"""Ensure table volume exists for the given dataset_id."""
|
||||
if self._config.volume_type != "table" or not dataset_id:
|
||||
return
|
||||
|
||||
# Skip for upload_files and other special directories that use USER VOLUME
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return
|
||||
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
|
||||
try:
|
||||
# Check if table exists
|
||||
check_sql = f"SHOW TABLES LIKE '{table_name}'"
|
||||
result = self._execute_sql(check_sql, fetch=True)
|
||||
|
||||
if not result:
|
||||
# Create table with volume
|
||||
create_sql = f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||
filename VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_filename (filename)
|
||||
) WITH VOLUME
|
||||
"""
|
||||
self._execute_sql(create_sql)
|
||||
logger.info("Created table volume: %s", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create table volume %s: %s", table_name, e)
|
||||
# Don't raise exception, let the operation continue
|
||||
# The table might exist but not be visible due to permissions
|
||||
|
||||
def save(self, filename: str, data: bytes) -> None:
|
||||
"""Save data to ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
data: File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Ensure table volume exists (for table volumes)
|
||||
if dataset_id:
|
||||
self._ensure_table_volume_exists(dataset_id)
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "save", dataset_id)
|
||||
|
||||
# Write data to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_file.write(data)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Upload to volume
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
Path(temp_file_path).unlink(missing_ok=True)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
"""Load file content from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "load_once", dataset_id)
|
||||
|
||||
# Download to temporary directory
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
|
||||
else:
|
||||
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
# Find the downloaded file (may be in subdirectories)
|
||||
downloaded_file = None
|
||||
for root, dirs, files in os.walk(temp_dir):
|
||||
for file in files:
|
||||
if file == filename or file == os.path.basename(filename):
|
||||
downloaded_file = Path(root) / file
|
||||
break
|
||||
if downloaded_file:
|
||||
break
|
||||
|
||||
if not downloaded_file or not downloaded_file.exists():
|
||||
raise FileNotFoundError(f"Downloaded file not found: {filename}")
|
||||
|
||||
content = downloaded_file.read_bytes()
|
||||
|
||||
logger.debug("File %s loaded from ClickZetta Volume", filename)
|
||||
return content
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
"""Load file as stream from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Yields:
|
||||
File content chunks
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
batch_size = 4096
|
||||
stream = BytesIO(content)
|
||||
|
||||
while chunk := stream.read(batch_size):
|
||||
yield chunk
|
||||
|
||||
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
|
||||
|
||||
def download(self, filename: str, target_filepath: str):
|
||||
"""Download file from ClickZetta Volume to local path.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
target_filepath: Local target file path
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
|
||||
with Path(target_filepath).open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
|
||||
|
||||
def exists(self, filename: str) -> bool:
|
||||
"""Check if file exists in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
True if file exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
exists = len(rows) > 0
|
||||
logger.debug("File %s exists check: %s", filename, exists)
|
||||
return exists
|
||||
except Exception as e:
|
||||
logger.warning("Error checking file existence for %s: %s", filename, e)
|
||||
return False
|
||||
|
||||
def delete(self, filename: str):
|
||||
"""Delete file from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
"""
|
||||
if not self.exists(filename):
|
||||
logger.debug("File %s not found, skip delete", filename)
|
||||
return
|
||||
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"REMOVE {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
logger.debug("File %s deleted from ClickZetta Volume", filename)
|
||||
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
"""Scan files and directories in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
path: Path to scan (dataset_id for table volumes)
|
||||
files: Include files in results
|
||||
directories: Include directories in results
|
||||
|
||||
Returns:
|
||||
List of file/directory paths
|
||||
"""
|
||||
try:
|
||||
# For table volumes, path is treated as dataset_id
|
||||
dataset_id = None
|
||||
if self._config.volume_type == "table":
|
||||
dataset_id = path
|
||||
path = "" # Root of the table volume
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# For User Volume, add dify prefix to path
|
||||
if volume_prefix == "USER VOLUME":
|
||||
if path:
|
||||
scan_path = f"{self._config.dify_prefix}/{path}"
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
|
||||
else:
|
||||
if path:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix}"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
result = []
|
||||
for row in rows:
|
||||
file_path = row[0] # relative_path column
|
||||
|
||||
# For User Volume, remove dify prefix from results
|
||||
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
|
||||
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
|
||||
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
|
||||
|
||||
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
|
||||
result.append(file_path)
|
||||
|
||||
logger.debug("Scanned %d items in path %s", len(result), path)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error scanning path %s", path)
|
||||
return []
|
||||
516
api/extensions/storage/clickzetta_volume/file_lifecycle.py
Normal file
516
api/extensions/storage/clickzetta_volume/file_lifecycle.py
Normal file
@ -0,0 +1,516 @@
|
||||
"""ClickZetta Volume文件生命周期管理
|
||||
|
||||
该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。
|
||||
支持知识库文件的完整生命周期管理。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStatus(Enum):
|
||||
"""文件状态枚举"""
|
||||
|
||||
ACTIVE = "active" # 活跃状态
|
||||
ARCHIVED = "archived" # 已归档
|
||||
DELETED = "deleted" # 已删除(软删除)
|
||||
BACKUP = "backup" # 备份文件
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileMetadata:
|
||||
"""文件元数据"""
|
||||
|
||||
filename: str
|
||||
size: int | None
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
version: int | None
|
||||
status: FileStatus
|
||||
checksum: Optional[str] = None
|
||||
tags: Optional[dict[str, str]] = None
|
||||
parent_version: Optional[int] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
data = asdict(self)
|
||||
data["created_at"] = self.created_at.isoformat()
|
||||
data["modified_at"] = self.modified_at.isoformat()
|
||||
data["status"] = self.status.value
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "FileMetadata":
|
||||
"""从字典创建实例"""
|
||||
data = data.copy()
|
||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
|
||||
data["status"] = FileStatus(data["status"])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class FileLifecycleManager:
|
||||
"""文件生命周期管理器"""
|
||||
|
||||
def __init__(self, storage, dataset_id: Optional[str] = None):
|
||||
"""初始化生命周期管理器
|
||||
|
||||
Args:
|
||||
storage: ClickZetta Volume存储实例
|
||||
dataset_id: 数据集ID(用于Table Volume)
|
||||
"""
|
||||
self._storage = storage
|
||||
self._dataset_id = dataset_id
|
||||
self._metadata_file = ".dify_file_metadata.json"
|
||||
self._version_prefix = ".versions/"
|
||||
self._backup_prefix = ".backups/"
|
||||
self._deleted_prefix = ".deleted/"
|
||||
|
||||
# 获取权限管理器(如果存在)
|
||||
self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)
|
||||
|
||||
def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
|
||||
"""保存文件并管理生命周期
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
data: 文件内容
|
||||
tags: 文件标签
|
||||
|
||||
Returns:
|
||||
文件元数据
|
||||
"""
|
||||
# 权限检查
|
||||
if not self._check_permission(filename, "save"):
|
||||
from .volume_permissions import VolumePermissionError
|
||||
|
||||
raise VolumePermissionError(
|
||||
f"Permission denied for lifecycle save operation on file: {filename}",
|
||||
operation="save",
|
||||
volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
|
||||
dataset_id=self._dataset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# 1. 检查是否存在旧版本
|
||||
metadata_dict = self._load_metadata()
|
||||
current_metadata = metadata_dict.get(filename)
|
||||
|
||||
# 2. 如果存在旧版本,创建版本备份
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata)
|
||||
|
||||
# 3. 计算文件信息
|
||||
now = datetime.now()
|
||||
checksum = self._calculate_checksum(data)
|
||||
new_version = (current_metadata["version"] + 1) if current_metadata else 1
|
||||
|
||||
# 4. 保存新文件
|
||||
self._storage.save(filename, data)
|
||||
|
||||
# 5. 创建元数据
|
||||
created_at = now
|
||||
parent_version = None
|
||||
|
||||
if current_metadata:
|
||||
# 如果created_at是字符串,转换为datetime
|
||||
if isinstance(current_metadata["created_at"], str):
|
||||
created_at = datetime.fromisoformat(current_metadata["created_at"])
|
||||
else:
|
||||
created_at = current_metadata["created_at"]
|
||||
parent_version = current_metadata["version"]
|
||||
|
||||
file_metadata = FileMetadata(
|
||||
filename=filename,
|
||||
size=len(data),
|
||||
created_at=created_at,
|
||||
modified_at=now,
|
||||
version=new_version,
|
||||
status=FileStatus.ACTIVE,
|
||||
checksum=checksum,
|
||||
tags=tags or {},
|
||||
parent_version=parent_version,
|
||||
)
|
||||
|
||||
# 6. 更新元数据
|
||||
metadata_dict[filename] = file_metadata.to_dict()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
|
||||
return file_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save file with lifecycle")
|
||||
raise
|
||||
|
||||
def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
|
||||
"""获取文件元数据
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
文件元数据,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
return FileMetadata.from_dict(metadata_dict[filename])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get file metadata for %s", filename)
|
||||
return None
|
||||
|
||||
def list_file_versions(self, filename: str) -> list[FileMetadata]:
|
||||
"""列出文件的所有版本
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
文件版本列表,按版本号排序
|
||||
"""
|
||||
try:
|
||||
versions = []
|
||||
|
||||
# 获取当前版本
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
versions.append(current_metadata)
|
||||
|
||||
# 获取历史版本
|
||||
version_pattern = f"{self._version_prefix}{filename}.v*"
|
||||
try:
|
||||
version_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
for file_path in version_files:
|
||||
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
|
||||
# 解析版本号
|
||||
version_str = file_path.split(".v")[-1].split(".")[0]
|
||||
try:
|
||||
version_num = int(version_str)
|
||||
# 这里简化处理,实际应该从版本文件中读取元数据
|
||||
# 暂时创建基本的元数据信息
|
||||
except ValueError:
|
||||
continue
|
||||
except:
|
||||
# 如果无法扫描版本文件,只返回当前版本
|
||||
pass
|
||||
|
||||
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list file versions for %s", filename)
|
||||
return []
|
||||
|
||||
def restore_version(self, filename: str, version: int) -> bool:
|
||||
"""恢复文件到指定版本
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
version: 要恢复的版本号
|
||||
|
||||
Returns:
|
||||
恢复是否成功
|
||||
"""
|
||||
try:
|
||||
version_filename = f"{self._version_prefix}{filename}.v{version}"
|
||||
|
||||
# 检查版本文件是否存在
|
||||
if not self._storage.exists(version_filename):
|
||||
logger.warning("Version %s of %s not found", version, filename)
|
||||
return False
|
||||
|
||||
# 读取版本文件内容
|
||||
version_data = self._storage.load_once(version_filename)
|
||||
|
||||
# 保存当前版本为备份
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata.to_dict())
|
||||
|
||||
# 恢复文件
|
||||
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to restore %s to version %s", filename, version)
|
||||
return False
|
||||
|
||||
def archive_file(self, filename: str) -> bool:
|
||||
"""归档文件
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
归档是否成功
|
||||
"""
|
||||
# 权限检查
|
||||
if not self._check_permission(filename, "archive"):
|
||||
logger.warning("Permission denied for archive operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# 更新文件状态为归档
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename not in metadata_dict:
|
||||
logger.warning("File %s not found in metadata", filename)
|
||||
return False
|
||||
|
||||
metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s archived successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to archive file %s", filename)
|
||||
return False
|
||||
|
||||
def soft_delete_file(self, filename: str) -> bool:
|
||||
"""软删除文件(移动到删除目录)
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
# 权限检查
|
||||
if not self._check_permission(filename, "delete"):
|
||||
logger.warning("Permission denied for soft delete operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查文件是否存在
|
||||
if not self._storage.exists(filename):
|
||||
logger.warning("File %s not found", filename)
|
||||
return False
|
||||
|
||||
# 读取文件内容
|
||||
file_data = self._storage.load_once(filename)
|
||||
|
||||
# 移动到删除目录
|
||||
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
self._storage.save(deleted_filename, file_data)
|
||||
|
||||
# 删除原文件
|
||||
self._storage.delete(filename)
|
||||
|
||||
# 更新元数据
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
metadata_dict[filename]["status"] = FileStatus.DELETED.value
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s soft deleted successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to soft delete file %s", filename)
|
||||
return False
|
||||
|
||||
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
|
||||
"""清理旧版本文件
|
||||
|
||||
Args:
|
||||
max_versions: 保留的最大版本数
|
||||
max_age_days: 版本文件的最大保留天数
|
||||
|
||||
Returns:
|
||||
清理的文件数量
|
||||
"""
|
||||
try:
|
||||
cleaned_count = 0
|
||||
cutoff_date = datetime.now() - timedelta(days=max_age_days)
|
||||
|
||||
# 获取所有版本文件
|
||||
try:
|
||||
all_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
|
||||
|
||||
# 按文件分组
|
||||
file_versions: dict[str, list[tuple[int, str]]] = {}
|
||||
for version_file in version_files:
|
||||
# 解析文件名和版本
|
||||
parts = version_file[len(self._version_prefix) :].split(".v")
|
||||
if len(parts) >= 2:
|
||||
base_filename = parts[0]
|
||||
version_part = parts[1].split(".")[0]
|
||||
try:
|
||||
version_num = int(version_part)
|
||||
if base_filename not in file_versions:
|
||||
file_versions[base_filename] = []
|
||||
file_versions[base_filename].append((version_num, version_file))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# 清理每个文件的旧版本
|
||||
for base_filename, versions in file_versions.items():
|
||||
# 按版本号排序
|
||||
versions.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 保留最新的max_versions个版本,删除其余的
|
||||
if len(versions) > max_versions:
|
||||
to_delete = versions[max_versions:]
|
||||
for version_num, version_file in to_delete:
|
||||
self._storage.delete(version_file)
|
||||
cleaned_count += 1
|
||||
logger.debug("Cleaned old version: %s", version_file)
|
||||
|
||||
logger.info("Cleaned %d old version files", cleaned_count)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not scan for version files: %s", e)
|
||||
|
||||
return cleaned_count
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to cleanup old versions")
|
||||
return 0
|
||||
|
||||
def get_storage_statistics(self) -> dict[str, Any]:
|
||||
"""获取存储统计信息
|
||||
|
||||
Returns:
|
||||
存储统计字典
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
|
||||
stats: dict[str, Any] = {
|
||||
"total_files": len(metadata_dict),
|
||||
"active_files": 0,
|
||||
"archived_files": 0,
|
||||
"deleted_files": 0,
|
||||
"total_size": 0,
|
||||
"versions_count": 0,
|
||||
"oldest_file": None,
|
||||
"newest_file": None,
|
||||
}
|
||||
|
||||
oldest_date = None
|
||||
newest_date = None
|
||||
|
||||
for filename, metadata in metadata_dict.items():
|
||||
file_meta = FileMetadata.from_dict(metadata)
|
||||
|
||||
# 统计文件状态
|
||||
if file_meta.status == FileStatus.ACTIVE:
|
||||
stats["active_files"] = (stats["active_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.ARCHIVED:
|
||||
stats["archived_files"] = (stats["archived_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.DELETED:
|
||||
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
|
||||
|
||||
# 统计大小
|
||||
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
|
||||
|
||||
# 统计版本
|
||||
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
|
||||
|
||||
# 找出最新和最旧的文件
|
||||
if oldest_date is None or file_meta.created_at < oldest_date:
|
||||
oldest_date = file_meta.created_at
|
||||
stats["oldest_file"] = filename
|
||||
|
||||
if newest_date is None or file_meta.modified_at > newest_date:
|
||||
newest_date = file_meta.modified_at
|
||||
stats["newest_file"] = filename
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get storage statistics")
|
||||
return {}
|
||||
|
||||
def _create_version_backup(self, filename: str, metadata: dict):
|
||||
"""创建版本备份"""
|
||||
try:
|
||||
# 读取当前文件内容
|
||||
current_data = self._storage.load_once(filename)
|
||||
|
||||
# 保存为版本文件
|
||||
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
|
||||
self._storage.save(version_filename, current_data)
|
||||
|
||||
logger.debug("Created version backup: %s", version_filename)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create version backup for %s: %s", filename, e)
|
||||
|
||||
def _load_metadata(self) -> dict[str, Any]:
|
||||
"""加载元数据文件"""
|
||||
try:
|
||||
if self._storage.exists(self._metadata_file):
|
||||
metadata_content = self._storage.load_once(self._metadata_file)
|
||||
result = json.loads(metadata_content.decode("utf-8"))
|
||||
return dict(result) if result else {}
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load metadata: %s", e)
|
||||
return {}
|
||||
|
||||
def _save_metadata(self, metadata_dict: dict):
|
||||
"""保存元数据文件"""
|
||||
try:
|
||||
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
||||
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
|
||||
logger.debug("Metadata saved successfully")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save metadata")
|
||||
raise
|
||||
|
||||
def _calculate_checksum(self, data: bytes) -> str:
|
||||
"""计算文件校验和"""
|
||||
import hashlib
|
||||
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
def _check_permission(self, filename: str, operation: str) -> bool:
|
||||
"""检查文件操作权限
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
operation: 操作类型
|
||||
|
||||
Returns:
|
||||
True if permission granted, False otherwise
|
||||
"""
|
||||
# 如果没有权限管理器,默认允许
|
||||
if not self._permission_manager:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 根据操作类型映射到权限
|
||||
operation_mapping = {
|
||||
"save": "save",
|
||||
"load": "load_once",
|
||||
"delete": "delete",
|
||||
"archive": "delete", # 归档需要删除权限
|
||||
"restore": "save", # 恢复需要写权限
|
||||
"cleanup": "delete", # 清理需要删除权限
|
||||
"read": "load_once",
|
||||
"write": "save",
|
||||
}
|
||||
|
||||
mapped_operation = operation_mapping.get(operation, operation)
|
||||
|
||||
# 检查权限
|
||||
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission check failed for %s operation %s", filename, operation)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
return False
|
||||
646
api/extensions/storage/clickzetta_volume/volume_permissions.py
Normal file
646
api/extensions/storage/clickzetta_volume/volume_permissions.py
Normal file
@ -0,0 +1,646 @@
|
||||
"""ClickZetta Volume权限管理机制
|
||||
|
||||
该模块提供Volume权限检查、验证和管理功能。
|
||||
根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VolumePermission(Enum):
|
||||
"""Volume权限类型枚举"""
|
||||
|
||||
READ = "SELECT" # 对应ClickZetta的SELECT权限
|
||||
WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限
|
||||
LIST = "SELECT" # 列出文件需要SELECT权限
|
||||
DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限
|
||||
USAGE = "USAGE" # External Volume需要的基本权限
|
||||
|
||||
|
||||
class VolumePermissionManager:
|
||||
"""Volume权限管理器"""
|
||||
|
||||
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None):
|
||||
"""初始化权限管理器
|
||||
|
||||
Args:
|
||||
connection_or_config: ClickZetta连接对象或配置字典
|
||||
volume_type: Volume类型 (user|table|external)
|
||||
volume_name: Volume名称 (用于external volume)
|
||||
"""
|
||||
# 支持两种初始化方式:连接对象或配置字典
|
||||
if isinstance(connection_or_config, dict):
|
||||
# 从配置字典创建连接
|
||||
import clickzetta # type: ignore[import-untyped]
|
||||
|
||||
config = connection_or_config
|
||||
self._connection = clickzetta.connect(
|
||||
username=config.get("username"),
|
||||
password=config.get("password"),
|
||||
instance=config.get("instance"),
|
||||
service=config.get("service"),
|
||||
workspace=config.get("workspace"),
|
||||
vcluster=config.get("vcluster"),
|
||||
schema=config.get("schema") or config.get("database"),
|
||||
)
|
||||
self._volume_type = config.get("volume_type", volume_type)
|
||||
self._volume_name = config.get("volume_name", volume_name)
|
||||
else:
|
||||
# 直接使用连接对象
|
||||
self._connection = connection_or_config
|
||||
self._volume_type = volume_type
|
||||
self._volume_name = volume_name
|
||||
|
||||
if not self._connection:
|
||||
raise ValueError("Valid connection or config is required")
|
||||
if not self._volume_type:
|
||||
raise ValueError("volume_type is required")
|
||||
|
||||
self._permission_cache: dict[str, set[str]] = {}
|
||||
self._current_username = None # 将从连接中获取当前用户名
|
||||
|
||||
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
|
||||
"""检查用户是否有执行特定操作的权限
|
||||
|
||||
Args:
|
||||
operation: 要执行的操作类型
|
||||
dataset_id: 数据集ID (用于table volume)
|
||||
|
||||
Returns:
|
||||
True if user has permission, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self._volume_type == "user":
|
||||
return self._check_user_volume_permission(operation)
|
||||
elif self._volume_type == "table":
|
||||
return self._check_table_volume_permission(operation, dataset_id)
|
||||
elif self._volume_type == "external":
|
||||
return self._check_external_volume_permission(operation)
|
||||
else:
|
||||
logger.warning("Unknown volume type: %s", self._volume_type)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission check failed")
|
||||
return False
|
||||
|
||||
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
|
||||
"""检查User Volume权限
|
||||
|
||||
User Volume权限规则:
|
||||
- 用户对自己的User Volume有全部权限
|
||||
- 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限
|
||||
- 更注重连接身份验证,而不是复杂的权限检查
|
||||
"""
|
||||
try:
|
||||
# 获取当前用户名
|
||||
current_user = self._get_current_username()
|
||||
|
||||
# 检查基本连接状态
|
||||
with self._connection.cursor() as cursor:
|
||||
# 简单的连接测试,如果能执行查询说明用户有基本权限
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
logger.debug(
|
||||
"User Volume permission check for %s, operation %s: granted (basic connection verified)",
|
||||
current_user,
|
||||
operation.name,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"User Volume permission check failed: cannot verify basic connection for %s", current_user
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("User Volume permission check failed")
|
||||
# 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示
|
||||
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
|
||||
return False
|
||||
|
||||
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
|
||||
"""检查Table Volume权限
|
||||
|
||||
Table Volume权限规则:
|
||||
- Table Volume权限继承对应表的权限
|
||||
- SELECT权限 -> 可以READ/LIST文件
|
||||
- INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件
|
||||
"""
|
||||
if not dataset_id:
|
||||
logger.warning("dataset_id is required for table volume permission check")
|
||||
return False
|
||||
|
||||
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
|
||||
|
||||
try:
|
||||
# 检查表权限
|
||||
permissions = self._get_table_permissions(table_name)
|
||||
required_permissions = set(operation.value.split(","))
|
||||
|
||||
# 检查是否有所需的所有权限
|
||||
has_permission = required_permissions.issubset(permissions)
|
||||
|
||||
logger.debug(
|
||||
"Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
|
||||
table_name,
|
||||
operation.name,
|
||||
required_permissions,
|
||||
permissions,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
return has_permission
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Table volume permission check failed for %s", table_name)
|
||||
return False
|
||||
|
||||
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
|
||||
"""检查External Volume权限
|
||||
|
||||
External Volume权限规则:
|
||||
- 尝试获取对External Volume的权限
|
||||
- 如果权限检查失败,进行备选验证
|
||||
- 对于开发环境,提供更宽松的权限检查
|
||||
"""
|
||||
if not self._volume_name:
|
||||
logger.warning("volume_name is required for external volume permission check")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查External Volume权限
|
||||
permissions = self._get_external_volume_permissions(self._volume_name)
|
||||
|
||||
# External Volume权限映射:根据操作类型确定所需权限
|
||||
required_permissions = set()
|
||||
|
||||
if operation in [VolumePermission.READ, VolumePermission.LIST]:
|
||||
required_permissions.add("read")
|
||||
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
|
||||
required_permissions.add("write")
|
||||
|
||||
# 检查是否有所需的所有权限
|
||||
has_permission = required_permissions.issubset(permissions)
|
||||
|
||||
logger.debug(
|
||||
"External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
|
||||
self._volume_name,
|
||||
operation.name,
|
||||
required_permissions,
|
||||
permissions,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
# 如果权限检查失败,尝试备选验证
|
||||
if not has_permission:
|
||||
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)
|
||||
|
||||
# 备选验证:尝试列出Volume来验证基本访问权限
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == self._volume_name:
|
||||
logger.info("Fallback verification successful for %s", self._volume_name)
|
||||
return True
|
||||
except Exception as fallback_e:
|
||||
logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e)
|
||||
|
||||
return has_permission
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("External volume permission check failed for %s", self._volume_name)
|
||||
logger.info("External Volume permission check failed, but permission checking is disabled in this version")
|
||||
return False
|
||||
|
||||
def _get_table_permissions(self, table_name: str) -> set[str]:
|
||||
"""获取用户对指定表的权限
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
用户对该表的权限集合
|
||||
"""
|
||||
cache_key = f"table:{table_name}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查当前用户权限
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
# 解析权限结果,查找对该表的权限
|
||||
for grant in grants:
|
||||
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
|
||||
privilege = grant[0].upper()
|
||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||
object_name = grant[2] if len(grant) > 2 else ""
|
||||
|
||||
# 检查是否是对该表的权限
|
||||
if (
|
||||
object_type == "TABLE"
|
||||
and object_name == table_name
|
||||
or object_type == "SCHEMA"
|
||||
and object_name in table_name
|
||||
):
|
||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||
if privilege == "ALL":
|
||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||
else:
|
||||
permissions.add(privilege)
|
||||
|
||||
# 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
|
||||
permissions.add("SELECT")
|
||||
except Exception:
|
||||
logger.debug("Cannot query table %s, no SELECT permission", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check table permissions for %s: %s", table_name, e)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
pass
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def _get_current_username(self) -> str:
|
||||
"""获取当前用户名"""
|
||||
if self._current_username:
|
||||
return self._current_username
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
self._current_username = result[0]
|
||||
return str(self._current_username)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get current username")
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _get_user_permissions(self, username: str) -> set[str]:
|
||||
"""获取用户的基本权限集合"""
|
||||
cache_key = f"user_permissions:{username}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查当前用户权限
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
# 解析权限结果,查找用户的基本权限
|
||||
for grant in grants:
|
||||
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
|
||||
privilege = grant[0].upper()
|
||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||
|
||||
# 收集所有相关权限
|
||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||
if privilege == "ALL":
|
||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||
else:
|
||||
permissions.add(privilege)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check user permissions for %s: %s", username, e)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
pass
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
|
||||
"""获取用户对指定External Volume的权限
|
||||
|
||||
Args:
|
||||
volume_name: External Volume名称
|
||||
|
||||
Returns:
|
||||
用户对该Volume的权限集合
|
||||
"""
|
||||
cache_key = f"external_volume:{volume_name}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查Volume权限
|
||||
logger.info("Checking permissions for volume: %s", volume_name)
|
||||
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
logger.info("Raw grants result for %s: %s", volume_name, grants)
|
||||
|
||||
# 解析权限结果
|
||||
# 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
|
||||
# grantee_name, grantor_name, grant_option, granted_time)
|
||||
for grant in grants:
|
||||
logger.info("Processing grant: %s", grant)
|
||||
if len(grant) >= 5:
|
||||
granted_type = grant[0]
|
||||
privilege = grant[1].upper()
|
||||
granted_on = grant[3]
|
||||
object_name = grant[4]
|
||||
|
||||
logger.info(
|
||||
"Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
|
||||
granted_type,
|
||||
privilege,
|
||||
granted_on,
|
||||
object_name,
|
||||
)
|
||||
|
||||
# 检查是否是对该Volume的权限或者是层级权限
|
||||
if (
|
||||
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
|
||||
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
|
||||
logger.info("Matching grant found for %s", volume_name)
|
||||
|
||||
if "READ" in privilege:
|
||||
permissions.add("read")
|
||||
logger.info("Added READ permission for %s", volume_name)
|
||||
if "WRITE" in privilege:
|
||||
permissions.add("write")
|
||||
logger.info("Added WRITE permission for %s", volume_name)
|
||||
if "ALTER" in privilege:
|
||||
permissions.add("alter")
|
||||
logger.info("Added ALTER permission for %s", volume_name)
|
||||
if privilege == "ALL":
|
||||
permissions.update(["read", "write", "alter"])
|
||||
logger.info("Added ALL permissions for %s", volume_name)
|
||||
|
||||
logger.info("Final permissions for %s: %s", volume_name, permissions)
|
||||
|
||||
# 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == volume_name:
|
||||
permissions.add("read") # 至少有读权限
|
||||
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Cannot access volume %s, no basic permission", volume_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
|
||||
# 在权限检查失败时,尝试基本的Volume访问验证
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == volume_name:
|
||||
logger.info("Basic volume access verified for %s", volume_name)
|
||||
permissions.add("read")
|
||||
permissions.add("write") # 假设有写权限
|
||||
break
|
||||
except Exception as basic_e:
|
||||
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
|
||||
# 最后的备选方案:假设有基本权限
|
||||
permissions.add("read")
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def clear_permission_cache(self):
|
||||
"""清空权限缓存"""
|
||||
self._permission_cache.clear()
|
||||
logger.debug("Permission cache cleared")
|
||||
|
||||
def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
|
||||
"""获取权限摘要
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集ID (用于table volume)
|
||||
|
||||
Returns:
|
||||
权限摘要字典
|
||||
"""
|
||||
summary = {}
|
||||
|
||||
for operation in VolumePermission:
|
||||
summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
|
||||
|
||||
return summary
|
||||
|
||||
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
|
||||
"""检查文件路径的权限继承
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
operation: 要执行的操作
|
||||
|
||||
Returns:
|
||||
True if user has permission, False otherwise
|
||||
"""
|
||||
try:
|
||||
# 解析文件路径
|
||||
path_parts = file_path.strip("/").split("/")
|
||||
|
||||
if not path_parts:
|
||||
logger.warning("Invalid file path for permission inheritance check")
|
||||
return False
|
||||
|
||||
# 对于Table Volume,第一层是dataset_id
|
||||
if self._volume_type == "table":
|
||||
if len(path_parts) < 1:
|
||||
return False
|
||||
|
||||
dataset_id = path_parts[0]
|
||||
|
||||
# 检查对dataset的权限
|
||||
has_dataset_permission = self.check_permission(operation, dataset_id)
|
||||
|
||||
if not has_dataset_permission:
|
||||
logger.debug("Permission denied for dataset %s", dataset_id)
|
||||
return False
|
||||
|
||||
# 检查路径遍历攻击
|
||||
if self._contains_path_traversal(file_path):
|
||||
logger.warning("Path traversal attack detected: %s", file_path)
|
||||
return False
|
||||
|
||||
# 检查是否访问敏感目录
|
||||
if self._is_sensitive_path(file_path):
|
||||
logger.warning("Access to sensitive path denied: %s", file_path)
|
||||
return False
|
||||
|
||||
logger.debug("Permission inherited for path %s", file_path)
|
||||
return True
|
||||
|
||||
elif self._volume_type == "user":
|
||||
# User Volume的权限继承
|
||||
current_user = self._get_current_username()
|
||||
|
||||
# 检查是否试图访问其他用户的目录
|
||||
if len(path_parts) > 1 and path_parts[0] != current_user:
|
||||
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
|
||||
return False
|
||||
|
||||
# 检查基本权限
|
||||
return self.check_permission(operation)
|
||||
|
||||
elif self._volume_type == "external":
|
||||
# External Volume的权限继承
|
||||
# 检查对External Volume的权限
|
||||
return self.check_permission(operation)
|
||||
|
||||
else:
|
||||
logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission inheritance check failed")
|
||||
return False
|
||||
|
||||
def _contains_path_traversal(self, file_path: str) -> bool:
|
||||
"""检查路径是否包含路径遍历攻击"""
|
||||
# 检查常见的路径遍历模式
|
||||
traversal_patterns = [
|
||||
"../",
|
||||
"..\\",
|
||||
"..%2f",
|
||||
"..%2F",
|
||||
"..%5c",
|
||||
"..%5C",
|
||||
"%2e%2e%2f",
|
||||
"%2e%2e%5c",
|
||||
"....//",
|
||||
"....\\\\",
|
||||
]
|
||||
|
||||
file_path_lower = file_path.lower()
|
||||
|
||||
for pattern in traversal_patterns:
|
||||
if pattern in file_path_lower:
|
||||
return True
|
||||
|
||||
# 检查绝对路径
|
||||
if file_path.startswith("/") or file_path.startswith("\\"):
|
||||
return True
|
||||
|
||||
# 检查Windows驱动器路径
|
||||
if len(file_path) >= 2 and file_path[1] == ":":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_sensitive_path(self, file_path: str) -> bool:
|
||||
"""检查路径是否为敏感路径"""
|
||||
sensitive_patterns = [
|
||||
"passwd",
|
||||
"shadow",
|
||||
"hosts",
|
||||
"config",
|
||||
"secrets",
|
||||
"private",
|
||||
"key",
|
||||
"certificate",
|
||||
"cert",
|
||||
"ssl",
|
||||
"database",
|
||||
"backup",
|
||||
"dump",
|
||||
"log",
|
||||
"tmp",
|
||||
]
|
||||
|
||||
file_path_lower = file_path.lower()
|
||||
|
||||
return any(pattern in file_path_lower for pattern in sensitive_patterns)
|
||||
|
||||
def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
|
||||
"""验证操作权限
|
||||
|
||||
Args:
|
||||
operation: 操作名称 (save|load|exists|delete|scan)
|
||||
dataset_id: 数据集ID
|
||||
|
||||
Returns:
|
||||
True if operation is allowed, False otherwise
|
||||
"""
|
||||
operation_mapping = {
|
||||
"save": VolumePermission.WRITE,
|
||||
"load": VolumePermission.READ,
|
||||
"load_once": VolumePermission.READ,
|
||||
"load_stream": VolumePermission.READ,
|
||||
"download": VolumePermission.READ,
|
||||
"exists": VolumePermission.READ,
|
||||
"delete": VolumePermission.DELETE,
|
||||
"scan": VolumePermission.LIST,
|
||||
}
|
||||
|
||||
if operation not in operation_mapping:
|
||||
logger.warning("Unknown operation: %s", operation)
|
||||
return False
|
||||
|
||||
volume_permission = operation_mapping[operation]
|
||||
return self.check_permission(volume_permission, dataset_id)
|
||||
|
||||
|
||||
class VolumePermissionError(Exception):
|
||||
"""Volume权限错误异常"""
|
||||
|
||||
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
|
||||
self.operation = operation
|
||||
self.volume_type = volume_type
|
||||
self.dataset_id = dataset_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_volume_permission(
|
||||
permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""权限检查装饰器函数
|
||||
|
||||
Args:
|
||||
permission_manager: 权限管理器
|
||||
operation: 操作名称
|
||||
dataset_id: 数据集ID
|
||||
|
||||
Raises:
|
||||
VolumePermissionError: 如果没有权限
|
||||
"""
|
||||
if not permission_manager.validate_operation(operation, dataset_id):
|
||||
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
|
||||
if dataset_id:
|
||||
error_message += f" (dataset: {dataset_id})"
|
||||
|
||||
raise VolumePermissionError(
|
||||
error_message,
|
||||
operation=operation,
|
||||
volume_type=permission_manager._volume_type or "unknown",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
@ -5,6 +5,7 @@ class StorageType(StrEnum):
|
||||
ALIYUN_OSS = "aliyun-oss"
|
||||
AZURE_BLOB = "azure-blob"
|
||||
BAIDU_OBS = "baidu-obs"
|
||||
CLICKZETTA_VOLUME = "clickzetta-volume"
|
||||
GOOGLE_STORAGE = "google-storage"
|
||||
HUAWEI_OBS = "huawei-obs"
|
||||
LOCAL = "local"
|
||||
|
||||
Reference in New Issue
Block a user