feat: add a future-based interface to make VM.execute_command easier to use

This commit is contained in:
Harry
2026-01-07 11:57:00 +08:00
parent 888be71639
commit 05c3344554
5 changed files with 367 additions and 132 deletions

View File

@ -1,15 +1,11 @@
import contextlib
import logging
import shlex
import threading
import time
from collections.abc import Mapping, Sequence
from typing import Any
from core.virtual_environment.__base.exec import NotSupportedOperationError
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@ -17,35 +13,22 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.command.entities import CommandNodeData
from core.workflow.nodes.command.exc import CommandExecutionError, CommandTimeoutError
from core.workflow.nodes.command.exc import CommandExecutionError
logger = logging.getLogger(__name__)
COMMAND_NODE_TIMEOUT_SECONDS = 60
def _drain_transport(transport: TransportReadCloser, buffer: bytearray) -> None:
    """Read ``transport`` until EOF, appending everything read to ``buffer``.

    Data is requested in 4096-byte reads and accumulated in place in
    ``buffer``. A ``TransportEOFError`` ends the loop silently; any other
    ``Exception`` raised while reading is logged (with traceback) rather than
    propagated. In every case a close of the transport is attempted, and an
    ``Exception`` raised by that close is suppressed.

    Args:
        transport: Readable transport to drain; closed before returning.
        buffer: Destination byte buffer, extended with each chunk read.
    """
    try:
        while True:
            chunk = transport.read(4096)
            buffer.extend(chunk)
    except TransportEOFError:
        # End of stream is the expected way out of the read loop.
        return
    except Exception:
        logger.exception("Failed reading transport")
    finally:
        # Best-effort close: a failing close must not mask the outcome above.
        with contextlib.suppress(Exception):
            transport.close()
class CommandNode(Node[CommandNodeData]):
"""Command Node - execute shell commands in a VirtualEnvironment."""
# FIXME: This is a temporary solution for sandbox injection from SandboxLayer.
# The sandbox is dynamically attached by SandboxLayer.on_node_run_start() before
# node execution and cleared by on_node_run_end(). A cleaner approach would be
# to pass sandbox through GraphRuntimeState or use a proper dependency injection pattern.
sandbox: VirtualEnvironment | None = None
node_type = NodeType.COMMAND
def _render_template(self, template: str) -> str:
parser = VariableTemplateParser(template=template)
selectors = parser.extract_variable_selectors()
@ -59,11 +42,8 @@ class CommandNode(Node[CommandNodeData]):
return parser.format(inputs)
node_type = NodeType.COMMAND
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""Get default config of node."""
return {
"type": "command",
"config": {
@ -91,7 +71,6 @@ class CommandNode(Node[CommandNodeData]):
raw_command = self._render_template(raw_command).strip()
working_directory = working_directory or None
timeout_seconds = COMMAND_NODE_TIMEOUT_SECONDS
if not raw_command:
return NodeRunResult(
@ -105,136 +84,52 @@ class CommandNode(Node[CommandNodeData]):
shell_command = f"cd {shlex.quote(working_directory)} && {raw_command}"
command = ["sh", "-lc", shell_command]
# 0 or negative means no timeout
deadline = None
if timeout_seconds > 0:
deadline = time.monotonic() + timeout_seconds
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
connection_handle = self.sandbox.establish_connection()
pid = ""
stdin_transport = None
stdout_transport = None
stderr_transport = None
threads: list[threading.Thread] = []
stdout_buf = bytearray()
stderr_buf = bytearray()
try:
pid, stdin_transport, stdout_transport, stderr_transport = self.sandbox.execute_command(
connection_handle, command
)
is_combined_stream = stdout_transport is stderr_transport
stdout_thread = threading.Thread(
target=_drain_transport,
args=(stdout_transport, stdout_buf),
daemon=True,
)
threads.append(stdout_thread)
stdout_thread.start()
if not is_combined_stream:
stderr_thread = threading.Thread(
target=_drain_transport,
args=(stderr_transport, stderr_buf),
daemon=True,
)
threads.append(stderr_thread)
stderr_thread.start()
exit_code: int | None = None
while True:
if deadline is not None and time.monotonic() > deadline:
raise CommandTimeoutError(f"Command timed out after {timeout_seconds}s")
try:
status = self.sandbox.get_command_status(connection_handle, pid)
except NotSupportedOperationError:
break
if status.status == status.Status.COMPLETED:
exit_code = status.exit_code
break
time.sleep(0.1)
# Ensure transports are fully drained.
def _join_all() -> bool:
for t in threads:
remaining = None
if deadline is not None:
remaining = max(0.0, deadline - time.monotonic())
t.join(timeout=remaining)
if t.is_alive():
return False
return True
if not _join_all():
raise CommandTimeoutError(f"Command output not drained within {timeout_seconds}s")
stdout_text = stdout_buf.decode("utf-8", errors="replace")
stderr_text = "" if is_combined_stream else stderr_buf.decode("utf-8", errors="replace")
future = self.sandbox.run_command(connection_handle, command)
result = future.result(timeout=timeout)
outputs: dict[str, Any] = {
"stdout": stdout_text,
"stderr": stderr_text,
"exit_code": exit_code,
"pid": pid,
"stdout": result.stdout.decode("utf-8", errors="replace"),
"stderr": result.stderr.decode("utf-8", errors="replace"),
"exit_code": result.exit_code,
"pid": result.pid,
}
process_data = {"command": command, "working_directory": working_directory}
if exit_code not in (None, 0):
if result.exit_code not in (None, 0):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs=outputs,
process_data={"command": command, "working_directory": working_directory},
error=f"Command exited with code {exit_code}",
process_data=process_data,
error=f"Command exited with code {result.exit_code}",
error_type=CommandExecutionError.__name__,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
process_data={"command": command, "working_directory": working_directory},
process_data=process_data,
)
except (CommandExecutionError, CommandTimeoutError) as e:
if isinstance(e, CommandTimeoutError) and stdout_transport is not None:
for transport in (stdout_transport, stderr_transport):
if transport is None:
continue
with contextlib.suppress(Exception):
transport.close()
for t in threads:
t.join(timeout=0.2)
except CommandTimeoutError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
"stdout": stdout_buf.decode("utf-8", errors="replace"),
"stderr": stderr_buf.decode("utf-8", errors="replace"),
"exit_code": None,
"pid": pid,
},
process_data={"command": command, "working_directory": working_directory},
error=str(e),
error_type=type(e).__name__,
error=f"Command timed out after {COMMAND_NODE_TIMEOUT_SECONDS}s",
error_type=CommandTimeoutError.__name__,
)
except CommandCancelledError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="Command was cancelled",
error_type=CommandCancelledError.__name__,
)
except Exception as e:
logger.exception("Command node %s failed", self.id)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
"stdout": stdout_buf.decode("utf-8", errors="replace"),
"stderr": stderr_buf.decode("utf-8", errors="replace"),
"exit_code": None,
"pid": pid,
},
process_data={"command": command, "working_directory": working_directory},
error=str(e),
error_type=type(e).__name__,
)
@ -250,8 +145,7 @@ class CommandNode(Node[CommandNodeData]):
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""Extract variable mappings from node data."""
_ = graph_config # Explicitly mark as unused
_ = graph_config
typed_node_data = CommandNodeData.model_validate(node_data)