mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
fix(virtual-env): fix Docker stdout/stderr demuxing and exit code parsing
- Add _DockerDemuxer to properly separate stdout/stderr from multiplexed stream - Fix binary header garbage in Docker exec output (tty=False 8-byte header) - Fix LocalVirtualEnvironment.get_command_status() to use os.WEXITSTATUS() - Update tests to use Transport API instead of raw file descriptors
This commit is contained in:
@ -1,31 +1,27 @@
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser
|
||||
from core.virtual_environment.providers import local_without_isolation
|
||||
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||
|
||||
|
||||
def _read_all(fd: int) -> bytes:
|
||||
def _drain_transport(transport: TransportReadCloser) -> bytes:
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
data = os.read(fd, 4096)
|
||||
if not data:
|
||||
break
|
||||
chunks.append(data)
|
||||
try:
|
||||
while True:
|
||||
data = transport.read(4096)
|
||||
if not data:
|
||||
break
|
||||
chunks.append(data)
|
||||
except TransportEOFError:
|
||||
pass
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
def _close_fds(*fds: int) -> None:
|
||||
for fd in fds:
|
||||
try:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
|
||||
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
|
||||
@ -35,7 +31,7 @@ def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEn
|
||||
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
|
||||
working_path = local_env.get_working_path()
|
||||
assert local_env.metadata.id
|
||||
assert os.path.isdir(working_path)
|
||||
assert Path(working_path).is_dir()
|
||||
|
||||
|
||||
def test_upload_download_roundtrip(local_env: LocalVirtualEnvironment):
|
||||
@ -54,7 +50,7 @@ def test_list_files_respects_limit(local_env: LocalVirtualEnvironment):
|
||||
all_files = local_env.list_files("", limit=10)
|
||||
all_paths = {state.path for state in all_files}
|
||||
|
||||
assert os.path.join("dir", "file_a.txt") in all_paths
|
||||
assert "dir/file_a.txt" in all_paths or "dir\\file_a.txt" in all_paths
|
||||
assert "file_b.txt" in all_paths
|
||||
|
||||
limited_files = local_env.list_files("", limit=1)
|
||||
@ -66,16 +62,15 @@ def test_execute_command_uses_working_directory(local_env: LocalVirtualEnvironme
|
||||
connection = local_env.establish_connection()
|
||||
command = ["/bin/sh", "-c", "cat message.txt"]
|
||||
|
||||
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
|
||||
_, stdin_transport, stdout_transport, stderr_transport = local_env.execute_command(connection, command)
|
||||
|
||||
try:
|
||||
os.close(stdin_fd)
|
||||
if hasattr(os, "waitpid"):
|
||||
os.waitpid(pid, 0)
|
||||
stdout = _read_all(stdout_fd)
|
||||
stderr = _read_all(stderr_fd)
|
||||
stdin_transport.close()
|
||||
stdout = _drain_transport(stdout_transport)
|
||||
stderr = _drain_transport(stderr_transport)
|
||||
finally:
|
||||
_close_fds(stdin_fd, stdout_fd, stderr_fd)
|
||||
stdout_transport.close()
|
||||
stderr_transport.close()
|
||||
|
||||
assert stdout == b"hello"
|
||||
assert stderr == b""
|
||||
@ -85,17 +80,37 @@ def test_execute_command_pipes_stdio(local_env: LocalVirtualEnvironment):
|
||||
connection = local_env.establish_connection()
|
||||
command = ["/bin/sh", "-c", "tr a-z A-Z < /dev/stdin; printf ERR >&2"]
|
||||
|
||||
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
|
||||
_, stdin_transport, stdout_transport, stderr_transport = local_env.execute_command(connection, command)
|
||||
|
||||
try:
|
||||
os.write(stdin_fd, b"abc")
|
||||
os.close(stdin_fd)
|
||||
if hasattr(os, "waitpid"):
|
||||
os.waitpid(pid, 0)
|
||||
stdout = _read_all(stdout_fd)
|
||||
stderr = _read_all(stderr_fd)
|
||||
stdin_transport.write(b"abc")
|
||||
stdin_transport.close()
|
||||
stdout = _drain_transport(stdout_transport)
|
||||
stderr = _drain_transport(stderr_transport)
|
||||
finally:
|
||||
_close_fds(stdin_fd, stdout_fd, stderr_fd)
|
||||
stdout_transport.close()
|
||||
stderr_transport.close()
|
||||
|
||||
assert stdout == b"ABC"
|
||||
assert stderr == b"ERR"
|
||||
|
||||
|
||||
def test_run_command_returns_output(local_env: LocalVirtualEnvironment):
|
||||
local_env.upload_file("message.txt", BytesIO(b"hello"))
|
||||
connection = local_env.establish_connection()
|
||||
|
||||
result = local_env.run_command(connection, ["/bin/sh", "-c", "cat message.txt"]).result(timeout=10)
|
||||
|
||||
assert result.stdout == b"hello"
|
||||
assert result.stderr == b""
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_run_command_captures_stderr(local_env: LocalVirtualEnvironment):
|
||||
connection = local_env.establish_connection()
|
||||
|
||||
result = local_env.run_command(connection, ["/bin/sh", "-c", "echo OUT; echo ERR >&2"]).result(timeout=10)
|
||||
|
||||
assert b"OUT" in result.stdout
|
||||
assert b"ERR" in result.stderr
|
||||
assert result.exit_code == 0
|
||||
|
||||
Reference in New Issue
Block a user