refactor(sandbox): sandbox provider system default configuration

This commit is contained in:
Harry
2026-01-16 18:21:53 +08:00
parent 8b42435f7a
commit 0bd17c6d0f
19 changed files with 382 additions and 457 deletions

View File

@ -13,45 +13,18 @@ logger = logging.getLogger(__name__)
@console_ns.route("/workspaces/current/sandbox-providers")
class SandboxProviderListApi(Resource):
"""List all sandbox providers for the current tenant."""
@console_ns.doc("list_sandbox_providers")
@console_ns.doc(description="Get list of available sandbox providers with configuration status")
@console_ns.response(
200,
"Success",
fields.List(fields.Raw(description="Sandbox provider information")),
)
@console_ns.response(200, "Success", fields.List(fields.Raw(description="Sandbox provider information")))
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List all sandbox providers."""
_, current_tenant_id = current_account_with_tenant()
providers = SandboxProviderService.list_providers(current_tenant_id)
return jsonable_encoder([p.model_dump() for p in providers])
@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>")
class SandboxProviderApi(Resource):
"""Get specific sandbox provider details."""
@console_ns.doc("get_sandbox_provider")
@console_ns.doc(description="Get specific sandbox provider details")
@console_ns.doc(params={"provider_type": "Sandbox provider type (e2b, docker, local)"})
@console_ns.response(200, "Success", fields.Raw(description="Sandbox provider details"))
@setup_required
@login_required
@account_initialization_required
def get(self, provider_type: str):
"""Get a specific sandbox provider."""
_, current_tenant_id = current_account_with_tenant()
provider = SandboxProviderService.get_provider(current_tenant_id, provider_type)
if not provider:
return {"message": f"Provider {provider_type} not found"}, 404
return jsonable_encoder(provider.model_dump())
config_parser = reqparse.RequestParser()
config_parser.add_argument("config", type=dict, required=True, location="json")
@ -120,20 +93,3 @@ class SandboxProviderActivateApi(Resource):
return result
except ValueError as e:
return {"message": str(e)}, 400
@console_ns.route("/workspaces/current/sandbox-provider/active")
class SandboxProviderActiveApi(Resource):
"""Get the currently active sandbox provider."""
@console_ns.doc("get_active_sandbox_provider")
@console_ns.doc(description="Get the currently active sandbox provider for the workspace")
@console_ns.response(200, "Success")
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Get the active sandbox provider."""
_, current_tenant_id = current_account_with_tenant()
active_provider = SandboxProviderService.get_active_provider(current_tenant_id)
return {"provider_type": active_provider}

View File

@ -12,13 +12,13 @@ from .constants import (
DIFY_CLI_PATH,
DIFY_CLI_PATH_PATTERN,
)
from .factory import VMBuilder, VMType
from .initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer
from .manager import SandboxManager
from .session import SandboxSession
from .storage import ArchiveSandboxStorage, SandboxStorage
from .utils.debug import sandbox_debug
from .utils.encryption import create_sandbox_config_encrypter, masked_config
from .vm import SandboxBuilder, SandboxType, VMConfig
__all__ = [
"APP_ASSETS_PATH",
@ -34,12 +34,13 @@ __all__ = [
"DifyCliInitializer",
"DifyCliLocator",
"DifyCliToolConfig",
"SandboxBuilder",
"SandboxInitializer",
"SandboxManager",
"SandboxSession",
"SandboxStorage",
"VMBuilder",
"VMType",
"SandboxType",
"VMConfig",
"create_sandbox_config_encrypter",
"masked_config",
"sandbox_debug",

View File

@ -0,0 +1,3 @@
from .providers import SandboxProviderApiEntity
__all__ = ["SandboxProviderApiEntity"]

View File

@ -0,0 +1,21 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
class SandboxProviderApiEntity(BaseModel):
provider_type: str = Field(..., description="Provider type identifier")
is_system_configured: bool = Field(default=False)
is_tenant_configured: bool = Field(default=False)
is_active: bool = Field(default=False)
config: Mapping[str, Any] = Field(default_factory=dict)
config_schema: list[dict[str, Any]] = Field(default_factory=list)
class SandboxProviderEntity(BaseModel):
id: str = Field(..., description="Provider identifier")
provider_type: str = Field(..., description="Provider type identifier")
is_active: bool = Field(default=False)
config: Mapping[str, Any] = Field(default_factory=dict)
config_schema: list[dict[str, Any]] = Field(default_factory=list)

View File

@ -1,81 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
if TYPE_CHECKING:
from .initializer import SandboxInitializer
class VMType(StrEnum):
DOCKER = "docker"
E2B = "e2b"
LOCAL = "local"
def _get_vm_class(vm_type: VMType) -> type[VirtualEnvironment]:
match vm_type:
case VMType.DOCKER:
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
return DockerDaemonEnvironment
case VMType.E2B:
from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment
return E2BEnvironment
case VMType.LOCAL:
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
return LocalVirtualEnvironment
case _:
raise ValueError(f"Unsupported VM type: {vm_type}")
class VMBuilder:
def __init__(self, tenant_id: str, vm_type: VMType) -> None:
self._tenant_id = tenant_id
self._vm_type = vm_type
self._user_id: str | None = None
self._options: dict[str, Any] = {}
self._environments: dict[str, str] = {}
self._initializers: list[SandboxInitializer] = []
def user(self, user_id: str) -> VMBuilder:
self._user_id = user_id
return self
def options(self, options: Mapping[str, Any]) -> VMBuilder:
self._options = dict(options)
return self
def environments(self, environments: Mapping[str, str]) -> VMBuilder:
self._environments = dict(environments)
return self
def initializer(self, initializer: SandboxInitializer) -> VMBuilder:
self._initializers.append(initializer)
return self
def initializers(self, initializers: Sequence[SandboxInitializer]) -> VMBuilder:
self._initializers.extend(initializers)
return self
def build(self) -> VirtualEnvironment:
vm_class = _get_vm_class(self._vm_type)
vm = vm_class(
tenant_id=self._tenant_id,
options=self._options,
environments=self._environments,
user_id=self._user_id,
)
for init in self._initializers:
init.initialize(vm)
return vm
@staticmethod
def validate(vm_type: VMType, options: Mapping[str, Any]) -> None:
vm_class = _get_vm_class(vm_type)
vm_class.validate(options)

109
api/core/sandbox/vm.py Normal file
View File

@ -0,0 +1,109 @@
"""
Facade module for virtual machine providers.
Provides unified interfaces to access different VM provider implementations
(E2B, Docker, Local) through VMType, VMBuilder, and VMConfig.
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from .initializer import SandboxInitializer
class SandboxType(StrEnum):
"""
Sandbox types.
"""
DOCKER = "docker"
E2B = "e2b"
LOCAL = "local"
@classmethod
def get_all(cls) -> list[str]:
"""
Get all available sandbox types.
"""
if dify_config.EDITION == "SELF_HOSTED":
return [p.value for p in cls]
else:
return [p.value for p in cls if p != SandboxType.LOCAL]
def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]:
match sandbox_type:
case SandboxType.DOCKER:
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
return DockerDaemonEnvironment
case SandboxType.E2B:
from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment
return E2BEnvironment
case SandboxType.LOCAL:
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
return LocalVirtualEnvironment
case _:
raise ValueError(f"Unsupported sandbox type: {sandbox_type}")
class SandboxBuilder:
def __init__(self, tenant_id: str, sandbox_type: SandboxType) -> None:
self._tenant_id = tenant_id
self._sandbox_type = sandbox_type
self._user_id: str | None = None
self._options: dict[str, Any] = {}
self._environments: dict[str, str] = {}
self._initializers: list[SandboxInitializer] = []
def user(self, user_id: str) -> SandboxBuilder:
self._user_id = user_id
return self
def options(self, options: Mapping[str, Any]) -> SandboxBuilder:
self._options = dict(options)
return self
def environments(self, environments: Mapping[str, str]) -> SandboxBuilder:
self._environments = dict(environments)
return self
def initializer(self, initializer: SandboxInitializer) -> SandboxBuilder:
self._initializers.append(initializer)
return self
def initializers(self, initializers: Sequence[SandboxInitializer]) -> SandboxBuilder:
self._initializers.extend(initializers)
return self
def build(self) -> VirtualEnvironment:
vm_class = _get_sandbox_class(self._sandbox_type)
vm = vm_class(
tenant_id=self._tenant_id,
options=self._options,
environments=self._environments,
user_id=self._user_id,
)
for init in self._initializers:
init.initialize(vm)
return vm
@staticmethod
def validate(vm_type: SandboxType, options: Mapping[str, Any]) -> None:
vm_class = _get_sandbox_class(vm_type)
vm_class.validate(options)
class VMConfig:
@staticmethod
def get_schema(vm_type: SandboxType) -> list[BasicProviderConfig]:
return _get_sandbox_class(vm_type).get_config_schema()

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from io import BytesIO
from typing import Any
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
@ -174,3 +175,8 @@ class VirtualEnvironment(ABC):
Returns:
CommandStatus: The status of the command execution.
"""
@classmethod
@abstractmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
pass

View File

@ -15,6 +15,7 @@ import docker.errors
from docker.models.containers import Container
import docker
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
@ -256,6 +257,13 @@ class DockerDaemonEnvironment(VirtualEnvironment):
DOCKER_IMAGE = "docker_image"
DOCKER_COMMAND = "docker_command"
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return [
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_SOCK),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.DOCKER_IMAGE),
]
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
docker_sock = options.get(cls.OptionsKey.DOCKER_SOCK, cls._DEFAULT_DOCKER_SOCK)

View File

@ -10,6 +10,7 @@ from uuid import uuid4
from e2b_code_interpreter import Sandbox # type: ignore[import-untyped]
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
@ -96,6 +97,14 @@ class E2BEnvironment(VirtualEnvironment):
class StoreKey(StrEnum):
SANDBOX = "sandbox"
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.API_KEY),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_API_URL),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.E2B_DEFAULT_TEMPLATE),
]
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
from e2b.exceptions import AuthenticationException # type: ignore[import-untyped]

View File

@ -8,6 +8,7 @@ from platform import machine, system
from typing import Any
from uuid import uuid4
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
@ -72,6 +73,12 @@ class LocalVirtualEnvironment(VirtualEnvironment):
NEVER USE IT IN PRODUCTION ENVIRONMENTS.
"""
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return [
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="base_working_path"),
]
@classmethod
def validate(cls, options: Mapping[str, Any]) -> None:
pass

View File

@ -0,0 +1,68 @@
"""sandbox_provider_configure_type
Revision ID: 45471e916693
Revises: d88f3edbd99d
Create Date: 2026-01-16 17:28:46.691473
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '45471e916693'
down_revision = 'd88f3edbd99d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
sa.Column('quota_used', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
)
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
with op.batch_alter_table('sandbox_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('configure_type', sa.String(length=20), server_default='user', nullable=False))
batch_op.drop_constraint(batch_op.f('unique_sandbox_provider_tenant_type'), type_='unique')
batch_op.create_unique_constraint('unique_sandbox_provider_tenant_type', ['tenant_id', 'provider_type', 'configure_type'])
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_index('workflow_run_created_at_id_idx')
with op.batch_alter_table('sandbox_providers', schema=None) as batch_op:
batch_op.drop_constraint('unique_sandbox_provider_tenant_type', type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_sandbox_provider_tenant_type'), ['tenant_id', 'provider_type'], postgresql_nulls_not_distinct=False)
batch_op.drop_column('configure_type')
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_index('message_created_at_id_idx')
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
batch_op.drop_index('tenant_credit_pool_pool_type_idx')
op.drop_table('tenant_credit_pools')
# ### end Alembic commands ###

View File

@ -51,7 +51,7 @@ class SandboxProvider(TypeBase):
__tablename__ = "sandbox_providers"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="sandbox_provider_pkey"),
sa.UniqueConstraint("tenant_id", "provider_type", name="unique_sandbox_provider_tenant_type"),
sa.UniqueConstraint("tenant_id", "provider_type", "configure_type", name="unique_sandbox_provider_tenant_type"),
sa.Index("idx_sandbox_providers_tenant_id", "tenant_id"),
sa.Index("idx_sandbox_providers_tenant_active", "tenant_id", "is_active"),
)
@ -62,6 +62,7 @@ class SandboxProvider(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="e2b, docker, local")
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False, comment="Encrypted config JSON")
configure_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default="user", default="user")
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False

View File

@ -1,3 +0,0 @@
from .sandbox_provider_service import SandboxProviderService
__all__ = ["SandboxProviderService"]

View File

@ -1,332 +1,191 @@
"""
Sandbox Provider Service for managing sandbox configurations.
Supports three provider types:
- e2b: Cloud-based sandbox (requires API key)
- docker: Local Docker-based sandbox (self-hosted)
- local: Local execution without isolation (self-hosted only)
"""
import json
import logging
from collections.abc import Mapping
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE
from core.entities.provider_entities import BasicProviderConfig
from core.sandbox import VMBuilder, VMType, create_sandbox_config_encrypter, masked_config
from core.tools.utils.system_encryption import (
decrypt_system_params,
)
from core.sandbox import SandboxBuilder, SandboxType, VMConfig, create_sandbox_config_encrypter, masked_config
from core.sandbox.entities import SandboxProviderApiEntity
from core.sandbox.entities.providers import SandboxProviderEntity
from core.tools.utils.system_encryption import decrypt_system_params
from extensions.ext_database import db
from models.sandbox import SandboxProvider, SandboxProviderSystemConfig
logger = logging.getLogger(__name__)
class SandboxProviderType(StrEnum):
E2B = "e2b"
DOCKER = "docker"
LOCAL = "local"
def _get_encrypter(tenant_id: str, provider_type: str):
return create_sandbox_config_encrypter(tenant_id, VMConfig.get_schema(SandboxType(provider_type)), provider_type)[0]
class E2BConfig(BaseModel):
api_key: str = ""
e2b_api_url: str = "https://api.e2b.app"
e2b_default_template: str = "code-interpreter-v1"
@model_validator(mode="before")
@classmethod
def check_required(cls, values: dict[str, Any]) -> dict[str, Any]:
if not values.get("api_key"):
raise ValueError("api_key is required")
return values
class DockerConfig(BaseModel):
docker_sock: str = "unix:///var/run/docker.sock"
docker_image: str = "ubuntu:latest"
class LocalConfig(BaseModel):
pass
PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
SandboxProviderType.E2B: E2BConfig,
SandboxProviderType.DOCKER: DockerConfig,
SandboxProviderType.LOCAL: LocalConfig,
}
PROVIDER_CONFIG_SCHEMAS: dict[str, list[BasicProviderConfig]] = {
SandboxProviderType.E2B: [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name="api_key"),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="e2b_api_url"),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="e2b_default_template"),
],
SandboxProviderType.DOCKER: [
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="docker_sock"),
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="docker_image"),
],
SandboxProviderType.LOCAL: [
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name="base_working_path"),
],
}
class SandboxProviderInfo(BaseModel):
provider_type: str = Field(..., description="Provider type identifier")
label: str = Field(..., description="Display name")
description: str = Field(..., description="Provider description")
icon: str = Field(..., description="Icon identifier")
is_system_configured: bool = Field(default=False, description="Whether system default is configured")
is_tenant_configured: bool = Field(default=False, description="Whether tenant has custom config")
is_active: bool = Field(default=False, description="Whether this provider is active for the tenant")
config: Mapping[str, Any] = Field(default_factory=dict, description="Masked config")
config_schema: list[dict[str, Any]] = Field(default_factory=list, description="Config form schema")
PROVIDER_METADATA: dict[str, dict[str, str]] = {
SandboxProviderType.E2B: {
"label": "E2B",
"description": "Cloud-based sandbox powered by E2B. Secure, scalable, and managed.",
"icon": "e2b",
},
SandboxProviderType.DOCKER: {
"label": "Docker",
"description": "Local Docker-based sandbox. Requires Docker daemon running on the host.",
"icon": "docker",
},
SandboxProviderType.LOCAL: {
"label": "Local",
"description": "Local execution without isolation. Only for development/testing.",
"icon": "local",
},
}
def _query_tenant_config(session: Session, tenant_id: str, provider_type: str) -> SandboxProvider | None:
return (
session.query(SandboxProvider)
.filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.provider_type == provider_type)
.first()
)
class SandboxProviderService:
@classmethod
def get_available_provider_types(cls) -> list[str]:
providers = [SandboxProviderType.E2B, SandboxProviderType.DOCKER]
if dify_config.EDITION == "SELF_HOSTED":
providers.append(SandboxProviderType.LOCAL)
return [provider.value for provider in providers]
@classmethod
def list_providers(cls, tenant_id: str) -> list[SandboxProviderInfo]:
result: list[SandboxProviderInfo] = []
def list_providers(cls, tenant_id: str) -> list[SandboxProviderApiEntity]:
with Session(db.engine, expire_on_commit=False) as session:
provider_types = SandboxType.get_all()
tenant_configs = {
cfg.provider_type: cfg
for cfg in session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).all()
config.provider_type: config
for config in session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).all()
}
system_configs = {
config.provider_type: config
for config in session.query(SandboxProviderSystemConfig)
.filter(SandboxProviderSystemConfig.provider_type.in_(provider_types))
.all()
}
system_defaults = {cfg.provider_type for cfg in session.query(SandboxProviderSystemConfig).all()}
for provider_type in cls.get_available_provider_types():
providers: list[SandboxProviderApiEntity] = []
current_provider = cls.get_active_sandbox_config(session, tenant_id)
for provider_type in SandboxType.get_all():
tenant_config = tenant_configs.get(provider_type)
schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, [])
metadata = PROVIDER_METADATA.get(provider_type, {})
config: Mapping[str, Any] = {}
if tenant_config and tenant_config.config:
encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, provider_type)
config = masked_config(schema, encrypter.decrypt(tenant_config.config))
result.append(
SandboxProviderInfo(
provider_type=provider_type,
label=metadata.get("label", provider_type),
description=metadata.get("description", ""),
icon=metadata.get("icon", provider_type),
is_system_configured=provider_type in system_defaults and tenant_config is None,
is_tenant_configured=tenant_config is not None,
is_active=tenant_config.is_active if tenant_config else False,
config=config,
config_schema=[{"name": c.name, "type": c.type.value} for c in schema],
schema = VMConfig.get_schema(SandboxType(provider_type))
if tenant_config:
is_tenant_configured = tenant_config.configure_type == "user"
if is_tenant_configured:
decrypted_config = _get_encrypter(tenant_id, provider_type).decrypt(data=tenant_config.config)
config = masked_config(schemas=schema, config=decrypted_config)
else:
config = {}
providers.append(
SandboxProviderApiEntity(
provider_type=provider_type,
is_system_configured=system_configs.get(provider_type) is not None,
is_tenant_configured=is_tenant_configured,
is_active=current_provider.id == tenant_config.id,
config=config,
config_schema=[c.model_dump() for c in schema],
)
)
)
return result
@classmethod
def get_provider(cls, tenant_id: str, provider_type: str) -> SandboxProviderInfo | None:
if provider_type not in cls.get_available_provider_types():
return None
providers = cls.list_providers(tenant_id)
for provider in providers:
if provider.provider_type == provider_type:
return provider
return None
else:
system_config = system_configs.get(provider_type)
providers.append(
SandboxProviderApiEntity(
provider_type=provider_type,
is_active=system_config is not None and system_config.id == current_provider.id,
is_system_configured=system_config is not None,
config_schema=[c.model_dump() for c in schema],
)
)
return providers
@classmethod
def validate_config(cls, provider_type: str, config: Mapping[str, Any]) -> None:
model_class = PROVIDER_CONFIG_MODELS.get(provider_type)
if model_class:
model_class.model_validate(config)
VMBuilder.validate(VMType(provider_type), config)
SandboxBuilder.validate(SandboxType(provider_type), config)
@classmethod
def save_config(
cls,
tenant_id: str,
provider_type: str,
config: Mapping[str, Any],
) -> dict[str, Any]:
if provider_type not in cls.get_available_provider_types():
def save_config(cls, tenant_id: str, provider_type: str, config: Mapping[str, Any]) -> dict[str, Any]:
if provider_type not in SandboxType.get_all():
raise ValueError(f"Invalid provider type: {provider_type}")
with Session(db.engine) as session:
existing = (
session.query(SandboxProvider)
.filter(
SandboxProvider.tenant_id == tenant_id,
SandboxProvider.provider_type == provider_type,
)
.first()
)
schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, [])
encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, provider_type)
final_config = dict(config)
if existing and existing.config:
existing_config = encrypter.decrypt(existing.config)
for key, value in final_config.items():
if value == HIDDEN_VALUE:
final_config[key] = existing_config.get(key, "")
cls.validate_config(provider_type, final_config)
encrypted = encrypter.encrypt(final_config)
if existing:
existing.encrypted_config = json.dumps(encrypted)
else:
new_config = SandboxProvider(
provider = _query_tenant_config(session, tenant_id, provider_type)
encrypter = _get_encrypter(tenant_id, provider_type)
if not provider:
provider = SandboxProvider(
tenant_id=tenant_id,
provider_type=provider_type,
encrypted_config=json.dumps(encrypted),
is_active=False,
encrypted_config=json.dumps({}),
)
session.add(new_config)
session.add(provider)
new_config = dict(config)
old_config = encrypter.decrypt(provider.config)
for key, value in new_config.items():
if value == HIDDEN_VALUE:
new_config[key] = old_config.get(key, "")
cls.validate_config(provider_type, new_config)
provider.encrypted_config = json.dumps(encrypter.encrypt(new_config))
provider.is_active = provider.is_active or cls.is_system_default_config(session, tenant_id)
provider.configure_type = "user"
session.commit()
return {"result": "success"}
@classmethod
def delete_config(cls, tenant_id: str, provider_type: str) -> dict[str, Any]:
with Session(db.engine) as session:
config = (
session.query(SandboxProvider)
.filter(
SandboxProvider.tenant_id == tenant_id,
SandboxProvider.provider_type == provider_type,
)
.first()
)
if not config:
return {"result": "success"}
session.delete(config)
session.commit()
if config := _query_tenant_config(session, tenant_id, provider_type):
session.delete(config)
session.commit()
return {"result": "success"}
@classmethod
def is_system_default_config(cls, session: Session, tenant_id: str) -> bool:
system_configed: SandboxProviderSystemConfig | None = session.query(SandboxProviderSystemConfig).first()
if not system_configed:
return False
active_config = cls.get_active_sandbox_config(session, tenant_id)
return active_config.id == system_configed.id
@classmethod
def activate_provider(cls, tenant_id: str, provider_type: str) -> dict[str, Any]:
if provider_type not in cls.get_available_provider_types():
if provider_type not in SandboxType.get_all():
raise ValueError(f"Invalid provider type: {provider_type}")
with Session(db.engine) as session:
tenant_config = (
session.query(SandboxProvider)
.filter(
SandboxProvider.tenant_id == tenant_id,
SandboxProvider.provider_type == provider_type,
)
.first()
)
tenant_config = _query_tenant_config(session, tenant_id, provider_type)
system_config = session.query(SandboxProviderSystemConfig).filter_by(provider_type=provider_type).first()
system_default = (
session.query(SandboxProviderSystemConfig)
.filter(SandboxProviderSystemConfig.provider_type == provider_type)
.first()
)
config_schema = PROVIDER_CONFIG_SCHEMAS.get(provider_type, [])
needs_config = len(config_schema) > 0
if needs_config and not tenant_config and not system_default:
raise ValueError(f"Provider {provider_type} is not configured. Please add configuration first.")
session.query(SandboxProvider).filter(
SandboxProvider.tenant_id == tenant_id,
).update({"is_active": False})
session.query(SandboxProvider).filter(SandboxProvider.tenant_id == tenant_id).update({"is_active": False})
# using tenant config
if tenant_config:
tenant_config.is_active = True
else:
new_config = SandboxProvider(
tenant_id=tenant_id,
provider_type=provider_type,
encrypted_config=json.dumps({}),
is_active=True,
session.commit()
return {"result": "success"}
# using system config
if system_config:
session.add(
SandboxProvider(
is_active=True,
tenant_id=tenant_id,
configure_type="system",
provider_type=provider_type,
encrypted_config=json.dumps({}),
)
)
session.add(new_config)
session.commit()
return {"result": "success"}
session.commit()
return {"result": "success"}
raise ValueError(f"No sandbox provider configured for tenant {tenant_id} and provider type {provider_type}")
@classmethod
def get_active_provider(cls, tenant_id: str) -> str | None:
with Session(db.engine, expire_on_commit=False) as session:
config = (
session.query(SandboxProvider)
.filter(
SandboxProvider.tenant_id == tenant_id,
SandboxProvider.is_active.is_(True),
)
.first()
def get_active_sandbox_config(cls, session: Session, tenant_id: str) -> SandboxProviderEntity:
tenant_configed = (
session.query(SandboxProvider)
.filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.is_active.is_(True))
.first()
)
if tenant_configed:
config = _get_encrypter(tenant_id, tenant_configed.provider_type).decrypt(tenant_configed.config)
return SandboxProviderEntity(
id=tenant_configed.id, provider_type=tenant_configed.provider_type, config=config
)
return config.provider_type if config else None
system_configed: SandboxProviderSystemConfig | None = session.query(SandboxProviderSystemConfig).first()
if system_configed:
return SandboxProviderEntity(
id=system_configed.id,
provider_type=system_configed.provider_type,
config=decrypt_system_params(system_configed.encrypted_config),
)
raise ValueError(f"No sandbox provider configured for tenant {tenant_id}")
@classmethod
def create_sandbox_builder(cls, tenant_id: str) -> VMBuilder:
def create_sandbox_builder(cls, tenant_id: str) -> SandboxBuilder:
with Session(db.engine, expire_on_commit=False) as session:
tenant_config = (
session.query(SandboxProvider)
.filter(
SandboxProvider.tenant_id == tenant_id,
SandboxProvider.is_active.is_(True),
)
.first()
)
config: Mapping[str, Any] = {}
provider_type = None
if tenant_config:
schema = PROVIDER_CONFIG_SCHEMAS.get(tenant_config.provider_type, [])
encrypter, _ = create_sandbox_config_encrypter(tenant_id, schema, tenant_config.provider_type)
config = encrypter.decrypt(tenant_config.config)
provider_type = tenant_config.provider_type
else:
system_default = session.query(SandboxProviderSystemConfig).first()
if system_default:
config = decrypt_system_params(system_default.encrypted_config)
provider_type = system_default.provider_type
if not config or not provider_type:
raise ValueError(f"No active sandbox provider for tenant {tenant_id} or system default")
return VMBuilder(tenant_id, VMType(provider_type)).options(config)
provider_type, config = cls.get_active_sandbox_config(session, tenant_id)
return SandboxBuilder(tenant_id, SandboxType(provider_type)).options(config)

View File

@ -3,20 +3,20 @@ from unittest.mock import MagicMock, patch
import pytest
from core.sandbox import VMBuilder, VMType
from core.sandbox import SandboxBuilder, SandboxType
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
class TestVMType:
def test_values(self):
assert VMType.DOCKER == "docker"
assert VMType.E2B == "e2b"
assert VMType.LOCAL == "local"
assert SandboxType.DOCKER == "docker"
assert SandboxType.E2B == "e2b"
assert SandboxType.LOCAL == "local"
def test_is_string_enum(self):
assert isinstance(VMType.DOCKER.value, str)
assert isinstance(VMType.E2B.value, str)
assert isinstance(VMType.LOCAL.value, str)
assert isinstance(SandboxType.DOCKER.value, str)
assert isinstance(SandboxType.E2B.value, str)
assert isinstance(SandboxType.LOCAL.value, str)
class TestVMBuilder:
@ -29,7 +29,7 @@ class TestVMBuilder:
mock_class,
):
result = (
VMBuilder("test-tenant", VMType.DOCKER)
SandboxBuilder("test-tenant", SandboxType.DOCKER)
.options({"docker_image": "python:3.11-slim"})
.environments({"PYTHONUNBUFFERED": "1"})
.build()
@ -51,7 +51,7 @@ class TestVMBuilder:
"core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment",
mock_class,
):
VMBuilder("test-tenant", VMType.DOCKER).user("user-123").build()
SandboxBuilder("test-tenant", SandboxType.DOCKER).user("user-123").build()
mock_class.assert_called_once_with(
tenant_id="test-tenant",
@ -69,7 +69,7 @@ class TestVMBuilder:
"core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment",
mock_class,
):
VMBuilder("test-tenant", VMType.DOCKER).initializer(mock_initializer).build()
SandboxBuilder("test-tenant", SandboxType.DOCKER).initializer(mock_initializer).build()
mock_initializer.initialize.assert_called_once_with(mock_instance)
@ -80,7 +80,7 @@ class TestVMBuilder:
"core.virtual_environment.providers.local_without_isolation.LocalVirtualEnvironment",
return_value=mock_instance,
) as mock_class:
VMBuilder("test-tenant", VMType.LOCAL).build()
SandboxBuilder("test-tenant", SandboxType.LOCAL).build()
mock_class.assert_called_once()
def test_build_e2b(self):
@ -90,12 +90,12 @@ class TestVMBuilder:
"core.virtual_environment.providers.e2b_sandbox.E2BEnvironment",
return_value=mock_instance,
) as mock_class:
VMBuilder("test-tenant", VMType.E2B).build()
SandboxBuilder("test-tenant", SandboxType.E2B).build()
mock_class.assert_called_once()
def test_build_unsupported_type_raises(self):
with pytest.raises(ValueError, match="Unsupported VM type"):
VMBuilder("test-tenant", "unsupported").build() # type: ignore[arg-type]
SandboxBuilder("test-tenant", "unsupported").build() # type: ignore[arg-type]
def test_validate(self):
mock_class = MagicMock()
@ -104,13 +104,13 @@ class TestVMBuilder:
"core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment",
mock_class,
):
VMBuilder.validate(VMType.DOCKER, {"key": "value"})
SandboxBuilder.validate(SandboxType.DOCKER, {"key": "value"})
mock_class.validate.assert_called_once_with({"key": "value"})
class TestVMBuilderIntegration:
def test_local_sandbox(self, tmp_path: Path):
sandbox = VMBuilder("test-tenant", VMType.LOCAL).options({"base_working_path": str(tmp_path)}).build()
sandbox = SandboxBuilder("test-tenant", SandboxType.LOCAL).options({"base_working_path": str(tmp_path)}).build()
try:
assert sandbox is not None

View File

@ -47,7 +47,7 @@ const ProviderCard = ({
<span className="system-md-semibold text-text-primary">
{provider.label}
</span>
{provider.is_system_configured && (
{provider.is_system_configured && !provider.is_tenant_configured && (
<span className="system-2xs-medium rounded-[5px] border border-divider-deep px-[5px] py-[3px] text-text-tertiary">
{t('sandboxProvider.managedBySaas', { ns: 'common' })}
</span>

View File

@ -10,18 +10,6 @@ export const getSandboxProviderListContract = base
.input(type<unknown>())
.output(type<SandboxProvider[]>())
export const getSandboxProviderContract = base
.route({
path: '/workspaces/current/sandbox-provider/{providerType}',
method: 'GET',
})
.input(type<{
params: {
providerType: string
}
}>())
.output(type<SandboxProvider>())
export const saveSandboxProviderConfigContract = base
.route({
path: '/workspaces/current/sandbox-provider/{providerType}/config',
@ -60,11 +48,3 @@ export const activateSandboxProviderContract = base
}
}>())
.output(type<{ result: string }>())
export const getActiveSandboxProviderContract = base
.route({
path: '/workspaces/current/sandbox-provider/active',
method: 'GET',
})
.input(type<unknown>())
.output(type<{ provider_type: string | null }>())

View File

@ -16,8 +16,6 @@ import { bindPartnerStackContract, invoicesContract } from './console/billing'
import {
activateSandboxProviderContract,
deleteSandboxProviderConfigContract,
getActiveSandboxProviderContract,
getSandboxProviderContract,
getSandboxProviderListContract,
saveSandboxProviderConfigContract,
} from './console/sandbox-provider'
@ -40,11 +38,9 @@ export const consoleRouterContract = {
},
sandboxProvider: {
getSandboxProviderList: getSandboxProviderListContract,
getSandboxProvider: getSandboxProviderContract,
saveSandboxProviderConfig: saveSandboxProviderConfigContract,
deleteSandboxProviderConfig: deleteSandboxProviderConfigContract,
activateSandboxProvider: activateSandboxProviderContract,
getActiveSandboxProvider: getActiveSandboxProviderContract,
},
appAsset: {
tree: treeContract,

View File

@ -12,14 +12,6 @@ export const useGetSandboxProviderList = () => {
})
}
export const useGetSandboxProvider = (providerType: string) => {
return useQuery({
queryKey: consoleQuery.sandboxProvider.getSandboxProvider.queryKey({ input: { params: { providerType } } }),
queryFn: () => consoleClient.sandboxProvider.getSandboxProvider({ params: { providerType } }),
enabled: !!providerType,
})
}
export const useSaveSandboxProviderConfig = () => {
const queryClient = useQueryClient()
return useMutation({
@ -65,10 +57,3 @@ export const useActivateSandboxProvider = () => {
},
})
}
export const useGetActiveSandboxProvider = () => {
return useQuery({
queryKey: consoleQuery.sandboxProvider.getActiveSandboxProvider.queryKey(),
queryFn: () => consoleClient.sandboxProvider.getActiveSandboxProvider(),
})
}