mirror of
https://github.com/langgenius/dify.git
synced 2026-03-17 04:47:50 +08:00
refactor: a lot of optimization and enhancement
This commit is contained in:
@ -23,7 +23,6 @@ if TYPE_CHECKING:
|
||||
from .initializer.app_assets_initializer import AppAssetsInitializer
|
||||
from .initializer.dify_cli_initializer import DifyCliInitializer
|
||||
from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer
|
||||
from .manager import SandboxManager
|
||||
from .sandbox import Sandbox
|
||||
from .storage import ArchiveSandboxStorage, SandboxStorage
|
||||
from .utils.debug import sandbox_debug
|
||||
@ -47,7 +46,6 @@ __all__ = [
|
||||
"SandboxBuilder",
|
||||
"SandboxInitializeContext",
|
||||
"SandboxInitializer",
|
||||
"SandboxManager",
|
||||
"SandboxProviderApiEntity",
|
||||
"SandboxStorage",
|
||||
"SandboxType",
|
||||
|
||||
@ -131,6 +131,7 @@ class SandboxBuilder:
|
||||
environments=self._environments,
|
||||
user_id=self._user_id,
|
||||
)
|
||||
vm.open_enviroment()
|
||||
sandbox = Sandbox(
|
||||
vm=vm,
|
||||
storage=self._storage,
|
||||
@ -174,7 +175,11 @@ class SandboxBuilder:
|
||||
|
||||
if sandbox.is_cancelled():
|
||||
return
|
||||
sandbox.mount()
|
||||
# Storage mount is part of readiness. If restore/mount fails,
|
||||
# the sandbox must surface initialization failure instead of
|
||||
# becoming "ready" with missing files.
|
||||
if not sandbox.mount():
|
||||
raise RuntimeError("Sandbox storage mount failed")
|
||||
sandbox.mark_ready()
|
||||
except Exception as exc:
|
||||
try:
|
||||
|
||||
@ -19,12 +19,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyCliInitializer(AsyncSandboxInitializer):
|
||||
_cli_api_session: object | None
|
||||
|
||||
def __init__(self, cli_root: str | Path | None = None) -> None:
|
||||
self._locator = DifyCliLocator(root=cli_root)
|
||||
self._tools: list[object] = []
|
||||
self._cli_api_session = None
|
||||
|
||||
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
|
||||
vm = sandbox.vm
|
||||
@ -57,7 +54,7 @@ class DifyCliInitializer(AsyncSandboxInitializer):
|
||||
logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id)
|
||||
return
|
||||
|
||||
self._cli_api_session = CliApiSessionManager().create(
|
||||
global_cli_session = CliApiSessionManager().create(
|
||||
tenant_id=ctx.tenant_id,
|
||||
user_id=ctx.user_id,
|
||||
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
|
||||
@ -67,7 +64,7 @@ class DifyCliInitializer(AsyncSandboxInitializer):
|
||||
["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir"
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
config = DifyCliConfig.create(self._cli_api_session, ctx.tenant_id, bundle.get_tool_dependencies())
|
||||
config = DifyCliConfig.create(global_cli_session, ctx.tenant_id, bundle.get_tool_dependencies())
|
||||
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
|
||||
config_path = cli.global_config_path
|
||||
vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))
|
||||
|
||||
@ -5,8 +5,6 @@ from pathlib import PurePosixPath
|
||||
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
|
||||
from core.sandbox.inspector.archive_source import SandboxFileArchiveSource
|
||||
from core.sandbox.inspector.base import SandboxFileSource
|
||||
from core.sandbox.inspector.runtime_source import SandboxFileRuntimeSource
|
||||
from core.sandbox.manager import SandboxManager
|
||||
|
||||
|
||||
class SandboxFileBrowser:
|
||||
@ -31,14 +29,6 @@ class SandboxFileBrowser:
|
||||
return "." if normalized in (".", "") else normalized
|
||||
|
||||
def _backend(self) -> SandboxFileSource:
|
||||
sandbox = SandboxManager.get(self._sandbox_id)
|
||||
if sandbox is not None:
|
||||
return SandboxFileRuntimeSource(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
sandbox_id=self._sandbox_id,
|
||||
runtime=sandbox.vm,
|
||||
)
|
||||
return SandboxFileArchiveSource(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
|
||||
@ -1,103 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxManager:
|
||||
"""Registry for active Sandbox instances.
|
||||
|
||||
Stores complete Sandbox objects (not just VirtualEnvironment) to provide
|
||||
access to sandbox metadata like tenant_id, app_id, user_id, assets_id.
|
||||
"""
|
||||
|
||||
_NUM_SHARDS: Final[int] = 1024
|
||||
_SHARD_MASK: Final[int] = _NUM_SHARDS - 1
|
||||
|
||||
_shard_locks: Final[tuple[threading.Lock, ...]] = tuple(threading.Lock() for _ in range(_NUM_SHARDS))
|
||||
_shards: list[dict[str, Sandbox]] = [{} for _ in range(_NUM_SHARDS)]
|
||||
|
||||
@classmethod
|
||||
def _shard_index(cls, sandbox_id: str) -> int:
|
||||
return hash(sandbox_id) & cls._SHARD_MASK
|
||||
|
||||
@classmethod
|
||||
def register(cls, sandbox_id: str, sandbox: Sandbox) -> None:
|
||||
if not sandbox_id:
|
||||
raise ValueError("sandbox_id cannot be empty")
|
||||
|
||||
shard_index = cls._shard_index(sandbox_id)
|
||||
with cls._shard_locks[shard_index]:
|
||||
shard = cls._shards[shard_index]
|
||||
if sandbox_id in shard:
|
||||
raise RuntimeError(
|
||||
f"Sandbox already registered for sandbox_id={sandbox_id}. "
|
||||
"Call unregister() first if you need to replace it."
|
||||
)
|
||||
|
||||
new_shard = dict(shard)
|
||||
new_shard[sandbox_id] = sandbox
|
||||
cls._shards[shard_index] = new_shard
|
||||
|
||||
logger.debug(
|
||||
"Registered sandbox: sandbox_id=%s, id=%s, app_id=%s",
|
||||
sandbox_id,
|
||||
sandbox.id,
|
||||
sandbox.app_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get(cls, sandbox_id: str) -> Sandbox | None:
|
||||
shard_index = cls._shard_index(sandbox_id)
|
||||
return cls._shards[shard_index].get(sandbox_id)
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, sandbox_id: str) -> Sandbox | None:
|
||||
shard_index = cls._shard_index(sandbox_id)
|
||||
with cls._shard_locks[shard_index]:
|
||||
shard = cls._shards[shard_index]
|
||||
sandbox = shard.get(sandbox_id)
|
||||
if sandbox is None:
|
||||
return None
|
||||
|
||||
new_shard = dict(shard)
|
||||
new_shard.pop(sandbox_id, None)
|
||||
cls._shards[shard_index] = new_shard
|
||||
|
||||
logger.debug(
|
||||
"Unregistered sandbox: sandbox_id=%s, id=%s",
|
||||
sandbox_id,
|
||||
sandbox.id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
@classmethod
|
||||
def has(cls, sandbox_id: str) -> bool:
|
||||
shard_index = cls._shard_index(sandbox_id)
|
||||
return sandbox_id in cls._shards[shard_index]
|
||||
|
||||
@classmethod
|
||||
def is_sandbox_runtime(cls, sandbox_id: str) -> bool:
|
||||
return cls.has(sandbox_id)
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
for lock in cls._shard_locks:
|
||||
lock.acquire()
|
||||
try:
|
||||
for i in range(cls._NUM_SHARDS):
|
||||
cls._shards[i] = {}
|
||||
logger.debug("Cleared all registered sandboxes")
|
||||
finally:
|
||||
for lock in reversed(cls._shard_locks):
|
||||
lock.release()
|
||||
|
||||
@classmethod
|
||||
def count(cls) -> int:
|
||||
return sum(len(shard) for shard in cls._shards)
|
||||
@ -2,6 +2,7 @@ import secrets
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
from core.skill.entities import ToolAccessPolicy
|
||||
|
||||
from .session import BaseSession, SessionManager
|
||||
@ -17,7 +18,11 @@ class CliContext(BaseModel):
|
||||
|
||||
class CliApiSessionManager(SessionManager[CliApiSession]):
|
||||
def __init__(self, ttl: int | None = None):
|
||||
super().__init__(key_prefix="cli_api_session", session_class=CliApiSession, ttl=ttl)
|
||||
super().__init__(
|
||||
key_prefix="cli_api_session",
|
||||
session_class=CliApiSession,
|
||||
ttl=ttl or dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
)
|
||||
|
||||
def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession:
|
||||
session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json"))
|
||||
|
||||
@ -26,6 +26,9 @@ class CommandFuture:
|
||||
Lightweight future for command execution.
|
||||
Mirrors concurrent.futures.Future API with 4 essential methods:
|
||||
result(), done(), cancel(), cancelled().
|
||||
|
||||
When a command is cancelled or times out the future now asks the provider
|
||||
to terminate the underlying process/session before marking itself done.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -35,6 +38,7 @@ class CommandFuture:
|
||||
stdout_transport: TransportReadCloser,
|
||||
stderr_transport: TransportReadCloser,
|
||||
poll_status: Callable[[], CommandStatus],
|
||||
terminate_command: Callable[[], bool] | None = None,
|
||||
poll_interval: float = 0.1,
|
||||
):
|
||||
self._pid = pid
|
||||
@ -42,6 +46,7 @@ class CommandFuture:
|
||||
self._stdout_transport = stdout_transport
|
||||
self._stderr_transport = stderr_transport
|
||||
self._poll_status = poll_status
|
||||
self._terminate_command = terminate_command
|
||||
self._poll_interval = poll_interval
|
||||
|
||||
self._done_event = threading.Event()
|
||||
@ -49,7 +54,9 @@ class CommandFuture:
|
||||
self._result: CommandResult | None = None
|
||||
self._exception: BaseException | None = None
|
||||
self._cancelled = False
|
||||
self._timed_out = False
|
||||
self._started = False
|
||||
self._termination_requested = False
|
||||
|
||||
def result(self, timeout: float | None = None) -> CommandResult:
|
||||
"""
|
||||
@ -61,15 +68,22 @@ class CommandFuture:
|
||||
Raises:
|
||||
CommandTimeoutError: If timeout exceeded.
|
||||
CommandCancelledError: If command was cancelled.
|
||||
|
||||
A timeout is terminal for this future: it triggers best-effort command
|
||||
termination and subsequent ``result()`` calls keep raising timeout.
|
||||
"""
|
||||
self._ensure_started()
|
||||
|
||||
if not self._done_event.wait(timeout):
|
||||
self._request_stop(timed_out=True)
|
||||
raise CommandTimeoutError(f"Command timed out after {timeout}s")
|
||||
|
||||
if self._cancelled:
|
||||
raise CommandCancelledError("Command was cancelled")
|
||||
|
||||
if self._timed_out:
|
||||
raise CommandTimeoutError("Command timed out")
|
||||
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
@ -82,16 +96,10 @@ class CommandFuture:
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""
|
||||
Attempt to cancel command by closing transports.
|
||||
Attempt to cancel command by terminating it and closing transports.
|
||||
Returns True if cancelled, False if already completed.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._done_event.is_set():
|
||||
return False
|
||||
self._cancelled = True
|
||||
self._close_transports()
|
||||
self._done_event.set()
|
||||
return True
|
||||
return self._request_stop(cancelled=True)
|
||||
|
||||
def cancelled(self) -> bool:
|
||||
return self._cancelled
|
||||
@ -103,6 +111,28 @@ class CommandFuture:
|
||||
thread = threading.Thread(target=self._execute, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _request_stop(self, *, cancelled: bool = False, timed_out: bool = False) -> bool:
|
||||
should_terminate = False
|
||||
with self._lock:
|
||||
if self._done_event.is_set():
|
||||
return False
|
||||
|
||||
if cancelled:
|
||||
self._cancelled = True
|
||||
if timed_out:
|
||||
self._timed_out = True
|
||||
|
||||
should_terminate = not self._termination_requested
|
||||
if should_terminate:
|
||||
self._termination_requested = True
|
||||
|
||||
self._close_transports()
|
||||
self._done_event.set()
|
||||
|
||||
if should_terminate:
|
||||
self._terminate_running_command()
|
||||
return True
|
||||
|
||||
def _execute(self) -> None:
|
||||
stdout_buf = bytearray()
|
||||
stderr_buf = bytearray()
|
||||
@ -141,7 +171,7 @@ class CommandFuture:
|
||||
self._close_transports()
|
||||
|
||||
def _wait_for_completion(self) -> int | None:
|
||||
while not self._cancelled:
|
||||
while not self._cancelled and not self._timed_out:
|
||||
try:
|
||||
status = self._poll_status()
|
||||
except NotSupportedOperationError:
|
||||
@ -167,3 +197,12 @@ class CommandFuture:
|
||||
for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
|
||||
with contextlib.suppress(Exception):
|
||||
transport.close()
|
||||
|
||||
def _terminate_running_command(self) -> None:
|
||||
if self._terminate_command is None:
|
||||
return
|
||||
|
||||
try:
|
||||
self._terminate_command()
|
||||
except Exception:
|
||||
logger.exception("Failed to terminate command for pid %s", self._pid)
|
||||
|
||||
@ -71,6 +71,7 @@ def submit_command(
|
||||
stdout_transport=stdout_transport,
|
||||
stderr_transport=stderr_transport,
|
||||
poll_status=partial(env.get_command_status, connection, pid),
|
||||
terminate_command=partial(env.terminate_command, connection, pid),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -11,8 +11,19 @@ from core.virtual_environment.channel.transport import TransportReadCloser, Tran
|
||||
class VirtualEnvironment(ABC):
|
||||
"""
|
||||
Base class for virtual environment implementations.
|
||||
|
||||
``VirtualEnvironment`` instances are configured at construction time but do
|
||||
not allocate provider resources until ``open_enviroment()`` is called.
|
||||
This keeps object construction side-effect free and gives callers a chance
|
||||
to own startup error handling explicitly.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None
|
||||
options: Mapping[str, Any]
|
||||
_environments: Mapping[str, str]
|
||||
_metadata: Metadata | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -21,19 +32,45 @@ class VirtualEnvironment(ABC):
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the virtual environment with metadata.
|
||||
Initialize the virtual environment configuration.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID associated with this environment (required).
|
||||
options: Provider-specific configuration options.
|
||||
environments: Environment variables to set in the virtual environment.
|
||||
user_id: The user ID associated with this environment (optional).
|
||||
|
||||
The provider runtime itself is created later by ``open_enviroment()``.
|
||||
"""
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.options = options
|
||||
self.metadata = self._construct_environment(options, environments or {})
|
||||
self._environments = dict(environments or {})
|
||||
self._metadata = None
|
||||
|
||||
@property
|
||||
def metadata(self) -> Metadata:
|
||||
"""Provider metadata for a started environment.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the environment has not been started yet.
|
||||
"""
|
||||
|
||||
if self._metadata is None:
|
||||
raise RuntimeError("Virtual environment has not been started")
|
||||
return self._metadata
|
||||
|
||||
def open_enviroment(self) -> Metadata:
|
||||
"""Allocate provider resources and return the resulting metadata.
|
||||
|
||||
Multiple calls are safe and return the existing metadata after the first
|
||||
successful start.
|
||||
"""
|
||||
|
||||
if self._metadata is None:
|
||||
self._metadata = self._construct_environment(self.options, self._environments)
|
||||
return self._metadata
|
||||
|
||||
@abstractmethod
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
@ -131,6 +168,18 @@ class VirtualEnvironment(ABC):
|
||||
Multiple calls to `release_environment` with the same `environment_id` is acceptable.
|
||||
"""
|
||||
|
||||
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
|
||||
"""Best-effort termination hook for a running command.
|
||||
|
||||
Providers that can map ``pid`` back to a real process/session should
|
||||
override this method and stop the command. The default implementation is
|
||||
a no-op so providers without a termination mechanism remain compatible.
|
||||
"""
|
||||
|
||||
_ = connection_handle
|
||||
_ = pid
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def execute_command(
|
||||
self,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import posixpath
|
||||
import shlex
|
||||
import threading
|
||||
@ -32,6 +33,8 @@ from core.virtual_environment.channel.transport import (
|
||||
)
|
||||
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
@ -132,35 +135,53 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the
|
||||
provider can rely on E2B's native timeout instead of a background
|
||||
keepalive thread that continuously extends the session.
|
||||
|
||||
E2B allocates the remote sandbox before metadata probing completes, so
|
||||
startup failures must best-effort terminate the sandbox before the
|
||||
exception escapes.
|
||||
"""
|
||||
# Import E2B SDK lazily so it is loaded after gevent monkey-patching.
|
||||
from e2b_code_interpreter import Sandbox # type: ignore[import-untyped]
|
||||
|
||||
# TODO: add Dify as the user agent
|
||||
sandbox = Sandbox.create(
|
||||
template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"),
|
||||
timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
api_key=options.get(self.OptionsKey.API_KEY, ""),
|
||||
api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL),
|
||||
envs=dict(environments),
|
||||
)
|
||||
info = sandbox.get_info(api_key=options.get(self.OptionsKey.API_KEY, ""))
|
||||
system_info = sandbox.commands.run("uname -m -s").stdout.strip()
|
||||
system_parts = system_info.split()
|
||||
if len(system_parts) == 2:
|
||||
os_part, arch_part = system_parts
|
||||
else:
|
||||
arch_part = system_parts[0]
|
||||
os_part = system_parts[1] if len(system_parts) > 1 else ""
|
||||
sandbox = None
|
||||
sandbox_id: str | None = None
|
||||
api_key = options.get(self.OptionsKey.API_KEY, "")
|
||||
try:
|
||||
sandbox = Sandbox.create(
|
||||
template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"),
|
||||
timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
api_key=api_key,
|
||||
api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL),
|
||||
envs=dict(environments),
|
||||
)
|
||||
info = sandbox.get_info(api_key=api_key)
|
||||
sandbox_id = info.sandbox_id
|
||||
system_info = sandbox.commands.run("uname -m -s").stdout.strip()
|
||||
system_parts = system_info.split()
|
||||
if len(system_parts) == 2:
|
||||
os_part, arch_part = system_parts
|
||||
else:
|
||||
arch_part = system_parts[0]
|
||||
os_part = system_parts[1] if len(system_parts) > 1 else ""
|
||||
|
||||
return Metadata(
|
||||
id=info.sandbox_id,
|
||||
arch=self._convert_architecture(arch_part.strip()),
|
||||
os=self._convert_operating_system(os_part.strip()),
|
||||
store={
|
||||
self.StoreKey.SANDBOX: sandbox,
|
||||
},
|
||||
)
|
||||
return Metadata(
|
||||
id=info.sandbox_id,
|
||||
arch=self._convert_architecture(arch_part.strip()),
|
||||
os=self._convert_operating_system(os_part.strip()),
|
||||
store={
|
||||
self.StoreKey.SANDBOX: sandbox,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
if sandbox_id is None and sandbox is not None:
|
||||
sandbox_id = getattr(sandbox, "sandbox_id", None)
|
||||
if sandbox_id is not None:
|
||||
try:
|
||||
Sandbox.kill(api_key=api_key, sandbox_id=sandbox_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to cleanup E2B sandbox after startup failure")
|
||||
raise
|
||||
|
||||
def release_environment(self) -> None:
|
||||
"""
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cached_property
|
||||
@ -246,6 +247,16 @@ class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
except ChildProcessError:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
||||
|
||||
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
|
||||
"""Terminate a locally spawned process by PID when cancellation is requested."""
|
||||
|
||||
_ = connection_handle
|
||||
try:
|
||||
os.kill(int(pid), signal.SIGTERM)
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_os_architecture(self) -> Arch:
|
||||
"""
|
||||
Get the operating system architecture.
|
||||
|
||||
@ -76,6 +76,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
) -> None:
|
||||
self._connections: dict[str, Any] = {}
|
||||
self._commands: dict[str, CommandStatus] = {}
|
||||
self._command_channels: dict[str, Any] = {}
|
||||
self._lock = threading.Lock()
|
||||
super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)
|
||||
|
||||
@ -163,6 +164,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
|
||||
with self._lock:
|
||||
self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
||||
self._command_channels[pid] = channel
|
||||
|
||||
threading.Thread(
|
||||
target=self._consume_channel_output,
|
||||
@ -179,6 +181,23 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
||||
return status
|
||||
|
||||
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
|
||||
"""Best-effort termination by closing the SSH channel that owns the command."""
|
||||
|
||||
_ = connection_handle
|
||||
with self._lock:
|
||||
channel = self._command_channels.get(pid)
|
||||
if channel is None:
|
||||
return False
|
||||
self._commands[pid] = CommandStatus(
|
||||
status=CommandStatus.Status.COMPLETED,
|
||||
exit_code=self._COMMAND_TIMEOUT_EXIT_CODE,
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
channel.close()
|
||||
return True
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
destination_path = self._workspace_path(path)
|
||||
with self._client() as client:
|
||||
@ -424,6 +443,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
channel.close()
|
||||
|
||||
with self._lock:
|
||||
self._command_channels.pop(pid, None)
|
||||
self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
||||
|
||||
def _set_sftp_operation_timeout(self, sftp: Any) -> None:
|
||||
|
||||
@ -12,7 +12,6 @@ from uuid import uuid4
|
||||
|
||||
from core.sandbox.builder import SandboxBuilder
|
||||
from core.sandbox.entities.sandbox_type import SandboxType
|
||||
from core.sandbox.manager import SandboxManager
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.sandbox.storage.noop_storage import NoopSandboxStorage
|
||||
from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError
|
||||
@ -100,26 +99,29 @@ class ZipSandbox:
|
||||
self._sandbox_id = uuid4().hex
|
||||
|
||||
storage = NoopSandboxStorage()
|
||||
self._sandbox = (
|
||||
SandboxBuilder(self._tenant_id, SandboxType(provider_type))
|
||||
.options(provider_options)
|
||||
.user(self._user_id)
|
||||
.app(self._app_id)
|
||||
.storage(storage, assets_id="zip-sandbox")
|
||||
.build()
|
||||
)
|
||||
self._sandbox.wait_ready(timeout=60)
|
||||
self._vm = self._sandbox.vm
|
||||
|
||||
SandboxManager.register(self._sandbox_id, self._sandbox)
|
||||
try:
|
||||
self._sandbox = (
|
||||
SandboxBuilder(self._tenant_id, SandboxType(provider_type))
|
||||
.options(provider_options)
|
||||
.user(self._user_id)
|
||||
.app(self._app_id)
|
||||
.storage(storage, assets_id="zip-sandbox")
|
||||
.build()
|
||||
)
|
||||
self._sandbox.wait_ready(timeout=60)
|
||||
self._vm = self._sandbox.vm
|
||||
except Exception:
|
||||
if self._sandbox is not None:
|
||||
self._sandbox.release()
|
||||
self._vm = None
|
||||
self._sandbox = None
|
||||
self._sandbox_id = None
|
||||
raise
|
||||
|
||||
def _stop(self) -> None:
|
||||
if self._vm is None:
|
||||
return
|
||||
|
||||
if self._sandbox_id:
|
||||
SandboxManager.unregister(self._sandbox_id)
|
||||
|
||||
if self._sandbox is not None:
|
||||
self._sandbox.release()
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
@ -18,6 +19,7 @@ def _make_future(
|
||||
exit_code: int = 0,
|
||||
delay_completion: float = 0,
|
||||
close_streams: bool = True,
|
||||
terminate_command: Callable[[], bool] | None = None,
|
||||
) -> CommandFuture:
|
||||
stdout_transport = QueueTransportReadCloser()
|
||||
stderr_transport = QueueTransportReadCloser()
|
||||
@ -48,6 +50,7 @@ def _make_future(
|
||||
stdout_transport=stdout_transport,
|
||||
stderr_transport=stderr_transport,
|
||||
poll_status=poll_status,
|
||||
terminate_command=terminate_command,
|
||||
poll_interval=0.05,
|
||||
)
|
||||
|
||||
@ -78,6 +81,21 @@ def test_result_raises_timeout_error_when_exceeded():
|
||||
future.result(timeout=0.2)
|
||||
|
||||
|
||||
def test_timeout_requests_command_termination():
|
||||
terminated = threading.Event()
|
||||
|
||||
future = _make_future(
|
||||
delay_completion=10.0,
|
||||
close_streams=False,
|
||||
terminate_command=lambda: terminated.set() or True,
|
||||
)
|
||||
|
||||
with pytest.raises(CommandTimeoutError):
|
||||
future.result(timeout=0.2)
|
||||
|
||||
assert terminated.wait(timeout=1.0)
|
||||
|
||||
|
||||
def test_done_returns_false_while_running():
|
||||
future = _make_future(delay_completion=10.0, close_streams=False)
|
||||
|
||||
@ -115,6 +133,19 @@ def test_result_raises_cancelled_error_after_cancel():
|
||||
future.result()
|
||||
|
||||
|
||||
def test_cancel_requests_command_termination():
|
||||
terminated = threading.Event()
|
||||
|
||||
future = _make_future(
|
||||
delay_completion=10.0,
|
||||
close_streams=False,
|
||||
terminate_command=lambda: terminated.set() or True,
|
||||
)
|
||||
|
||||
assert future.cancel() is True
|
||||
assert terminated.wait(timeout=1.0)
|
||||
|
||||
|
||||
def test_nonzero_exit_code_is_returned():
|
||||
future = _make_future(stdout=b"err", exit_code=42)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.virtual_environment.__base.entities import (
|
||||
Arch,
|
||||
CommandStatus,
|
||||
@ -64,6 +65,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
self._establish_count = 0
|
||||
self._release_count = 0
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
self.open_enviroment()
|
||||
|
||||
def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX)
|
||||
@ -111,6 +113,10 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
def validate(cls, _options: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_config_schema(cls) -> list[BasicProviderConfig]:
|
||||
return []
|
||||
|
||||
|
||||
class TestWithConnection:
|
||||
def test_connection_established_and_released(self):
|
||||
|
||||
@ -26,7 +26,9 @@ def _drain_transport(transport: TransportReadCloser) -> bytes:
|
||||
@pytest.fixture
|
||||
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
|
||||
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
|
||||
return LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
|
||||
env = LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
|
||||
env.open_enviroment()
|
||||
return env
|
||||
|
||||
|
||||
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
|
||||
|
||||
@ -42,6 +42,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
self.last_execute_cwd: str | None = None
|
||||
self.released_connections: list[str] = []
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
self.open_enviroment()
|
||||
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)
|
||||
|
||||
Reference in New Issue
Block a user