refactor(virtual_environment): add cwd parameter to execute_command method across all providers for improved command execution context

This commit is contained in:
Harry
2026-01-12 14:20:03 +08:00
parent f990f4a8d4
commit 201a18d6ba
8 changed files with 79 additions and 47 deletions

View File

@ -134,7 +134,11 @@ class VirtualEnvironment(ABC):
@abstractmethod
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a command in the virtual environment.
@ -142,6 +146,8 @@ class VirtualEnvironment(ABC):
Args:
connection_handle (ConnectionHandle): The handle for managing the connection.
command (list[str]): The command to execute as a list of strings.
environments (Mapping[str, str] | None): Environment variables for the command.
cwd (str | None): Working directory for the command. If None, uses the provider's default.
Returns:
tuple[int, TransportWriteCloser, TransportReadCloser, TransportReadCloser]
@ -176,6 +182,7 @@ class VirtualEnvironment(ABC):
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> CommandFuture:
"""
Execute a command and return a Future for the result.
@ -187,6 +194,7 @@ class VirtualEnvironment(ABC):
connection_handle: The connection handle.
command: Command as list of strings.
environments: Environment variables.
cwd: Working directory for the command. If None, uses the provider's default.
Returns:
CommandFuture that can be used to get result with timeout or cancel.
@ -195,7 +203,7 @@ class VirtualEnvironment(ABC):
result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
"""
pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command(
connection_handle, command, environments
connection_handle, command, environments, cwd
)
return CommandFuture(

View File

@ -185,7 +185,11 @@ class DaytonaEnvironment(VirtualEnvironment):
return files
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
@ -193,9 +197,11 @@ class DaytonaEnvironment(VirtualEnvironment):
stderr_stream = QueueTransportReadCloser()
pid = uuid4().hex
working_dir = cwd or self._working_dir
thread = threading.Thread(
target=self._exec_thread,
args=(pid, sandbox, command, environments or {}, stdout_stream, stderr_stream),
args=(pid, sandbox, command, environments or {}, working_dir, stdout_stream, stderr_stream),
daemon=True,
)
@ -236,6 +242,7 @@ class DaytonaEnvironment(VirtualEnvironment):
sandbox: Sandbox,
command: list[str],
environments: Mapping[str, str],
cwd: str,
stdout_stream: QueueTransportReadCloser,
stderr_stream: QueueTransportReadCloser,
) -> None:
@ -249,6 +256,7 @@ class DaytonaEnvironment(VirtualEnvironment):
response = sandbox.process.exec(
command=shlex.join(command),
env=dict(environments),
cwd=cwd,
)
exit_code = response.exit_code
output = response.artifacts.stdout if response.artifacts and response.artifacts.stdout else response.result

View File

@ -449,13 +449,20 @@ class DockerDaemonEnvironment(VirtualEnvironment):
return
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
container = self._get_container()
container_id = container.id
if not isinstance(container_id, str) or not container_id:
raise RuntimeError("Docker container ID is not available for exec.")
api_client = self.get_docker_api_client(self.get_docker_sock())
working_dir = cwd or self._working_dir
exec_info: dict[str, object] = cast(
dict[str, object],
api_client.exec_create( # pyright: ignore[reportUnknownMemberType] #
@ -465,7 +472,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
stdout=True,
stderr=True,
tty=False,
workdir=self._working_dir,
workdir=working_dir,
environment=environments,
),
)

View File

@ -200,7 +200,11 @@ class E2BEnvironment(VirtualEnvironment):
]
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a command in the E2B virtual environment.
@ -212,9 +216,11 @@ class E2BEnvironment(VirtualEnvironment):
stdout_stream = QueueTransportReadCloser()
stderr_stream = QueueTransportReadCloser()
working_dir = cwd or self._WORKDIR
threading.Thread(
target=self._cmd_thread,
args=(sandbox, command, environments, stdout_stream, stderr_stream),
args=(sandbox, command, environments, working_dir, stdout_stream, stderr_stream),
).start()
return (
@ -235,10 +241,10 @@ class E2BEnvironment(VirtualEnvironment):
sandbox: Sandbox,
command: list[str],
environments: Mapping[str, str] | None,
cwd: str,
stdout_stream: QueueTransportReadCloser,
stderr_stream: QueueTransportReadCloser,
) -> None:
""" """
stdout_stream_write_handler = stdout_stream.get_write_handler()
stderr_stream_write_handler = stderr_stream.get_write_handler()
@ -246,7 +252,7 @@ class E2BEnvironment(VirtualEnvironment):
sandbox.commands.run(
cmd=shlex.join(command),
envs=dict(environments or {}),
# stdin=True,
cwd=cwd,
on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
)

View File

@ -171,16 +171,13 @@ class LocalVirtualEnvironment(VirtualEnvironment):
pass
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
self,
connection_handle: ConnectionHandle,
command: list[str],
environments: Mapping[str, str] | None = None,
cwd: str | None = None,
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
"""
Execute a command in the local virtual environment.
Args:
connection_handle (ConnectionHandle): The connection handle.
command (list[str]): The command to execute.
"""
working_path = self.get_working_path()
working_path = cwd or self.get_working_path()
stdin_read_fd, stdin_write_fd = os.pipe()
stdout_read_fd, stdout_write_fd = os.pipe()
stderr_read_fd, stderr_write_fd = os.pipe()

View File

@ -82,26 +82,8 @@ class CommandNode(Node[CommandNodeData]):
connection_handle = sandbox.establish_connection()
try:
# TODO: VirtualEnvironment.run_command lacks native cwd support.
# Once the interface adds a `cwd` parameter, remove this shell hack
# and pass working_directory directly to run_command.
if working_directory:
check_cmd = ["test", "-d", working_directory]
check_future = sandbox.run_command(connection_handle, check_cmd)
check_result = check_future.result(timeout=timeout)
if check_result.exit_code != 0:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Working directory does not exist: {working_directory}",
error_type="WorkingDirectoryNotFoundError",
)
command = ["sh", "-c", f"cd {shlex.quote(working_directory)} && {raw_command}"]
else:
command = shlex.split(raw_command)
future = sandbox.run_command(connection_handle, command)
command = shlex.split(raw_command)
future = sandbox.run_command(connection_handle, command, cwd=working_directory)
result = future.result(timeout=timeout)
outputs: dict[str, Any] = {