mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
feat(sandbox): add SSH agentbox provider for middleware and docker deployments
This commit is contained in:
@ -34,6 +34,10 @@ def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]:
|
||||
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||
|
||||
return LocalVirtualEnvironment
|
||||
case SandboxType.SSH:
|
||||
from core.virtual_environment.providers.ssh_sandbox import SSHSandboxEnvironment
|
||||
|
||||
return SSHSandboxEnvironment
|
||||
case _:
|
||||
raise ValueError(f"Unsupported sandbox type: {sandbox_type}")
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ class SandboxType(StrEnum):
|
||||
DOCKER = "docker"
|
||||
E2B = "e2b"
|
||||
LOCAL = "local"
|
||||
SSH = "ssh"
|
||||
|
||||
@classmethod
|
||||
def get_all(cls) -> list[str]:
|
||||
|
||||
437
api/core/virtual_environment/providers/ssh_sandbox.py
Normal file
437
api/core/virtual_environment/providers/ssh_sandbox.py
Normal file
@ -0,0 +1,437 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import shlex
|
||||
import stat
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from io import BytesIO
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.virtual_environment.__base.entities import (
|
||||
Arch,
|
||||
CommandStatus,
|
||||
ConnectionHandle,
|
||||
FileState,
|
||||
Metadata,
|
||||
OperatingSystem,
|
||||
)
|
||||
from core.virtual_environment.__base.exec import SandboxConfigValidationError, VirtualEnvironmentLaunchFailedError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
|
||||
from core.virtual_environment.channel.transport import TransportWriteCloser
|
||||
|
||||
|
||||
class _SSHStdinTransport(TransportWriteCloser):
|
||||
def __init__(self, channel: Any):
|
||||
self._channel = channel
|
||||
self._closed = False
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
if self._closed:
|
||||
raise TransportEOFError("Transport is closed")
|
||||
if not data:
|
||||
return
|
||||
self._channel.sendall(data)
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
with contextlib.suppress(Exception):
|
||||
self._channel.shutdown_write()
|
||||
|
||||
|
||||
class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
_DEFAULT_SSH_HOST = "agentbox"
|
||||
_DEFAULT_SSH_PORT = 22
|
||||
_DEFAULT_BASE_WORKING_PATH = "/workspace/sandboxes"
|
||||
|
||||
class OptionsKey(StrEnum):
|
||||
SSH_HOST = "ssh_host"
|
||||
SSH_PORT = "ssh_port"
|
||||
SSH_USERNAME = "ssh_username"
|
||||
SSH_PASSWORD = "ssh_password"
|
||||
BASE_WORKING_PATH = "base_working_path"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
options: Mapping[str, Any],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
self._connections: dict[str, Any] = {}
|
||||
self._commands: dict[str, CommandStatus] = {}
|
||||
self._lock = threading.Lock()
|
||||
super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
def get_config_schema(cls) -> list[BasicProviderConfig]:
|
||||
return [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_HOST),
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_PORT),
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.SSH_USERNAME),
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=cls.OptionsKey.SSH_PASSWORD),
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.TEXT_INPUT, name=cls.OptionsKey.BASE_WORKING_PATH),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def validate(cls, options: Mapping[str, Any]) -> None:
|
||||
cls._require_non_empty_option(options, cls.OptionsKey.SSH_USERNAME)
|
||||
cls._require_non_empty_option(options, cls.OptionsKey.SSH_PASSWORD)
|
||||
with cls._create_ssh_client(options):
|
||||
return
|
||||
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
environment_id = uuid4().hex
|
||||
working_path = self._workspace_path_from_id(environment_id)
|
||||
|
||||
try:
|
||||
with self._client() as client:
|
||||
self._run_command(client, f"mkdir -p {shlex.quote(working_path)}")
|
||||
arch_stdout = self._run_command(client, "uname -m")
|
||||
os_stdout = self._run_command(client, "uname -s")
|
||||
except Exception as e:
|
||||
raise VirtualEnvironmentLaunchFailedError(f"Failed to construct SSH environment: {e}") from e
|
||||
|
||||
return Metadata(
|
||||
id=environment_id,
|
||||
arch=self._parse_arch(arch_stdout.decode("utf-8", errors="replace").strip()),
|
||||
os=self._parse_os(os_stdout.decode("utf-8", errors="replace").strip()),
|
||||
store={"working_path": working_path},
|
||||
)
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
connection_id = uuid4().hex
|
||||
client = self._create_ssh_client(self.options)
|
||||
with self._lock:
|
||||
self._connections[connection_id] = client
|
||||
return ConnectionHandle(id=connection_id)
|
||||
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
with self._lock:
|
||||
client = self._connections.pop(connection_handle.id, None)
|
||||
if client is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
client.close()
|
||||
|
||||
def release_environment(self) -> None:
|
||||
working_path = self.get_working_path()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._client() as client:
|
||||
self._run_command(client, f"rm -rf {shlex.quote(working_path)}")
|
||||
|
||||
def execute_command(
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, QueueTransportReadCloser, QueueTransportReadCloser]:
|
||||
client = self._get_connection(connection_handle)
|
||||
transport = client.get_transport()
|
||||
if transport is None:
|
||||
raise RuntimeError("SSH transport is not available")
|
||||
|
||||
channel = transport.open_session()
|
||||
channel.set_combine_stderr(False)
|
||||
|
||||
execution_command = self._build_exec_command(command, environments, cwd)
|
||||
channel.exec_command(execution_command)
|
||||
|
||||
pid = uuid4().hex
|
||||
stdin_transport = _SSHStdinTransport(channel)
|
||||
stdout_transport = QueueTransportReadCloser()
|
||||
stderr_transport = QueueTransportReadCloser()
|
||||
|
||||
with self._lock:
|
||||
self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
||||
|
||||
threading.Thread(
|
||||
target=self._consume_channel_output,
|
||||
args=(pid, channel, stdout_transport, stderr_transport),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
return pid, stdin_transport, stdout_transport, stderr_transport
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
with self._lock:
|
||||
status = self._commands.get(pid)
|
||||
if status is None:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
||||
return status
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
destination_path = self._workspace_path(path)
|
||||
with self._client() as client:
|
||||
sftp = client.open_sftp()
|
||||
try:
|
||||
self._sftp_mkdirs(sftp, str(PurePosixPath(destination_path).parent))
|
||||
with sftp.file(destination_path, "wb") as remote_file:
|
||||
remote_file.write(content.getvalue())
|
||||
finally:
|
||||
sftp.close()
|
||||
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
source_path = self._workspace_path(path)
|
||||
with self._client() as client:
|
||||
sftp = client.open_sftp()
|
||||
try:
|
||||
with sftp.file(source_path, "rb") as remote_file:
|
||||
return BytesIO(remote_file.read())
|
||||
finally:
|
||||
sftp.close()
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
if limit <= 0:
|
||||
return []
|
||||
|
||||
root_directory = self._workspace_path(directory_path)
|
||||
files: list[FileState] = []
|
||||
|
||||
with self._client() as client:
|
||||
sftp = client.open_sftp()
|
||||
try:
|
||||
pending = [root_directory]
|
||||
while pending and len(files) < limit:
|
||||
current_directory = pending.pop(0)
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
for attr in sftp.listdir_attr(current_directory):
|
||||
current_path = str(PurePosixPath(current_directory) / attr.filename)
|
||||
mode = attr.st_mode
|
||||
if stat.S_ISDIR(mode):
|
||||
pending.append(current_path)
|
||||
continue
|
||||
|
||||
files.append(
|
||||
FileState(
|
||||
path=self._to_relative_workspace_path(current_path),
|
||||
size=attr.st_size,
|
||||
created_at=int(attr.st_mtime),
|
||||
updated_at=int(attr.st_mtime),
|
||||
)
|
||||
)
|
||||
if len(files) >= limit:
|
||||
break
|
||||
finally:
|
||||
sftp.close()
|
||||
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _require_non_empty_option(cls, options: Mapping[str, Any], key: OptionsKey) -> str:
|
||||
value = options.get(key)
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
raise SandboxConfigValidationError(f"Missing required option: {key}")
|
||||
return value.strip()
|
||||
|
||||
@classmethod
|
||||
def _create_ssh_client(cls, options: Mapping[str, Any]) -> Any:
|
||||
import paramiko
|
||||
|
||||
host = options.get(cls.OptionsKey.SSH_HOST, cls._DEFAULT_SSH_HOST)
|
||||
port = options.get(cls.OptionsKey.SSH_PORT, cls._DEFAULT_SSH_PORT)
|
||||
username = cls._require_non_empty_option(options, cls.OptionsKey.SSH_USERNAME)
|
||||
password = cls._require_non_empty_option(options, cls.OptionsKey.SSH_PASSWORD)
|
||||
|
||||
if not isinstance(host, str) or not host.strip():
|
||||
raise SandboxConfigValidationError(f"Invalid option value: {cls.OptionsKey.SSH_HOST}")
|
||||
|
||||
try:
|
||||
port_int = int(port)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise SandboxConfigValidationError(f"Invalid option value: {cls.OptionsKey.SSH_PORT}") from e
|
||||
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
try:
|
||||
client.connect(
|
||||
hostname=host.strip(),
|
||||
port=port_int,
|
||||
username=username,
|
||||
password=password,
|
||||
look_for_keys=False,
|
||||
allow_agent=False,
|
||||
timeout=10,
|
||||
)
|
||||
except Exception as e:
|
||||
with contextlib.suppress(Exception):
|
||||
client.close()
|
||||
raise SandboxConfigValidationError(f"SSH connection failed: {e}") from e
|
||||
|
||||
return client
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _client(self):
|
||||
client = self._create_ssh_client(self.options)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
client.close()
|
||||
|
||||
def _get_connection(self, connection_handle: ConnectionHandle) -> Any:
|
||||
with self._lock:
|
||||
client = self._connections.get(connection_handle.id)
|
||||
if client is None:
|
||||
raise ValueError(f"Connection handle not found: {connection_handle.id}")
|
||||
return client
|
||||
|
||||
def _workspace_path_from_id(self, environment_id: str) -> str:
|
||||
base_path = self.options.get(self.OptionsKey.BASE_WORKING_PATH, self._DEFAULT_BASE_WORKING_PATH)
|
||||
if not isinstance(base_path, str) or not base_path.strip():
|
||||
base_path = self._DEFAULT_BASE_WORKING_PATH
|
||||
return str(PurePosixPath(base_path) / environment_id)
|
||||
|
||||
def get_working_path(self) -> str:
|
||||
working_path = self.metadata.store.get("working_path")
|
||||
if not isinstance(working_path, str) or not working_path:
|
||||
return self._workspace_path_from_id(self.metadata.id)
|
||||
return working_path
|
||||
|
||||
def _workspace_path(self, path: str | None) -> str:
|
||||
if not path:
|
||||
return self.get_working_path()
|
||||
|
||||
normalized = PurePosixPath(path)
|
||||
if normalized.is_absolute():
|
||||
return str(normalized)
|
||||
return str(PurePosixPath(self.get_working_path()) / self._normalize_relative_path(path))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_relative_path(path: str) -> PurePosixPath:
|
||||
parts: list[str] = []
|
||||
for part in PurePosixPath(path).parts:
|
||||
if part in ("", ".", "/"):
|
||||
continue
|
||||
if part == "..":
|
||||
if not parts:
|
||||
raise ValueError("Path escapes the workspace.")
|
||||
parts.pop()
|
||||
continue
|
||||
parts.append(part)
|
||||
return PurePosixPath(*parts)
|
||||
|
||||
def _to_relative_workspace_path(self, path: str) -> str:
|
||||
workspace = PurePosixPath(self.get_working_path())
|
||||
target = PurePosixPath(path)
|
||||
if target.is_relative_to(workspace):
|
||||
return target.relative_to(workspace).as_posix()
|
||||
return target.as_posix()
|
||||
|
||||
def _build_exec_command(
|
||||
self, command: list[str], environments: Mapping[str, str] | None = None, cwd: str | None = None
|
||||
) -> str:
|
||||
working_path = self._workspace_path(cwd)
|
||||
command_body = f"cd {shlex.quote(working_path)} && "
|
||||
|
||||
if environments:
|
||||
env_clause = " ".join(f"{key}={shlex.quote(value)}" for key, value in environments.items())
|
||||
command_body += f"{env_clause} "
|
||||
|
||||
command_body += shlex.join(command)
|
||||
return f"sh -lc {shlex.quote(command_body)}"
|
||||
|
||||
@staticmethod
|
||||
def _run_command(client: Any, command: str) -> bytes:
|
||||
_, stdout, stderr = client.exec_command(command)
|
||||
exit_code = stdout.channel.recv_exit_status()
|
||||
stdout_data = stdout.read()
|
||||
stderr_data = stderr.read()
|
||||
|
||||
if exit_code != 0:
|
||||
stderr_text = stderr_data.decode("utf-8", errors="replace")
|
||||
raise RuntimeError(f"SSH command failed ({exit_code}): {stderr_text}")
|
||||
|
||||
return stdout_data
|
||||
|
||||
def _consume_channel_output(
|
||||
self,
|
||||
pid: str,
|
||||
channel: Any,
|
||||
stdout_transport: QueueTransportReadCloser,
|
||||
stderr_transport: QueueTransportReadCloser,
|
||||
) -> None:
|
||||
stdout_writer = stdout_transport.get_write_handler()
|
||||
stderr_writer = stderr_transport.get_write_handler()
|
||||
exit_code: int | None = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
if channel.recv_ready():
|
||||
stdout_writer.write(channel.recv(4096))
|
||||
if channel.recv_stderr_ready():
|
||||
stderr_writer.write(channel.recv_stderr(4096))
|
||||
|
||||
if channel.exit_status_ready() and not channel.recv_ready() and not channel.recv_stderr_ready():
|
||||
exit_code = int(channel.recv_exit_status())
|
||||
break
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
stdout_transport.close()
|
||||
with contextlib.suppress(Exception):
|
||||
stderr_transport.close()
|
||||
with contextlib.suppress(Exception):
|
||||
channel.close()
|
||||
|
||||
with self._lock:
|
||||
self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
||||
|
||||
@staticmethod
|
||||
def _parse_arch(raw_arch: str) -> Arch:
|
||||
arch = raw_arch.lower()
|
||||
if arch in {"x86_64", "amd64"}:
|
||||
return Arch.AMD64
|
||||
if arch in {"arm64", "aarch64"}:
|
||||
return Arch.ARM64
|
||||
return Arch.AMD64
|
||||
|
||||
@staticmethod
|
||||
def _parse_os(raw_os: str) -> OperatingSystem:
|
||||
system_name = raw_os.lower()
|
||||
if system_name == "darwin":
|
||||
return OperatingSystem.DARWIN
|
||||
return OperatingSystem.LINUX
|
||||
|
||||
@staticmethod
|
||||
def _sftp_mkdirs(sftp: Any, directory: str) -> None:
|
||||
if not directory or directory == "/":
|
||||
return
|
||||
|
||||
path = PurePosixPath(directory)
|
||||
current = PurePosixPath("/") if path.is_absolute() else PurePosixPath()
|
||||
|
||||
for part in path.parts:
|
||||
if part in ("", "/"):
|
||||
continue
|
||||
current = current / part
|
||||
current_path = str(current)
|
||||
try:
|
||||
attrs = sftp.stat(current_path)
|
||||
if not stat.S_ISDIR(attrs.st_mode):
|
||||
raise OSError(f"Path exists but is not a directory: {current_path}")
|
||||
continue
|
||||
except OSError as e:
|
||||
missing = isinstance(e, FileNotFoundError) or getattr(e, "errno", None) == 2
|
||||
missing = missing or "no such file" in str(e).lower()
|
||||
if not missing:
|
||||
raise
|
||||
|
||||
try:
|
||||
sftp.mkdir(current_path)
|
||||
except OSError:
|
||||
# Some SFTP servers report generic "Failure" when directory already exists.
|
||||
attrs = sftp.stat(current_path)
|
||||
if not stat.S_ISDIR(attrs.st_mode):
|
||||
raise OSError(f"Failed to create directory: {current_path}")
|
||||
Reference in New Issue
Block a user