mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
refactor(sandbox): sandbox provider system default configuration
This commit is contained in:
@ -1,3 +0,0 @@
|
||||
from .sandbox_provider_service import SandboxProviderService
|
||||
|
||||
__all__ = ["SandboxProviderService"]
|
||||
@ -1,332 +1,191 @@
|
||||
"""
|
||||
Sandbox Provider Service for managing sandbox configurations.
|
||||
|
||||
Supports three provider types:
|
||||
- e2b: Cloud-based sandbox (requires API key)
|
||||
- docker: Local Docker-based sandbox (self-hosted)
|
||||
- local: Local execution without isolation (self-hosted only)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.sandbox import VMBuilder, VMType, create_sandbox_config_encrypter, masked_config
|
||||
from core.tools.utils.system_encryption import (
|
||||
decrypt_system_params,
|
||||
)
|
||||
from core.sandbox import SandboxBuilder, SandboxType, VMConfig, create_sandbox_config_encrypter, masked_config
|
||||
from core.sandbox.entities import SandboxProviderApiEntity
|
||||
from core.sandbox.entities.providers import SandboxProviderEntity
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models.sandbox import SandboxProvider, SandboxProviderSystemConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxProviderType(StrEnum):
|
||||
E2B = "e2b"
|
||||
DOCKER = "docker"
|
||||
LOCAL = "local"
|
||||
def _get_encrypter(tenant_id: str, provider_type: str):
|
||||
return create_sandbox_config_encrypter(tenant_id, VMConfig.get_schema(SandboxType(provider_type)), provider_type)[0]
|
||||
|
||||
|
||||
class E2BConfig(BaseModel):
|
||||
api_key: str = ""
|
||||
e2b_api_url: str = "https://api.e2b.app"
|
||||
e2b_default_template: str = "code-interpreter-v1"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_required(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
if not values.get("api_key"):
|
||||
raise ValueError("api_key is required")
|
||||
return values
|
||||
|
||||
|
||||
class DockerConfig(BaseModel):
|
||||
docker_sock: str = "unix:///var/run/docker.sock"
|
||||
docker_image: str = "ubuntu:latest"
|
||||
|
||||
|
||||
class LocalConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
|
||||
SandboxProviderType.E2B: E2BConfig,
|
||||
SandboxProviderType.DOCKER: DockerConfig,
|
||||
SandboxProviderType.LOCAL: LocalConfig,
|
||||
}
|
||||
|
||||
PROVIDER_CONFIG_SCHEMAS: dict[str, list[BasicProviderConfig]] = {
|
||||
SandboxProviderType.E2B: [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name="api_key"),
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="e2b_api_url"),
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="e2b_default_template"),
|
||||
],
|
||||
SandboxProviderType.DOCKER: [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="docker_sock"),
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="docker_image"),
|
||||
],
|
||||
SandboxProviderType.LOCAL: [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="base_working_path"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class SandboxProviderInfo(BaseModel):
|
||||
provider_type: str = Field(..., description="Provider type identifier")
|
||||
label: str = Field(..., description="Display name")
|
||||
description: str = Field(..., description="Provider description")
|
||||
icon: str = Field(..., description="Icon identifier")
|
||||
is_system_configured: bool = Field(default=False, description="Whether system default is configured")
|
||||
is_tenant_configured: bool = Field(default=False, description="Whether tenant has custom config")
|
||||
is_active: bool = Field(default=False, description="Whether this provider is active for the tenant")
|
||||
config: Mapping[str, Any] = Field(default_factory=dict, description="Masked config")
|
||||
config_schema: list[dict[str, Any]] = Field(default_factory=list, description="Config form schema")
|
||||
|
||||
|
||||
PROVIDER_METADATA: dict[str, dict[str, str]] = {
|
||||
SandboxProviderType.E2B: {
|
||||
"label": "E2B",
|
||||
"description": "Cloud-based sandbox powered by E2B. Secure, scalable, and managed.",
|
||||
"icon": "e2b",
|
||||
},
|
||||
SandboxProviderType.DOCKER: {
|
||||
"label": "Docker",
|
||||
"description": "Local Docker-based sandbox. Requires Docker daemon running on the host.",
|
||||
"icon": "docker",
|
||||
},
|
||||
SandboxProviderType.LOCAL: {
|
||||
"label": "Local",
|
||||
"description": "Local execution without isolation. Only for development/testing.",
|
||||
"icon": "local",
|
||||
},
|
||||
}
|
||||
def _query_tenant_config(session: Session, tenant_id: str, provider_type: str) -> SandboxProvider | None:
|
||||
return (
|
||||
session.query(SandboxProvider)
|
||||
.filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.provider_type == provider_type)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
class SandboxProviderService:
|
||||
@classmethod
|
||||
def get_available_provider_types(cls) -> list[str]:
|
||||
providers = [SandboxProviderType.E2B, SandboxProviderType.DOCKER]
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
providers.append(SandboxProviderType.LOCAL)
|
||||
return [provider.value for provider in providers]
|
||||
|
||||
@classmethod
|
||||
def list_providers(cls, tenant_id: str) -> list[SandboxProviderInfo]:
|
||||
result: list[SandboxProviderInfo] = []
|
||||
|
||||
def list_providers(cls, tenant_id: str) -> list[SandboxProviderApiEntity]:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
provider_types = SandboxType.get_all()
|
||||
tenant_configs = {
|
||||
cfg.provider_type: cfg
|
||||
for cfg in session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).all()
|
||||
config.provider_type: config
|
||||
for config in session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).all()
|
||||
}
|
||||
system_configs = {
|
||||
config.provider_type: config
|
||||
for config in session.query(SandboxProviderSystemConfig)
|
||||
.filter(SandboxProviderSystemConfig.provider_type.in_(provider_types))
|
||||
.all()
|
||||
}
|
||||
system_defaults = {cfg.provider_type for cfg in session.query(SandboxProviderSystemConfig).all()}
|
||||
|
||||
for provider_type in cls.get_available_provider_types():
|
||||
providers: list[SandboxProviderApiEntity] = []
|
||||
current_provider = cls.get_active_sandbox_config(session, tenant_id)
|
||||
for provider_type in SandboxType.get_all():
|
||||
tenant_config = tenant_configs.get(provider_type)
|
||||
schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, [])
|
||||
metadata = PROVIDER_METADATA.get(provider_type, {})
|
||||
|
||||
config: Mapping[str, Any] = {}
|
||||
if tenant_config and tenant_config.config:
|
||||
encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, provider_type)
|
||||
config = masked_config(schema, encrypter.decrypt(tenant_config.config))
|
||||
|
||||
result.append(
|
||||
SandboxProviderInfo(
|
||||
provider_type=provider_type,
|
||||
label=metadata.get("label", provider_type),
|
||||
description=metadata.get("description", ""),
|
||||
icon=metadata.get("icon", provider_type),
|
||||
is_system_configured=provider_type in system_defaults and tenant_config is None,
|
||||
is_tenant_configured=tenant_config is not None,
|
||||
is_active=tenant_config.is_active if tenant_config else False,
|
||||
config=config,
|
||||
config_schema=[{"name": c.name, "type": c.type.value} for c in schema],
|
||||
schema = VMConfig.get_schema(SandboxType(provider_type))
|
||||
if tenant_config:
|
||||
is_tenant_configured = tenant_config.configure_type == "user"
|
||||
if is_tenant_configured:
|
||||
decrypted_config = _get_encrypter(tenant_id, provider_type).decrypt(data=tenant_config.config)
|
||||
config = masked_config(schemas=schema, config=decrypted_config)
|
||||
else:
|
||||
config = {}
|
||||
providers.append(
|
||||
SandboxProviderApiEntity(
|
||||
provider_type=provider_type,
|
||||
is_system_configured=system_configs.get(provider_type) is not None,
|
||||
is_tenant_configured=is_tenant_configured,
|
||||
is_active=current_provider.id == tenant_config.id,
|
||||
config=config,
|
||||
config_schema=[c.model_dump() for c in schema],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, tenant_id: str, provider_type: str) -> SandboxProviderInfo | None:
|
||||
if provider_type not in cls.get_available_provider_types():
|
||||
return None
|
||||
|
||||
providers = cls.list_providers(tenant_id)
|
||||
for provider in providers:
|
||||
if provider.provider_type == provider_type:
|
||||
return provider
|
||||
return None
|
||||
else:
|
||||
system_config = system_configs.get(provider_type)
|
||||
providers.append(
|
||||
SandboxProviderApiEntity(
|
||||
provider_type=provider_type,
|
||||
is_active=system_config is not None and system_config.id == current_provider.id,
|
||||
is_system_configured=system_config is not None,
|
||||
config_schema=[c.model_dump() for c in schema],
|
||||
)
|
||||
)
|
||||
return providers
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, provider_type: str, config: Mapping[str, Any]) -> None:
|
||||
model_class = PROVIDER_CONFIG_MODELS.get(provider_type)
|
||||
if model_class:
|
||||
model_class.model_validate(config)
|
||||
|
||||
VMBuilder.validate(VMType(provider_type), config)
|
||||
SandboxBuilder.validate(SandboxType(provider_type), config)
|
||||
|
||||
@classmethod
|
||||
def save_config(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_type: str,
|
||||
config: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if provider_type not in cls.get_available_provider_types():
|
||||
def save_config(cls, tenant_id: str, provider_type: str, config: Mapping[str, Any]) -> dict[str, Any]:
|
||||
if provider_type not in SandboxType.get_all():
|
||||
raise ValueError(f"Invalid provider type: {provider_type}")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
existing = (
|
||||
session.query(SandboxProvider)
|
||||
.filter(
|
||||
SandboxProvider.tenant_id == tenant_id,
|
||||
SandboxProvider.provider_type == provider_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, [])
|
||||
encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, provider_type)
|
||||
|
||||
final_config = dict(config)
|
||||
if existing and existing.config:
|
||||
existing_config = encrypter.decrypt(existing.config)
|
||||
for key, value in final_config.items():
|
||||
if value == HIDDEN_VALUE:
|
||||
final_config[key] = existing_config.get(key, "")
|
||||
|
||||
cls.validate_config(provider_type, final_config)
|
||||
|
||||
encrypted = encrypter.encrypt(final_config)
|
||||
|
||||
if existing:
|
||||
existing.encrypted_config = json.dumps(encrypted)
|
||||
else:
|
||||
new_config = SandboxProvider(
|
||||
provider = _query_tenant_config(session, tenant_id, provider_type)
|
||||
encrypter = _get_encrypter(tenant_id, provider_type)
|
||||
if not provider:
|
||||
provider = SandboxProvider(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=provider_type,
|
||||
encrypted_config=json.dumps(encrypted),
|
||||
is_active=False,
|
||||
encrypted_config=json.dumps({}),
|
||||
)
|
||||
session.add(new_config)
|
||||
session.add(provider)
|
||||
|
||||
new_config = dict(config)
|
||||
old_config = encrypter.decrypt(provider.config)
|
||||
for key, value in new_config.items():
|
||||
if value == HIDDEN_VALUE:
|
||||
new_config[key] = old_config.get(key, "")
|
||||
|
||||
cls.validate_config(provider_type, new_config)
|
||||
|
||||
provider.encrypted_config = json.dumps(encrypter.encrypt(new_config))
|
||||
provider.is_active = provider.is_active or cls.is_system_default_config(session, tenant_id)
|
||||
provider.configure_type = "user"
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def delete_config(cls, tenant_id: str, provider_type: str) -> dict[str, Any]:
|
||||
with Session(db.engine) as session:
|
||||
config = (
|
||||
session.query(SandboxProvider)
|
||||
.filter(
|
||||
SandboxProvider.tenant_id == tenant_id,
|
||||
SandboxProvider.provider_type == provider_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not config:
|
||||
return {"result": "success"}
|
||||
|
||||
session.delete(config)
|
||||
session.commit()
|
||||
|
||||
if config := _query_tenant_config(session, tenant_id, provider_type):
|
||||
session.delete(config)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def is_system_default_config(cls, session: Session, tenant_id: str) -> bool:
|
||||
system_configed: SandboxProviderSystemConfig | None = session.query(SandboxProviderSystemConfig).first()
|
||||
if not system_configed:
|
||||
return False
|
||||
active_config = cls.get_active_sandbox_config(session, tenant_id)
|
||||
return active_config.id == system_configed.id
|
||||
|
||||
@classmethod
|
||||
def activate_provider(cls, tenant_id: str, provider_type: str) -> dict[str, Any]:
|
||||
if provider_type not in cls.get_available_provider_types():
|
||||
if provider_type not in SandboxType.get_all():
|
||||
raise ValueError(f"Invalid provider type: {provider_type}")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
tenant_config = (
|
||||
session.query(SandboxProvider)
|
||||
.filter(
|
||||
SandboxProvider.tenant_id == tenant_id,
|
||||
SandboxProvider.provider_type == provider_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tenant_config = _query_tenant_config(session, tenant_id, provider_type)
|
||||
system_config = session.query(SandboxProviderSystemConfig).filter_by(provider_type=provider_type).first()
|
||||
|
||||
system_default = (
|
||||
session.query(SandboxProviderSystemConfig)
|
||||
.filter(SandboxProviderSystemConfig.provider_type == provider_type)
|
||||
.first()
|
||||
)
|
||||
|
||||
config_schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, [])
|
||||
needs_config = len(config_schema) > 0
|
||||
|
||||
if needs_config and not tenant_config and not system_default:
|
||||
raise ValueError(f"Provider {provider_type} is not configured. Please add configuration first.")
|
||||
|
||||
session.query(SandboxProvider).filter(
|
||||
SandboxProvider.tenant_id == tenant_id,
|
||||
).update({"is_active": False})
|
||||
session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).update({"is_active": False})
|
||||
|
||||
# using tenant config
|
||||
if tenant_config:
|
||||
tenant_config.is_active = True
|
||||
else:
|
||||
new_config = SandboxProvider(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=provider_type,
|
||||
encrypted_config=json.dumps({}),
|
||||
is_active=True,
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
# using system config
|
||||
if system_config:
|
||||
session.add(
|
||||
SandboxProvider(
|
||||
is_active=True,
|
||||
tenant_id=tenant_id,
|
||||
configure_type="system",
|
||||
provider_type=provider_type,
|
||||
encrypted_config=json.dumps({}),
|
||||
)
|
||||
)
|
||||
session.add(new_config)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
raise ValueError(f"No sandbox provider configured for tenant {tenant_id} and provider type {provider_type}")
|
||||
|
||||
@classmethod
|
||||
def get_active_provider(cls, tenant_id: str) -> str | None:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = (
|
||||
session.query(SandboxProvider)
|
||||
.filter(
|
||||
SandboxProvider.tenant_id == tenant_id,
|
||||
SandboxProvider.is_active.is_(True),
|
||||
)
|
||||
.first()
|
||||
def get_active_sandbox_config(cls, session: Session, tenant_id: str) -> SandboxProviderEntity:
|
||||
tenant_configed = (
|
||||
session.query(SandboxProvider)
|
||||
.filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.is_active.is_(True))
|
||||
.first()
|
||||
)
|
||||
if tenant_configed:
|
||||
config = _get_encrypter(tenant_id, tenant_configed.provider_type).decrypt(tenant_configed.config)
|
||||
return SandboxProviderEntity(
|
||||
id=tenant_configed.id, provider_type=tenant_configed.provider_type, config=config
|
||||
)
|
||||
return config.provider_type if config else None
|
||||
|
||||
system_configed: SandboxProviderSystemConfig | None = session.query(SandboxProviderSystemConfig).first()
|
||||
if system_configed:
|
||||
return SandboxProviderEntity(
|
||||
id=system_configed.id,
|
||||
provider_type=system_configed.provider_type,
|
||||
config=decrypt_system_params(system_configed.encrypted_config),
|
||||
)
|
||||
|
||||
raise ValueError(f"No sandbox provider configured for tenant {tenant_id}")
|
||||
|
||||
@classmethod
|
||||
def create_sandbox_builder(cls, tenant_id: str) -> VMBuilder:
|
||||
def create_sandbox_builder(cls, tenant_id: str) -> SandboxBuilder:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
tenant_config = (
|
||||
session.query(SandboxProvider)
|
||||
.filter(
|
||||
SandboxProvider.tenant_id == tenant_id,
|
||||
SandboxProvider.is_active.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
config: Mapping[str, Any] = {}
|
||||
provider_type = None
|
||||
if tenant_config:
|
||||
schema = PROVIDER_CONFIG_SCHEMAS.get(tenant_config.provider_type, [])
|
||||
encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, tenant_config.provider_type)
|
||||
config = encrypter.decrypt(tenant_config.config)
|
||||
provider_type = tenant_config.provider_type
|
||||
else:
|
||||
system_default = session.query(SandboxProviderSystemConfig).first()
|
||||
if system_default:
|
||||
config = decrypt_system_params(system_default.encrypted_config)
|
||||
provider_type = system_default.provider_type
|
||||
|
||||
if not config or not provider_type:
|
||||
raise ValueError(f"No active sandbox provider for tenant {tenant_id} or system default")
|
||||
|
||||
return VMBuilder(tenant_id, VMType(provider_type)).options(config)
|
||||
provider_type, config = cls.get_active_sandbox_config(session, tenant_id)
|
||||
return SandboxBuilder(tenant_id, SandboxType(provider_type)).options(config)
|
||||
|
||||
Reference in New Issue
Block a user