mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
refactor(sandbox): extract connection helpers and move run_command to helper module
- Add helpers.py with connection management utilities:
- with_connection: context manager for connection lifecycle
- submit_command: execute command and return CommandFuture
- execute: run command with auto connection, raise on failure
- try_execute: run command with auto connection, return result
- Add CommandExecutionError to exec.py for typed error handling
with access to exit_code, stderr, and full result
- Remove run_command method from VirtualEnvironment base class
(now available as submit_command helper)
- Update all call sites to use new helper functions:
- sandbox/session.py
- sandbox/storage/archive_storage.py
- sandbox/bash/bash_tool.py
- workflow/nodes/command/node.py
- Add comprehensive unit tests for helpers with connection reuse
This commit is contained in:
@ -1,17 +1,17 @@
|
||||
from core.sandbox.bash_tool import SandboxBashTool
|
||||
from core.sandbox.constants import (
|
||||
DIFY_CLI_CONFIG_PATH,
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_PATH_PATTERN,
|
||||
)
|
||||
from core.sandbox.dify_cli import (
|
||||
from core.sandbox.bash.bash_tool import SandboxBashTool
|
||||
from core.sandbox.bash.dify_cli import (
|
||||
DifyCliBinary,
|
||||
DifyCliConfig,
|
||||
DifyCliEnvConfig,
|
||||
DifyCliLocator,
|
||||
DifyCliToolConfig,
|
||||
)
|
||||
from core.sandbox.initializer import DifyCliInitializer, SandboxInitializer
|
||||
from core.sandbox.constants import (
|
||||
DIFY_CLI_CONFIG_PATH,
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_PATH_PATTERN,
|
||||
)
|
||||
from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer
|
||||
from core.sandbox.session import SandboxSession
|
||||
|
||||
__all__ = [
|
||||
|
||||
17
api/core/sandbox/bash/__init__.py
Normal file
17
api/core/sandbox/bash/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from core.sandbox.bash.bash_tool import SandboxBashTool
|
||||
from core.sandbox.bash.dify_cli import (
|
||||
DifyCliBinary,
|
||||
DifyCliConfig,
|
||||
DifyCliEnvConfig,
|
||||
DifyCliLocator,
|
||||
DifyCliToolConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DifyCliBinary",
|
||||
"DifyCliConfig",
|
||||
"DifyCliEnvConfig",
|
||||
"DifyCliLocator",
|
||||
"DifyCliToolConfig",
|
||||
"SandboxBashTool",
|
||||
]
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.sandbox.debug import sandbox_debug
|
||||
from core.sandbox.utils.debug import sandbox_debug
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -13,6 +13,7 @@ from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.virtual_environment.__base.helpers import submit_command, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
COMMAND_TIMEOUT_SECONDS = 60
|
||||
@ -66,31 +67,29 @@ class SandboxBashTool(Tool):
|
||||
yield self.create_text_message("Error: No command provided")
|
||||
return
|
||||
|
||||
connection_handle = self._sandbox.establish_connection()
|
||||
try:
|
||||
cmd_list = ["bash", "-c", command]
|
||||
with with_connection(self._sandbox) as conn:
|
||||
cmd_list = ["bash", "-c", command]
|
||||
|
||||
sandbox_debug("bash_tool", "cmd_list", cmd_list)
|
||||
future = self._sandbox.run_command(connection_handle, cmd_list)
|
||||
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
|
||||
result = future.result(timeout=timeout)
|
||||
sandbox_debug("bash_tool", "cmd_list", cmd_list)
|
||||
future = submit_command(self._sandbox, conn, cmd_list)
|
||||
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
|
||||
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
|
||||
exit_code = result.exit_code
|
||||
stdout = result.stdout.decode("utf-8", errors="replace") if result.stdout else ""
|
||||
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
|
||||
exit_code = result.exit_code
|
||||
|
||||
output_parts: list[str] = []
|
||||
if stdout:
|
||||
output_parts.append(f"\n{stdout}")
|
||||
if stderr:
|
||||
output_parts.append(f"\n{stderr}")
|
||||
output_parts.append(f"\nCommand exited with code {exit_code}")
|
||||
output_parts: list[str] = []
|
||||
if stdout:
|
||||
output_parts.append(f"\n{stdout}")
|
||||
if stderr:
|
||||
output_parts.append(f"\n{stderr}")
|
||||
output_parts.append(f"\nCommand exited with code {exit_code}")
|
||||
|
||||
yield self.create_text_message("\n".join(output_parts))
|
||||
yield self.create_text_message("\n".join(output_parts))
|
||||
|
||||
except TimeoutError:
|
||||
yield self.create_text_message(f"Error: Command timed out after {COMMAND_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
yield self.create_text_message(f"Error: {e!s}")
|
||||
finally:
|
||||
self._sandbox.release_connection(connection_handle)
|
||||
6
api/core/sandbox/initializer/__init__.py
Normal file
6
api/core/sandbox/initializer/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from core.sandbox.initializer.initializer import DifyCliInitializer, SandboxInitializer
|
||||
|
||||
__all__ = [
|
||||
"DifyCliInitializer",
|
||||
"SandboxInitializer",
|
||||
]
|
||||
@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from core.sandbox.bash.dify_cli import DifyCliLocator
|
||||
from core.sandbox.constants import DIFY_CLI_PATH
|
||||
from core.sandbox.dify_cli import DifyCliLocator
|
||||
from core.virtual_environment.__base.helpers import execute
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -23,14 +24,10 @@ class DifyCliInitializer(SandboxInitializer):
|
||||
binary = self._locator.resolve(env.metadata.os, env.metadata.arch)
|
||||
env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes()))
|
||||
|
||||
connection_handle = env.establish_connection()
|
||||
try:
|
||||
future = env.run_command(connection_handle, ["chmod", "+x", DIFY_CLI_PATH])
|
||||
result = future.result(timeout=10)
|
||||
if result.exit_code not in (0, None):
|
||||
stderr = result.stderr.decode("utf-8", errors="replace") if result.stderr else ""
|
||||
raise RuntimeError(f"Failed to mark dify CLI as executable: {stderr}")
|
||||
|
||||
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)
|
||||
finally:
|
||||
env.release_connection(connection_handle)
|
||||
execute(
|
||||
env,
|
||||
["chmod", "+x", DIFY_CLI_PATH],
|
||||
timeout=10,
|
||||
error_message="Failed to mark dify CLI as executable",
|
||||
)
|
||||
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)
|
||||
@ -5,13 +5,14 @@ import logging
|
||||
from io import BytesIO
|
||||
from types import TracebackType
|
||||
|
||||
from core.sandbox.bash_tool import SandboxBashTool
|
||||
from core.sandbox.bash.bash_tool import SandboxBashTool
|
||||
from core.sandbox.bash.dify_cli import DifyCliConfig
|
||||
from core.sandbox.constants import DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH
|
||||
from core.sandbox.debug import sandbox_debug
|
||||
from core.sandbox.dify_cli import DifyCliConfig
|
||||
from core.sandbox.manager import SandboxManager
|
||||
from core.sandbox.utils.debug import sandbox_debug
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.virtual_environment.__base.helpers import execute
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -50,14 +51,12 @@ class SandboxSession:
|
||||
sandbox_debug("sandbox", "config_json", config_json)
|
||||
sandbox.upload_file(DIFY_CLI_CONFIG_PATH, BytesIO(config_json.encode("utf-8")))
|
||||
|
||||
connection_handle = sandbox.establish_connection()
|
||||
try:
|
||||
future = sandbox.run_command(connection_handle, [DIFY_CLI_PATH, "init"])
|
||||
result = future.result(timeout=30)
|
||||
if result.is_error:
|
||||
raise RuntimeError(f"Failed to initialize Dify CLI in sandbox: {result.error_message}")
|
||||
finally:
|
||||
sandbox.release_connection(connection_handle)
|
||||
execute(
|
||||
sandbox,
|
||||
[DIFY_CLI_PATH, "init"],
|
||||
timeout=30,
|
||||
error_message="Failed to initialize Dify CLI in sandbox",
|
||||
)
|
||||
|
||||
except Exception:
|
||||
CliApiSessionManager().delete(session.id)
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
from io import BytesIO
|
||||
|
||||
from core.sandbox.storage.sandbox_storage import SandboxStorage
|
||||
from core.virtual_environment.__base.helpers import try_execute
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from extensions.ext_storage import Storage
|
||||
|
||||
@ -29,38 +30,25 @@ class ArchiveSandboxStorage(SandboxStorage):
|
||||
archive_data = self._storage.load_once(self._storage_key)
|
||||
sandbox.upload_file(ARCHIVE_NAME, BytesIO(archive_data))
|
||||
|
||||
connection = sandbox.establish_connection()
|
||||
try:
|
||||
future = sandbox.run_command(connection, ["tar", "-xzf", ARCHIVE_NAME])
|
||||
result = future.result(timeout=60)
|
||||
if result.is_error:
|
||||
logger.error("Failed to extract archive: %s", result.error_message)
|
||||
return False
|
||||
finally:
|
||||
sandbox.release_connection(connection)
|
||||
result = try_execute(sandbox, ["tar", "-xzf", ARCHIVE_NAME], timeout=60)
|
||||
if result.is_error:
|
||||
logger.error("Failed to extract archive: %s", result.error_message)
|
||||
return False
|
||||
|
||||
connection = sandbox.establish_connection()
|
||||
try:
|
||||
sandbox.run_command(connection, ["rm", ARCHIVE_NAME]).result(timeout=10)
|
||||
finally:
|
||||
sandbox.release_connection(connection)
|
||||
try_execute(sandbox, ["rm", ARCHIVE_NAME], timeout=10)
|
||||
|
||||
logger.info("Mounted archive for sandbox %s", self._sandbox_id)
|
||||
return True
|
||||
|
||||
def unmount(self, sandbox: VirtualEnvironment) -> bool:
|
||||
connection = sandbox.establish_connection()
|
||||
try:
|
||||
future = sandbox.run_command(
|
||||
connection,
|
||||
["tar", "-czf", ARCHIVE_NAME, "-C", WORKSPACE_DIR, "."],
|
||||
)
|
||||
result = future.result(timeout=120)
|
||||
if result.is_error:
|
||||
logger.error("Failed to create archive: %s", result.error_message)
|
||||
return False
|
||||
finally:
|
||||
sandbox.release_connection(connection)
|
||||
result = try_execute(
|
||||
sandbox,
|
||||
["tar", "-czf", ARCHIVE_NAME, "-C", WORKSPACE_DIR, "."],
|
||||
timeout=120,
|
||||
)
|
||||
if result.is_error:
|
||||
logger.error("Failed to create archive: %s", result.error_message)
|
||||
return False
|
||||
|
||||
archive_content = sandbox.download_file(ARCHIVE_NAME)
|
||||
self._storage.save(self._storage_key, archive_content.getvalue())
|
||||
|
||||
2
api/core/sandbox/utils/__init__.py
Normal file
2
api/core/sandbox/utils/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Sandbox utilities
|
||||
# Connection helpers have been moved to core.virtual_environment.helpers
|
||||
@ -1,3 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.virtual_environment.__base.entities import CommandResult
|
||||
|
||||
|
||||
class ArchNotSupportedError(Exception):
|
||||
"""Exception raised when the architecture is not supported."""
|
||||
|
||||
@ -20,3 +28,19 @@ class SandboxConfigValidationError(ValueError):
|
||||
"""Exception raised when sandbox configuration validation fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CommandExecutionError(Exception):
|
||||
"""Raised when a command execution fails."""
|
||||
|
||||
def __init__(self, message: str, result: CommandResult):
|
||||
super().__init__(message)
|
||||
self.result = result
|
||||
|
||||
@property
|
||||
def exit_code(self) -> int | None:
|
||||
return self.result.exit_code
|
||||
|
||||
@property
|
||||
def stderr(self) -> bytes:
|
||||
return self.result.stderr
|
||||
|
||||
149
api/core/virtual_environment/__base/helpers.py
Normal file
149
api/core/virtual_environment/__base/helpers.py
Normal file
@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
||||
from core.virtual_environment.__base.command_future import CommandFuture
|
||||
from core.virtual_environment.__base.entities import CommandResult, ConnectionHandle
|
||||
from core.virtual_environment.__base.exec import CommandExecutionError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
|
||||
@contextmanager
|
||||
def with_connection(env: VirtualEnvironment) -> Generator[ConnectionHandle, None, None]:
|
||||
"""Context manager for VirtualEnvironment connection lifecycle.
|
||||
|
||||
Automatically establishes and releases connection handles.
|
||||
|
||||
Usage:
|
||||
with with_connection(env) as conn:
|
||||
future = run_command(env, conn, ["echo", "hello"])
|
||||
result = future.result(timeout=10)
|
||||
"""
|
||||
connection_handle = env.establish_connection()
|
||||
try:
|
||||
yield connection_handle
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
env.release_connection(connection_handle)
|
||||
|
||||
|
||||
def submit_command(
|
||||
env: VirtualEnvironment,
|
||||
connection: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
*,
|
||||
cwd: str | None = None,
|
||||
) -> CommandFuture:
|
||||
"""Execute a command and return a Future for the result.
|
||||
|
||||
High-level interface that handles IO draining internally.
|
||||
For streaming output, use env.execute_command() instead.
|
||||
|
||||
Args:
|
||||
env: The virtual environment to execute the command in.
|
||||
connection: The connection handle.
|
||||
command: Command as list of strings.
|
||||
environments: Environment variables.
|
||||
cwd: Working directory for the command. If None, uses the provider's default.
|
||||
|
||||
Returns:
|
||||
CommandFuture that can be used to get result with timeout or cancel.
|
||||
|
||||
Example:
|
||||
with with_connection(env) as conn:
|
||||
result = run_command(env, conn, ["ls", "-la"]).result(timeout=30)
|
||||
"""
|
||||
pid, stdin_transport, stdout_transport, stderr_transport = env.execute_command(
|
||||
connection, command, environments, cwd
|
||||
)
|
||||
|
||||
return CommandFuture(
|
||||
pid=pid,
|
||||
stdin_transport=stdin_transport,
|
||||
stdout_transport=stdout_transport,
|
||||
stderr_transport=stderr_transport,
|
||||
poll_status=partial(env.get_command_status, connection, pid),
|
||||
)
|
||||
|
||||
|
||||
def _execute_with_connection(
|
||||
env: VirtualEnvironment,
|
||||
conn: ConnectionHandle,
|
||||
command: list[str],
|
||||
timeout: float | None,
|
||||
cwd: str | None,
|
||||
) -> CommandResult:
|
||||
"""Internal helper to execute command with given connection."""
|
||||
future = submit_command(env, conn, command, cwd=cwd)
|
||||
return future.result(timeout=timeout)
|
||||
|
||||
|
||||
def execute(
|
||||
env: VirtualEnvironment,
|
||||
command: list[str],
|
||||
*,
|
||||
timeout: float | None = 30,
|
||||
cwd: str | None = None,
|
||||
error_message: str = "Command failed",
|
||||
connection: ConnectionHandle | None = None,
|
||||
) -> CommandResult:
|
||||
"""Execute a command with automatic connection management.
|
||||
|
||||
Raises CommandExecutionError if the command fails (non-zero exit code).
|
||||
|
||||
Args:
|
||||
env: The virtual environment to execute the command in.
|
||||
command: The command to execute as a list of strings.
|
||||
timeout: Maximum time to wait for the command to complete (seconds).
|
||||
cwd: Working directory for the command.
|
||||
error_message: Custom error message prefix for failures.
|
||||
connection: Optional connection handle to reuse. If None, creates and releases a new connection.
|
||||
|
||||
Returns:
|
||||
CommandResult on success.
|
||||
|
||||
Raises:
|
||||
CommandExecutionError: If the command fails.
|
||||
"""
|
||||
if connection is not None:
|
||||
result = _execute_with_connection(env, connection, command, timeout, cwd)
|
||||
else:
|
||||
with with_connection(env) as conn:
|
||||
result = _execute_with_connection(env, conn, command, timeout, cwd)
|
||||
|
||||
if result.is_error:
|
||||
raise CommandExecutionError(f"{error_message}: {result.error_message}", result)
|
||||
return result
|
||||
|
||||
|
||||
def try_execute(
|
||||
env: VirtualEnvironment,
|
||||
command: list[str],
|
||||
*,
|
||||
timeout: float | None = 30,
|
||||
cwd: str | None = None,
|
||||
connection: ConnectionHandle | None = None,
|
||||
) -> CommandResult:
|
||||
"""Execute a command with automatic connection management.
|
||||
|
||||
Does not raise on failure - returns the result for caller to handle.
|
||||
|
||||
Args:
|
||||
env: The virtual environment to execute the command in.
|
||||
command: The command to execute as a list of strings.
|
||||
timeout: Maximum time to wait for the command to complete (seconds).
|
||||
cwd: Working directory for the command.
|
||||
connection: Optional connection handle to reuse. If None, creates and releases a new connection.
|
||||
|
||||
Returns:
|
||||
CommandResult containing stdout, stderr, and exit_code.
|
||||
"""
|
||||
if connection is not None:
|
||||
return _execute_with_connection(env, connection, command, timeout, cwd)
|
||||
|
||||
with with_connection(env) as conn:
|
||||
return _execute_with_connection(env, conn, command, timeout, cwd)
|
||||
@ -1,10 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.command_future import CommandFuture
|
||||
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||
|
||||
@ -176,40 +174,3 @@ class VirtualEnvironment(ABC):
|
||||
Returns:
|
||||
CommandStatus: The status of the command execution.
|
||||
"""
|
||||
|
||||
def run_command(
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> CommandFuture:
|
||||
"""
|
||||
Execute a command and return a Future for the result.
|
||||
|
||||
High-level interface that handles IO draining internally.
|
||||
For streaming output, use execute_command() instead.
|
||||
|
||||
Args:
|
||||
connection_handle: The connection handle.
|
||||
command: Command as list of strings.
|
||||
environments: Environment variables.
|
||||
cwd: Working directory for the command. If None, uses the provider's default.
|
||||
|
||||
Returns:
|
||||
CommandFuture that can be used to get result with timeout or cancel.
|
||||
|
||||
Example:
|
||||
result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
|
||||
"""
|
||||
pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command(
|
||||
connection_handle, command, environments, cwd
|
||||
)
|
||||
|
||||
return CommandFuture(
|
||||
pid=pid,
|
||||
stdin_transport=stdin_transport,
|
||||
stdout_transport=stdout_transport,
|
||||
stderr_transport=stderr_transport,
|
||||
poll_status=partial(self.get_command_status, connection_handle, pid),
|
||||
)
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import shlex
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.sandbox.debug import sandbox_debug
|
||||
from core.sandbox.manager import SandboxManager
|
||||
from core.sandbox.utils.debug import sandbox_debug
|
||||
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
|
||||
from core.virtual_environment.__base.helpers import submit_command, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
@ -80,42 +80,41 @@ class CommandNode(Node[CommandNodeData]):
|
||||
)
|
||||
|
||||
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
|
||||
connection_handle = sandbox.establish_connection()
|
||||
|
||||
try:
|
||||
command = shlex.split(raw_command)
|
||||
with with_connection(sandbox) as conn:
|
||||
command = shlex.split(raw_command)
|
||||
|
||||
sandbox_debug("command_node", "command", command)
|
||||
sandbox_debug("command_node", "command", command)
|
||||
|
||||
future = sandbox.run_command(connection_handle, command, cwd=working_directory)
|
||||
result = future.result(timeout=timeout)
|
||||
future = submit_command(sandbox, conn, command, cwd=working_directory)
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
outputs: dict[str, Any] = {
|
||||
"stdout": result.stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": result.stderr.decode("utf-8", errors="replace"),
|
||||
"exit_code": result.exit_code,
|
||||
"pid": result.pid,
|
||||
}
|
||||
process_data = {"command": command, "working_directory": working_directory}
|
||||
outputs: dict[str, Any] = {
|
||||
"stdout": result.stdout.decode("utf-8", errors="replace"),
|
||||
"stderr": result.stderr.decode("utf-8", errors="replace"),
|
||||
"exit_code": result.exit_code,
|
||||
"pid": result.pid,
|
||||
}
|
||||
process_data = {"command": command, "working_directory": working_directory}
|
||||
|
||||
if result.exit_code not in (None, 0):
|
||||
stderr_text = result.stderr.decode("utf-8", errors="replace")
|
||||
error_message = f"{stderr_text}\n\nCommand exited with code {result.exit_code}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error_message,
|
||||
error_type=CommandExecutionError.__name__,
|
||||
)
|
||||
|
||||
if result.exit_code not in (None, 0):
|
||||
error_message = (
|
||||
f"{result.stderr.decode('utf-8', errors='replace')}\n\nCommand exited with code {result.exit_code}"
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error_message,
|
||||
error_type=CommandExecutionError.__name__,
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
except CommandTimeoutError:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -135,9 +134,6 @@ class CommandNode(Node[CommandNodeData]):
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
sandbox.release_connection(connection_handle)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
|
||||
@ -5,7 +5,8 @@ import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from ..core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
|
||||
from .base import Base
|
||||
from .types import LongText, StringUUID
|
||||
|
||||
|
||||
@ -19,9 +19,9 @@ from sqlalchemy.orm import Session
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.sandbox.encryption import create_sandbox_config_encrypter, masked_config
|
||||
from core.sandbox.factory import VMFactory, VMType
|
||||
from core.sandbox.initializer import DifyCliInitializer
|
||||
from core.sandbox.utils.encryption import create_sandbox_config_encrypter, masked_config
|
||||
from core.tools.utils.system_encryption import (
|
||||
decrypt_system_params,
|
||||
)
|
||||
|
||||
264
api/tests/unit_tests/core/virtual_environment/test_helpers.py
Normal file
264
api/tests/unit_tests/core/virtual_environment/test_helpers.py
Normal file
@ -0,0 +1,264 @@
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.__base.entities import (
|
||||
Arch,
|
||||
CommandStatus,
|
||||
ConnectionHandle,
|
||||
FileState,
|
||||
Metadata,
|
||||
OperatingSystem,
|
||||
)
|
||||
from core.virtual_environment.__base.exec import CommandExecutionError
|
||||
from core.virtual_environment.__base.helpers import execute, try_execute, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||
|
||||
|
||||
class MockReadTransport(TransportReadCloser):
|
||||
"""Mock transport that returns data once then raises EOF."""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
self._data = data
|
||||
self._read = False
|
||||
|
||||
def read(self, n: int) -> bytes:
|
||||
if self._read:
|
||||
raise TransportEOFError()
|
||||
self._read = True
|
||||
return self._data[:n] if n < len(self._data) else self._data
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MockWriteTransport(TransportWriteCloser):
|
||||
"""Mock transport for stdin (no-op)."""
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
"""Fake virtual environment for testing connection utilities."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
exit_code: int | None = 0,
|
||||
stdout: bytes = b"",
|
||||
stderr: bytes = b"",
|
||||
):
|
||||
self._exit_code = exit_code
|
||||
self._stdout = stdout
|
||||
self._stderr = stderr
|
||||
self._connection_established = False
|
||||
self._connection_released = False
|
||||
self._establish_count = 0
|
||||
self._release_count = 0
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
|
||||
def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX)
|
||||
|
||||
def upload_file(self, _path: str, _content: BytesIO) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def download_file(self, _path: str) -> BytesIO:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_files(self, _directory_path: str, _limit: int) -> list[FileState]:
|
||||
return []
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
self._connection_established = True
|
||||
self._establish_count += 1
|
||||
return ConnectionHandle(id=f"test-conn-{self._establish_count}")
|
||||
|
||||
def release_connection(self, _connection_handle: ConnectionHandle) -> None:
|
||||
self._connection_released = True
|
||||
self._release_count += 1
|
||||
|
||||
def release_environment(self) -> None:
|
||||
pass
|
||||
|
||||
def execute_command(
|
||||
self,
|
||||
_connection_handle: ConnectionHandle,
|
||||
_command: list[str],
|
||||
_environments: Mapping[str, str] | None = None,
|
||||
_cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""Return mock transports for testing."""
|
||||
return (
|
||||
"test-pid",
|
||||
MockWriteTransport(),
|
||||
MockReadTransport(self._stdout),
|
||||
MockReadTransport(self._stderr),
|
||||
)
|
||||
|
||||
def get_command_status(self, _connection_handle: ConnectionHandle, _pid: str) -> CommandStatus:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=self._exit_code)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, _options: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class TestWithConnection:
|
||||
def test_connection_established_and_released(self):
|
||||
env = FakeVirtualEnvironment()
|
||||
|
||||
with with_connection(env) as conn:
|
||||
assert env._connection_established is True
|
||||
assert conn.id == "test-conn-1"
|
||||
|
||||
assert env._connection_released is True
|
||||
|
||||
def test_connection_released_on_exception(self):
|
||||
env = FakeVirtualEnvironment()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with with_connection(env):
|
||||
raise ValueError("test error")
|
||||
|
||||
assert env._connection_released is True
|
||||
|
||||
|
||||
class TestExecute:
|
||||
def test_execute_success(self):
|
||||
env = FakeVirtualEnvironment(exit_code=0, stdout=b"hello world")
|
||||
|
||||
result = execute(env, ["echo", "hello"])
|
||||
|
||||
assert result.stdout == b"hello world"
|
||||
assert result.exit_code == 0
|
||||
assert env._connection_released is True
|
||||
|
||||
def test_execute_raises_on_nonzero_exit_code(self):
|
||||
env = FakeVirtualEnvironment(exit_code=1, stderr=b"command not found")
|
||||
|
||||
with pytest.raises(CommandExecutionError, match="Command failed: command not found") as exc_info:
|
||||
execute(env, ["invalid-command"])
|
||||
|
||||
assert exc_info.value.exit_code == 1
|
||||
assert exc_info.value.stderr == b"command not found"
|
||||
assert env._connection_released is True
|
||||
|
||||
def test_execute_with_custom_error_message(self):
|
||||
env = FakeVirtualEnvironment(exit_code=1, stderr=b"error")
|
||||
|
||||
with pytest.raises(CommandExecutionError, match="Custom error: error"):
|
||||
execute(env, ["cmd"], error_message="Custom error")
|
||||
|
||||
def test_execute_releases_connection_on_error(self):
|
||||
env = FakeVirtualEnvironment(exit_code=1, stderr=b"error")
|
||||
|
||||
with pytest.raises(CommandExecutionError):
|
||||
execute(env, ["cmd"])
|
||||
|
||||
assert env._connection_released is True
|
||||
|
||||
|
||||
class TestTryExecute:
|
||||
def test_try_execute_success(self):
|
||||
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
|
||||
|
||||
result = try_execute(env, ["echo", "test"])
|
||||
|
||||
assert result.stdout == b"output"
|
||||
assert result.exit_code == 0
|
||||
assert env._connection_released is True
|
||||
|
||||
def test_try_execute_returns_error_result(self):
|
||||
env = FakeVirtualEnvironment(exit_code=1, stderr=b"error message")
|
||||
|
||||
result = try_execute(env, ["failing-command"])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert result.stderr == b"error message"
|
||||
assert result.is_error is True
|
||||
assert env._connection_released is True
|
||||
|
||||
def test_try_execute_does_not_raise(self):
|
||||
env = FakeVirtualEnvironment(exit_code=127, stderr=b"not found")
|
||||
|
||||
result = try_execute(env, ["nonexistent"])
|
||||
|
||||
assert result.exit_code == 127
|
||||
assert env._connection_released is True
|
||||
|
||||
|
||||
class TestConnectionReuse:
|
||||
def test_execute_with_reused_connection(self):
|
||||
"""Test that execute reuses provided connection without creating new one."""
|
||||
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
|
||||
|
||||
with with_connection(env) as conn:
|
||||
# Execute with reused connection
|
||||
result = execute(env, ["cmd1"], connection=conn)
|
||||
assert result.stdout == b"output"
|
||||
|
||||
# Should have only established one connection (from with_connection)
|
||||
assert env._establish_count == 1
|
||||
assert env._release_count == 0 # Not released yet
|
||||
|
||||
# Now connection should be released
|
||||
assert env._release_count == 1
|
||||
|
||||
def test_execute_without_connection_creates_new(self):
|
||||
"""Test that execute without connection creates and releases its own."""
|
||||
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
|
||||
|
||||
execute(env, ["cmd1"])
|
||||
|
||||
assert env._establish_count == 1
|
||||
assert env._release_count == 1
|
||||
|
||||
def test_multiple_executes_with_same_connection(self):
|
||||
"""Test multiple execute calls reusing the same connection."""
|
||||
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
|
||||
|
||||
with with_connection(env) as conn:
|
||||
execute(env, ["cmd1"], connection=conn)
|
||||
execute(env, ["cmd2"], connection=conn)
|
||||
execute(env, ["cmd3"], connection=conn)
|
||||
|
||||
# Only one connection established
|
||||
assert env._establish_count == 1
|
||||
assert env._release_count == 0
|
||||
|
||||
# Released once at the end
|
||||
assert env._release_count == 1
|
||||
|
||||
def test_try_execute_with_reused_connection(self):
|
||||
"""Test that try_execute reuses provided connection."""
|
||||
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
|
||||
|
||||
with with_connection(env) as conn:
|
||||
result = try_execute(env, ["cmd1"], connection=conn)
|
||||
assert result.stdout == b"output"
|
||||
assert env._establish_count == 1
|
||||
assert env._release_count == 0
|
||||
|
||||
assert env._release_count == 1
|
||||
|
||||
def test_mixed_execute_and_try_execute_reuse(self):
|
||||
"""Test mixing execute and try_execute with same connection."""
|
||||
env = FakeVirtualEnvironment(exit_code=0, stdout=b"output")
|
||||
|
||||
with with_connection(env) as conn:
|
||||
execute(env, ["cmd1"], connection=conn)
|
||||
try_execute(env, ["cmd2"], connection=conn)
|
||||
execute(env, ["cmd3"], connection=conn)
|
||||
|
||||
assert env._establish_count == 1
|
||||
|
||||
assert env._release_count == 1
|
||||
@ -3,6 +3,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.__base.helpers import submit_command
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser
|
||||
from core.virtual_environment.providers import local_without_isolation
|
||||
@ -99,7 +100,7 @@ def test_run_command_returns_output(local_env: LocalVirtualEnvironment):
|
||||
local_env.upload_file("message.txt", BytesIO(b"hello"))
|
||||
connection = local_env.establish_connection()
|
||||
|
||||
result = local_env.run_command(connection, ["/bin/sh", "-c", "cat message.txt"]).result(timeout=10)
|
||||
result = submit_command(local_env, connection, ["/bin/sh", "-c", "cat message.txt"]).result(timeout=10)
|
||||
|
||||
assert result.stdout == b"hello"
|
||||
assert result.stderr == b""
|
||||
@ -109,7 +110,7 @@ def test_run_command_returns_output(local_env: LocalVirtualEnvironment):
|
||||
def test_run_command_captures_stderr(local_env: LocalVirtualEnvironment):
|
||||
connection = local_env.establish_connection()
|
||||
|
||||
result = local_env.run_command(connection, ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]).result(timeout=10)
|
||||
result = submit_command(local_env, connection, ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]).result(timeout=10)
|
||||
|
||||
assert b"OUT" in result.stdout
|
||||
assert b"ERR" in result.stderr
|
||||
|
||||
Reference in New Issue
Block a user