feat(sandbox-layer): refactor sandbox management and integrate with SandboxManager

- Simplified the SandboxLayer initialization by removing unused parameters and consolidating sandbox creation logic.
- Integrated SandboxManager for better lifecycle management of sandboxes during workflow execution.
- Updated error handling to ensure proper initialization and cleanup of sandboxes.
- Enhanced CommandNode to retrieve sandboxes from SandboxManager, improving sandbox availability checks.
- Added unit tests to validate the new sandbox management approach and ensure robust error handling.
This commit is contained in:
Harry
2026-01-09 11:08:55 +08:00
parent b09a831d15
commit 0da4d64d38
7 changed files with 481 additions and 270 deletions

View File

@ -1,11 +1,3 @@
"""
Unit tests for the SandboxLayer.
This module tests the SandboxLayer lifecycle management including initialization,
event handling, and cleanup of VirtualEnvironment instances.
"""
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@ -13,7 +5,7 @@ import pytest
from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer
from core.virtual_environment.__base.entities import Arch
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.factory import SandboxFactory, SandboxType
from core.virtual_environment.sandbox_manager import SandboxManager
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
@ -23,16 +15,12 @@ from core.workflow.graph_events.graph import (
class MockMetadata:
"""Mock metadata for testing."""
def __init__(self, sandbox_id: str = "test-sandbox-id", arch: Arch = Arch.AMD64):
self.id = sandbox_id
self.arch = arch
class MockVirtualEnvironment:
"""Mock VirtualEnvironment for testing."""
def __init__(self, sandbox_id: str = "test-sandbox-id"):
self.metadata = MockMetadata(sandbox_id=sandbox_id)
self._released = False
@ -41,33 +29,46 @@ class MockVirtualEnvironment:
self._released = True
class MockSystemVariableView:
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
self._workflow_execution_id = workflow_execution_id
@property
def workflow_execution_id(self) -> str | None:
return self._workflow_execution_id
class MockReadOnlyGraphRuntimeStateWrapper:
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
self._system_variable = MockSystemVariableView(workflow_execution_id)
@property
def system_variable(self) -> MockSystemVariableView:
return self._system_variable
@pytest.fixture(autouse=True)
def clean_sandbox_manager():
SandboxManager.clear()
yield
SandboxManager.clear()
class TestSandboxLayer:
"""Unit tests for SandboxLayer."""
def test_init_with_default_parameters(self):
"""Test SandboxLayer initialization with default parameters."""
layer = SandboxLayer()
assert layer._sandbox_type is None # pyright: ignore[reportPrivateUsage]
assert layer._options == {} # pyright: ignore[reportPrivateUsage]
assert layer._environments == {} # pyright: ignore[reportPrivateUsage]
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
def test_init_with_custom_parameters(self):
"""Test SandboxLayer initialization with custom parameters."""
def test_init_with_parameters(self):
layer = SandboxLayer(
sandbox_type=SandboxType.LOCAL,
tenant_id="test-tenant",
options={"base_working_path": "/tmp/sandbox"},
environments={"PYTHONUNBUFFERED": "1"},
)
assert layer._sandbox_type == SandboxType.LOCAL # pyright: ignore[reportPrivateUsage]
assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage]
assert layer._options == {"base_working_path": "/tmp/sandbox"} # pyright: ignore[reportPrivateUsage]
assert layer._environments == {"PYTHONUNBUFFERED": "1"} # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
def test_sandbox_property_raises_when_not_initialized(self):
"""Test that accessing sandbox property raises error before initialization."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
with pytest.raises(RuntimeError) as exc_info:
_ = layer.sandbox
@ -75,170 +76,213 @@ class TestSandboxLayer:
assert "Sandbox not initialized" in str(exc_info.value)
def test_sandbox_property_returns_sandbox_after_initialization(self):
"""Test that sandbox property returns the sandbox after on_graph_start."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MockVirtualEnvironment()
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id")
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
assert layer.sandbox is mock_sandbox
def test_on_graph_start_creates_sandbox(self):
"""Test that on_graph_start creates a sandbox via factory."""
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self):
layer = SandboxLayer(
sandbox_type=SandboxType.DOCKER,
options={"docker_image": "python:3.11"},
tenant_id="test-tenant-123",
environments={"PATH": "/usr/bin"},
)
mock_sandbox = MockVirtualEnvironment()
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123")
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox) as mock_create:
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
) as mock_create:
layer.on_graph_start()
mock_create.assert_called_once_with(
tenant_id="default",
sandbox_type=SandboxType.DOCKER,
options={"docker_image": "python:3.11"},
tenant_id="test-tenant-123",
environments={"PATH": "/usr/bin"},
)
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
"""Test that on_graph_start raises SandboxInitializationError on factory failure."""
layer = SandboxLayer(sandbox_type=SandboxType.DOCKER)
assert SandboxManager.get("test-exec-123") is mock_sandbox
with patch.object(SandboxFactory, "create", side_effect=Exception("Docker not available")):
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
layer = SandboxLayer(tenant_id="test-tenant")
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id")
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
side_effect=Exception("Sandbox provider not available"),
):
with pytest.raises(SandboxInitializationError) as exc_info:
layer.on_graph_start()
assert "Failed to initialize sandbox" in str(exc_info.value)
assert "Docker not available" in str(exc_info.value)
assert "Sandbox provider not available" in str(exc_info.value)
def test_on_graph_start_raises_when_workflow_execution_id_not_set(self):
layer = SandboxLayer(tenant_id="test-tenant")
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id=None)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with pytest.raises(RuntimeError) as exc_info:
layer.on_graph_start()
assert "workflow_execution_id is not set" in str(exc_info.value)
def test_on_event_is_noop(self):
"""Test that on_event does nothing (no-op)."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
# These should not raise any exceptions
layer.on_event(GraphRunStartedEvent())
layer.on_event(GraphRunSucceededEvent(outputs={}))
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
def test_on_graph_end_releases_sandbox(self):
"""Test that on_graph_end releases the sandbox."""
layer = SandboxLayer()
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self):
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
workflow_execution_id = "test-exec-456"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
assert SandboxManager.has(workflow_execution_id)
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
assert not SandboxManager.has(workflow_execution_id)
def test_on_graph_end_releases_sandbox_even_on_error(self):
"""Test that on_graph_end releases sandbox even when workflow had an error."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
workflow_execution_id = "test-exec-789"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
layer.on_graph_end(error=Exception("Workflow failed"))
mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
assert not SandboxManager.has(workflow_execution_id)
def test_on_graph_end_handles_release_failure_gracefully(self):
"""Test that on_graph_end handles release failures without raising."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
mock_sandbox.release_environment.side_effect = Exception("Container already removed")
workflow_execution_id = "test-exec-fail"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
# Should not raise exception
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
def test_on_graph_end_noop_when_sandbox_not_initialized(self):
"""Test that on_graph_end is a no-op when sandbox was never initialized."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
# Should not raise exception
layer.on_graph_end(error=None)
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
def test_on_graph_end_is_idempotent(self):
"""Test that calling on_graph_end multiple times is safe."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
workflow_execution_id = "test-exec-idempotent"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
layer.on_graph_end(error=None)
layer.on_graph_end(error=None) # Second call should be no-op
layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once()
def test_layer_inherits_from_graph_engine_layer(self):
"""Test that SandboxLayer properly inherits from GraphEngineLayer."""
layer = SandboxLayer()
layer = SandboxLayer(tenant_id="test-tenant")
# Should have the graph_runtime_state property from base class
with pytest.raises(GraphEngineLayerNotInitializedError):
_ = layer.graph_runtime_state
# Should have command_channel from base class
assert layer.command_channel is None
class TestSandboxLayerIntegration:
"""Integration tests for SandboxLayer with real LocalVirtualEnvironment."""
def test_full_lifecycle_with_mocked_provider(self):
layer = SandboxLayer(tenant_id="integration-tenant")
workflow_execution_id = "integration-test-exec"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox")
def test_full_lifecycle_with_local_sandbox(self, tmp_path: Path):
"""Test complete lifecycle: init -> start -> end with local sandbox."""
layer = SandboxLayer(
sandbox_type=SandboxType.LOCAL,
options={"base_working_path": str(tmp_path)},
)
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
# Start
layer.on_graph_start()
assert layer._workflow_execution_id == workflow_execution_id # pyright: ignore[reportPrivateUsage]
assert layer.sandbox is mock_sandbox
assert SandboxManager.get(workflow_execution_id) is mock_sandbox
# Verify sandbox is created
assert layer._sandbox is not None # pyright: ignore[reportPrivateUsage]
sandbox_id = layer.sandbox.metadata.id
assert sandbox_id is not None
# End
layer.on_graph_end(error=None)
# Verify sandbox is released
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
assert not SandboxManager.has(workflow_execution_id)
mock_sandbox.release_environment.assert_called_once()
def test_lifecycle_with_workflow_error(self, tmp_path: Path):
"""Test lifecycle when workflow encounters an error."""
layer = SandboxLayer(
sandbox_type=SandboxType.LOCAL,
options={"base_working_path": str(tmp_path)},
)
def test_lifecycle_with_workflow_error(self):
layer = SandboxLayer(tenant_id="error-tenant")
workflow_execution_id = "integration-error-test"
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
mock_sandbox = MagicMock(spec=VirtualEnvironment)
mock_sandbox.metadata = MockMetadata()
with patch(
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
return_value=mock_sandbox,
):
layer.on_graph_start()
layer.on_graph_start()
assert layer.sandbox.metadata.id is not None
# Simulate workflow error
layer.on_graph_end(error=Exception("Workflow execution failed"))
# Sandbox should still be cleaned up
# pyright: ignore[reportPrivateUsage]
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
assert not SandboxManager.has(workflow_execution_id)
mock_sandbox.release_environment.assert_called_once()