feat(sandbox): enhance sandbox initialization with draft support and asset management

- Introduced DraftAppAssetsInitializer for handling draft assets.
- Updated SandboxLayer to conditionally set sandbox ID and storage based on workflow version.
- Improved asset initialization logging and error handling.
- Refactored ArchiveSandboxStorage to support exclusion patterns during archiving.
- Modified command and LLM nodes to retrieve sandbox from workflow context, supporting draft workflows.
This commit is contained in:
Harry
2026-01-20 19:44:20 +08:00
parent da6fdc963c
commit 18a589003e
8 changed files with 114 additions and 369 deletions

View File

@ -1,6 +1,8 @@
import logging
from core.sandbox import AppAssetsInitializer, DifyCliInitializer, SandboxManager
from core.sandbox.constants import APP_ASSETS_PATH
from core.sandbox.initializer.app_assets_initializer import DraftAppAssetsInitializer
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
from core.sandbox.vm import SandboxBuilder
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@ -33,12 +35,11 @@ class SandboxLayer(GraphEngineLayer):
self._user_id = user_id
self._workflow_version = workflow_version
self._workflow_execution_id = workflow_execution_id
self._sandbox_id = (
self._workflow_execution_id
if self._workflow_version == Workflow.VERSION_DRAFT
else SandboxBuilder.draft_id(self._user_id)
is_draft = self._workflow_version == Workflow.VERSION_DRAFT
self._sandbox_id = SandboxBuilder.draft_id(self._user_id) if is_draft else self._workflow_execution_id
self._sandbox_storage = ArchiveSandboxStorage(
self._tenant_id, self._sandbox_id, exclude_patterns=[APP_ASSETS_PATH] if is_draft else None
)
self._sandbox_storage = ArchiveSandboxStorage(self._tenant_id, self._sandbox_id)
def on_graph_start(self) -> None:
try:
@ -61,9 +62,15 @@ class SandboxLayer(GraphEngineLayer):
)
AppAssetService.build_assets(self._tenant_id, self._app_id, assets)
assets_initializer = (
DraftAppAssetsInitializer(self._tenant_id, self._app_id, assets.id)
if is_draft
else AppAssetsInitializer(self._tenant_id, self._app_id, assets.id)
)
builder = (
SandboxProviderService.create_sandbox_builder(self._tenant_id)
.initializer(AppAssetsInitializer(self._tenant_id, self._app_id, assets.id))
.initializer(assets_initializer)
.initializer(DifyCliInitializer(self._tenant_id, self._user_id, self._app_id, assets.id))
)
try:
@ -78,12 +85,6 @@ class SandboxLayer(GraphEngineLayer):
raise SandboxInitializationError(f"Failed to build sandbox: {e}") from e
SandboxManager.register(self._sandbox_id, sandbox)
logger.info(
"Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s",
self._sandbox_id,
sandbox.metadata.id,
sandbox.metadata.arch,
)
# mount sandbox files from storage
mounted = self._sandbox_storage.mount(sandbox)

View File

@ -18,7 +18,6 @@ from ..constants import (
DIFY_CLI_PATH,
DIFY_CLI_TOOLS_ROOT,
)
from ..manager import SandboxManager
from .bash_tool import SandboxBashTool
logger = logging.getLogger(__name__)
@ -28,7 +27,7 @@ class SandboxBashSession:
def __init__(
self,
*,
workflow_execution_id: str,
sandbox: VirtualEnvironment,
tenant_id: str,
user_id: str,
node_id: str,
@ -36,7 +35,7 @@ class SandboxBashSession:
assets_id: str,
allow_tools: list[tuple[str, str]] | None,
) -> None:
self._workflow_execution_id = workflow_execution_id
self._sandbox = sandbox
self._tenant_id = tenant_id
self._user_id = user_id
self._node_id = node_id
@ -46,25 +45,18 @@ class SandboxBashSession:
self._assets_id = assets_id
self._allow_tools = allow_tools
self._sandbox = None
self._bash_tool = None
self._session_id = None
def __enter__(self) -> SandboxBashSession:
sandbox = SandboxManager.get(self._workflow_execution_id)
if sandbox is None:
raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}")
self._sandbox = sandbox
if self._allow_tools is not None:
if self._node_id is None:
raise ValueError("node_id is required when allow_tools is specified")
tools_path = self._setup_node_tools_directory(sandbox, self._node_id, self._allow_tools)
tools_path = self._setup_node_tools_directory(self._sandbox, self._node_id, self._allow_tools)
else:
tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH
self._bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id, tools_path=tools_path)
self._bash_tool = SandboxBashTool(sandbox=self._sandbox, tenant_id=self._tenant_id, tools_path=tools_path)
return self
def _setup_node_tools_directory(

View File

@ -6,7 +6,7 @@ from core.virtual_environment.__base.virtual_environment import VirtualEnvironme
from extensions.ext_storage import storage
from extensions.storage.file_presign_storage import FilePresignStorage
from ..constants import APP_ASSETS_ZIP_PATH
from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH
from .base import SandboxInitializer
logger = logging.getLogger(__name__)
@ -33,7 +33,36 @@ class AppAssetsInitializer(SandboxInitializer):
["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} 2>/dev/null || [ $? -eq 1 ]"],
error_message="Failed to unzip assets",
)
.add(["rm", "-f", APP_ASSETS_ZIP_PATH], error_message="Failed to cleanup temp zip file")
.execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
)
logger.info(
"App assets initialized for app_id=%s, published_id=%s",
self._app_id,
self._assets_id,
)
class DraftAppAssetsInitializer(SandboxInitializer):
def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None:
self._tenant_id = tenant_id
self._app_id = app_id
self._assets_id = assets_id
def initialize(self, env: VirtualEnvironment) -> None:
zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id)
download_url = FilePresignStorage(storage.storage_runner).get_download_url(zip_key)
(
pipeline(env)
.add(["rm", "-rf", APP_ASSETS_PATH])
.add(["wget", "-q", download_url, "-O", APP_ASSETS_ZIP_PATH], error_message="Failed to download assets zip")
# unzip with silent error and return 1 if the zip is empty
# FIXME(Mairuis): should use a more robust way to check if the zip is empty
.add(
["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} 2>/dev/null || [ $? -eq 1 ]"],
error_message="Failed to unzip assets",
)
.execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
)

View File

@ -18,10 +18,19 @@ ARCHIVE_DOWNLOAD_TIMEOUT = 60 * 5
ARCHIVE_UPLOAD_TIMEOUT = 60 * 5
def build_tar_exclude_args(patterns: list[str]) -> list[str]:
return [f"--exclude={p}" for p in patterns]
class ArchiveSandboxStorage(SandboxStorage):
def __init__(self, tenant_id: str, sandbox_id: str):
_tenant_id: str
_sandbox_id: str
_exclude_patterns: list[str]
def __init__(self, tenant_id: str, sandbox_id: str, exclude_patterns: list[str] | None = None):
self._tenant_id = tenant_id
self._sandbox_id = sandbox_id
self._exclude_patterns = exclude_patterns or []
@property
def _storage_key(self) -> str:
@ -36,7 +45,7 @@ class ArchiveSandboxStorage(SandboxStorage):
try:
(
pipeline(sandbox)
.add(["wget", download_url, "-O", ARCHIVE_NAME], error_message="Failed to download archive")
.add(["wget", "-q", download_url, "-O", ARCHIVE_NAME], error_message="Failed to download archive")
.add(["tar", "-xzf", ARCHIVE_NAME], error_message="Failed to extract archive")
.add(["rm", ARCHIVE_NAME], error_message="Failed to cleanup archive")
.execute(timeout=ARCHIVE_DOWNLOAD_TIMEOUT, raise_on_error=True)
@ -53,10 +62,22 @@ class ArchiveSandboxStorage(SandboxStorage):
(
pipeline(sandbox)
.add(
["tar", "-czf", ARCHIVE_PATH, "--warning=no-file-changed", "-C", WORKSPACE_DIR, "."],
[
"tar",
"-czf",
ARCHIVE_PATH,
"--warning=no-file-changed",
*build_tar_exclude_args(self._exclude_patterns),
"-C",
WORKSPACE_DIR,
".",
],
error_message="Failed to create archive",
)
.add(["wget", upload_url, "-O", ARCHIVE_PATH], error_message="Failed to upload archive")
.add(
["curl", "-s", "-f", "-X", "PUT", "-T", ARCHIVE_PATH, upload_url],
error_message="Failed to upload archive",
)
.execute(timeout=ARCHIVE_UPLOAD_TIMEOUT, raise_on_error=True)
)
logger.info("Unmounted archive for sandbox %s", self._sandbox_id)

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import Any
from core.sandbox import SandboxManager, sandbox_debug
from core.sandbox.vm import SandboxBuilder
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
@ -24,11 +25,18 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60
class CommandNode(Node[CommandNodeData]):
node_type = NodeType.COMMAND
# FIXME(Mairuis): should read sandbox from workflow run context...
def _get_sandbox(self) -> VirtualEnvironment | None:
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
if not workflow_execution_id:
return None
return SandboxManager.get(workflow_execution_id)
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
if sandbox_by_workflow_run_id is not None:
return sandbox_by_workflow_run_id
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
if sandbox_by_draft_id is not None:
return sandbox_by_draft_id
return None
def _render_template(self, template: str) -> str:
parser = VariableTemplateParser(template=template)

View File

@ -51,6 +51,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.sandbox import SandboxBashSession, SandboxManager
from core.sandbox.vm import SandboxBuilder
from core.tools.__base.tool import Tool
from core.tools.signature import sign_upload_file
from core.tools.tool_manager import ToolManager
@ -63,6 +64,7 @@ from core.variables import (
ObjectSegment,
StringSegment,
)
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from core.workflow.entities.tool_entities import ToolCallResult
@ -172,6 +174,19 @@ class LLMNode(Node[LLMNodeData]):
def version(cls) -> str:
return "1"
# FIXME(Mairuis): should read sandbox from workflow run context...
def _get_sandbox(self) -> VirtualEnvironment | None:
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
if not workflow_execution_id:
return None
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
if sandbox_by_workflow_run_id is not None:
return sandbox_by_workflow_run_id
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
if sandbox_by_draft_id is not None:
return sandbox_by_draft_id
return None
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
@ -287,13 +302,11 @@ class LLMNode(Node[LLMNodeData]):
structured_output: LLMStructuredOutput | None = None
if self.tool_call_enabled:
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
is_sandbox_runtime = workflow_execution_id is not None and SandboxManager.is_sandbox_runtime(
workflow_execution_id
)
if is_sandbox_runtime:
# FIXME(Mairuis): should read sandbox from workflow run context...
sandbox = self._get_sandbox()
if sandbox:
generator = self._invoke_llm_with_sandbox(
sandbox=sandbox,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@ -1827,21 +1840,18 @@ class LLMNode(Node[LLMNodeData]):
def _invoke_llm_with_sandbox(
self,
sandbox: VirtualEnvironment,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
variable_pool: VariablePool,
) -> Generator[NodeEventBase, None, LLMGenerationData]:
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
if not workflow_execution_id:
raise LLMNodeError("workflow_execution_id is required for sandbox runtime mode")
allow_tools = self._get_allow_tools_list()
result: LLMGenerationData | None = None
with SandboxBashSession(
workflow_execution_id=workflow_execution_id,
sandbox=sandbox,
tenant_id=self.tenant_id,
user_id=self.user_id,
node_id=self.id,

View File

@ -15,8 +15,9 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.sandbox import SandboxManager
from core.sandbox.constants import APP_ASSETS_PATH
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
from core.sandbox.storage.sandbox_storage import SandboxStorage
from core.sandbox.vm import SandboxBuilder
from core.variables import Variable, VariableBase
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -706,20 +707,24 @@ class WorkflowService:
from core.sandbox import AppAssetsInitializer, DifyCliInitializer
from services.app_asset_service import AppAssetService
assets = AppAssetService.get_or_create_assets(draft_workflow.tenant_id, app_model.id, is_draft=True)
assets = AppAssetService.get_or_create_assets(session, app_model, account.id)
if not assets:
raise ValueError(f"No assets found for tid={draft_workflow.tenant_id}, app_id={app_model.id}")
# FIXME(Mairuis): single step execution
AppAssetService.build_assets(draft_workflow.tenant_id, app_model.id, assets)
sandbox_id = SandboxBuilder.draft_id(account.id)
sandbox_storage = ArchiveSandboxStorage(
draft_workflow.tenant_id, sandbox_id, exclude_patterns=[APP_ASSETS_PATH]
)
sandbox = (
SandboxProviderService.create_sandbox_builder(draft_workflow.tenant_id)
.initializer(DifyCliInitializer(draft_workflow.tenant_id, account.id, app_model.id, assets.id))
.initializer(AppAssetsInitializer(draft_workflow.tenant_id, app_model.id, assets.id))
.storage(ArchiveSandboxStorage(draft_workflow.tenant_id, SandboxStorage.draft_id(account.id)))
.build()
)
sandbox_storage.mount(sandbox)
single_step_execution_id = f"single-step-{uuid.uuid4()}"
SandboxManager.register(single_step_execution_id, sandbox)
@ -742,6 +747,9 @@ class WorkflowService:
start_at=start_at,
node_id=node_id,
)
# FIXME(Mairuis): fidn a better way to handle this
if sandbox is not None:
sandbox_storage.unmount(sandbox)
finally:
if single_step_execution_id:
sandbox = SandboxManager.unregister(single_step_execution_id)

View File

@ -1,324 +0,0 @@
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer
from core.sandbox import SandboxManager
from core.sandbox.storage.sandbox_storage import SandboxStorage
from core.virtual_environment.__base.entities import Arch
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from models.app_asset import AppAssets
class MockMetadata:
def __init__(self, sandbox_id: str = "test-sandbox-id", arch: Arch = Arch.AMD64):
self.id = sandbox_id
self.arch = arch
class MockVirtualEnvironment:
def __init__(self, sandbox_id: str = "test-sandbox-id"):
self.metadata = MockMetadata(sandbox_id=sandbox_id)
self._released = False
def release_environment(self) -> None:
self._released = True
class MockVMBuilder:
_sandbox: VirtualEnvironment
def __init__(self, sandbox: VirtualEnvironment) -> None:
self._sandbox = sandbox
def environments(self, _: object) -> "MockVMBuilder":
return self
def initializer(self, _: object) -> "MockVMBuilder":
return self
def build(self) -> VirtualEnvironment:
return self._sandbox
@pytest.fixture(autouse=True)
def clean_sandbox_manager():
SandboxManager.clear()
yield
SandboxManager.clear()
@pytest.fixture
def mock_sandbox_storage() -> MagicMock:
mock_storage = MagicMock(spec=SandboxStorage)
mock_storage.mount.return_value = False
mock_storage.unmount.return_value = True
return mock_storage
def create_mock_builder(sandbox: Any) -> MockVMBuilder:
return MockVMBuilder(sandbox)
def create_layer(
tenant_id: str = "test-tenant",
app_id: str = "test-app",
workflow_version: str = AppAssets.VERSION_DRAFT,
sandbox_id: str = "test-sandbox",
sandbox_storage: Any = None,
) -> SandboxLayer:
if sandbox_storage is None:
sandbox_storage = MagicMock(spec=SandboxStorage)
sandbox_storage.mount.return_value = False
sandbox_storage.unmount.return_value = True
return SandboxLayer(
tenant_id=tenant_id,
app_id=app_id,
workflow_version=workflow_version,
sandbox_id=sandbox_id,
sandbox_storage=sandbox_storage,
)
class TestSandboxLayer:
def test_init_with_parameters(self, mock_sandbox_storage: MagicMock) -> None:
layer = create_layer(
tenant_id="test-tenant",
app_id="test-app",
sandbox_id="test-sandbox",
sandbox_storage=mock_sandbox_storage,
)
assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage]
assert layer._app_id == "test-app" # pyright: ignore[reportPrivateUsage]
assert layer._sandbox_id == "test-sandbox" # pyright: ignore[reportPrivateUsage]
def test_sandbox_property_raises_when_not_initialized(self, mock_sandbox_storage: MagicMock) -> None:
layer = create_layer(sandbox_storage=mock_sandbox_storage)
with pytest.raises(RuntimeError) as exc_info:
_ = layer.sandbox
assert "Sandbox not found" in str(exc_info.value)
def test_sandbox_property_returns_sandbox_after_initialization(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "test-exec-id"
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
mock_sandbox = MockVirtualEnvironment()
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
return_value=create_mock_builder(mock_sandbox),
),
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
layer.on_graph_start()
assert layer.sandbox is mock_sandbox
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "test-exec-123"
layer = create_layer(
tenant_id="test-tenant-123",
app_id="test-app-123",
sandbox_id=sandbox_id,
sandbox_storage=mock_sandbox_storage,
)
mock_sandbox = MockVirtualEnvironment()
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
return_value=create_mock_builder(mock_sandbox),
) as mock_create,
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
layer.on_graph_start()
mock_create.assert_called_once_with("test-tenant-123")
assert SandboxManager.get(sandbox_id) is mock_sandbox
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(
self, mock_sandbox_storage: MagicMock
) -> None:
layer = create_layer(sandbox_storage=mock_sandbox_storage)
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
side_effect=Exception("Sandbox provider not available"),
),
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
with pytest.raises(SandboxInitializationError) as exc_info:
layer.on_graph_start()
assert "Failed to initialize sandbox" in str(exc_info.value)
assert "Sandbox provider not available" in str(exc_info.value)
def test_on_event_is_noop(self, mock_sandbox_storage: MagicMock) -> None:
layer = create_layer(sandbox_storage=mock_sandbox_storage)
layer.on_event(GraphRunStartedEvent())
layer.on_event(GraphRunSucceededEvent(outputs={}))
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "test-exec-456"
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
return_value=create_mock_builder(mock_sandbox),
),
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
layer.on_graph_start()
assert SandboxManager.has(sandbox_id)
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
assert not SandboxManager.has(sandbox_id)
def test_on_graph_end_releases_sandbox_even_on_error(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "test-exec-789"
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
return_value=create_mock_builder(mock_sandbox),
),
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
layer.on_graph_start()
layer.on_graph_end(error=Exception("Workflow failed"))
mock_sandbox.release_environment.assert_called_once()
assert not SandboxManager.has(sandbox_id)
def test_on_graph_end_handles_release_failure_gracefully(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "test-exec-fail"
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
mock_sandbox.release_environment.side_effect = Exception("Container already removed")
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
return_value=create_mock_builder(mock_sandbox),
),
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
layer.on_graph_start()
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
def test_on_graph_end_noop_when_sandbox_not_registered(self, mock_sandbox_storage: MagicMock) -> None:
layer = create_layer(sandbox_id="nonexistent-sandbox", sandbox_storage=mock_sandbox_storage)
layer.on_graph_end(error=None)
def test_on_graph_end_is_idempotent(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "test-exec-idempotent"
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
return_value=create_mock_builder(mock_sandbox),
),
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
layer.on_graph_start()
layer.on_graph_end(error=None)
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
def test_layer_inherits_from_graph_engine_layer(self, mock_sandbox_storage: MagicMock) -> None:
layer = create_layer(sandbox_storage=mock_sandbox_storage)
with pytest.raises(GraphEngineLayerNotInitializedError):
_ = layer.graph_runtime_state
assert layer.command_channel is None
class TestSandboxLayerIntegration:
def test_full_lifecycle_with_mocked_provider(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "integration-test-exec"
layer = create_layer(
tenant_id="integration-tenant",
app_id="integration-app",
sandbox_id=sandbox_id,
sandbox_storage=mock_sandbox_storage,
)
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox")
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
return_value=create_mock_builder(mock_sandbox),
),
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
layer.on_graph_start()
assert layer.sandbox is mock_sandbox
assert SandboxManager.get(sandbox_id) is mock_sandbox
layer.on_graph_end(error=None)
assert not SandboxManager.has(sandbox_id)
mock_sandbox.release_environment.assert_called_once()
def test_lifecycle_with_workflow_error(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "integration-error-test"
layer = create_layer(
tenant_id="error-tenant",
app_id="error-app",
sandbox_id=sandbox_id,
sandbox_storage=mock_sandbox_storage,
)
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
with (
patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
return_value=create_mock_builder(mock_sandbox),
),
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
):
layer.on_graph_start()
assert layer.sandbox.metadata.id is not None
layer.on_graph_end(error=Exception("Workflow execution failed"))
assert not SandboxManager.has(sandbox_id)
mock_sandbox.release_environment.assert_called_once()