refactor(sandbox): extract connection helpers and move run_command to helper module

- Add helpers.py with connection management utilities:
    - with_connection: context manager for connection lifecycle
    - submit_command: execute command and return CommandFuture
    - execute: run command with auto connection, raise on failure
    - try_execute: run command with auto connection, return result

  - Add CommandExecutionError to exec.py for typed error handling
    with access to exit_code, stderr, and full result

  - Remove run_command method from VirtualEnvironment base class
    (now available as submit_command helper)

  - Update all call sites to use new helper functions:
    - sandbox/session.py
    - sandbox/storage/archive_storage.py
    - sandbox/bash/bash_tool.py
    - workflow/nodes/command/node.py

  - Add comprehensive unit tests for helpers with connection reuse
This commit is contained in:
Harry
2026-01-14 23:23:00 +08:00
parent 31427e9c42
commit a0c388f283
19 changed files with 553 additions and 149 deletions

View File

@ -1,17 +1,17 @@
from core.sandbox.bash_tool import SandboxBashTool
from core.sandbox.constants import (
DIFY_CLI_CONFIG_PATH,
DIFY_CLI_PATH,
DIFY_CLI_PATH_PATTERN,
)
from core.sandbox.dify_cli import (
from core.sandbox.bash.bash_tool import SandboxBashTool
from core.sandbox.bash.dify_cli import (
DifyCliBinary,
DifyCliConfig,
DifyCliEnvConfig,
DifyCliLocator,
DifyCliToolConfig,
)
from core.sandbox.initializer import DifyCliInitializer, SandboxInitializer
from core.sandbox.constants import (
DIFY_CLI_CONFIG_PATH,
DIFY_CLI_PATH,
DIFY_CLI_PATH_PATTERN,
)
from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer
from core.sandbox.session import SandboxSession
__all__ = [

View File

@ -0,0 +1,17 @@
from core.sandbox.bash.bash_tool import SandboxBashTool
from core.sandbox.bash.dify_cli import (
DifyCliBinary,
DifyCliConfig,
DifyCliEnvConfig,
DifyCliLocator,
DifyCliToolConfig,
)
__all__ = [
"DifyCliBinary",
"DifyCliConfig",
"DifyCliEnvConfig",
"DifyCliLocator",
"DifyCliToolConfig",
"SandboxBashTool",
]

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
from typing import Any
from core.sandbox.debug import sandbox_debug
from core.sandbox.utils.debug import sandbox_debug
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
@ -13,6 +13,7 @@ from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
COMMAND_TIMEOUT_SECONDS = 60
@ -66,31 +67,29 @@ class SandboxBashTool(Tool):
yield self.create_text_message("Error: No command provided")
return
connection_handle = self._sandbox.establish_connection()
try:
cmd_list = ["bash", "-c", command]
with with_connection(self._sandbox) as conn:
cmd_list = ["bash", "-c", command]
sandbox_debug("bash_tool", "cmd_list", cmd_list)
future = self._sandbox.run_command(connection_handle, cmd_list)
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
result = future.result(timeout=timeout)
sandbox_debug("bash_tool", "cmd_list", cmd_list)
future = submit_command(self._sandbox, conn, cmd_list)
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
result = future.result(timeout=timeout)
stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
exit_code = result.exit_code
stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
exit_code = result.exit_code
output_parts: list[str] = []
if stdout:
output_parts.append(f"\n{stdout}")
if stderr:
output_parts.append(f"\n{stderr}")
output_parts.append(f"\nCommand exited with code {exit_code}")
output_parts: list[str] = []
if stdout:
output_parts.append(f"\n{stdout}")
if stderr:
output_parts.append(f"\n{stderr}")
output_parts.append(f"\nCommand exited with code {exit_code}")
yield self.create_text_message("\n".join(output_parts))
yield self.create_text_message("\n".join(output_parts))
except TimeoutError:
yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s")
except Exception as e:
yield self.create_text_message(f"Error: {e!s}")
finally:
self._sandbox.release_connection(connection_handle)

View File

@ -0,0 +1,6 @@
from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer
__all__ = [
"DifyCliInitializer",
"SandboxInitializer",
]

View File

@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from core.sandbox.bash.dify_cli import DifyCliLocator
from core.sandbox.constants import DIFY_CLI_PATH
from core.sandbox.dify_cli import DifyCliLocator
from core.virtual_environment.__base.helpers import execute
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
logger = logging.getLogger(__name__)
@ -23,14 +24,10 @@ class DifyCliInitializer(SandboxInitializer):
binary = self._locator.resolve(env.metadata.os, env.metadata.arch)
env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes()))
connection_handle = env.establish_connection()
try:
future = env.run_command(connection_handle, ["chmod", "+x", DIFY_CLI_PATH])
result = future.result(timeout=10)
if result.exit_code not in (0, None):
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
raise RuntimeError(f"Failed to mark dify CLI as executable: {stderr}")
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)
finally:
env.release_connection(connection_handle)
execute(
env,
["chmod", "+x", DIFY_CLI_PATH],
timeout=10,
error_message="Failed to mark dify CLI as executable",
)
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)

View File

@ -5,13 +5,14 @@ import logging
from io import BytesIO
from types import TracebackType
from core.sandbox.bash_tool import SandboxBashTool
from core.sandbox.bash.bash_tool import SandboxBashTool
from core.sandbox.bash.dify_cli import DifyCliConfig
from core.sandbox.constants import DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH
from core.sandbox.debug import sandbox_debug
from core.sandbox.dify_cli import DifyCliConfig
from core.sandbox.manager import SandboxManager
from core.sandbox.utils.debug import sandbox_debug
from core.session.cli_api import CliApiSessionManager
from core.tools.__base.tool import Tool
from core.virtual_environment.__base.helpers import execute
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
logger = logging.getLogger(__name__)
@ -50,14 +51,12 @@ class SandboxSession:
sandbox_debug("sandbox", "config_json", config_json)
sandbox.upload_file(DIFY_CLI_CONFIG_PATH, BytesIO(config_json.encode("utf-8")))
connection_handle = sandbox.establish_connection()
try:
future = sandbox.run_command(connection_handle, [DIFY_CLI_PATH, "init"])
result = future.result(timeout=30)
if result.is_error:
raise RuntimeError(f"Failed to initialize Dify CLI in sandbox: {result.error_message}")
finally:
sandbox.release_connection(connection_handle)
execute(
sandbox,
[DIFY_CLI_PATH, "init"],
timeout=30,
error_message="Failed to initialize Dify CLI in sandbox",
)
except Exception:
CliApiSessionManager().delete(session.id)

View File

@ -2,6 +2,7 @@ import logging
from io import BytesIO
from core.sandbox.storage.sandbox_storage import SandboxStorage
from core.virtual_environment.__base.helpers import try_execute
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from extensions.ext_storage import Storage
@ -29,38 +30,25 @@ class ArchiveSandboxStorage(SandboxStorage):
archive_data = self._storage.load_once(self._storage_key)
sandbox.upload_file(ARCHIVE_NAME, BytesIO(archive_data))
connection = sandbox.establish_connection()
try:
future = sandbox.run_command(connection, ["tar", "-xzf", ARCHIVE_NAME])
result = future.result(timeout=60)
if result.is_error:
logger.error("Failed to extract archive: %s", result.error_message)
return False
finally:
sandbox.release_connection(connection)
result = try_execute(sandbox, ["tar", "-xzf", ARCHIVE_NAME], timeout=60)
if result.is_error:
logger.error("Failed to extract archive: %s", result.error_message)
return False
connection = sandbox.establish_connection()
try:
sandbox.run_command(connection, ["rm", ARCHIVE_NAME]).result(timeout=10)
finally:
sandbox.release_connection(connection)
try_execute(sandbox, ["rm", ARCHIVE_NAME], timeout=10)
logger.info("Mounted archive for sandbox %s", self._sandbox_id)
return True
def unmount(self, sandbox: VirtualEnvironment) -> bool:
connection = sandbox.establish_connection()
try:
future = sandbox.run_command(
connection,
["tar", "-czf", ARCHIVE_NAME, "-C", WORKSPACE_DIR, "."],
)
result = future.result(timeout=120)
if result.is_error:
logger.error("Failed to create archive: %s", result.error_message)
return False
finally:
sandbox.release_connection(connection)
result = try_execute(
sandbox,
["tar", "-czf", ARCHIVE_NAME, "-C", WORKSPACE_DIR, "."],
timeout=120,
)
if result.is_error:
logger.error("Failed to create archive: %s", result.error_message)
return False
archive_content = sandbox.download_file(ARCHIVE_NAME)
self._storage.save(self._storage_key, archive_content.getvalue())

View File

@ -0,0 +1,2 @@
# Sandbox utilities
# Connection helpers have been moved to core.virtual_environment.__base.helpers

View File

@ -1,3 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.virtual_environment.__base.entities import CommandResult
class ArchNotSupportedError(Exception):
"""Exception raised when the architecture is not supported."""
@ -20,3 +28,19 @@ class SandboxConfigValidationError(ValueError):
"""Exception raised when sandbox configuration validation fails."""
pass
class CommandExecutionError(Exception):
    """Error signalling that an executed command did not succeed.

    The full command result is kept on the exception so handlers can inspect
    it; ``exit_code`` and ``stderr`` are exposed as read-only shortcuts.
    """

    def __init__(self, message: str, result: CommandResult):
        # Store the result before delegating so it is available to any code
        # that inspects the exception as soon as it exists.
        self.result = result
        super().__init__(message)

    @property
    def stderr(self) -> bytes:
        """Standard error captured on the stored result."""
        return self.result.stderr

    @property
    def exit_code(self) -> int | None:
        """Exit code recorded on the stored result (may be ``None``)."""
        return self.result.exit_code

View File

@ -0,0 +1,149 @@
from __future__ import annotations
import contextlib
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from functools import partial
from core.virtual_environment.__base.command_future import CommandFuture
from core.virtual_environment.__base.entities import CommandResult, ConnectionHandle
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
@contextmanager
def with_connection(env: VirtualEnvironment) -> Generator[ConnectionHandle, None, None]:
    """Manage the lifecycle of a VirtualEnvironment connection.

    Establishes a connection on entry and releases it on exit, whether the
    body finishes normally or raises. A failure while releasing the
    connection is logged and swallowed so that it never masks an exception
    raised by the body.

    Usage:
        with with_connection(env) as conn:
            future = submit_command(env, conn, ["echo", "hello"])
            result = future.result(timeout=10)

    Args:
        env: The virtual environment to connect to.

    Yields:
        The handle returned by ``env.establish_connection()``.
    """
    connection_handle = env.establish_connection()
    try:
        yield connection_handle
    finally:
        try:
            env.release_connection(connection_handle)
        except Exception:
            # Releasing is best-effort cleanup: do not raise, but leave a
            # trace so leaked connections can be diagnosed.
            import logging

            logging.getLogger(__name__).warning("Failed to release connection handle", exc_info=True)
def submit_command(
    env: VirtualEnvironment,
    connection: ConnectionHandle,
    command: list[str],
    environments: Mapping[str, str] | None = None,
    *,
    cwd: str | None = None,
) -> CommandFuture:
    """Start a command and return a future for its result.

    Starts the command through ``env.execute_command()`` and wraps the
    returned pid and stdin/stdout/stderr transports in a ``CommandFuture``
    whose status polling is bound to this connection and pid.
    For streaming output, use ``env.execute_command()`` directly instead.

    Args:
        env: The virtual environment to execute the command in.
        connection: The connection handle to run the command on.
        command: Command as a list of strings.
        environments: Environment variables for the command.
        cwd: Working directory for the command. If None, uses the provider's default.

    Returns:
        CommandFuture that can be used to get the result with a timeout or cancel.

    Example:
        with with_connection(env) as conn:
            result = submit_command(env, conn, ["ls", "-la"]).result(timeout=30)
    """
    pid, stdin_transport, stdout_transport, stderr_transport = env.execute_command(
        connection, command, environments, cwd
    )

    return CommandFuture(
        pid=pid,
        stdin_transport=stdin_transport,
        stdout_transport=stdout_transport,
        stderr_transport=stderr_transport,
        # Bind the connection and pid now so the future can poll status on its own.
        poll_status=partial(env.get_command_status, connection, pid),
    )
def _execute_with_connection(
    env: VirtualEnvironment,
    conn: ConnectionHandle,
    command: list[str],
    timeout: float | None,
    cwd: str | None,
) -> CommandResult:
    """Submit ``command`` on an already-open connection and wait for its result.

    The connection is neither opened nor released here; that is the caller's job.
    """
    return submit_command(env, conn, command, cwd=cwd).result(timeout=timeout)
def execute(
    env: VirtualEnvironment,
    command: list[str],
    *,
    timeout: float | None = 30,
    cwd: str | None = None,
    error_message: str = "Command failed",
    connection: ConnectionHandle | None = None,
) -> CommandResult:
    """Run a command, managing the connection automatically, and raise on failure.

    Args:
        env: The virtual environment to execute the command in.
        command: The command to execute as a list of strings.
        timeout: Maximum time to wait for the command to complete (seconds).
        cwd: Working directory for the command.
        error_message: Prefix used for the message of the raised error.
        connection: Optional connection handle to reuse. If None, a new
            connection is established for this call and released afterwards.

    Returns:
        The CommandResult when it does not report an error.

    Raises:
        CommandExecutionError: If the result reports an error (``result.is_error``).
    """
    # A caller-supplied handle is used as-is and never released here; otherwise
    # a handle is opened for the duration of this one command.
    scope = contextlib.nullcontext(connection) if connection is not None else with_connection(env)
    with scope as active_connection:
        outcome = _execute_with_connection(env, active_connection, command, timeout, cwd)

    if not outcome.is_error:
        return outcome
    raise CommandExecutionError(f"{error_message}: {outcome.error_message}", outcome)
def try_execute(
    env: VirtualEnvironment,
    command: list[str],
    *,
    timeout: float | None = 30,
    cwd: str | None = None,
    connection: ConnectionHandle | None = None,
) -> CommandResult:
    """Run a command, managing the connection automatically, without raising on failure.

    Unlike ``execute``, the result is returned as-is so the caller decides how
    to treat an unsuccessful command.

    Args:
        env: The virtual environment to execute the command in.
        command: The command to execute as a list of strings.
        timeout: Maximum time to wait for the command to complete (seconds).
        cwd: Working directory for the command.
        connection: Optional connection handle to reuse. If None, a new
            connection is established for this call and released afterwards.

    Returns:
        The CommandResult produced by the command.
    """
    if connection is None:
        # No handle supplied: open one that lives only for this command.
        with with_connection(env) as owned_connection:
            return _execute_with_connection(env, owned_connection, command, timeout, cwd)

    # Reuse the caller's handle; its lifecycle stays with the caller.
    return _execute_with_connection(env, connection, command, timeout, cwd)

View File

@ -1,10 +1,8 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import partial
from io import BytesIO
from typing import Any
from core.virtual_environment.__base.command_future import CommandFuture
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
@ -176,40 +174,3 @@ class VirtualEnvironment(ABC):
Returns:
CommandStatus: The status of the command execution.
"""
def run_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> CommandFuture:
"""
Execute a command and return a Future for the result.
High-level interface that handles IO draining internally.
For streaming output, use execute_command() instead.
Args:
connection_handle: The connection handle.
command: Command as list of strings.
environments: Environment variables.
cwd: Working directory for the command. If None, uses the provider's default.
Returns:
CommandFuture that can be used to get result with timeout or cancel.
Example:
result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
"""
pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command(
connection_handle, command, environments, cwd
)
return CommandFuture(
pid=pid,
stdin_transport=stdin_transport,
stdout_transport=stdout_transport,
stderr_transport=stderr_transport,
poll_status=partial(self.get_command_status, connection_handle, pid),
)

View File

@ -1,12 +1,12 @@
import contextlib
import logging
import shlex
from collections.abc import Mapping, Sequence
from typing import Any
from core.sandbox.debug import sandbox_debug
from core.sandbox.manager import SandboxManager
from core.sandbox.utils.debug import sandbox_debug
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -80,42 +80,41 @@ class CommandNode(Node[CommandNodeData]):
)
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
connection_handle = sandbox.establish_connection()
try:
command = shlex.split(raw_command)
with with_connection(sandbox) as conn:
command = shlex.split(raw_command)
sandbox_debug("command_node", "command", command)
sandbox_debug("command_node", "command", command)
future = sandbox.run_command(connection_handle, command, cwd=working_directory)
result = future.result(timeout=timeout)
future = submit_command(sandbox, conn, command, cwd=working_directory)
result = future.result(timeout=timeout)
outputs: dict[str, Any] = {
"stdout": result.stdout.decode("utf-8", errors="replace"),
"stderr": result.stderr.decode("utf-8", errors="replace"),
"exit_code": result.exit_code,
"pid": result.pid,
}
process_data = {"command": command, "working_directory": working_directory}
outputs: dict[str, Any] = {
"stdout": result.stdout.decode("utf-8", errors="replace"),
"stderr": result.stderr.decode("utf-8", errors="replace"),
"exit_code": result.exit_code,
"pid": result.pid,
}
process_data = {"command": command, "working_directory": working_directory}
if result.exit_code not in (None, 0):
stderr_text = result.stderr.decode("utf-8", errors="replace")
error_message = f"{stderr_text}\n\nCommand exited with code {result.exit_code}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs=outputs,
process_data=process_data,
error=error_message,
error_type=CommandExecutionError.__name__,
)
if result.exit_code not in (None, 0):
error_message = (
f"{result.stderr.decode('utf-8', errors='replace')}\n\nCommand exited with code {result.exit_code}"
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
process_data=process_data,
error=error_message,
error_type=CommandExecutionError.__name__,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
process_data=process_data,
)
except CommandTimeoutError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -135,9 +134,6 @@ class CommandNode(Node[CommandNodeData]):
error=str(e),
error_type=type(e).__name__,
)
finally:
with contextlib.suppress(Exception):
sandbox.release_connection(connection_handle)
@classmethod
def _extract_variable_selector_to_variable_mapping(