mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
refactor(virtual_environment): add cwd parameter to execute_command method across all providers for improved command execution context
This commit is contained in:
@ -134,7 +134,11 @@ class VirtualEnvironment(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the virtual environment.
|
||||
@ -142,6 +146,8 @@ class VirtualEnvironment(ABC):
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The handle for managing the connection.
|
||||
command (list[str]): The command to execute as a list of strings.
|
||||
environments (Mapping[str, str] | None): Environment variables for the command.
|
||||
cwd (str | None): Working directory for the command. If None, uses the provider's default.
|
||||
|
||||
Returns:
|
||||
tuple[int, TransportWriteCloser, TransportReadCloser, TransportReadCloser]
|
||||
@ -176,6 +182,7 @@ class VirtualEnvironment(ABC):
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> CommandFuture:
|
||||
"""
|
||||
Execute a command and return a Future for the result.
|
||||
@ -187,6 +194,7 @@ class VirtualEnvironment(ABC):
|
||||
connection_handle: The connection handle.
|
||||
command: Command as list of strings.
|
||||
environments: Environment variables.
|
||||
cwd: Working directory for the command. If None, uses the provider's default.
|
||||
|
||||
Returns:
|
||||
CommandFuture that can be used to get result with timeout or cancel.
|
||||
@ -195,7 +203,7 @@ class VirtualEnvironment(ABC):
|
||||
result = env.run_command(handle, ["ls", "-la"]).result(timeout=30)
|
||||
"""
|
||||
pid, stdin_transport, stdout_transport, stderr_transport = self.execute_command(
|
||||
connection_handle, command, environments
|
||||
connection_handle, command, environments, cwd
|
||||
)
|
||||
|
||||
return CommandFuture(
|
||||
|
||||
@ -185,7 +185,11 @@ class DaytonaEnvironment(VirtualEnvironment):
|
||||
return files
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
|
||||
|
||||
@ -193,9 +197,11 @@ class DaytonaEnvironment(VirtualEnvironment):
|
||||
stderr_stream = QueueTransportReadCloser()
|
||||
pid = uuid4().hex
|
||||
|
||||
working_dir = cwd or self._working_dir
|
||||
|
||||
thread = threading.Thread(
|
||||
target=self._exec_thread,
|
||||
args=(pid, sandbox, command, environments or {}, stdout_stream, stderr_stream),
|
||||
args=(pid, sandbox, command, environments or {}, working_dir, stdout_stream, stderr_stream),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
@ -236,6 +242,7 @@ class DaytonaEnvironment(VirtualEnvironment):
|
||||
sandbox: Sandbox,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str],
|
||||
cwd: str,
|
||||
stdout_stream: QueueTransportReadCloser,
|
||||
stderr_stream: QueueTransportReadCloser,
|
||||
) -> None:
|
||||
@ -249,6 +256,7 @@ class DaytonaEnvironment(VirtualEnvironment):
|
||||
response = sandbox.process.exec(
|
||||
command=shlex.join(command),
|
||||
env=dict(environments),
|
||||
cwd=cwd,
|
||||
)
|
||||
exit_code = response.exit_code
|
||||
output = response.artifacts.stdout if response.artifacts and response.artifacts.stdout else response.result
|
||||
|
||||
@ -449,13 +449,20 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
return
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
container = self._get_container()
|
||||
container_id = container.id
|
||||
if not isinstance(container_id, str) or not container_id:
|
||||
raise RuntimeError("Docker container ID is not available for exec.")
|
||||
api_client = self.get_docker_api_client(self.get_docker_sock())
|
||||
|
||||
working_dir = cwd or self._working_dir
|
||||
|
||||
exec_info: dict[str, object] = cast(
|
||||
dict[str, object],
|
||||
api_client.exec_create( # pyright: ignore[reportUnknownMemberType] #
|
||||
@ -465,7 +472,7 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
stdout=True,
|
||||
stderr=True,
|
||||
tty=False,
|
||||
workdir=self._working_dir,
|
||||
workdir=working_dir,
|
||||
environment=environments,
|
||||
),
|
||||
)
|
||||
|
||||
@ -200,7 +200,11 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
]
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the E2B virtual environment.
|
||||
@ -212,9 +216,11 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
stdout_stream = QueueTransportReadCloser()
|
||||
stderr_stream = QueueTransportReadCloser()
|
||||
|
||||
working_dir = cwd or self._WORKDIR
|
||||
|
||||
threading.Thread(
|
||||
target=self._cmd_thread,
|
||||
args=(sandbox, command, environments, stdout_stream, stderr_stream),
|
||||
args=(sandbox, command, environments, working_dir, stdout_stream, stderr_stream),
|
||||
).start()
|
||||
|
||||
return (
|
||||
@ -235,10 +241,10 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
sandbox: Sandbox,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None,
|
||||
cwd: str,
|
||||
stdout_stream: QueueTransportReadCloser,
|
||||
stderr_stream: QueueTransportReadCloser,
|
||||
) -> None:
|
||||
""" """
|
||||
stdout_stream_write_handler = stdout_stream.get_write_handler()
|
||||
stderr_stream_write_handler = stderr_stream.get_write_handler()
|
||||
|
||||
@ -246,7 +252,7 @@ class E2BEnvironment(VirtualEnvironment):
|
||||
sandbox.commands.run(
|
||||
cmd=shlex.join(command),
|
||||
envs=dict(environments or {}),
|
||||
# stdin=True,
|
||||
cwd=cwd,
|
||||
on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
|
||||
on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
|
||||
)
|
||||
|
||||
@ -171,16 +171,13 @@ class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
pass
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the local virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The connection handle.
|
||||
command (list[str]): The command to execute.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
working_path = cwd or self.get_working_path()
|
||||
stdin_read_fd, stdin_write_fd = os.pipe()
|
||||
stdout_read_fd, stdout_write_fd = os.pipe()
|
||||
stderr_read_fd, stderr_write_fd = os.pipe()
|
||||
|
||||
@ -82,26 +82,8 @@ class CommandNode(Node[CommandNodeData]):
|
||||
connection_handle = sandbox.establish_connection()
|
||||
|
||||
try:
|
||||
# TODO: VirtualEnvironment.run_command lacks native cwd support.
|
||||
# Once the interface adds a `cwd` parameter, remove this shell hack
|
||||
# and pass working_directory directly to run_command.
|
||||
if working_directory:
|
||||
check_cmd = ["test", "-d", working_directory]
|
||||
check_future = sandbox.run_command(connection_handle, check_cmd)
|
||||
check_result = check_future.result(timeout=timeout)
|
||||
|
||||
if check_result.exit_code != 0:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Working directory does not exist: {working_directory}",
|
||||
error_type="WorkingDirectoryNotFoundError",
|
||||
)
|
||||
|
||||
command = ["sh", "-c", f"cd {shlex.quote(working_directory)} && {raw_command}"]
|
||||
else:
|
||||
command = shlex.split(raw_command)
|
||||
|
||||
future = sandbox.run_command(connection_handle, command)
|
||||
command = shlex.split(raw_command)
|
||||
future = sandbox.run_command(connection_handle, command, cwd=working_directory)
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
outputs: dict[str, Any] = {
|
||||
|
||||
Reference in New Issue
Block a user