refactor: assorted optimizations and enhancements

This commit is contained in:
Harry
2026-03-12 18:22:57 +08:00
parent 4a64362193
commit f0c6c0159c
17 changed files with 248 additions and 173 deletions

View File

@ -23,7 +23,6 @@ if TYPE_CHECKING:
from .initializer.app_assets_initializer import AppAssetsInitializer from .initializer.app_assets_initializer import AppAssetsInitializer
from .initializer.dify_cli_initializer import DifyCliInitializer from .initializer.dify_cli_initializer import DifyCliInitializer
from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer
from .manager import SandboxManager
from .sandbox import Sandbox from .sandbox import Sandbox
from .storage import ArchiveSandboxStorage, SandboxStorage from .storage import ArchiveSandboxStorage, SandboxStorage
from .utils.debug import sandbox_debug from .utils.debug import sandbox_debug
@ -47,7 +46,6 @@ __all__ = [
"SandboxBuilder", "SandboxBuilder",
"SandboxInitializeContext", "SandboxInitializeContext",
"SandboxInitializer", "SandboxInitializer",
"SandboxManager",
"SandboxProviderApiEntity", "SandboxProviderApiEntity",
"SandboxStorage", "SandboxStorage",
"SandboxType", "SandboxType",

View File

@ -131,6 +131,7 @@ class SandboxBuilder:
environments=self._environments, environments=self._environments,
user_id=self._user_id, user_id=self._user_id,
) )
vm.open_enviroment()
sandbox = Sandbox( sandbox = Sandbox(
vm=vm, vm=vm,
storage=self._storage, storage=self._storage,
@ -174,7 +175,11 @@ class SandboxBuilder:
if sandbox.is_cancelled(): if sandbox.is_cancelled():
return return
sandbox.mount() # Storage mount is part of readiness. If restore/mount fails,
# the sandbox must surface initialization failure instead of
# becoming "ready" with missing files.
if not sandbox.mount():
raise RuntimeError("Sandbox storage mount failed")
sandbox.mark_ready() sandbox.mark_ready()
except Exception as exc: except Exception as exc:
try: try:

View File

@ -19,12 +19,9 @@ logger = logging.getLogger(__name__)
class DifyCliInitializer(AsyncSandboxInitializer): class DifyCliInitializer(AsyncSandboxInitializer):
_cli_api_session: object | None
def __init__(self, cli_root: str | Path | None = None) -> None: def __init__(self, cli_root: str | Path | None = None) -> None:
self._locator = DifyCliLocator(root=cli_root) self._locator = DifyCliLocator(root=cli_root)
self._tools: list[object] = [] self._tools: list[object] = []
self._cli_api_session = None
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
vm = sandbox.vm vm = sandbox.vm
@ -57,7 +54,7 @@ class DifyCliInitializer(AsyncSandboxInitializer):
logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id) logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id)
return return
self._cli_api_session = CliApiSessionManager().create( global_cli_session = CliApiSessionManager().create(
tenant_id=ctx.tenant_id, tenant_id=ctx.tenant_id,
user_id=ctx.user_id, user_id=ctx.user_id,
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())), context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
@ -67,7 +64,7 @@ class DifyCliInitializer(AsyncSandboxInitializer):
["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir" ["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir"
).execute(raise_on_error=True) ).execute(raise_on_error=True)
config = DifyCliConfig.create(self._cli_api_session, ctx.tenant_id, bundle.get_tool_dependencies()) config = DifyCliConfig.create(global_cli_session, ctx.tenant_id, bundle.get_tool_dependencies())
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False) config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
config_path = cli.global_config_path config_path = cli.global_config_path
vm.upload_file(config_path, BytesIO(config_json.encode("utf-8"))) vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))

View File

@ -5,8 +5,6 @@ from pathlib import PurePosixPath
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
from core.sandbox.inspector.archive_source import SandboxFileArchiveSource from core.sandbox.inspector.archive_source import SandboxFileArchiveSource
from core.sandbox.inspector.base import SandboxFileSource from core.sandbox.inspector.base import SandboxFileSource
from core.sandbox.inspector.runtime_source import SandboxFileRuntimeSource
from core.sandbox.manager import SandboxManager
class SandboxFileBrowser: class SandboxFileBrowser:
@ -31,14 +29,6 @@ class SandboxFileBrowser:
return "." if normalized in (".", "") else normalized return "." if normalized in (".", "") else normalized
def _backend(self) -> SandboxFileSource: def _backend(self) -> SandboxFileSource:
sandbox = SandboxManager.get(self._sandbox_id)
if sandbox is not None:
return SandboxFileRuntimeSource(
tenant_id=self._tenant_id,
app_id=self._app_id,
sandbox_id=self._sandbox_id,
runtime=sandbox.vm,
)
return SandboxFileArchiveSource( return SandboxFileArchiveSource(
tenant_id=self._tenant_id, tenant_id=self._tenant_id,
app_id=self._app_id, app_id=self._app_id,

View File

@ -1,103 +0,0 @@
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING, Final
if TYPE_CHECKING:
from core.sandbox.sandbox import Sandbox
logger = logging.getLogger(__name__)
class SandboxManager:
    """Process-wide registry mapping sandbox ids to live Sandbox objects.

    Whole ``Sandbox`` instances are stored (rather than bare
    VirtualEnvironment handles) so callers can reach sandbox metadata such
    as tenant_id, app_id, user_id and assets_id.

    Concurrency model: writes take a per-shard lock and swap in a freshly
    copied dict (copy-on-write), which lets ``get``/``has``/``count`` read
    the shard dicts without taking any lock.
    """

    _NUM_SHARDS: Final[int] = 1024
    _SHARD_MASK: Final[int] = _NUM_SHARDS - 1
    _shard_locks: Final[tuple[threading.Lock, ...]] = tuple(threading.Lock() for _ in range(_NUM_SHARDS))
    _shards: list[dict[str, Sandbox]] = [{} for _ in range(_NUM_SHARDS)]

    @classmethod
    def _shard_index(cls, sandbox_id: str) -> int:
        # Power-of-two bucketing via the shard mask; hash() is salted per
        # process but stable within it, which is all a per-process registry
        # needs.
        return hash(sandbox_id) & cls._SHARD_MASK

    @classmethod
    def register(cls, sandbox_id: str, sandbox: Sandbox) -> None:
        """Add *sandbox* under *sandbox_id*.

        Raises:
            ValueError: If *sandbox_id* is empty.
            RuntimeError: If the id is already registered.
        """
        if not sandbox_id:
            raise ValueError("sandbox_id cannot be empty")
        idx = cls._shard_index(sandbox_id)
        with cls._shard_locks[idx]:
            current = cls._shards[idx]
            if sandbox_id in current:
                raise RuntimeError(
                    f"Sandbox already registered for sandbox_id={sandbox_id}. "
                    "Call unregister() first if you need to replace it."
                )
            # Copy-on-write: readers never observe a half-updated dict.
            cls._shards[idx] = {**current, sandbox_id: sandbox}
        logging.getLogger(__name__).debug(
            "Registered sandbox: sandbox_id=%s, id=%s, app_id=%s",
            sandbox_id,
            sandbox.id,
            sandbox.app_id,
        )

    @classmethod
    def get(cls, sandbox_id: str) -> Sandbox | None:
        """Return the registered sandbox for *sandbox_id*, or None (lock-free)."""
        return cls._shards[cls._shard_index(sandbox_id)].get(sandbox_id)

    @classmethod
    def unregister(cls, sandbox_id: str) -> Sandbox | None:
        """Remove and return the sandbox for *sandbox_id*; None if absent."""
        idx = cls._shard_index(sandbox_id)
        with cls._shard_locks[idx]:
            current = cls._shards[idx]
            found = current.get(sandbox_id)
            if found is None:
                return None
            # Rebuild the shard without this entry (copy-on-write removal).
            cls._shards[idx] = {key: value for key, value in current.items() if key != sandbox_id}
        logging.getLogger(__name__).debug(
            "Unregistered sandbox: sandbox_id=%s, id=%s",
            sandbox_id,
            found.id,
        )
        return found

    @classmethod
    def has(cls, sandbox_id: str) -> bool:
        """True if *sandbox_id* is currently registered (lock-free)."""
        return sandbox_id in cls._shards[cls._shard_index(sandbox_id)]

    @classmethod
    def is_sandbox_runtime(cls, sandbox_id: str) -> bool:
        """Alias for :meth:`has`, kept for call sites that ask the question this way."""
        return cls.has(sandbox_id)

    @classmethod
    def clear(cls) -> None:
        """Drop every registered sandbox.

        Acquires every shard lock (released in reverse order) so the wipe is
        atomic with respect to concurrent writers.
        """
        for lock in cls._shard_locks:
            lock.acquire()
        try:
            for idx in range(cls._NUM_SHARDS):
                cls._shards[idx] = {}
            logging.getLogger(__name__).debug("Cleared all registered sandboxes")
        finally:
            for lock in reversed(cls._shard_locks):
                lock.release()

    @classmethod
    def count(cls) -> int:
        """Total number of registered sandboxes across all shards (lock-free)."""
        return sum(map(len, cls._shards))

View File

@ -2,6 +2,7 @@ import secrets
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from configs import dify_config
from core.skill.entities import ToolAccessPolicy from core.skill.entities import ToolAccessPolicy
from .session import BaseSession, SessionManager from .session import BaseSession, SessionManager
@ -17,7 +18,11 @@ class CliContext(BaseModel):
class CliApiSessionManager(SessionManager[CliApiSession]): class CliApiSessionManager(SessionManager[CliApiSession]):
def __init__(self, ttl: int | None = None): def __init__(self, ttl: int | None = None):
super().__init__(key_prefix="cli_api_session", session_class=CliApiSession, ttl=ttl) super().__init__(
key_prefix="cli_api_session",
session_class=CliApiSession,
ttl=ttl or dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession: def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession:
session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json")) session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json"))

View File

@ -26,6 +26,9 @@ class CommandFuture:
Lightweight future for command execution. Lightweight future for command execution.
Mirrors concurrent.futures.Future API with 4 essential methods: Mirrors concurrent.futures.Future API with 4 essential methods:
result(), done(), cancel(), cancelled(). result(), done(), cancel(), cancelled().
When a command is cancelled or times out the future now asks the provider
to terminate the underlying process/session before marking itself done.
""" """
def __init__( def __init__(
@ -35,6 +38,7 @@ class CommandFuture:
stdout_transport: TransportReadCloser, stdout_transport: TransportReadCloser,
stderr_transport: TransportReadCloser, stderr_transport: TransportReadCloser,
poll_status: Callable[[], CommandStatus], poll_status: Callable[[], CommandStatus],
terminate_command: Callable[[], bool] | None = None,
poll_interval: float = 0.1, poll_interval: float = 0.1,
): ):
self._pid = pid self._pid = pid
@ -42,6 +46,7 @@ class CommandFuture:
self._stdout_transport = stdout_transport self._stdout_transport = stdout_transport
self._stderr_transport = stderr_transport self._stderr_transport = stderr_transport
self._poll_status = poll_status self._poll_status = poll_status
self._terminate_command = terminate_command
self._poll_interval = poll_interval self._poll_interval = poll_interval
self._done_event = threading.Event() self._done_event = threading.Event()
@ -49,7 +54,9 @@ class CommandFuture:
self._result: CommandResult | None = None self._result: CommandResult | None = None
self._exception: BaseException | None = None self._exception: BaseException | None = None
self._cancelled = False self._cancelled = False
self._timed_out = False
self._started = False self._started = False
self._termination_requested = False
def result(self, timeout: float | None = None) -> CommandResult: def result(self, timeout: float | None = None) -> CommandResult:
""" """
@ -61,15 +68,22 @@ class CommandFuture:
Raises: Raises:
CommandTimeoutError: If timeout exceeded. CommandTimeoutError: If timeout exceeded.
CommandCancelledError: If command was cancelled. CommandCancelledError: If command was cancelled.
A timeout is terminal for this future: it triggers best-effort command
termination and subsequent ``result()`` calls keep raising timeout.
""" """
self._ensure_started() self._ensure_started()
if not self._done_event.wait(timeout): if not self._done_event.wait(timeout):
self._request_stop(timed_out=True)
raise CommandTimeoutError(f"Command timed out after {timeout}s") raise CommandTimeoutError(f"Command timed out after {timeout}s")
if self._cancelled: if self._cancelled:
raise CommandCancelledError("Command was cancelled") raise CommandCancelledError("Command was cancelled")
if self._timed_out:
raise CommandTimeoutError("Command timed out")
if self._exception is not None: if self._exception is not None:
raise self._exception raise self._exception
@ -82,16 +96,10 @@ class CommandFuture:
def cancel(self) -> bool: def cancel(self) -> bool:
""" """
Attempt to cancel command by closing transports. Attempt to cancel command by terminating it and closing transports.
Returns True if cancelled, False if already completed. Returns True if cancelled, False if already completed.
""" """
with self._lock: return self._request_stop(cancelled=True)
if self._done_event.is_set():
return False
self._cancelled = True
self._close_transports()
self._done_event.set()
return True
def cancelled(self) -> bool: def cancelled(self) -> bool:
return self._cancelled return self._cancelled
@ -103,6 +111,28 @@ class CommandFuture:
thread = threading.Thread(target=self._execute, daemon=True) thread = threading.Thread(target=self._execute, daemon=True)
thread.start() thread.start()
def _request_stop(self, *, cancelled: bool = False, timed_out: bool = False) -> bool:
should_terminate = False
with self._lock:
if self._done_event.is_set():
return False
if cancelled:
self._cancelled = True
if timed_out:
self._timed_out = True
should_terminate = not self._termination_requested
if should_terminate:
self._termination_requested = True
self._close_transports()
self._done_event.set()
if should_terminate:
self._terminate_running_command()
return True
def _execute(self) -> None: def _execute(self) -> None:
stdout_buf = bytearray() stdout_buf = bytearray()
stderr_buf = bytearray() stderr_buf = bytearray()
@ -141,7 +171,7 @@ class CommandFuture:
self._close_transports() self._close_transports()
def _wait_for_completion(self) -> int | None: def _wait_for_completion(self) -> int | None:
while not self._cancelled: while not self._cancelled and not self._timed_out:
try: try:
status = self._poll_status() status = self._poll_status()
except NotSupportedOperationError: except NotSupportedOperationError:
@ -167,3 +197,12 @@ class CommandFuture:
for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport): for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
transport.close() transport.close()
def _terminate_running_command(self) -> None:
if self._terminate_command is None:
return
try:
self._terminate_command()
except Exception:
logger.exception("Failed to terminate command for pid %s", self._pid)

View File

@ -71,6 +71,7 @@ def submit_command(
stdout_transport=stdout_transport, stdout_transport=stdout_transport,
stderr_transport=stderr_transport, stderr_transport=stderr_transport,
poll_status=partial(env.get_command_status, connection, pid), poll_status=partial(env.get_command_status, connection, pid),
terminate_command=partial(env.terminate_command, connection, pid),
) )

View File

@ -11,8 +11,19 @@ from core.virtual_environment.channel.transport import TransportReadCloser, Tran
class VirtualEnvironment(ABC): class VirtualEnvironment(ABC):
""" """
Base class for virtual environment implementations. Base class for virtual environment implementations.
``VirtualEnvironment`` instances are configured at construction time but do
not allocate provider resources until ``open_enviroment()`` is called.
This keeps object construction side-effect free and gives callers a chance
to own startup error handling explicitly.
""" """
tenant_id: str
user_id: str | None
options: Mapping[str, Any]
_environments: Mapping[str, str]
_metadata: Metadata | None
def __init__( def __init__(
self, self,
tenant_id: str, tenant_id: str,
@ -21,19 +32,45 @@ class VirtualEnvironment(ABC):
user_id: str | None = None, user_id: str | None = None,
) -> None: ) -> None:
""" """
Initialize the virtual environment with metadata. Initialize the virtual environment configuration.
Args: Args:
tenant_id: The tenant ID associated with this environment (required). tenant_id: The tenant ID associated with this environment (required).
options: Provider-specific configuration options. options: Provider-specific configuration options.
environments: Environment variables to set in the virtual environment. environments: Environment variables to set in the virtual environment.
user_id: The user ID associated with this environment (optional). user_id: The user ID associated with this environment (optional).
The provider runtime itself is created later by ``open_enviroment()``.
""" """
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.user_id = user_id self.user_id = user_id
self.options = options self.options = options
self.metadata = self._construct_environment(options, environments or {}) self._environments = dict(environments or {})
self._metadata = None
@property
def metadata(self) -> Metadata:
"""Provider metadata for a started environment.
Raises:
RuntimeError: If the environment has not been started yet.
"""
if self._metadata is None:
raise RuntimeError("Virtual environment has not been started")
return self._metadata
def open_enviroment(self) -> Metadata:
"""Allocate provider resources and return the resulting metadata.
Multiple calls are safe and return the existing metadata after the first
successful start.
"""
if self._metadata is None:
self._metadata = self._construct_environment(self.options, self._environments)
return self._metadata
@abstractmethod @abstractmethod
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
@ -131,6 +168,18 @@ class VirtualEnvironment(ABC):
Multiple calls to `release_environment` with the same `environment_id` is acceptable. Multiple calls to `release_environment` with the same `environment_id` is acceptable.
""" """
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
"""Best-effort termination hook for a running command.
Providers that can map ``pid`` back to a real process/session should
override this method and stop the command. The default implementation is
a no-op so providers without a termination mechanism remain compatible.
"""
_ = connection_handle
_ = pid
return False
@abstractmethod @abstractmethod
def execute_command( def execute_command(
self, self,

View File

@ -1,3 +1,4 @@
import logging
import posixpath import posixpath
import shlex import shlex
import threading import threading
@ -32,6 +33,8 @@ from core.virtual_environment.channel.transport import (
) )
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS
logger = logging.getLogger(__name__)
""" """
import logging import logging
from collections.abc import Mapping from collections.abc import Mapping
@ -132,35 +135,53 @@ class E2BEnvironment(VirtualEnvironment):
The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the
provider can rely on E2B's native timeout instead of a background provider can rely on E2B's native timeout instead of a background
keepalive thread that continuously extends the session. keepalive thread that continuously extends the session.
E2B allocates the remote sandbox before metadata probing completes, so
startup failures must best-effort terminate the sandbox before the
exception escapes.
""" """
# Import E2B SDK lazily so it is loaded after gevent monkey-patching. # Import E2B SDK lazily so it is loaded after gevent monkey-patching.
from e2b_code_interpreter import Sandbox # type: ignore[import-untyped] from e2b_code_interpreter import Sandbox # type: ignore[import-untyped]
# TODO: add Dify as the user agent # TODO: add Dify as the user agent
sandbox = Sandbox.create( sandbox = None
template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"), sandbox_id: str | None = None
timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME, api_key = options.get(self.OptionsKey.API_KEY, "")
api_key=options.get(self.OptionsKey.API_KEY, ""), try:
api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL), sandbox = Sandbox.create(
envs=dict(environments), template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"),
) timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
info = sandbox.get_info(api_key=options.get(self.OptionsKey.API_KEY, "")) api_key=api_key,
system_info = sandbox.commands.run("uname -m -s").stdout.strip() api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL),
system_parts = system_info.split() envs=dict(environments),
if len(system_parts) == 2: )
os_part, arch_part = system_parts info = sandbox.get_info(api_key=api_key)
else: sandbox_id = info.sandbox_id
arch_part = system_parts[0] system_info = sandbox.commands.run("uname -m -s").stdout.strip()
os_part = system_parts[1] if len(system_parts) > 1 else "" system_parts = system_info.split()
if len(system_parts) == 2:
os_part, arch_part = system_parts
else:
arch_part = system_parts[0]
os_part = system_parts[1] if len(system_parts) > 1 else ""
return Metadata( return Metadata(
id=info.sandbox_id, id=info.sandbox_id,
arch=self._convert_architecture(arch_part.strip()), arch=self._convert_architecture(arch_part.strip()),
os=self._convert_operating_system(os_part.strip()), os=self._convert_operating_system(os_part.strip()),
store={ store={
self.StoreKey.SANDBOX: sandbox, self.StoreKey.SANDBOX: sandbox,
}, },
) )
except Exception:
if sandbox_id is None and sandbox is not None:
sandbox_id = getattr(sandbox, "sandbox_id", None)
if sandbox_id is not None:
try:
Sandbox.kill(api_key=api_key, sandbox_id=sandbox_id)
except Exception:
logger.exception("Failed to cleanup E2B sandbox after startup failure")
raise
def release_environment(self) -> None: def release_environment(self) -> None:
""" """

View File

@ -1,6 +1,7 @@
import os import os
import pathlib import pathlib
import shutil import shutil
import signal
import subprocess import subprocess
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from functools import cached_property from functools import cached_property
@ -246,6 +247,16 @@ class LocalVirtualEnvironment(VirtualEnvironment):
except ChildProcessError: except ChildProcessError:
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None) return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
"""Terminate a locally spawned process by PID when cancellation is requested."""
_ = connection_handle
try:
os.kill(int(pid), signal.SIGTERM)
except ProcessLookupError:
return False
return True
def _get_os_architecture(self) -> Arch: def _get_os_architecture(self) -> Arch:
""" """
Get the operating system architecture. Get the operating system architecture.

View File

@ -76,6 +76,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
) -> None: ) -> None:
self._connections: dict[str, Any] = {} self._connections: dict[str, Any] = {}
self._commands: dict[str, CommandStatus] = {} self._commands: dict[str, CommandStatus] = {}
self._command_channels: dict[str, Any] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id) super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)
@ -163,6 +164,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
with self._lock: with self._lock:
self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None) self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
self._command_channels[pid] = channel
threading.Thread( threading.Thread(
target=self._consume_channel_output, target=self._consume_channel_output,
@ -179,6 +181,23 @@ class SSHSandboxEnvironment(VirtualEnvironment):
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None) return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
return status return status
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
"""Best-effort termination by closing the SSH channel that owns the command."""
_ = connection_handle
with self._lock:
channel = self._command_channels.get(pid)
if channel is None:
return False
self._commands[pid] = CommandStatus(
status=CommandStatus.Status.COMPLETED,
exit_code=self._COMMAND_TIMEOUT_EXIT_CODE,
)
with contextlib.suppress(Exception):
channel.close()
return True
def upload_file(self, path: str, content: BytesIO) -> None: def upload_file(self, path: str, content: BytesIO) -> None:
destination_path = self._workspace_path(path) destination_path = self._workspace_path(path)
with self._client() as client: with self._client() as client:
@ -424,6 +443,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
channel.close() channel.close()
with self._lock: with self._lock:
self._command_channels.pop(pid, None)
self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code) self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
def _set_sftp_operation_timeout(self, sftp: Any) -> None: def _set_sftp_operation_timeout(self, sftp: Any) -> None:

View File

@ -12,7 +12,6 @@ from uuid import uuid4
from core.sandbox.builder import SandboxBuilder from core.sandbox.builder import SandboxBuilder
from core.sandbox.entities.sandbox_type import SandboxType from core.sandbox.entities.sandbox_type import SandboxType
from core.sandbox.manager import SandboxManager
from core.sandbox.sandbox import Sandbox from core.sandbox.sandbox import Sandbox
from core.sandbox.storage.noop_storage import NoopSandboxStorage from core.sandbox.storage.noop_storage import NoopSandboxStorage
from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError
@ -100,26 +99,29 @@ class ZipSandbox:
self._sandbox_id = uuid4().hex self._sandbox_id = uuid4().hex
storage = NoopSandboxStorage() storage = NoopSandboxStorage()
self._sandbox = ( try:
SandboxBuilder(self._tenant_id, SandboxType(provider_type)) self._sandbox = (
.options(provider_options) SandboxBuilder(self._tenant_id, SandboxType(provider_type))
.user(self._user_id) .options(provider_options)
.app(self._app_id) .user(self._user_id)
.storage(storage, assets_id="zip-sandbox") .app(self._app_id)
.build() .storage(storage, assets_id="zip-sandbox")
) .build()
self._sandbox.wait_ready(timeout=60) )
self._vm = self._sandbox.vm self._sandbox.wait_ready(timeout=60)
self._vm = self._sandbox.vm
SandboxManager.register(self._sandbox_id, self._sandbox) except Exception:
if self._sandbox is not None:
self._sandbox.release()
self._vm = None
self._sandbox = None
self._sandbox_id = None
raise
def _stop(self) -> None: def _stop(self) -> None:
if self._vm is None: if self._vm is None:
return return
if self._sandbox_id:
SandboxManager.unregister(self._sandbox_id)
if self._sandbox is not None: if self._sandbox is not None:
self._sandbox.release() self._sandbox.release()

View File

@ -1,4 +1,5 @@
import threading import threading
from collections.abc import Callable
import pytest import pytest
@ -18,6 +19,7 @@ def _make_future(
exit_code: int = 0, exit_code: int = 0,
delay_completion: float = 0, delay_completion: float = 0,
close_streams: bool = True, close_streams: bool = True,
terminate_command: Callable[[], bool] | None = None,
) -> CommandFuture: ) -> CommandFuture:
stdout_transport = QueueTransportReadCloser() stdout_transport = QueueTransportReadCloser()
stderr_transport = QueueTransportReadCloser() stderr_transport = QueueTransportReadCloser()
@ -48,6 +50,7 @@ def _make_future(
stdout_transport=stdout_transport, stdout_transport=stdout_transport,
stderr_transport=stderr_transport, stderr_transport=stderr_transport,
poll_status=poll_status, poll_status=poll_status,
terminate_command=terminate_command,
poll_interval=0.05, poll_interval=0.05,
) )
@ -78,6 +81,21 @@ def test_result_raises_timeout_error_when_exceeded():
future.result(timeout=0.2) future.result(timeout=0.2)
def test_timeout_requests_command_termination():
terminated = threading.Event()
future = _make_future(
delay_completion=10.0,
close_streams=False,
terminate_command=lambda: terminated.set() or True,
)
with pytest.raises(CommandTimeoutError):
future.result(timeout=0.2)
assert terminated.wait(timeout=1.0)
def test_done_returns_false_while_running(): def test_done_returns_false_while_running():
future = _make_future(delay_completion=10.0, close_streams=False) future = _make_future(delay_completion=10.0, close_streams=False)
@ -115,6 +133,19 @@ def test_result_raises_cancelled_error_after_cancel():
future.result() future.result()
def test_cancel_requests_command_termination():
terminated = threading.Event()
future = _make_future(
delay_completion=10.0,
close_streams=False,
terminate_command=lambda: terminated.set() or True,
)
assert future.cancel() is True
assert terminated.wait(timeout=1.0)
def test_nonzero_exit_code_is_returned(): def test_nonzero_exit_code_is_returned():
future = _make_future(stdout=b"err", exit_code=42) future = _make_future(stdout=b"err", exit_code=42)

View File

@ -4,6 +4,7 @@ from typing import Any
import pytest import pytest
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import ( from core.virtual_environment.__base.entities import (
Arch, Arch,
CommandStatus, CommandStatus,
@ -64,6 +65,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
self._establish_count = 0 self._establish_count = 0
self._release_count = 0 self._release_count = 0
super().__init__(tenant_id="test-tenant", options={}, environments={}) super().__init__(tenant_id="test-tenant", options={}, environments={})
self.open_enviroment()
def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata: def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata:
return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX) return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX)
@ -111,6 +113,10 @@ class FakeVirtualEnvironment(VirtualEnvironment):
def validate(cls, _options: Mapping[str, Any]) -> None: def validate(cls, _options: Mapping[str, Any]) -> None:
pass pass
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return []
class TestWithConnection: class TestWithConnection:
def test_connection_established_and_released(self): def test_connection_established_and_released(self):

View File

@ -26,7 +26,9 @@ def _drain_transport(transport: TransportReadCloser) -> bytes:
@pytest.fixture @pytest.fixture
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment: def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64") monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
return LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)}) env = LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
env.open_enviroment()
return env
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment): def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):

View File

@ -42,6 +42,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
self.last_execute_cwd: str | None = None self.last_execute_cwd: str | None = None
self.released_connections: list[str] = [] self.released_connections: list[str] = []
super().__init__(tenant_id="test-tenant", options={}, environments={}) super().__init__(tenant_id="test-tenant", options={}, environments={})
self.open_enviroment()
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata: def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX) return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)