From f0c6c0159cebebaeafac1aac9085dffceff3578f Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 12 Mar 2026 18:22:57 +0800 Subject: [PATCH] refactor: a lot of optimization and enhancement --- api/core/sandbox/__init__.py | 2 - api/core/sandbox/builder.py | 7 +- .../initializer/dify_cli_initializer.py | 7 +- api/core/sandbox/inspector/browser.py | 10 -- api/core/sandbox/manager.py | 103 ------------------ api/core/session/cli_api.py | 7 +- .../__base/command_future.py | 57 ++++++++-- .../virtual_environment/__base/helpers.py | 1 + .../__base/virtual_environment.py | 53 ++++++++- .../providers/e2b_sandbox.py | 67 ++++++++---- .../providers/local_without_isolation.py | 11 ++ .../providers/ssh_sandbox.py | 20 ++++ api/core/zip_sandbox/zip_sandbox.py | 34 +++--- .../__base/test_command_future.py | 31 ++++++ .../core/virtual_environment/test_helpers.py | 6 + .../test_local_without_isolation.py | 4 +- .../nodes/command/test_command_node.py | 1 + 17 files changed, 248 insertions(+), 173 deletions(-) delete mode 100644 api/core/sandbox/manager.py diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py index 608e57546a..b2eded2c87 100644 --- a/api/core/sandbox/__init__.py +++ b/api/core/sandbox/__init__.py @@ -23,7 +23,6 @@ if TYPE_CHECKING: from .initializer.app_assets_initializer import AppAssetsInitializer from .initializer.dify_cli_initializer import DifyCliInitializer from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer - from .manager import SandboxManager from .sandbox import Sandbox from .storage import ArchiveSandboxStorage, SandboxStorage from .utils.debug import sandbox_debug @@ -47,7 +46,6 @@ __all__ = [ "SandboxBuilder", "SandboxInitializeContext", "SandboxInitializer", - "SandboxManager", "SandboxProviderApiEntity", "SandboxStorage", "SandboxType", diff --git a/api/core/sandbox/builder.py b/api/core/sandbox/builder.py index 04d12b657b..6ec2bd38aa 100644 --- a/api/core/sandbox/builder.py +++ b/api/core/sandbox/builder.py @@ -131,6 +131,7 @@ class SandboxBuilder: environments=self._environments, user_id=self._user_id, ) + vm.open_enviroment() sandbox = Sandbox( vm=vm, storage=self._storage, @@ -174,7 +175,11 @@ class SandboxBuilder: if sandbox.is_cancelled(): return - sandbox.mount() + # Storage mount is part of readiness. If restore/mount fails, + # the sandbox must surface initialization failure instead of + # becoming "ready" with missing files. + if not sandbox.mount(): + raise RuntimeError("Sandbox storage mount failed") sandbox.mark_ready() except Exception as exc: try: diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py index 2b01d898fd..49b71f75f0 100644 --- a/api/core/sandbox/initializer/dify_cli_initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -19,12 +19,9 @@ logger = logging.getLogger(__name__) class DifyCliInitializer(AsyncSandboxInitializer): - _cli_api_session: object | None - def __init__(self, cli_root: str | Path | None = None) -> None: self._locator = DifyCliLocator(root=cli_root) self._tools: list[object] = [] - self._cli_api_session = None def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: vm = sandbox.vm @@ -57,7 +54,7 @@ class DifyCliInitializer(AsyncSandboxInitializer): logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id) return - self._cli_api_session = CliApiSessionManager().create( + global_cli_session = CliApiSessionManager().create( tenant_id=ctx.tenant_id, user_id=ctx.user_id, context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())), @@ -67,7 +64,7 @@ class DifyCliInitializer(AsyncSandboxInitializer): ["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir" ).execute(raise_on_error=True) - config = DifyCliConfig.create(self._cli_api_session, ctx.tenant_id, bundle.get_tool_dependencies()) + config = DifyCliConfig.create(global_cli_session, ctx.tenant_id, bundle.get_tool_dependencies()) config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False) config_path = cli.global_config_path vm.upload_file(config_path, BytesIO(config_json.encode("utf-8"))) diff --git a/api/core/sandbox/inspector/browser.py b/api/core/sandbox/inspector/browser.py index 0f3bdcb326..d94945606b 100644 --- a/api/core/sandbox/inspector/browser.py +++ b/api/core/sandbox/inspector/browser.py @@ -5,8 +5,6 @@ from pathlib import PurePosixPath from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode from core.sandbox.inspector.archive_source import SandboxFileArchiveSource from core.sandbox.inspector.base import SandboxFileSource -from core.sandbox.inspector.runtime_source import SandboxFileRuntimeSource -from core.sandbox.manager import SandboxManager class SandboxFileBrowser: @@ -31,14 +29,6 @@ class SandboxFileBrowser: return "." if normalized in (".", "") else normalized def _backend(self) -> SandboxFileSource: - sandbox = SandboxManager.get(self._sandbox_id) - if sandbox is not None: - return SandboxFileRuntimeSource( - tenant_id=self._tenant_id, - app_id=self._app_id, - sandbox_id=self._sandbox_id, - runtime=sandbox.vm, - ) return SandboxFileArchiveSource( tenant_id=self._tenant_id, app_id=self._app_id, diff --git a/api/core/sandbox/manager.py b/api/core/sandbox/manager.py deleted file mode 100644 index a948d59da5..0000000000 --- a/api/core/sandbox/manager.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import logging -import threading -from typing import TYPE_CHECKING, Final - -if TYPE_CHECKING: - from core.sandbox.sandbox import Sandbox - -logger = logging.getLogger(__name__) - - -class SandboxManager: - """Registry for active Sandbox instances. - - Stores complete Sandbox objects (not just VirtualEnvironment) to provide - access to sandbox metadata like tenant_id, app_id, user_id, assets_id. - """ - - _NUM_SHARDS: Final[int] = 1024 - _SHARD_MASK: Final[int] = _NUM_SHARDS - 1 - - _shard_locks: Final[tuple[threading.Lock, ...]] = tuple(threading.Lock() for _ in range(_NUM_SHARDS)) - _shards: list[dict[str, Sandbox]] = [{} for _ in range(_NUM_SHARDS)] - - @classmethod - def _shard_index(cls, sandbox_id: str) -> int: - return hash(sandbox_id) & cls._SHARD_MASK - - @classmethod - def register(cls, sandbox_id: str, sandbox: Sandbox) -> None: - if not sandbox_id: - raise ValueError("sandbox_id cannot be empty") - - shard_index = cls._shard_index(sandbox_id) - with cls._shard_locks[shard_index]: - shard = cls._shards[shard_index] - if sandbox_id in shard: - raise RuntimeError( - f"Sandbox already registered for sandbox_id={sandbox_id}. " - "Call unregister() first if you need to replace it." - ) - - new_shard = dict(shard) - new_shard[sandbox_id] = sandbox - cls._shards[shard_index] = new_shard - - logger.debug( - "Registered sandbox: sandbox_id=%s, id=%s, app_id=%s", - sandbox_id, - sandbox.id, - sandbox.app_id, - ) - - @classmethod - def get(cls, sandbox_id: str) -> Sandbox | None: - shard_index = cls._shard_index(sandbox_id) - return cls._shards[shard_index].get(sandbox_id) - - @classmethod - def unregister(cls, sandbox_id: str) -> Sandbox | None: - shard_index = cls._shard_index(sandbox_id) - with cls._shard_locks[shard_index]: - shard = cls._shards[shard_index] - sandbox = shard.get(sandbox_id) - if sandbox is None: - return None - - new_shard = dict(shard) - new_shard.pop(sandbox_id, None) - cls._shards[shard_index] = new_shard - - logger.debug( - "Unregistered sandbox: sandbox_id=%s, id=%s", - sandbox_id, - sandbox.id, - ) - return sandbox - - @classmethod - def has(cls, sandbox_id: str) -> bool: - shard_index = cls._shard_index(sandbox_id) - return sandbox_id in cls._shards[shard_index] - - @classmethod - def is_sandbox_runtime(cls, sandbox_id: str) -> bool: - return cls.has(sandbox_id) - - @classmethod - def clear(cls) -> None: - for lock in cls._shard_locks: - lock.acquire() - try: - for i in range(cls._NUM_SHARDS): - cls._shards[i] = {} - logger.debug("Cleared all registered sandboxes") - finally: - for lock in reversed(cls._shard_locks): - lock.release() - - @classmethod - def count(cls) -> int: - return sum(len(shard) for shard in cls._shards) diff --git a/api/core/session/cli_api.py b/api/core/session/cli_api.py index 44d13c55ae..a8c42a5ede 100644 --- a/api/core/session/cli_api.py +++ b/api/core/session/cli_api.py @@ -2,6 +2,7 @@ import secrets from pydantic import BaseModel, Field +from configs import dify_config from core.skill.entities import ToolAccessPolicy from .session import BaseSession, SessionManager @@ -17,7 +18,11 @@ class CliContext(BaseModel): class CliApiSessionManager(SessionManager[CliApiSession]): def __init__(self, ttl: int | None = None): - super().__init__(key_prefix="cli_api_session", session_class=CliApiSession, ttl=ttl) + super().__init__( + key_prefix="cli_api_session", + session_class=CliApiSession, + ttl=ttl or dify_config.WORKFLOW_MAX_EXECUTION_TIME, + ) def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession: session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json")) diff --git a/api/core/virtual_environment/__base/command_future.py b/api/core/virtual_environment/__base/command_future.py index 88bef76054..f5363fcf4f 100644 --- a/api/core/virtual_environment/__base/command_future.py +++ b/api/core/virtual_environment/__base/command_future.py @@ -26,6 +26,9 @@ class CommandFuture: Lightweight future for command execution. Mirrors concurrent.futures.Future API with 4 essential methods: result(), done(), cancel(), cancelled(). + + When a command is cancelled or times out the future now asks the provider + to terminate the underlying process/session before marking itself done. """ def __init__( @@ -35,6 +38,7 @@ class CommandFuture: stdout_transport: TransportReadCloser, stderr_transport: TransportReadCloser, poll_status: Callable[[], CommandStatus], + terminate_command: Callable[[], bool] | None = None, poll_interval: float = 0.1, ): self._pid = pid @@ -42,6 +46,7 @@ class CommandFuture: self._stdout_transport = stdout_transport self._stderr_transport = stderr_transport self._poll_status = poll_status + self._terminate_command = terminate_command self._poll_interval = poll_interval self._done_event = threading.Event() @@ -49,7 +54,9 @@ class CommandFuture: self._result: CommandResult | None = None self._exception: BaseException | None = None self._cancelled = False + self._timed_out = False self._started = False + self._termination_requested = False def result(self, timeout: float | None = None) -> CommandResult: """ @@ -61,15 +68,22 @@ class CommandFuture: Raises: CommandTimeoutError: If timeout exceeded. CommandCancelledError: If command was cancelled. + + A timeout is terminal for this future: it triggers best-effort command + termination and subsequent ``result()`` calls keep raising timeout. """ self._ensure_started() if not self._done_event.wait(timeout): + self._request_stop(timed_out=True) raise CommandTimeoutError(f"Command timed out after {timeout}s") if self._cancelled: raise CommandCancelledError("Command was cancelled") + if self._timed_out: + raise CommandTimeoutError("Command timed out") + if self._exception is not None: raise self._exception @@ -82,16 +96,10 @@ class CommandFuture: def cancel(self) -> bool: """ - Attempt to cancel command by closing transports. + Attempt to cancel command by terminating it and closing transports. Returns True if cancelled, False if already completed. """ - with self._lock: - if self._done_event.is_set(): - return False - self._cancelled = True - self._close_transports() - self._done_event.set() - return True + return self._request_stop(cancelled=True) def cancelled(self) -> bool: return self._cancelled @@ -103,6 +111,28 @@ class CommandFuture: thread = threading.Thread(target=self._execute, daemon=True) thread.start() + def _request_stop(self, *, cancelled: bool = False, timed_out: bool = False) -> bool: + should_terminate = False + with self._lock: + if self._done_event.is_set(): + return False + + if cancelled: + self._cancelled = True + if timed_out: + self._timed_out = True + + should_terminate = not self._termination_requested + if should_terminate: + self._termination_requested = True + + self._close_transports() + self._done_event.set() + + if should_terminate: + self._terminate_running_command() + return True + def _execute(self) -> None: stdout_buf = bytearray() stderr_buf = bytearray() @@ -141,7 +171,7 @@ class CommandFuture: self._close_transports() def _wait_for_completion(self) -> int | None: - while not self._cancelled: + while not self._cancelled and not self._timed_out: try: status = self._poll_status() except NotSupportedOperationError: @@ -167,3 +197,12 @@ class CommandFuture: for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport): with contextlib.suppress(Exception): transport.close() + + def _terminate_running_command(self) -> None: + if self._terminate_command is None: + return + + try: + self._terminate_command() + except Exception: + logger.exception("Failed to terminate command for pid %s", self._pid) diff --git a/api/core/virtual_environment/__base/helpers.py b/api/core/virtual_environment/__base/helpers.py index 555dc880ca..e8094f4ba7 100644 --- a/api/core/virtual_environment/__base/helpers.py +++ b/api/core/virtual_environment/__base/helpers.py @@ -71,6 +71,7 @@ def submit_command( stdout_transport=stdout_transport, stderr_transport=stderr_transport, poll_status=partial(env.get_command_status, connection, pid), + terminate_command=partial(env.terminate_command, connection, pid), ) diff --git a/api/core/virtual_environment/__base/virtual_environment.py b/api/core/virtual_environment/__base/virtual_environment.py index 826d3a6ec7..5332f1dd2e 100644 --- a/api/core/virtual_environment/__base/virtual_environment.py +++ b/api/core/virtual_environment/__base/virtual_environment.py @@ -11,8 +11,19 @@ from core.virtual_environment.channel.transport import TransportReadCloser, Tran class VirtualEnvironment(ABC): """ Base class for virtual environment implementations. + + ``VirtualEnvironment`` instances are configured at construction time but do + not allocate provider resources until ``open_enviroment()`` is called. + This keeps object construction side-effect free and gives callers a chance + to own startup error handling explicitly. """ + tenant_id: str + user_id: str | None + options: Mapping[str, Any] + _environments: Mapping[str, str] + _metadata: Metadata | None + def __init__( self, tenant_id: str, @@ -21,19 +32,45 @@ class VirtualEnvironment(ABC): user_id: str | None = None, ) -> None: """ - Initialize the virtual environment with metadata. + Initialize the virtual environment configuration. Args: tenant_id: The tenant ID associated with this environment (required). options: Provider-specific configuration options. environments: Environment variables to set in the virtual environment. user_id: The user ID associated with this environment (optional). + + The provider runtime itself is created later by ``open_enviroment()``. """ self.tenant_id = tenant_id self.user_id = user_id self.options = options - self.metadata = self._construct_environment(options, environments or {}) + self._environments = dict(environments or {}) + self._metadata = None + + @property + def metadata(self) -> Metadata: + """Provider metadata for a started environment. + + Raises: + RuntimeError: If the environment has not been started yet. + """ + + if self._metadata is None: + raise RuntimeError("Virtual environment has not been started") + return self._metadata + + def open_enviroment(self) -> Metadata: + """Allocate provider resources and return the resulting metadata. + + Multiple calls are safe and return the existing metadata after the first + successful start. + """ + + if self._metadata is None: + self._metadata = self._construct_environment(self.options, self._environments) + return self._metadata @abstractmethod def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: @@ -131,6 +168,18 @@ class VirtualEnvironment(ABC): Multiple calls to `release_environment` with the same `environment_id` is acceptable. """ + def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool: + """Best-effort termination hook for a running command. + + Providers that can map ``pid`` back to a real process/session should + override this method and stop the command. The default implementation is + a no-op so providers without a termination mechanism remain compatible. + """ + + _ = connection_handle + _ = pid + return False + @abstractmethod def execute_command( self, diff --git a/api/core/virtual_environment/providers/e2b_sandbox.py b/api/core/virtual_environment/providers/e2b_sandbox.py index 8559aacb32..ab2b5fbe6e 100644 --- a/api/core/virtual_environment/providers/e2b_sandbox.py +++ b/api/core/virtual_environment/providers/e2b_sandbox.py @@ -1,3 +1,4 @@ +import logging import posixpath import shlex import threading @@ -32,6 +33,8 @@ from core.virtual_environment.channel.transport import ( ) from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS +logger = logging.getLogger(__name__) + """ import logging from collections.abc import Mapping @@ -132,35 +135,53 @@ class E2BEnvironment(VirtualEnvironment): The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the provider can rely on E2B's native timeout instead of a background keepalive thread that continuously extends the session. + + E2B allocates the remote sandbox before metadata probing completes, so + startup failures must best-effort terminate the sandbox before the + exception escapes. """ # Import E2B SDK lazily so it is loaded after gevent monkey-patching. from e2b_code_interpreter import Sandbox # type: ignore[import-untyped] # TODO: add Dify as the user agent - sandbox = Sandbox.create( - template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"), - timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME, - api_key=options.get(self.OptionsKey.API_KEY, ""), - api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL), - envs=dict(environments), - ) - info = sandbox.get_info(api_key=options.get(self.OptionsKey.API_KEY, "")) - system_info = sandbox.commands.run("uname -m -s").stdout.strip() - system_parts = system_info.split() - if len(system_parts) == 2: - os_part, arch_part = system_parts - else: - arch_part = system_parts[0] - os_part = system_parts[1] if len(system_parts) > 1 else "" + sandbox = None + sandbox_id: str | None = None + api_key = options.get(self.OptionsKey.API_KEY, "") + try: + sandbox = Sandbox.create( + template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"), + timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + api_key=api_key, + api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL), + envs=dict(environments), + ) + info = sandbox.get_info(api_key=api_key) + sandbox_id = info.sandbox_id + system_info = sandbox.commands.run("uname -m -s").stdout.strip() + system_parts = system_info.split() + if len(system_parts) == 2: + os_part, arch_part = system_parts + else: + arch_part = system_parts[0] + os_part = system_parts[1] if len(system_parts) > 1 else "" - return Metadata( - id=info.sandbox_id, - arch=self._convert_architecture(arch_part.strip()), - os=self._convert_operating_system(os_part.strip()), - store={ - self.StoreKey.SANDBOX: sandbox, - }, - ) + return Metadata( + id=info.sandbox_id, + arch=self._convert_architecture(arch_part.strip()), + os=self._convert_operating_system(os_part.strip()), + store={ + self.StoreKey.SANDBOX: sandbox, + }, + ) + except Exception: + if sandbox_id is None and sandbox is not None: + sandbox_id = getattr(sandbox, "sandbox_id", None) + if sandbox_id is not None: + try: + Sandbox.kill(api_key=api_key, sandbox_id=sandbox_id) + except Exception: + logger.exception("Failed to cleanup E2B sandbox after startup failure") + raise def release_environment(self) -> None: """ diff --git a/api/core/virtual_environment/providers/local_without_isolation.py b/api/core/virtual_environment/providers/local_without_isolation.py index 54d3f28ed9..494de05738 100644 --- a/api/core/virtual_environment/providers/local_without_isolation.py +++ b/api/core/virtual_environment/providers/local_without_isolation.py @@ -1,6 +1,7 @@ import os import pathlib import shutil +import signal import subprocess from collections.abc import Mapping, Sequence from functools import cached_property @@ -246,6 +247,16 @@ class LocalVirtualEnvironment(VirtualEnvironment): except ChildProcessError: return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None) + def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool: + """Terminate a locally spawned process by PID when cancellation is requested.""" + + _ = connection_handle + try: + os.kill(int(pid), signal.SIGTERM) + except ProcessLookupError: + return False + return True + def _get_os_architecture(self) -> Arch: """ Get the operating system architecture. diff --git a/api/core/virtual_environment/providers/ssh_sandbox.py b/api/core/virtual_environment/providers/ssh_sandbox.py index f43f218c69..dd2c095509 100644 --- a/api/core/virtual_environment/providers/ssh_sandbox.py +++ b/api/core/virtual_environment/providers/ssh_sandbox.py @@ -76,6 +76,7 @@ class SSHSandboxEnvironment(VirtualEnvironment): ) -> None: self._connections: dict[str, Any] = {} self._commands: dict[str, CommandStatus] = {} + self._command_channels: dict[str, Any] = {} self._lock = threading.Lock() super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id) @@ -163,6 +164,7 @@ class SSHSandboxEnvironment(VirtualEnvironment): with self._lock: self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None) + self._command_channels[pid] = channel threading.Thread( target=self._consume_channel_output, @@ -179,6 +181,23 @@ class SSHSandboxEnvironment(VirtualEnvironment): return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None) return status + def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool: + """Best-effort termination by closing the SSH channel that owns the command.""" + + _ = connection_handle + with self._lock: + channel = self._command_channels.get(pid) + if channel is None: + return False + self._commands[pid] = CommandStatus( + status=CommandStatus.Status.COMPLETED, + exit_code=self._COMMAND_TIMEOUT_EXIT_CODE, + ) + + with contextlib.suppress(Exception): + channel.close() + return True + def upload_file(self, path: str, content: BytesIO) -> None: destination_path = self._workspace_path(path) with self._client() as client: @@ -424,6 +443,7 @@ class SSHSandboxEnvironment(VirtualEnvironment): channel.close() with self._lock: + self._command_channels.pop(pid, None) self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code) def _set_sftp_operation_timeout(self, sftp: Any) -> None: diff --git a/api/core/zip_sandbox/zip_sandbox.py b/api/core/zip_sandbox/zip_sandbox.py index 728ef538c6..0b8f60e01b 100644 --- a/api/core/zip_sandbox/zip_sandbox.py +++ b/api/core/zip_sandbox/zip_sandbox.py @@ -12,7 +12,6 @@ from uuid import uuid4 from core.sandbox.builder import SandboxBuilder from core.sandbox.entities.sandbox_type import SandboxType -from core.sandbox.manager import SandboxManager from core.sandbox.sandbox import Sandbox from core.sandbox.storage.noop_storage import NoopSandboxStorage from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError @@ -100,26 +99,29 @@ class ZipSandbox: self._sandbox_id = uuid4().hex storage = NoopSandboxStorage() - self._sandbox = ( - SandboxBuilder(self._tenant_id, SandboxType(provider_type)) - .options(provider_options) - .user(self._user_id) - .app(self._app_id) - .storage(storage, assets_id="zip-sandbox") - .build() - ) - self._sandbox.wait_ready(timeout=60) - self._vm = self._sandbox.vm - - SandboxManager.register(self._sandbox_id, self._sandbox) + try: + self._sandbox = ( + SandboxBuilder(self._tenant_id, SandboxType(provider_type)) + .options(provider_options) + .user(self._user_id) + .app(self._app_id) + .storage(storage, assets_id="zip-sandbox") + .build() + ) + self._sandbox.wait_ready(timeout=60) + self._vm = self._sandbox.vm + except Exception: + if self._sandbox is not None: + self._sandbox.release() + self._vm = None + self._sandbox = None + self._sandbox_id = None + raise def _stop(self) -> None: if self._vm is None: return - if self._sandbox_id: - SandboxManager.unregister(self._sandbox_id) - if self._sandbox is not None: self._sandbox.release() diff --git a/api/tests/unit_tests/core/virtual_environment/__base/test_command_future.py b/api/tests/unit_tests/core/virtual_environment/__base/test_command_future.py index 455a64a9af..78d6cfb664 100644 --- a/api/tests/unit_tests/core/virtual_environment/__base/test_command_future.py +++ b/api/tests/unit_tests/core/virtual_environment/__base/test_command_future.py @@ -1,4 +1,5 @@ import threading +from collections.abc import Callable import pytest @@ -18,6 +19,7 @@ def _make_future( exit_code: int = 0, delay_completion: float = 0, close_streams: bool = True, + terminate_command: Callable[[], bool] | None = None, ) -> CommandFuture: stdout_transport = QueueTransportReadCloser() stderr_transport = QueueTransportReadCloser() @@ -48,6 +50,7 @@ def _make_future( stdout_transport=stdout_transport, stderr_transport=stderr_transport, poll_status=poll_status, + terminate_command=terminate_command, poll_interval=0.05, ) @@ -78,6 +81,21 @@ def test_result_raises_timeout_error_when_exceeded(): future.result(timeout=0.2) +def test_timeout_requests_command_termination(): + terminated = threading.Event() + + future = _make_future( + delay_completion=10.0, + close_streams=False, + terminate_command=lambda: terminated.set() or True, + ) + + with pytest.raises(CommandTimeoutError): + future.result(timeout=0.2) + + assert terminated.wait(timeout=1.0) + + def test_done_returns_false_while_running(): future = _make_future(delay_completion=10.0, close_streams=False) @@ -115,6 +133,19 @@ def test_result_raises_cancelled_error_after_cancel(): future.result() +def test_cancel_requests_command_termination(): + terminated = threading.Event() + + future = _make_future( + delay_completion=10.0, + close_streams=False, + terminate_command=lambda: terminated.set() or True, + ) + + assert future.cancel() is True + assert terminated.wait(timeout=1.0) + + def test_nonzero_exit_code_is_returned(): future = _make_future(stdout=b"err", exit_code=42) diff --git a/api/tests/unit_tests/core/virtual_environment/test_helpers.py b/api/tests/unit_tests/core/virtual_environment/test_helpers.py index 8b91ee1c95..11e5885992 100644 --- a/api/tests/unit_tests/core/virtual_environment/test_helpers.py +++ b/api/tests/unit_tests/core/virtual_environment/test_helpers.py @@ -4,6 +4,7 @@ from typing import Any import pytest +from core.entities.provider_entities import BasicProviderConfig from core.virtual_environment.__base.entities import ( Arch, CommandStatus, @@ -64,6 +65,7 @@ class FakeVirtualEnvironment(VirtualEnvironment): self._establish_count = 0 self._release_count = 0 super().__init__(tenant_id="test-tenant", options={}, environments={}) + self.open_enviroment() def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata: return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX) @@ -111,6 +113,10 @@ class FakeVirtualEnvironment(VirtualEnvironment): def validate(cls, _options: Mapping[str, Any]) -> None: pass + @classmethod + def get_config_schema(cls) -> list[BasicProviderConfig]: + return [] + class TestWithConnection: def test_connection_established_and_released(self): diff --git a/api/tests/unit_tests/core/virtual_environment/test_local_without_isolation.py b/api/tests/unit_tests/core/virtual_environment/test_local_without_isolation.py index 52f1c986ee..d0239ce4bf 100644 --- a/api/tests/unit_tests/core/virtual_environment/test_local_without_isolation.py +++ b/api/tests/unit_tests/core/virtual_environment/test_local_without_isolation.py @@ -26,7 +26,9 @@ def _drain_transport(transport: TransportReadCloser) -> bytes: @pytest.fixture def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment: monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64") - return LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)}) + env = LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)}) + env.open_enviroment() + return env def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment): diff --git a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py index ead68c9112..b0115310a6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/command/test_command_node.py @@ -42,6 +42,7 @@ class FakeVirtualEnvironment(VirtualEnvironment): self.last_execute_cwd: str | None = None self.released_connections: list[str] = [] super().__init__(tenant_id="test-tenant", options={}, environments={}) + self.open_enviroment() def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)