mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
feat(sandbox): add AppAssetsInitializer and refactor VMFactory to VMBuilder
- Add AppAssetsInitializer to load published app assets into sandbox - Refactor VMFactory.create() to VMBuilder with builder pattern - Extract SandboxInitializer base class and DifyCliInitializer - Simplify SandboxLayer constructor (remove options/environments params) - Fix circular import in sandbox module by removing eager SandboxBashTool export - Update SandboxProviderService to return VMBuilder instead of VirtualEnvironment
This commit is contained in:
@ -1,6 +1,4 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.sandbox.manager import SandboxManager
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
@ -15,16 +13,9 @@ class SandboxInitializationError(Exception):
|
||||
|
||||
|
||||
class SandboxLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
options: Mapping[str, Any] | None = None,
|
||||
environments: Mapping[str, str] | None = None,
|
||||
) -> None:
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
super().__init__()
|
||||
self._tenant_id = tenant_id
|
||||
self._options: Mapping[str, Any] = options or {}
|
||||
self._environments: Mapping[str, str] = environments or {}
|
||||
self._workflow_execution_id: str | None = None
|
||||
|
||||
def _get_workflow_execution_id(self) -> str:
|
||||
@ -46,13 +37,16 @@ class SandboxLayer(GraphEngineLayer):
|
||||
self._workflow_execution_id = self._get_workflow_execution_id()
|
||||
|
||||
try:
|
||||
from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
|
||||
logger.info("Initializing sandbox for tenant_id=%s", self._tenant_id)
|
||||
sandbox = SandboxProviderService.create_sandbox(
|
||||
tenant_id=self._tenant_id,
|
||||
environments=self._environments,
|
||||
)
|
||||
app_id = self.graph_runtime_state.system_variable.app_id
|
||||
logger.info("Initializing sandbox for tenant_id=%s, app_id=%s", self._tenant_id, app_id)
|
||||
|
||||
builder = SandboxProviderService.create_sandbox_builder(self._tenant_id).initializer(DifyCliInitializer())
|
||||
if app_id:
|
||||
builder.initializer(AppAssetsInitializer(self._tenant_id, app_id))
|
||||
sandbox = builder.build()
|
||||
|
||||
SandboxManager.register(self._workflow_execution_id, sandbox)
|
||||
logger.info(
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from core.sandbox.bash.bash_tool import SandboxBashTool
|
||||
from core.sandbox.bash.dify_cli import (
|
||||
DifyCliBinary,
|
||||
DifyCliConfig,
|
||||
@ -7,24 +6,26 @@ from core.sandbox.bash.dify_cli import (
|
||||
DifyCliToolConfig,
|
||||
)
|
||||
from core.sandbox.constants import (
|
||||
APP_ASSETS_PATH,
|
||||
APP_ASSETS_ZIP_PATH,
|
||||
DIFY_CLI_CONFIG_PATH,
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_PATH_PATTERN,
|
||||
)
|
||||
from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer
|
||||
from core.sandbox.session import SandboxSession
|
||||
from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer
|
||||
|
||||
__all__ = [
|
||||
"APP_ASSETS_PATH",
|
||||
"APP_ASSETS_ZIP_PATH",
|
||||
"DIFY_CLI_CONFIG_PATH",
|
||||
"DIFY_CLI_PATH",
|
||||
"DIFY_CLI_PATH_PATTERN",
|
||||
"AppAssetsInitializer",
|
||||
"DifyCliBinary",
|
||||
"DifyCliConfig",
|
||||
"DifyCliEnvConfig",
|
||||
"DifyCliInitializer",
|
||||
"DifyCliLocator",
|
||||
"DifyCliToolConfig",
|
||||
"SandboxBashTool",
|
||||
"SandboxInitializer",
|
||||
"SandboxSession",
|
||||
]
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from core.sandbox.bash.bash_tool import SandboxBashTool
|
||||
from core.sandbox.bash.dify_cli import (
|
||||
DifyCliBinary,
|
||||
DifyCliConfig,
|
||||
@ -13,5 +12,4 @@ __all__ = [
|
||||
"DifyCliEnvConfig",
|
||||
"DifyCliLocator",
|
||||
"DifyCliToolConfig",
|
||||
"SandboxBashTool",
|
||||
]
|
||||
|
||||
@ -5,3 +5,7 @@ DIFY_CLI_PATH: Final[str] = ".dify/bin/dify"
|
||||
DIFY_CLI_PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}"
|
||||
|
||||
DIFY_CLI_CONFIG_PATH: Final[str] = ".dify_cli.json"
|
||||
|
||||
# App Assets
|
||||
APP_ASSETS_PATH: Final[str] = "assets"
|
||||
APP_ASSETS_ZIP_PATH: Final[str] = ".dify/tmp/assets.zip"
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -14,48 +16,66 @@ class VMType(StrEnum):
|
||||
LOCAL = "local"
|
||||
|
||||
|
||||
class VMFactory:
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
vm_type: VMType,
|
||||
options: Mapping[str, Any] | None = None,
|
||||
environments: Mapping[str, str] | None = None,
|
||||
user_id: str | None = None,
|
||||
initializers: Sequence["SandboxInitializer"] | None = None,
|
||||
) -> VirtualEnvironment:
|
||||
options = options or {}
|
||||
environments = environments or {}
|
||||
def _get_vm_class(vm_type: VMType) -> type[VirtualEnvironment]:
|
||||
match vm_type:
|
||||
case VMType.DOCKER:
|
||||
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
|
||||
|
||||
vm_class = cls._get_vm_class(vm_type)
|
||||
vm = vm_class(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)
|
||||
return DockerDaemonEnvironment
|
||||
case VMType.E2B:
|
||||
from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment
|
||||
|
||||
if initializers:
|
||||
for initializer in initializers:
|
||||
initializer.initialize(vm)
|
||||
return E2BEnvironment
|
||||
case VMType.LOCAL:
|
||||
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||
|
||||
return LocalVirtualEnvironment
|
||||
case _:
|
||||
raise ValueError(f"Unsupported VM type: {vm_type}")
|
||||
|
||||
|
||||
class VMBuilder:
|
||||
def __init__(self, tenant_id: str, vm_type: VMType) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._vm_type = vm_type
|
||||
self._user_id: str | None = None
|
||||
self._options: dict[str, Any] = {}
|
||||
self._environments: dict[str, str] = {}
|
||||
self._initializers: list[SandboxInitializer] = []
|
||||
|
||||
def user(self, user_id: str) -> VMBuilder:
|
||||
self._user_id = user_id
|
||||
return self
|
||||
|
||||
def options(self, options: Mapping[str, Any]) -> VMBuilder:
|
||||
self._options = dict(options)
|
||||
return self
|
||||
|
||||
def environments(self, environments: Mapping[str, str]) -> VMBuilder:
|
||||
self._environments = dict(environments)
|
||||
return self
|
||||
|
||||
def initializer(self, initializer: SandboxInitializer) -> VMBuilder:
|
||||
self._initializers.append(initializer)
|
||||
return self
|
||||
|
||||
def initializers(self, initializers: Sequence[SandboxInitializer]) -> VMBuilder:
|
||||
self._initializers.extend(initializers)
|
||||
return self
|
||||
|
||||
def build(self) -> VirtualEnvironment:
|
||||
vm_class = _get_vm_class(self._vm_type)
|
||||
vm = vm_class(
|
||||
tenant_id=self._tenant_id,
|
||||
options=self._options,
|
||||
environments=self._environments,
|
||||
user_id=self._user_id,
|
||||
)
|
||||
for init in self._initializers:
|
||||
init.initialize(vm)
|
||||
return vm
|
||||
|
||||
@classmethod
|
||||
def _get_vm_class(cls, vm_type: VMType) -> type[VirtualEnvironment]:
|
||||
match vm_type:
|
||||
case VMType.DOCKER:
|
||||
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
|
||||
|
||||
return DockerDaemonEnvironment
|
||||
case VMType.E2B:
|
||||
from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment
|
||||
|
||||
return E2BEnvironment
|
||||
case VMType.LOCAL:
|
||||
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||
|
||||
return LocalVirtualEnvironment
|
||||
case _:
|
||||
raise ValueError(f"Unsupported VM type: {vm_type}")
|
||||
|
||||
@classmethod
|
||||
def validate(cls, vm_type: VMType, options: Mapping[str, Any]) -> None:
|
||||
vm_class = cls._get_vm_class(vm_type)
|
||||
@staticmethod
|
||||
def validate(vm_type: VMType, options: Mapping[str, Any]) -> None:
|
||||
vm_class = _get_vm_class(vm_type)
|
||||
vm_class.validate(options)
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer
|
||||
from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer
|
||||
from core.sandbox.initializer.base import SandboxInitializer
|
||||
from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer
|
||||
|
||||
__all__ = [
|
||||
"AppAssetsInitializer",
|
||||
"DifyCliInitializer",
|
||||
"SandboxInitializer",
|
||||
]
|
||||
|
||||
86
api/core/sandbox/initializer/app_assets_initializer.py
Normal file
86
api/core/sandbox/initializer/app_assets_initializer.py
Normal file
@ -0,0 +1,86 @@
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.sandbox.constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH
|
||||
from core.sandbox.initializer.base import SandboxInitializer
|
||||
from core.virtual_environment.__base.helpers import execute, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.app_asset import AppAssetDraft
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppAssetsInitializer(SandboxInitializer):
|
||||
def __init__(self, tenant_id: str, app_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
|
||||
def initialize(self, env: VirtualEnvironment) -> None:
|
||||
published = self._get_latest_published()
|
||||
if not published:
|
||||
logger.debug("No published assets for app_id=%s, skipping", self._app_id)
|
||||
return
|
||||
|
||||
zip_key = AppAssetDraft.get_published_storage_key(self._tenant_id, self._app_id, published.id)
|
||||
try:
|
||||
zip_data = storage.load_once(zip_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to load assets zip for app_id=%s, key=%s",
|
||||
self._app_id,
|
||||
zip_key,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
env.upload_file(APP_ASSETS_ZIP_PATH, BytesIO(zip_data))
|
||||
|
||||
with with_connection(env) as conn:
|
||||
execute(
|
||||
env,
|
||||
["mkdir", "-p", ".dify/tmp"],
|
||||
connection=conn,
|
||||
error_message="Failed to create temp directory",
|
||||
)
|
||||
execute(
|
||||
env,
|
||||
["mkdir", "-p", APP_ASSETS_PATH],
|
||||
connection=conn,
|
||||
error_message="Failed to create assets directory",
|
||||
)
|
||||
execute(
|
||||
env,
|
||||
["unzip", "-o", APP_ASSETS_ZIP_PATH, "-d", APP_ASSETS_PATH],
|
||||
connection=conn,
|
||||
timeout=60,
|
||||
error_message="Failed to unzip assets",
|
||||
)
|
||||
execute(
|
||||
env,
|
||||
["rm", "-f", APP_ASSETS_ZIP_PATH],
|
||||
connection=conn,
|
||||
error_message="Failed to cleanup temp zip file",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"App assets initialized for app_id=%s, published_id=%s",
|
||||
self._app_id,
|
||||
published.id,
|
||||
)
|
||||
|
||||
def _get_latest_published(self) -> AppAssetDraft | None:
|
||||
with Session(db.engine) as session:
|
||||
return (
|
||||
session.query(AppAssetDraft)
|
||||
.filter(
|
||||
AppAssetDraft.tenant_id == self._tenant_id,
|
||||
AppAssetDraft.app_id == self._app_id,
|
||||
AppAssetDraft.version != AppAssetDraft.VERSION_DRAFT,
|
||||
)
|
||||
.order_by(AppAssetDraft.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
8
api/core/sandbox/initializer/base.py
Normal file
8
api/core/sandbox/initializer/base.py
Normal file
@ -0,0 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
|
||||
class SandboxInitializer(ABC):
|
||||
@abstractmethod
|
||||
def initialize(self, env: VirtualEnvironment) -> None: ...
|
||||
@ -1,21 +1,16 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from core.sandbox.bash.dify_cli import DifyCliLocator
|
||||
from core.sandbox.constants import DIFY_CLI_PATH
|
||||
from core.sandbox.initializer.base import SandboxInitializer
|
||||
from core.virtual_environment.__base.helpers import execute
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxInitializer(ABC):
|
||||
@abstractmethod
|
||||
def initialize(self, env: VirtualEnvironment) -> None: ...
|
||||
|
||||
|
||||
class DifyCliInitializer(SandboxInitializer):
|
||||
def __init__(self, cli_root: str | Path | None = None) -> None:
|
||||
self._locator = DifyCliLocator(root=cli_root)
|
||||
@ -1,8 +1,9 @@
|
||||
"""Sandbox debug utilities. TODO: Remove this module when sandbox debugging is complete."""
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import print_text
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
SANDBOX_DEBUG_ENABLED = True
|
||||
|
||||
@ -11,6 +12,9 @@ def sandbox_debug(tag: str, message: str, data: Any = None) -> None:
|
||||
if not SANDBOX_DEBUG_ENABLED:
|
||||
return
|
||||
|
||||
# Lazy import to avoid circular dependency
|
||||
from core.callback_handler.agent_tool_callback_handler import print_text
|
||||
|
||||
print_text(f"\n[{tag}]\n", color="blue")
|
||||
if data is not None:
|
||||
print_text(f"{message}: {data}\n", color="blue")
|
||||
|
||||
@ -19,13 +19,11 @@ from sqlalchemy.orm import Session
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.sandbox.factory import VMFactory, VMType
|
||||
from core.sandbox.initializer import DifyCliInitializer
|
||||
from core.sandbox.factory import VMBuilder, VMType
|
||||
from core.sandbox.utils.encryption import create_sandbox_config_encrypter, masked_config
|
||||
from core.tools.utils.system_encryption import (
|
||||
decrypt_system_params,
|
||||
)
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from extensions.ext_database import db
|
||||
from models.sandbox import SandboxProvider, SandboxProviderSystemConfig
|
||||
|
||||
@ -175,7 +173,7 @@ class SandboxProviderService:
|
||||
if model_class:
|
||||
model_class.model_validate(config)
|
||||
|
||||
VMFactory.validate(VMType(provider_type), config)
|
||||
VMBuilder.validate(VMType(provider_type), config)
|
||||
|
||||
@classmethod
|
||||
def save_config(
|
||||
@ -306,13 +304,8 @@ class SandboxProviderService:
|
||||
return config.provider_type if config else None
|
||||
|
||||
@classmethod
|
||||
def create_sandbox(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
environments: Mapping[str, str] | None = None,
|
||||
) -> VirtualEnvironment:
|
||||
def create_sandbox_builder(cls, tenant_id: str) -> VMBuilder:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get config: tenant config > system default > raise error
|
||||
tenant_config = (
|
||||
session.query(SandboxProvider)
|
||||
.filter(
|
||||
@ -337,10 +330,4 @@ class SandboxProviderService:
|
||||
if not config or not provider_type:
|
||||
raise ValueError(f"No active sandbox provider for tenant {tenant_id} or system default")
|
||||
|
||||
return VMFactory.create(
|
||||
tenant_id=tenant_id,
|
||||
vm_type=VMType(provider_type),
|
||||
options=dict(config),
|
||||
environments=environments or {},
|
||||
initializers=[DifyCliInitializer()],
|
||||
)
|
||||
return VMBuilder(tenant_id, VMType(provider_type)).options(config)
|
||||
|
||||
@ -701,7 +701,14 @@ class WorkflowService:
|
||||
sandbox = None
|
||||
single_step_execution_id: str | None = None
|
||||
if draft_workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
sandbox = SandboxProviderService.create_sandbox(tenant_id=draft_workflow.tenant_id)
|
||||
from core.sandbox.initializer import AppAssetsInitializer, DifyCliInitializer
|
||||
|
||||
sandbox = (
|
||||
SandboxProviderService.create_sandbox_builder(draft_workflow.tenant_id)
|
||||
.initializer(DifyCliInitializer())
|
||||
.initializer(AppAssetsInitializer(draft_workflow.tenant_id, app_model.id))
|
||||
.build()
|
||||
)
|
||||
single_step_execution_id = f"single-step-{uuid.uuid4()}"
|
||||
|
||||
SandboxManager.register(single_step_execution_id, sandbox)
|
||||
|
||||
@ -30,23 +30,50 @@ class MockVirtualEnvironment:
|
||||
|
||||
|
||||
class MockSystemVariableView:
|
||||
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
|
||||
def __init__(
|
||||
self,
|
||||
workflow_execution_id: str | None = "test-workflow-exec-id",
|
||||
app_id: str | None = "test-app-id",
|
||||
):
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
self._app_id = app_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._workflow_execution_id
|
||||
|
||||
@property
|
||||
def app_id(self) -> str | None:
|
||||
return self._app_id
|
||||
|
||||
|
||||
class MockReadOnlyGraphRuntimeStateWrapper:
|
||||
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
|
||||
self._system_variable = MockSystemVariableView(workflow_execution_id)
|
||||
def __init__(
|
||||
self,
|
||||
workflow_execution_id: str | None = "test-workflow-exec-id",
|
||||
app_id: str | None = "test-app-id",
|
||||
):
|
||||
self._system_variable = MockSystemVariableView(workflow_execution_id, app_id)
|
||||
|
||||
@property
|
||||
def system_variable(self) -> MockSystemVariableView:
|
||||
return self._system_variable
|
||||
|
||||
|
||||
class MockVMBuilder:
|
||||
def __init__(self, sandbox: VirtualEnvironment):
|
||||
self._sandbox = sandbox
|
||||
|
||||
def environments(self, _):
|
||||
return self
|
||||
|
||||
def initializer(self, _):
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self._sandbox
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
@ -54,17 +81,15 @@ def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
|
||||
|
||||
def create_mock_builder(sandbox):
|
||||
return MockVMBuilder(sandbox)
|
||||
|
||||
|
||||
class TestSandboxLayer:
|
||||
def test_init_with_parameters(self):
|
||||
layer = SandboxLayer(
|
||||
tenant_id="test-tenant",
|
||||
options={"base_working_path": "/tmp/sandbox"},
|
||||
environments={"PYTHONUNBUFFERED": "1"},
|
||||
)
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._options == {"base_working_path": "/tmp/sandbox"} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._environments == {"PYTHONUNBUFFERED": "1"} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_sandbox_property_raises_when_not_initialized(self):
|
||||
@ -82,32 +107,25 @@ class TestSandboxLayer:
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert layer.sandbox is mock_sandbox
|
||||
|
||||
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self):
|
||||
layer = SandboxLayer(
|
||||
tenant_id="test-tenant-123",
|
||||
environments={"PATH": "/usr/bin"},
|
||||
)
|
||||
layer = SandboxLayer(tenant_id="test-tenant-123")
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123")
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123", "test-app-123")
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
) as mock_create:
|
||||
layer.on_graph_start()
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
tenant_id="test-tenant-123",
|
||||
environments={"PATH": "/usr/bin"},
|
||||
)
|
||||
mock_create.assert_called_once_with("test-tenant-123")
|
||||
|
||||
assert SandboxManager.get("test-exec-123") is mock_sandbox
|
||||
|
||||
@ -117,7 +135,7 @@ class TestSandboxLayer:
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
side_effect=Exception("Sandbox provider not available"),
|
||||
):
|
||||
with pytest.raises(SandboxInitializationError) as exc_info:
|
||||
@ -152,8 +170,8 @@ class TestSandboxLayer:
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -174,8 +192,8 @@ class TestSandboxLayer:
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -195,8 +213,8 @@ class TestSandboxLayer:
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -221,8 +239,8 @@ class TestSandboxLayer:
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -250,8 +268,8 @@ class TestSandboxLayerIntegration:
|
||||
mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox")
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -274,8 +292,8 @@ class TestSandboxLayerIntegration:
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
|
||||
@ -1,148 +1,116 @@
|
||||
"""
|
||||
Unit tests for the SandboxFactory.
|
||||
|
||||
This module tests the factory pattern implementation for creating VirtualEnvironment instances
|
||||
based on sandbox type, including error handling for unsupported types.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.sandbox.factory import VMFactory, VMType
|
||||
from core.sandbox.factory import VMBuilder, VMType
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
|
||||
class TestSandboxType:
|
||||
"""Test cases for SandboxType enum."""
|
||||
|
||||
def test_sandbox_type_values(self):
|
||||
"""Test that SandboxType enum has expected values."""
|
||||
class TestVMType:
|
||||
def test_values(self):
|
||||
assert VMType.DOCKER == "docker"
|
||||
assert VMType.E2B == "e2b"
|
||||
assert VMType.LOCAL == "local"
|
||||
|
||||
def test_sandbox_type_is_string_enum(self):
|
||||
"""Test that SandboxType values are strings."""
|
||||
def test_is_string_enum(self):
|
||||
assert isinstance(VMType.DOCKER.value, str)
|
||||
assert isinstance(VMType.E2B.value, str)
|
||||
assert isinstance(VMType.LOCAL.value, str)
|
||||
|
||||
|
||||
class TestSandboxFactory:
|
||||
"""Test cases for SandboxFactory."""
|
||||
class TestVMBuilder:
|
||||
def test_build_docker(self):
|
||||
mock_instance = MagicMock(spec=VirtualEnvironment)
|
||||
mock_class = MagicMock(return_value=mock_instance)
|
||||
|
||||
def test_create_docker_sandbox_success(self):
|
||||
"""Test successful Docker sandbox creation."""
|
||||
mock_sandbox_instance = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance)
|
||||
|
||||
with patch.object(VMFactory, "_get_sandbox_class", return_value=mock_sandbox_class):
|
||||
result = VMFactory.create(
|
||||
tenant_id="test-tenant",
|
||||
vm_type=VMType.DOCKER,
|
||||
options={"docker_image": "python:3.11-slim"},
|
||||
environments={"PYTHONUNBUFFERED": "1"},
|
||||
with patch(
|
||||
"core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment",
|
||||
mock_class,
|
||||
):
|
||||
result = (
|
||||
VMBuilder("test-tenant", VMType.DOCKER)
|
||||
.options({"docker_image": "python:3.11-slim"})
|
||||
.environments({"PYTHONUNBUFFERED": "1"})
|
||||
.build()
|
||||
)
|
||||
|
||||
mock_sandbox_class.assert_called_once_with(
|
||||
mock_class.assert_called_once_with(
|
||||
tenant_id="test-tenant",
|
||||
options={"docker_image": "python:3.11-slim"},
|
||||
environments={"PYTHONUNBUFFERED": "1"},
|
||||
user_id=None,
|
||||
)
|
||||
assert result is mock_sandbox_instance
|
||||
assert result is mock_instance
|
||||
|
||||
def test_create_with_none_options_uses_empty_dict(self):
|
||||
"""Test that None options are converted to empty dict."""
|
||||
mock_sandbox_instance = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance)
|
||||
|
||||
with patch.object(VMFactory, "_get_sandbox_class", return_value=mock_sandbox_class):
|
||||
VMFactory.create(tenant_id="test-tenant", vm_type=VMType.DOCKER, options=None, environments=None)
|
||||
|
||||
mock_sandbox_class.assert_called_once_with(
|
||||
tenant_id="test-tenant", options={}, environments={}, user_id=None
|
||||
)
|
||||
|
||||
def test_create_with_default_parameters(self):
|
||||
"""Test sandbox creation with default parameters."""
|
||||
mock_sandbox_instance = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance)
|
||||
|
||||
with patch.object(VMFactory, "_get_sandbox_class", return_value=mock_sandbox_class):
|
||||
result = VMFactory.create(tenant_id="test-tenant", vm_type=VMType.DOCKER)
|
||||
|
||||
mock_sandbox_class.assert_called_once_with(
|
||||
tenant_id="test-tenant", options={}, environments={}, user_id=None
|
||||
)
|
||||
assert result is mock_sandbox_instance
|
||||
|
||||
def test_get_sandbox_class_docker_returns_correct_class(self):
|
||||
"""Test that DOCKER type returns DockerDaemonEnvironment class."""
|
||||
# Test by creating with mock to verify the class lookup works
|
||||
def test_build_with_user(self):
|
||||
mock_instance = MagicMock(spec=VirtualEnvironment)
|
||||
mock_class = MagicMock(return_value=mock_instance)
|
||||
|
||||
with patch(
|
||||
"core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment",
|
||||
return_value=mock_instance,
|
||||
) as mock_docker_class:
|
||||
VMFactory.create(tenant_id="test-tenant", vm_type=VMType.DOCKER)
|
||||
mock_docker_class.assert_called_once()
|
||||
mock_class,
|
||||
):
|
||||
VMBuilder("test-tenant", VMType.DOCKER).user("user-123").build()
|
||||
|
||||
def test_get_sandbox_class_local_returns_correct_class(self):
|
||||
"""Test that LOCAL type returns LocalVirtualEnvironment class."""
|
||||
mock_class.assert_called_once_with(
|
||||
tenant_id="test-tenant",
|
||||
options={},
|
||||
environments={},
|
||||
user_id="user-123",
|
||||
)
|
||||
|
||||
def test_build_with_initializers(self):
|
||||
mock_instance = MagicMock(spec=VirtualEnvironment)
|
||||
mock_class = MagicMock(return_value=mock_instance)
|
||||
mock_initializer = MagicMock()
|
||||
|
||||
with patch(
|
||||
"core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment",
|
||||
mock_class,
|
||||
):
|
||||
VMBuilder("test-tenant", VMType.DOCKER).initializer(mock_initializer).build()
|
||||
|
||||
mock_initializer.initialize.assert_called_once_with(mock_instance)
|
||||
|
||||
def test_build_local(self):
|
||||
mock_instance = MagicMock(spec=VirtualEnvironment)
|
||||
|
||||
with patch(
|
||||
"core.virtual_environment.providers.local_without_isolation.LocalVirtualEnvironment",
|
||||
return_value=mock_instance,
|
||||
) as mock_local_class:
|
||||
VMFactory.create(tenant_id="test-tenant", vm_type=VMType.LOCAL)
|
||||
mock_local_class.assert_called_once()
|
||||
) as mock_class:
|
||||
VMBuilder("test-tenant", VMType.LOCAL).build()
|
||||
mock_class.assert_called_once()
|
||||
|
||||
def test_get_sandbox_class_e2b_returns_correct_class(self):
|
||||
"""Test that E2B type returns E2BEnvironment class."""
|
||||
def test_build_e2b(self):
|
||||
mock_instance = MagicMock(spec=VirtualEnvironment)
|
||||
|
||||
with patch(
|
||||
"core.virtual_environment.providers.e2b_sandbox.E2BEnvironment",
|
||||
return_value=mock_instance,
|
||||
) as mock_e2b_class:
|
||||
VMFactory.create(tenant_id="test-tenant", vm_type=VMType.E2B)
|
||||
mock_e2b_class.assert_called_once()
|
||||
) as mock_class:
|
||||
VMBuilder("test-tenant", VMType.E2B).build()
|
||||
mock_class.assert_called_once()
|
||||
|
||||
def test_create_with_unsupported_type_raises_value_error(self):
|
||||
"""Test that unsupported sandbox type raises ValueError."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
VMFactory.create(tenant_id="test-tenant", vm_type="unsupported_type") # type: ignore[arg-type]
|
||||
def test_build_unsupported_type_raises(self):
|
||||
with pytest.raises(ValueError, match="Unsupported VM type"):
|
||||
VMBuilder("test-tenant", "unsupported").build() # type: ignore[arg-type]
|
||||
|
||||
assert "Unsupported sandbox type: unsupported_type" in str(exc_info.value)
|
||||
def test_validate(self):
|
||||
mock_class = MagicMock()
|
||||
|
||||
def test_create_propagates_instantiation_error(self):
|
||||
"""Test that sandbox instantiation errors are propagated."""
|
||||
mock_sandbox_class = MagicMock()
|
||||
mock_sandbox_class.side_effect = Exception("Docker daemon not available")
|
||||
|
||||
with patch.object(VMFactory, "_get_sandbox_class", return_value=mock_sandbox_class):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
VMFactory.create(tenant_id="test-tenant", vm_type=VMType.DOCKER)
|
||||
|
||||
assert "Docker daemon not available" in str(exc_info.value)
|
||||
with patch(
|
||||
"core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment",
|
||||
mock_class,
|
||||
):
|
||||
VMBuilder.validate(VMType.DOCKER, {"key": "value"})
|
||||
mock_class.validate.assert_called_once_with({"key": "value"})
|
||||
|
||||
|
||||
class TestSandboxFactoryIntegration:
|
||||
"""Integration tests for SandboxFactory with real providers (using LOCAL type)."""
|
||||
|
||||
def test_create_local_sandbox_integration(self, tmp_path: Path):
|
||||
"""Test creating a real local sandbox."""
|
||||
sandbox = VMFactory.create(
|
||||
tenant_id="test-tenant",
|
||||
vm_type=VMType.LOCAL,
|
||||
options={"base_working_path": str(tmp_path)},
|
||||
environments={},
|
||||
)
|
||||
class TestVMBuilderIntegration:
|
||||
def test_local_sandbox(self, tmp_path: Path):
|
||||
sandbox = VMBuilder("test-tenant", VMType.LOCAL).options({"base_working_path": str(tmp_path)}).build()
|
||||
|
||||
try:
|
||||
assert sandbox is not None
|
||||
|
||||
Reference in New Issue
Block a user