feat: add tenant_id support to Sandbox and VirtualEnvironment initialization

This commit is contained in:
Yeuoly
2026-01-08 16:19:29 +08:00
parent 94dbda503f
commit b09a831d15
8 changed files with 66 additions and 28 deletions

View File

@ -113,7 +113,10 @@ class SandboxLayer(GraphEngineLayer):
# Fallback to explicit configuration (backward compatibility) # Fallback to explicit configuration (backward compatibility)
sandbox_type = self._sandbox_type or SandboxType.DOCKER sandbox_type = self._sandbox_type or SandboxType.DOCKER
logger.info("Initializing sandbox, sandbox_type=%s", sandbox_type) logger.info("Initializing sandbox, sandbox_type=%s", sandbox_type)
# Use a placeholder tenant_id for backward compatibility when tenant_id is not provided
effective_tenant_id = self._tenant_id or "default"
self._sandbox = SandboxFactory.create( self._sandbox = SandboxFactory.create(
tenant_id=effective_tenant_id,
sandbox_type=sandbox_type, sandbox_type=sandbox_type,
options=self._options, options=self._options,
environments=self._environments, environments=self._environments,

View File

@ -14,11 +14,25 @@ class VirtualEnvironment(ABC):
Base class for virtual environment implementations. Base class for virtual environment implementations.
""" """
def __init__(self, options: Mapping[str, Any], environments: Mapping[str, str] | None = None) -> None: def __init__(
self,
tenant_id: str,
options: Mapping[str, Any],
environments: Mapping[str, str] | None = None,
user_id: str | None = None,
) -> None:
""" """
Initialize the virtual environment with metadata. Initialize the virtual environment with metadata.
Args:
tenant_id: The tenant ID associated with this environment (required).
options: Provider-specific configuration options.
environments: Environment variables to set in the virtual environment.
user_id: The user ID associated with this environment (optional).
""" """
self.tenant_id = tenant_id
self.user_id = user_id
self.options = options self.options = options
self.metadata = self._construct_environment(options, environments or {}) self.metadata = self._construct_environment(options, environments or {})

View File

@ -3,7 +3,8 @@ Sandbox factory for creating VirtualEnvironment instances.
Example: Example:
sandbox = SandboxFactory.create( sandbox = SandboxFactory.create(
SandboxType.DOCKER, tenant_id="tenant-uuid",
sandbox_type=SandboxType.DOCKER,
options={"docker_image": "python:3.11-slim"}, options={"docker_image": "python:3.11-slim"},
environments={"PATH": "/usr/local/bin"}, environments={"PATH": "/usr/local/bin"},
) )
@ -34,17 +35,21 @@ class SandboxFactory:
@classmethod @classmethod
def create( def create(
cls, cls,
tenant_id: str,
sandbox_type: SandboxType, sandbox_type: SandboxType,
options: Mapping[str, Any] | None = None, options: Mapping[str, Any] | None = None,
environments: Mapping[str, str] | None = None, environments: Mapping[str, str] | None = None,
user_id: str | None = None,
) -> VirtualEnvironment: ) -> VirtualEnvironment:
""" """
Create a VirtualEnvironment instance based on the specified type. Create a VirtualEnvironment instance based on the specified type.
Args: Args:
tenant_id: Tenant ID associated with the sandbox (required)
sandbox_type: Type of sandbox to create sandbox_type: Type of sandbox to create
options: Sandbox-specific configuration options options: Sandbox-specific configuration options
environments: Environment variables to set in the sandbox environments: Environment variables to set in the sandbox
user_id: User ID associated with the sandbox (optional)
Returns: Returns:
Configured VirtualEnvironment instance Configured VirtualEnvironment instance
@ -56,7 +61,7 @@ class SandboxFactory:
environments = environments or {} environments = environments or {}
sandbox_class = cls._get_sandbox_class(sandbox_type) sandbox_class = cls._get_sandbox_class(sandbox_type)
return sandbox_class(options=options, environments=environments) return sandbox_class(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)
@classmethod @classmethod
def _get_sandbox_class(cls, sandbox_type: SandboxType) -> type[VirtualEnvironment]: def _get_sandbox_class(cls, sandbox_type: SandboxType) -> type[VirtualEnvironment]:

View File

@ -362,6 +362,7 @@ class SandboxProviderService:
config = decrypt_system_oauth_params(system_default.encrypted_config) config = decrypt_system_oauth_params(system_default.encrypted_config)
return SandboxFactory.create( return SandboxFactory.create(
tenant_id=tenant_id,
sandbox_type=SandboxType(provider_type), sandbox_type=SandboxType(provider_type),
options=dict(config) if config else {}, options=dict(config) if config else {},
environments=environments or {}, environments=environments or {},

View File

@ -48,10 +48,10 @@ class TestSandboxLayer:
"""Test SandboxLayer initialization with default parameters.""" """Test SandboxLayer initialization with default parameters."""
layer = SandboxLayer() layer = SandboxLayer()
assert layer._sandbox_type == SandboxType.DOCKER assert layer._sandbox_type is None # pyright: ignore[reportPrivateUsage]
assert layer._options == {} assert layer._options == {} # pyright: ignore[reportPrivateUsage]
assert layer._environments == {} assert layer._environments == {} # pyright: ignore[reportPrivateUsage]
assert layer._sandbox is None assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
def test_init_with_custom_parameters(self): def test_init_with_custom_parameters(self):
"""Test SandboxLayer initialization with custom parameters.""" """Test SandboxLayer initialization with custom parameters."""
@ -61,9 +61,9 @@ class TestSandboxLayer:
environments={"PYTHONUNBUFFERED": "1"}, environments={"PYTHONUNBUFFERED": "1"},
) )
assert layer._sandbox_type == SandboxType.LOCAL assert layer._sandbox_type == SandboxType.LOCAL # pyright: ignore[reportPrivateUsage]
assert layer._options == {"base_working_path": "/tmp/sandbox"} assert layer._options == {"base_working_path": "/tmp/sandbox"} # pyright: ignore[reportPrivateUsage]
assert layer._environments == {"PYTHONUNBUFFERED": "1"} assert layer._environments == {"PYTHONUNBUFFERED": "1"} # pyright: ignore[reportPrivateUsage]
def test_sandbox_property_raises_when_not_initialized(self): def test_sandbox_property_raises_when_not_initialized(self):
"""Test that accessing sandbox property raises error before initialization.""" """Test that accessing sandbox property raises error before initialization."""
@ -97,6 +97,7 @@ class TestSandboxLayer:
layer.on_graph_start() layer.on_graph_start()
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
tenant_id="default",
sandbox_type=SandboxType.DOCKER, sandbox_type=SandboxType.DOCKER,
options={"docker_image": "python:3.11"}, options={"docker_image": "python:3.11"},
environments={"PATH": "/usr/bin"}, environments={"PATH": "/usr/bin"},
@ -110,7 +111,7 @@ class TestSandboxLayer:
with pytest.raises(SandboxInitializationError) as exc_info: with pytest.raises(SandboxInitializationError) as exc_info:
layer.on_graph_start() layer.on_graph_start()
assert "Failed to initialize docker sandbox" in str(exc_info.value) assert "Failed to initialize sandbox" in str(exc_info.value)
assert "Docker not available" in str(exc_info.value) assert "Docker not available" in str(exc_info.value)
def test_on_event_is_noop(self): def test_on_event_is_noop(self):
@ -134,7 +135,7 @@ class TestSandboxLayer:
layer.on_graph_end(error=None) layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once() mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
def test_on_graph_end_releases_sandbox_even_on_error(self): def test_on_graph_end_releases_sandbox_even_on_error(self):
"""Test that on_graph_end releases sandbox even when workflow had an error.""" """Test that on_graph_end releases sandbox even when workflow had an error."""
@ -148,7 +149,7 @@ class TestSandboxLayer:
layer.on_graph_end(error=Exception("Workflow failed")) layer.on_graph_end(error=Exception("Workflow failed"))
mock_sandbox.release_environment.assert_called_once() mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
def test_on_graph_end_handles_release_failure_gracefully(self): def test_on_graph_end_handles_release_failure_gracefully(self):
"""Test that on_graph_end handles release failures without raising.""" """Test that on_graph_end handles release failures without raising."""
@ -164,7 +165,7 @@ class TestSandboxLayer:
layer.on_graph_end(error=None) layer.on_graph_end(error=None)
mock_sandbox.release_environment.assert_called_once() mock_sandbox.release_environment.assert_called_once()
assert layer._sandbox is None assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
def test_on_graph_end_noop_when_sandbox_not_initialized(self): def test_on_graph_end_noop_when_sandbox_not_initialized(self):
"""Test that on_graph_end is a no-op when sandbox was never initialized.""" """Test that on_graph_end is a no-op when sandbox was never initialized."""
@ -173,7 +174,7 @@ class TestSandboxLayer:
# Should not raise exception # Should not raise exception
layer.on_graph_end(error=None) layer.on_graph_end(error=None)
assert layer._sandbox is None assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
def test_on_graph_end_is_idempotent(self): def test_on_graph_end_is_idempotent(self):
"""Test that calling on_graph_end multiple times is safe.""" """Test that calling on_graph_end multiple times is safe."""
@ -215,7 +216,7 @@ class TestSandboxLayerIntegration:
layer.on_graph_start() layer.on_graph_start()
# Verify sandbox is created # Verify sandbox is created
assert layer._sandbox is not None assert layer._sandbox is not None # pyright: ignore[reportPrivateUsage]
sandbox_id = layer.sandbox.metadata.id sandbox_id = layer.sandbox.metadata.id
assert sandbox_id is not None assert sandbox_id is not None
@ -223,7 +224,7 @@ class TestSandboxLayerIntegration:
layer.on_graph_end(error=None) layer.on_graph_end(error=None)
# Verify sandbox is released # Verify sandbox is released
assert layer._sandbox is None assert layer._sandbox is None # pyright: ignore[reportPrivateUsage]
def test_lifecycle_with_workflow_error(self, tmp_path: Path): def test_lifecycle_with_workflow_error(self, tmp_path: Path):
"""Test lifecycle when workflow encounters an error.""" """Test lifecycle when workflow encounters an error."""

View File

@ -40,14 +40,17 @@ class TestSandboxFactory:
with patch.object(SandboxFactory, "_get_sandbox_class", return_value=mock_sandbox_class): with patch.object(SandboxFactory, "_get_sandbox_class", return_value=mock_sandbox_class):
result = SandboxFactory.create( result = SandboxFactory.create(
tenant_id="test-tenant",
sandbox_type=SandboxType.DOCKER, sandbox_type=SandboxType.DOCKER,
options={"docker_image": "python:3.11-slim"}, options={"docker_image": "python:3.11-slim"},
environments={"PYTHONUNBUFFERED": "1"}, environments={"PYTHONUNBUFFERED": "1"},
) )
mock_sandbox_class.assert_called_once_with( mock_sandbox_class.assert_called_once_with(
tenant_id="test-tenant",
options={"docker_image": "python:3.11-slim"}, options={"docker_image": "python:3.11-slim"},
environments={"PYTHONUNBUFFERED": "1"}, environments={"PYTHONUNBUFFERED": "1"},
user_id=None,
) )
assert result is mock_sandbox_instance assert result is mock_sandbox_instance
@ -57,9 +60,13 @@ class TestSandboxFactory:
mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance) mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance)
with patch.object(SandboxFactory, "_get_sandbox_class", return_value=mock_sandbox_class): with patch.object(SandboxFactory, "_get_sandbox_class", return_value=mock_sandbox_class):
SandboxFactory.create(sandbox_type=SandboxType.DOCKER, options=None, environments=None) SandboxFactory.create(
tenant_id="test-tenant", sandbox_type=SandboxType.DOCKER, options=None, environments=None
)
mock_sandbox_class.assert_called_once_with(options={}, environments={}) mock_sandbox_class.assert_called_once_with(
tenant_id="test-tenant", options={}, environments={}, user_id=None
)
def test_create_with_default_parameters(self): def test_create_with_default_parameters(self):
"""Test sandbox creation with default parameters.""" """Test sandbox creation with default parameters."""
@ -67,9 +74,11 @@ class TestSandboxFactory:
mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance) mock_sandbox_class = MagicMock(return_value=mock_sandbox_instance)
with patch.object(SandboxFactory, "_get_sandbox_class", return_value=mock_sandbox_class): with patch.object(SandboxFactory, "_get_sandbox_class", return_value=mock_sandbox_class):
result = SandboxFactory.create(sandbox_type=SandboxType.DOCKER) result = SandboxFactory.create(tenant_id="test-tenant", sandbox_type=SandboxType.DOCKER)
mock_sandbox_class.assert_called_once_with(options={}, environments={}) mock_sandbox_class.assert_called_once_with(
tenant_id="test-tenant", options={}, environments={}, user_id=None
)
assert result is mock_sandbox_instance assert result is mock_sandbox_instance
def test_get_sandbox_class_docker_returns_correct_class(self): def test_get_sandbox_class_docker_returns_correct_class(self):
@ -81,7 +90,7 @@ class TestSandboxFactory:
"core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment", "core.virtual_environment.providers.docker_daemon_sandbox.DockerDaemonEnvironment",
return_value=mock_instance, return_value=mock_instance,
) as mock_docker_class: ) as mock_docker_class:
SandboxFactory.create(sandbox_type=SandboxType.DOCKER) SandboxFactory.create(tenant_id="test-tenant", sandbox_type=SandboxType.DOCKER)
mock_docker_class.assert_called_once() mock_docker_class.assert_called_once()
def test_get_sandbox_class_local_returns_correct_class(self): def test_get_sandbox_class_local_returns_correct_class(self):
@ -92,7 +101,7 @@ class TestSandboxFactory:
"core.virtual_environment.providers.local_without_isolation.LocalVirtualEnvironment", "core.virtual_environment.providers.local_without_isolation.LocalVirtualEnvironment",
return_value=mock_instance, return_value=mock_instance,
) as mock_local_class: ) as mock_local_class:
SandboxFactory.create(sandbox_type=SandboxType.LOCAL) SandboxFactory.create(tenant_id="test-tenant", sandbox_type=SandboxType.LOCAL)
mock_local_class.assert_called_once() mock_local_class.assert_called_once()
def test_get_sandbox_class_e2b_returns_correct_class(self): def test_get_sandbox_class_e2b_returns_correct_class(self):
@ -103,13 +112,13 @@ class TestSandboxFactory:
"core.virtual_environment.providers.e2b_sandbox.E2BEnvironment", "core.virtual_environment.providers.e2b_sandbox.E2BEnvironment",
return_value=mock_instance, return_value=mock_instance,
) as mock_e2b_class: ) as mock_e2b_class:
SandboxFactory.create(sandbox_type=SandboxType.E2B) SandboxFactory.create(tenant_id="test-tenant", sandbox_type=SandboxType.E2B)
mock_e2b_class.assert_called_once() mock_e2b_class.assert_called_once()
def test_create_with_unsupported_type_raises_value_error(self): def test_create_with_unsupported_type_raises_value_error(self):
"""Test that unsupported sandbox type raises ValueError.""" """Test that unsupported sandbox type raises ValueError."""
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
SandboxFactory.create(sandbox_type="unsupported_type") # type: ignore[arg-type] SandboxFactory.create(tenant_id="test-tenant", sandbox_type="unsupported_type") # type: ignore[arg-type]
assert "Unsupported sandbox type: unsupported_type" in str(exc_info.value) assert "Unsupported sandbox type: unsupported_type" in str(exc_info.value)
@ -120,7 +129,7 @@ class TestSandboxFactory:
with patch.object(SandboxFactory, "_get_sandbox_class", return_value=mock_sandbox_class): with patch.object(SandboxFactory, "_get_sandbox_class", return_value=mock_sandbox_class):
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
SandboxFactory.create(sandbox_type=SandboxType.DOCKER) SandboxFactory.create(tenant_id="test-tenant", sandbox_type=SandboxType.DOCKER)
assert "Docker daemon not available" in str(exc_info.value) assert "Docker daemon not available" in str(exc_info.value)
@ -131,6 +140,7 @@ class TestSandboxFactoryIntegration:
def test_create_local_sandbox_integration(self, tmp_path: Path): def test_create_local_sandbox_integration(self, tmp_path: Path):
"""Test creating a real local sandbox.""" """Test creating a real local sandbox."""
sandbox = SandboxFactory.create( sandbox = SandboxFactory.create(
tenant_id="test-tenant",
sandbox_type=SandboxType.LOCAL, sandbox_type=SandboxType.LOCAL,
options={"base_working_path": str(tmp_path)}, options={"base_working_path": str(tmp_path)},
environments={}, environments={},

View File

@ -25,7 +25,7 @@ def _drain_transport(transport: TransportReadCloser) -> bytes:
@pytest.fixture @pytest.fixture
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment: def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64") monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
return LocalVirtualEnvironment({"base_working_path": str(tmp_path)}) return LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment): def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):

View File

@ -28,7 +28,7 @@ class FakeSandbox(VirtualEnvironment):
self._close_streams = close_streams self._close_streams = close_streams
self.last_execute_command: list[str] | None = None self.last_execute_command: list[str] | None = None
self.released_connections: list[str] = [] self.released_connections: list[str] = []
super().__init__(options={}, environments={}) super().__init__(tenant_id="test-tenant", options={}, environments={})
def _construct_environment(self, options, environments): # type: ignore[override] def _construct_environment(self, options, environments): # type: ignore[override]
return Metadata(id="fake", arch=Arch.ARM64) return Metadata(id="fake", arch=Arch.ARM64)
@ -75,6 +75,10 @@ class FakeSandbox(VirtualEnvironment):
return self._statuses.pop(0) return self._statuses.pop(0)
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0) return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0)
@classmethod
def validate(cls, options: Any) -> None:
pass
def _make_node(*, command: str, working_directory: str = "") -> CommandNode: def _make_node(*, command: str, working_directory: str = "") -> CommandNode:
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={}) variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})