refactor(sandbox): extract connection helpers and move run_command to helper module

- Add helpers.py with connection management utilities:
    - with_connection: context manager for connection lifecycle
    - submit_command: execute command and return CommandFuture
    - execute: run command with auto connection, raise on failure
    - try_execute: run command with auto connection, return result

  - Add CommandExecutionError to exec.py for typed error handling
    with access to exit_code, stderr, and full result

  - Remove run_command method from VirtualEnvironment base class
    (now available as submit_command helper)

  - Update all call sites to use new helper functions:
    - sandbox/session.py
    - sandbox/storage/archive_storage.py
    - sandbox/bash/bash_tool.py
    - workflow/nodes/command/node.py

  - Add comprehensive unit tests for helpers with connection reuse
This commit is contained in:
Harry
2026-01-14 23:23:00 +08:00
parent 31427e9c42
commit a0c388f283
19 changed files with 553 additions and 149 deletions

View File

@ -1,17 +1,17 @@
from core.sandbox.bash_tool import SandboxBashTool
from core.sandbox.constants import (
DIFY_CLI_CONFIG_PATH,
DIFY_CLI_PATH,
DIFY_CLI_PATH_PATTERN,
)
from core.sandbox.dify_cli import (
from core.sandbox.bash.bash_tool import SandboxBashTool
from core.sandbox.bash.dify_cli import (
DifyCliBinary,
DifyCliConfig,
DifyCliEnvConfig,
DifyCliLocator,
DifyCliToolConfig,
)
from core.sandbox.initializer import DifyCliInitializer, SandboxInitializer
from core.sandbox.constants import (
DIFY_CLI_CONFIG_PATH,
DIFY_CLI_PATH,
DIFY_CLI_PATH_PATTERN,
)
from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer
from core.sandbox.session import SandboxSession
__all__ = [

View File

@ -0,0 +1,17 @@
from core.sandbox.bash.bash_tool import SandboxBashTool
from core.sandbox.bash.dify_cli import (
DifyCliBinary,
DifyCliConfig,
DifyCliEnvConfig,
DifyCliLocator,
DifyCliToolConfig,
)
__all__ = [
"DifyCliBinary",
"DifyCliConfig",
"DifyCliEnvConfig",
"DifyCliLocator",
"DifyCliToolConfig",
"SandboxBashTool",
]

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
from typing import Any
from core.sandbox.debug import sandbox_debug
from core.sandbox.utils.debug import sandbox_debug
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
@ -13,6 +13,7 @@ from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
COMMAND_TIMEOUT_SECONDS = 60
@ -66,31 +67,29 @@ class SandboxBashTool(Tool):
yield self.create_text_message("Error: No command provided")
return
connection_handle = self._sandbox.establish_connection()
try:
cmd_list = ["bash", "-c", command]
with with_connection(self._sandbox) as conn:
cmd_list = ["bash", "-c", command]
sandbox_debug("bash_tool", "cmd_list", cmd_list)
future = self._sandbox.run_command(connection_handle, cmd_list)
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
result = future.result(timeout=timeout)
sandbox_debug("bash_tool", "cmd_list", cmd_list)
future = submit_command(self._sandbox, conn, cmd_list)
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
result = future.result(timeout=timeout)
stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
exit_code = result.exit_code
stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
exit_code = result.exit_code
output_parts: list[str] = []
if stdout:
output_parts.append(f"\n{stdout}")
if stderr:
output_parts.append(f"\n{stderr}")
output_parts.append(f"\nCommand exited with code {exit_code}")
output_parts: list[str] = []
if stdout:
output_parts.append(f"\n{stdout}")
if stderr:
output_parts.append(f"\n{stderr}")
output_parts.append(f"\nCommand exited with code {exit_code}")
yield self.create_text_message("\n".join(output_parts))
yield self.create_text_message("\n".join(output_parts))
except TimeoutError:
yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s")
except Exception as e:
yield self.create_text_message(f"Error: {e!s}")
finally:
self._sandbox.release_connection(connection_handle)

View File

@ -0,0 +1,6 @@
from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer
__all__ = [
"DifyCliInitializer",
"SandboxInitializer",
]

View File

@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from core.sandbox.bash.dify_cli import DifyCliLocator
from core.sandbox.constants import DIFY_CLI_PATH
from core.sandbox.dify_cli import DifyCliLocator
from core.virtual_environment.__base.helpers import execute
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
logger = logging.getLogger(__name__)
@ -23,14 +24,10 @@ class DifyCliInitializer(SandboxInitializer):
binary = self._locator.resolve(env.metadata.os, env.metadata.arch)
env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes()))
connection_handle = env.establish_connection()
try:
future = env.run_command(connection_handle, ["chmod", "+x", DIFY_CLI_PATH])
result = future.result(timeout=10)
if result.exit_code not in (0, None):
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
raise RuntimeError(f"Failed to mark dify CLI as executable: {stderr}")
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)
finally:
env.release_connection(connection_handle)
execute(
env,
["chmod", "+x", DIFY_CLI_PATH],
timeout=10,
error_message="Failed to mark dify CLI as executable",
)
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)

View File

@ -5,13 +5,14 @@ import logging
from io import BytesIO
from types import TracebackType
from core.sandbox.bash_tool import SandboxBashTool
from core.sandbox.bash.bash_tool import SandboxBashTool
from core.sandbox.bash.dify_cli import DifyCliConfig
from core.sandbox.constants import DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH
from core.sandbox.debug import sandbox_debug
from core.sandbox.dify_cli import DifyCliConfig
from core.sandbox.manager import SandboxManager
from core.sandbox.utils.debug import sandbox_debug
from core.session.cli_api import CliApiSessionManager
from core.tools.__base.tool import Tool
from core.virtual_environment.__base.helpers import execute
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
logger = logging.getLogger(__name__)
@ -50,14 +51,12 @@ class SandboxSession:
sandbox_debug("sandbox", "config_json", config_json)
sandbox.upload_file(DIFY_CLI_CONFIG_PATH, BytesIO(config_json.encode("utf-8")))
connection_handle = sandbox.establish_connection()
try:
future = sandbox.run_command(connection_handle, [DIFY_CLI_PATH, "init"])
result = future.result(timeout=30)
if result.is_error:
raise RuntimeError(f"Failed to initialize Dify CLI in sandbox: {result.error_message}")
finally:
sandbox.release_connection(connection_handle)
execute(
sandbox,
[DIFY_CLI_PATH, "init"],
timeout=30,
error_message="Failed to initialize Dify CLI in sandbox",
)
except Exception:
CliApiSessionManager().delete(session.id)

View File

@ -2,6 +2,7 @@ import logging
from io import BytesIO
from core.sandbox.storage.sandbox_storage import SandboxStorage
from core.virtual_environment.__base.helpers import try_execute
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from extensions.ext_storage import Storage
@ -29,38 +30,25 @@ class ArchiveSandboxStorage(SandboxStorage):
archive_data = self._storage.load_once(self._storage_key)
sandbox.upload_file(ARCHIVE_NAME, BytesIO(archive_data))
connection = sandbox.establish_connection()
try:
future = sandbox.run_command(connection, ["tar", "-xzf", ARCHIVE_NAME])
result = future.result(timeout=60)
if result.is_error:
logger.error("Failed to extract archive: %s", result.error_message)
return False
finally:
sandbox.release_connection(connection)
result = try_execute(sandbox, ["tar", "-xzf", ARCHIVE_NAME], timeout=60)
if result.is_error:
logger.error("Failed to extract archive: %s", result.error_message)
return False
connection = sandbox.establish_connection()
try:
sandbox.run_command(connection, ["rm", ARCHIVE_NAME]).result(timeout=10)
finally:
sandbox.release_connection(connection)
try_execute(sandbox, ["rm", ARCHIVE_NAME], timeout=10)
logger.info("Mounted archive for sandbox %s", self._sandbox_id)
return True
def unmount(self, sandbox: VirtualEnvironment) -> bool:
connection = sandbox.establish_connection()
try:
future = sandbox.run_command(
connection,
["tar", "-czf", ARCHIVE_NAME, "-C", WORKSPACE_DIR, "."],
)
result = future.result(timeout=120)
if result.is_error:
logger.error("Failed to create archive: %s", result.error_message)
return False
finally:
sandbox.release_connection(connection)
result = try_execute(
sandbox,
["tar", "-czf", ARCHIVE_NAME, "-C", WORKSPACE_DIR, "."],
timeout=120,
)
if result.is_error:
logger.error("Failed to create archive: %s", result.error_message)
return False
archive_content = sandbox.download_file(ARCHIVE_NAME)
self._storage.save(self._storage_key, archive_content.getvalue())

View File

@ -0,0 +1,2 @@
# Sandbox utilities
# Connection helpers have been moved to core.virtual_environment.helpers

View File

@ -1,3 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.virtual_environment.__base.entities import CommandResult
class ArchNotSupportedError(Exception):
"""Exception raised when the architecture is not supported."""
@ -20,3 +28,19 @@ class SandboxConfigValidationError(ValueError):
"""Exception raised when sandbox configuration validation fails."""
pass
class CommandExecutionError(Exception):
"""Raised when a command execution fails."""
def __init__(self, message: str, result: CommandResult):
super().__init__(message)
self.result = result
@property
def exit_code(self) -> int | None:
return self.result.exit_code
@property
def stderr(self) -> bytes:
return self.result.stderr

View File

@ -0,0 +1,149 @@
from __future__ import annotations
import contextlib
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from functools import partial
from core.virtual_environment.__base.command_future import CommandFuture
from core.virtual_environment.__base.entities import CommandResult, ConnectionHandle
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
@contextmanager
def with_connection(env: VirtualEnvironment) -> Generator[ConnectionHandle, None, None]:
"""Context manager for VirtualEnvironment connection lifecycle.
Automatically establishes and releases connection handles.
Usage:
with with_connection(env) as conn:
future = run_command(env, conn, ["echo", "hello"])
result = future.result(timeout=10)
"""
connection_handle = env.establish_connection()
try:
yield connection_handle
finally:
with contextlib.suppress(Exception):
env.release_connection(connection_handle)
def submit_command(
env: VirtualEnvironment,
connection: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
*,
cwd: str | None = None,
) -> CommandFuture:
"""Execute a command and return a Future for the result.
High-level interface that handles IO draining internally.
For streaming output, use env.execute_command() instead.
Args:
env: The virtual environment to execute the command in.
connection: The connection handle.
command: Command as list of strings.
environments: Environment variables.
cwd: Working directory for the command. If None, uses the provider's default.
Returns:
CommandFuture that can be used to get result with timeout or cancel.
Example:
with with_connection(env) as conn:
result = run_command(env, conn, ["ls", "-la"]).result(timeout=30)
"""
pid, stdin_transport, stdout_transport, stderr_transport = env.execute_command(
connection, command, environments, cwd
)
return CommandFuture(
pid=pid,
stdin_transport=stdin_transport,
stdout_transport=stdout_transport,
stderr_transport=stderr_transport,
poll_status=partial(env.get_command_status, connection, pid),
)
def _execute_with_connection(
env: VirtualEnvironment,
conn: ConnectionHandle,
command: list[str],
timeout: float | None,
cwd: str | None,
) -> CommandResult:
"""Internal helper to execute command with given connection."""
future = submit_command(env, conn, command, cwd=cwd)
return future.result(timeout=timeout)
def execute(
env: VirtualEnvironment,
command: list[str],
*,
timeout: float | None = 30,
cwd: str | None = None,
error_message: str = "Command failed",
connection: ConnectionHandle | None = None,
) -> CommandResult:
"""Execute a command with automatic connection management.
Raises CommandExecutionError if the command fails (non-zero exit code).
Args:
env: The virtual environment to execute the command in.
command: The command to execute as a list of strings.
timeout: Maximum time to wait for the command to complete (seconds).
cwd: Working directory for the command.
error_message: Custom error message prefix for failures.
connection: Optional connection handle to reuse. If None, creates and releases a new connection.
Returns:
CommandResult on success.
Raises:
CommandExecutionError: If the command fails.
"""
if connection is not None:
result = _execute_with_connection(env, connection, command, timeout, cwd)
else:
with with_connection(env) as conn:
result = _execute_with_connection(env, conn, command, timeout, cwd)
if result.is_error:
raise CommandExecutionError(f"{error_message}: {result.error_message}", result)
return result
def try_execute(
env: VirtualEnvironment,
command: list[str],
*,
timeout: float | None = 30,
cwd: str | None = None,
connection: ConnectionHandle | None = None,
) -> CommandResult:
"""Execute a command with automatic connection management.
Does not raise on failure - returns the result for caller to handle.
Args:
env: The virtual environment to execute the command in.
command: The command to execute as a list of strings.
timeout: Maximum time to wait for the command to complete (seconds).
cwd: Working directory for the command.
connection: Optional connection handle to reuse. If None, creates and releases a new connection.
Returns:
CommandResult containing stdout, stderr, and exit_code.
"""
if connection is not None:
return _execute_with_connection(env, connection, command, timeout, cwd)
with with_connection(env) as conn:
return _execute_with_connection(env, conn, command, timeout, cwd)

View File

@ -1,10 +1,8 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import partial
from io import BytesIO
from typing import Any
from core.virtual_environment.__base.command_future import CommandFuture
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
@ -176,40 +174,3 @@ class VirtualEnvironment(ABC):
Returns:
CommandStatus: The status of the command execution.
"""
def run_command(
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> CommandFuture:
"""
Execute a command and return a Future for the result.
High-level interface that handles IO draining internally.
For streaming output, use execute_command() instead.
Args:
connection_handle: The connection handle.
command: Command as list of strings.
environments: Environment variables.
cwd: Working directory for the command. If None, uses the provider's default.
Returns:
CommandFuture that can be used to get result with timeout or cancel.
Example:
result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
"""
pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command(
connection_handle, command, environments, cwd
)
return CommandFuture(
pid=pid,
stdin_transport=stdin_transport,
stdout_transport=stdout_transport,
stderr_transport=stderr_transport,
poll_status=partial(self.get_command_status, connection_handle, pid),
)

View File

@ -1,12 +1,12 @@
import contextlib
import logging
import shlex
from collections.abc import Mapping, Sequence
from typing import Any
from core.sandbox.debug import sandbox_debug
from core.sandbox.manager import SandboxManager
from core.sandbox.utils.debug import sandbox_debug
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -80,42 +80,41 @@ class CommandNode(Node[CommandNodeData]):
)
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
connection_handle = sandbox.establish_connection()
try:
command = shlex.split(raw_command)
with with_connection(sandbox) as conn:
command = shlex.split(raw_command)
sandbox_debug("command_node", "command", command)
sandbox_debug("command_node", "command", command)
future = sandbox.run_command(connection_handle, command, cwd=working_directory)
result = future.result(timeout=timeout)
future = submit_command(sandbox, conn, command, cwd=working_directory)
result = future.result(timeout=timeout)
outputs: dict[str, Any] = {
"stdout": result.stdout.decode("utf-8", errors="replace"),
"stderr": result.stderr.decode("utf-8", errors="replace"),
"exit_code": result.exit_code,
"pid": result.pid,
}
process_data = {"command": command, "working_directory": working_directory}
outputs: dict[str, Any] = {
"stdout": result.stdout.decode("utf-8", errors="replace"),
"stderr": result.stderr.decode("utf-8", errors="replace"),
"exit_code": result.exit_code,
"pid": result.pid,
}
process_data = {"command": command, "working_directory": working_directory}
if result.exit_code not in (None, 0):
stderr_text = result.stderr.decode("utf-8", errors="replace")
error_message = f"{stderr_text}\n\nCommand exited with code {result.exit_code}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs=outputs,
process_data=process_data,
error=error_message,
error_type=CommandExecutionError.__name__,
)
if result.exit_code not in (None, 0):
error_message = (
f"{result.stderr.decode('utf-8', errors='replace')}\n\nCommand exited with code {result.exit_code}"
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
process_data=process_data,
error=error_message,
error_type=CommandExecutionError.__name__,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
process_data=process_data,
)
except CommandTimeoutError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -135,9 +134,6 @@ class CommandNode(Node[CommandNodeData]):
error=str(e),
error_type=type(e).__name__,
)
finally:
with contextlib.suppress(Exception):
sandbox.release_connection(connection_handle)
@classmethod
def _extract_variable_selector_to_variable_mapping(

View File

@ -5,7 +5,8 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from ..core.app.entities.app_asset_entities import AppAssetFileTree
from core.app.entities.app_asset_entities import AppAssetFileTree
from .base import Base
from .types import LongText, StringUUID

View File

@ -19,9 +19,9 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE
from core.entities.provider_entities import BasicProviderConfig
from core.sandbox.encryption import create_sandbox_config_encrypter, masked_config
from core.sandbox.factory import VMFactory, VMType
from core.sandbox.initializer import DifyCliInitializer
from core.sandbox.utils.encryption import create_sandbox_config_encrypter, masked_config
from core.tools.utils.system_encryption import (
decrypt_system_params,
)

View File

@ -0,0 +1,264 @@
from collections.abc import Mapping
from io import BytesIO
from typing import Any
import pytest
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
ConnectionHandle,
FileState,
Metadata,
OperatingSystem,
)
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute, try_execute, with_connection
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
class MockReadTransport(TransportReadCloser):
"""Mock transport that returns data once then raises EOF."""
def __init__(self, data: bytes):
self._data = data
self._read = False
def read(self, n: int) -> bytes:
if self._read:
raise TransportEOFError()
self._read = True
return self._data[:n] if n < len(self._data) else self._data
def close(self) -> None:
pass
class MockWriteTransport(TransportWriteCloser):
"""Mock transport for stdin (no-op)."""
def write(self, data: bytes) -> None:
pass
def close(self) -> None:
pass
class FakeVirtualEnvironment(VirtualEnvironment):
"""Fake virtual environment for testing connection utilities."""
def __init__(
self,
*,
exit_code: int | None = 0,
stdout: bytes = b"",
stderr: bytes = b"",
):
self._exit_code = exit_code
self._stdout = stdout
self._stderr = stderr
self._connection_established = False
self._connection_released = False
self._establish_count = 0
self._release_count = 0
super().__init__(tenant_id="test-tenant", options={}, environments={})
def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata:
return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX)
def upload_file(self, _path: str, _content: BytesIO) -> None:
raise NotImplementedError
def download_file(self, _path: str) -> BytesIO:
raise NotImplementedError
def list_files(self, _directory_path: str, _limit: int) -> list[FileState]:
return []
def establish_connection(self) -> ConnectionHandle:
self._connection_established = True
self._establish_count += 1
return ConnectionHandle(id=f"test-conn-{self._establish_count}")
def release_connection(self, _connection_handle: ConnectionHandle) -> None:
self._connection_released = True
self._release_count += 1
def release_environment(self) -> None:
pass
def execute_command(
self,
_connection_handle: ConnectionHandle,
_command: list[str],
_environments: Mapping[str, str] | None = None,
_cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""Return mock transports for testing."""
return (
"test-pid",
MockWriteTransport(),
MockReadTransport(self._stdout),
MockReadTransport(self._stderr),
)
def get_command_status(self, _connection_handle: ConnectionHandle, _pid: str) -> CommandStatus:
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=self._exit_code)
@classmethod
def validate(cls, _options: Mapping[str, Any]) -> None:
pass
class TestWithConnection:
def test_connection_established_and_released(self):
env = FakeVirtualEnvironment()
with with_connection(env) as conn:
assert env._connection_established is True
assert conn.id == "test-conn-1"
assert env._connection_released is True
def test_connection_released_on_exception(self):
env = FakeVirtualEnvironment()
with pytest.raises(ValueError):
with with_connection(env):
raise ValueError("test error")
assert env._connection_released is True
class TestExecute:
def test_execute_success(self):
env = FakeVirtualEnvironment(exit_code=0, stdout=b"hello world")
result = execute(env, ["echo", "hello"])
assert result.stdout == b"hello world"
assert result.exit_code == 0
assert env._connection_released is True
def test_execute_raises_on_nonzero_exit_code(self):
env = FakeVirtualEnvironment(exit_code=1, stderr=b"command not found")
with pytest.raises(CommandExecutionError, match="Command failed: command not found") as exc_info:
execute(env, ["invalid-command"])
assert exc_info.value.exit_code == 1
assert exc_info.value.stderr == b"command not found"
assert env._connection_released is True
def test_execute_with_custom_error_message(self):
env = FakeVirtualEnvironment(exit_code=1, stderr=b"error")
with pytest.raises(CommandExecutionError, match="Custom error: error"):
execute(env, ["cmd"], error_message="Custom error")
def test_execute_releases_connection_on_error(self):
env = FakeVirtualEnvironment(exit_code=1, stderr=b"error")
with pytest.raises(CommandExecutionError):
execute(env, ["cmd"])
assert env._connection_released is True
class TestTryExecute:
def test_try_execute_success(self):
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
result = try_execute(env, ["echo", "test"])
assert result.stdout == b"output"
assert result.exit_code == 0
assert env._connection_released is True
def test_try_execute_returns_error_result(self):
env = FakeVirtualEnvironment(exit_code=1, stderr=b"error message")
result = try_execute(env, ["failing-command"])
assert result.exit_code == 1
assert result.stderr == b"error message"
assert result.is_error is True
assert env._connection_released is True
def test_try_execute_does_not_raise(self):
env = FakeVirtualEnvironment(exit_code=127, stderr=b"not found")
result = try_execute(env, ["nonexistent"])
assert result.exit_code == 127
assert env._connection_released is True
class TestConnectionReuse:
def test_execute_with_reused_connection(self):
"""Test that execute reuses provided connection without creating new one."""
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
with with_connection(env) as conn:
# Execute with reused connection
result = execute(env, ["cmd1"], connection=conn)
assert result.stdout == b"output"
# Should have only established one connection (from with_connection)
assert env._establish_count == 1
assert env._release_count == 0 # Not released yet
# Now connection should be released
assert env._release_count == 1
def test_execute_without_connection_creates_new(self):
"""Test that execute without connection creates and releases its own."""
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
execute(env, ["cmd1"])
assert env._establish_count == 1
assert env._release_count == 1
def test_multiple_executes_with_same_connection(self):
"""Test multiple execute calls reusing the same connection."""
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
with with_connection(env) as conn:
execute(env, ["cmd1"], connection=conn)
execute(env, ["cmd2"], connection=conn)
execute(env, ["cmd3"], connection=conn)
# Only one connection established
assert env._establish_count == 1
assert env._release_count == 0
# Released once at the end
assert env._release_count == 1
def test_try_execute_with_reused_connection(self):
"""Test that try_execute reuses provided connection."""
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
with with_connection(env) as conn:
result = try_execute(env, ["cmd1"], connection=conn)
assert result.stdout == b"output"
assert env._establish_count == 1
assert env._release_count == 0
assert env._release_count == 1
def test_mixed_execute_and_try_execute_reuse(self):
"""Test mixing execute and try_execute with same connection."""
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
with with_connection(env) as conn:
execute(env, ["cmd1"], connection=conn)
try_execute(env, ["cmd2"], connection=conn)
execute(env, ["cmd3"], connection=conn)
assert env._establish_count == 1
assert env._release_count == 1

View File

@ -3,6 +3,7 @@ from pathlib import Path
import pytest
from core.virtual_environment.__base.helpers import submit_command
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser
from core.virtual_environment.providers import local_without_isolation
@ -99,7 +100,7 @@ def test_run_command_returns_output(local_env: LocalVirtualEnvironment):
local_env.upload_file("message.txt", BytesIO(b"hello"))
connection = local_env.establish_connection()
result = local_env.run_command(connection, ["/bin/sh", "-c", "cat message.txt"]).result(timeout=10)
result = submit_command(local_env, connection, ["/bin/sh", "-c", "cat message.txt"]).result(timeout=10)
assert result.stdout == b"hello"
assert result.stderr == b""
@ -109,7 +110,7 @@ def test_run_command_returns_output(local_env: LocalVirtualEnvironment):
def test_run_command_captures_stderr(local_env: LocalVirtualEnvironment):
connection = local_env.establish_connection()
result = local_env.run_command(connection, ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]).result(timeout=10)
result = submit_command(local_env, connection, ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]).result(timeout=10)
assert b"OUT" in result.stdout
assert b"ERR" in result.stderr