mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
feat(sandbox-layer): refactor sandbox management and integrate with SandboxManager
- Simplified the SandboxLayer initialization by removing unused parameters and consolidating sandbox creation logic. - Integrated SandboxManager for better lifecycle management of sandboxes during workflow execution. - Updated error handling to ensure proper initialization and cleanup of sandboxes. - Enhanced CommandNode to retrieve sandboxes from SandboxManager, improving sandbox availability checks. - Added unit tests to validate the new sandbox management approach and ensure robust error handling.
This commit is contained in:
@ -1,11 +1,3 @@
|
||||
"""
|
||||
Unit tests for the SandboxLayer.
|
||||
|
||||
This module tests the SandboxLayer lifecycle management including initialization,
|
||||
event handling, and cleanup of VirtualEnvironment instances.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -13,7 +5,7 @@ import pytest
|
||||
from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer
|
||||
from core.virtual_environment.__base.entities import Arch
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.factory import SandboxFactory, SandboxType
|
||||
from core.virtual_environment.sandbox_manager import SandboxManager
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
@ -23,16 +15,12 @@ from core.workflow.graph_events.graph import (
|
||||
|
||||
|
||||
class MockMetadata:
|
||||
"""Mock metadata for testing."""
|
||||
|
||||
def __init__(self, sandbox_id: str = "test-sandbox-id", arch: Arch = Arch.AMD64):
|
||||
self.id = sandbox_id
|
||||
self.arch = arch
|
||||
|
||||
|
||||
class MockVirtualEnvironment:
|
||||
"""Mock VirtualEnvironment for testing."""
|
||||
|
||||
def __init__(self, sandbox_id: str = "test-sandbox-id"):
|
||||
self.metadata = MockMetadata(sandbox_id=sandbox_id)
|
||||
self._released = False
|
||||
@ -41,33 +29,46 @@ class MockVirtualEnvironment:
|
||||
self._released = True
|
||||
|
||||
|
||||
class MockSystemVariableView:
|
||||
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._workflow_execution_id
|
||||
|
||||
|
||||
class MockReadOnlyGraphRuntimeStateWrapper:
|
||||
def __init__(self, workflow_execution_id: str | None = "test-workflow-exec-id"):
|
||||
self._system_variable = MockSystemVariableView(workflow_execution_id)
|
||||
|
||||
@property
|
||||
def system_variable(self) -> MockSystemVariableView:
|
||||
return self._system_variable
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
yield
|
||||
SandboxManager.clear()
|
||||
|
||||
|
||||
class TestSandboxLayer:
|
||||
"""Unit tests for SandboxLayer."""
|
||||
|
||||
def test_init_with_default_parameters(self):
|
||||
"""Test SandboxLayer initialization with default parameters."""
|
||||
layer = SandboxLayer()
|
||||
|
||||
assert layer._sandbox_type is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._options == {} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._environments == {} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_init_with_custom_parameters(self):
|
||||
"""Test SandboxLayer initialization with custom parameters."""
|
||||
def test_init_with_parameters(self):
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.LOCAL,
|
||||
tenant_id="test-tenant",
|
||||
options={"base_working_path": "/tmp/sandbox"},
|
||||
environments={"PYTHONUNBUFFERED": "1"},
|
||||
)
|
||||
|
||||
assert layer._sandbox_type == SandboxType.LOCAL # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._options == {"base_working_path": "/tmp/sandbox"} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._environments == {"PYTHONUNBUFFERED": "1"} # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_sandbox_property_raises_when_not_initialized(self):
|
||||
"""Test that accessing sandbox property raises error before initialization."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
_ = layer.sandbox
|
||||
@ -75,170 +76,213 @@ class TestSandboxLayer:
|
||||
assert "Sandbox not initialized" in str(exc_info.value)
|
||||
|
||||
def test_sandbox_property_returns_sandbox_after_initialization(self):
|
||||
"""Test that sandbox property returns the sandbox after on_graph_start."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id")
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert layer.sandbox is mock_sandbox
|
||||
|
||||
def test_on_graph_start_creates_sandbox(self):
|
||||
"""Test that on_graph_start creates a sandbox via factory."""
|
||||
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self):
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.DOCKER,
|
||||
options={"docker_image": "python:3.11"},
|
||||
tenant_id="test-tenant-123",
|
||||
environments={"PATH": "/usr/bin"},
|
||||
)
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-123")
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox) as mock_create:
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
) as mock_create:
|
||||
layer.on_graph_start()
|
||||
|
||||
mock_create.assert_called_once_with(
|
||||
tenant_id="default",
|
||||
sandbox_type=SandboxType.DOCKER,
|
||||
options={"docker_image": "python:3.11"},
|
||||
tenant_id="test-tenant-123",
|
||||
environments={"PATH": "/usr/bin"},
|
||||
)
|
||||
|
||||
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
|
||||
"""Test that on_graph_start raises SandboxInitializationError on factory failure."""
|
||||
layer = SandboxLayer(sandbox_type=SandboxType.DOCKER)
|
||||
assert SandboxManager.get("test-exec-123") is mock_sandbox
|
||||
|
||||
with patch.object(SandboxFactory, "create", side_effect=Exception("Docker not available")):
|
||||
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper("test-exec-id")
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
side_effect=Exception("Sandbox provider not available"),
|
||||
):
|
||||
with pytest.raises(SandboxInitializationError) as exc_info:
|
||||
layer.on_graph_start()
|
||||
|
||||
assert "Failed to initialize sandbox" in str(exc_info.value)
|
||||
assert "Docker not available" in str(exc_info.value)
|
||||
assert "Sandbox provider not available" in str(exc_info.value)
|
||||
|
||||
def test_on_graph_start_raises_when_workflow_execution_id_not_set(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id=None)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
layer.on_graph_start()
|
||||
|
||||
assert "workflow_execution_id is not set" in str(exc_info.value)
|
||||
|
||||
def test_on_event_is_noop(self):
|
||||
"""Test that on_event does nothing (no-op)."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
# These should not raise any exceptions
|
||||
layer.on_event(GraphRunStartedEvent())
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||
|
||||
def test_on_graph_end_releases_sandbox(self):
|
||||
"""Test that on_graph_end releases the sandbox."""
|
||||
layer = SandboxLayer()
|
||||
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
workflow_execution_id = "test-exec-456"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert SandboxManager.has(workflow_execution_id)
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
assert not SandboxManager.has(workflow_execution_id)
|
||||
|
||||
def test_on_graph_end_releases_sandbox_even_on_error(self):
|
||||
"""Test that on_graph_end releases sandbox even when workflow had an error."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
workflow_execution_id = "test-exec-789"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_end(error=Exception("Workflow failed"))
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
assert not SandboxManager.has(workflow_execution_id)
|
||||
|
||||
def test_on_graph_end_handles_release_failure_gracefully(self):
|
||||
"""Test that on_graph_end handles release failures without raising."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
mock_sandbox.release_environment.side_effect = Exception("Container already removed")
|
||||
workflow_execution_id = "test-exec-fail"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
# Should not raise exception
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_on_graph_end_noop_when_sandbox_not_initialized(self):
|
||||
"""Test that on_graph_end is a no-op when sandbox was never initialized."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
# Should not raise exception
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_on_graph_end_is_idempotent(self):
|
||||
"""Test that calling on_graph_end multiple times is safe."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
workflow_execution_id = "test-exec-idempotent"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
|
||||
with patch.object(SandboxFactory, "create", return_value=mock_sandbox):
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
layer.on_graph_end(error=None) # Second call should be no-op
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_layer_inherits_from_graph_engine_layer(self):
|
||||
"""Test that SandboxLayer properly inherits from GraphEngineLayer."""
|
||||
layer = SandboxLayer()
|
||||
layer = SandboxLayer(tenant_id="test-tenant")
|
||||
|
||||
# Should have the graph_runtime_state property from base class
|
||||
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||
_ = layer.graph_runtime_state
|
||||
|
||||
# Should have command_channel from base class
|
||||
assert layer.command_channel is None
|
||||
|
||||
|
||||
class TestSandboxLayerIntegration:
|
||||
"""Integration tests for SandboxLayer with real LocalVirtualEnvironment."""
|
||||
def test_full_lifecycle_with_mocked_provider(self):
|
||||
layer = SandboxLayer(tenant_id="integration-tenant")
|
||||
workflow_execution_id = "integration-test-exec"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox")
|
||||
|
||||
def test_full_lifecycle_with_local_sandbox(self, tmp_path: Path):
|
||||
"""Test complete lifecycle: init -> start -> end with local sandbox."""
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.LOCAL,
|
||||
options={"base_working_path": str(tmp_path)},
|
||||
)
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
# Start
|
||||
layer.on_graph_start()
|
||||
assert layer._workflow_execution_id == workflow_execution_id # pyright: ignore[reportPrivateUsage]
|
||||
assert layer.sandbox is mock_sandbox
|
||||
assert SandboxManager.get(workflow_execution_id) is mock_sandbox
|
||||
|
||||
# Verify sandbox is created
|
||||
assert layer._sandbox is not None # pyright: ignore[reportPrivateUsage]
|
||||
sandbox_id = layer.sandbox.metadata.id
|
||||
assert sandbox_id is not None
|
||||
|
||||
# End
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
# Verify sandbox is released
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
assert not SandboxManager.has(workflow_execution_id)
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_lifecycle_with_workflow_error(self, tmp_path: Path):
|
||||
"""Test lifecycle when workflow encounters an error."""
|
||||
layer = SandboxLayer(
|
||||
sandbox_type=SandboxType.LOCAL,
|
||||
options={"base_working_path": str(tmp_path)},
|
||||
)
|
||||
def test_lifecycle_with_workflow_error(self):
|
||||
layer = SandboxLayer(tenant_id="error-tenant")
|
||||
workflow_execution_id = "integration-error-test"
|
||||
mock_runtime_state = MockReadOnlyGraphRuntimeStateWrapper(workflow_execution_id)
|
||||
layer._graph_runtime_state = mock_runtime_state # type: ignore[assignment]
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox",
|
||||
return_value=mock_sandbox,
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_start()
|
||||
assert layer.sandbox.metadata.id is not None
|
||||
|
||||
# Simulate workflow error
|
||||
layer.on_graph_end(error=Exception("Workflow execution failed"))
|
||||
|
||||
# Sandbox should still be cleaned up
|
||||
# pyright: ignore[reportPrivateUsage]
|
||||
assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._workflow_execution_id is None # pyright: ignore[reportPrivateUsage]
|
||||
assert not SandboxManager.has(workflow_execution_id)
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user