feat: introduce TransportEOFError for handling closed transport scenarios and update transport classes to raise it

This commit is contained in:
Yeuoly
2026-01-01 18:46:08 +08:00
parent 180fdffab1
commit 2673fe05a5
8 changed files with 163 additions and 27 deletions

View File

@ -0,0 +1,4 @@
class TransportEOFError(Exception):
"""Exception raised when attempting to read from a closed transport."""
pass

View File

@ -1,5 +1,6 @@
import os
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
@ -18,10 +19,16 @@ class PipeTransport(Transport):
self.w_fd = w_fd
def write(self, data: bytes) -> None:
os.write(self.w_fd, data)
try:
os.write(self.w_fd, data)
except OSError:
raise TransportEOFError("Pipe write error, maybe the read end is closed")
def read(self, n: int) -> bytes:
return os.read(self.r_fd, n)
data = os.read(self.r_fd, n)
if data == b"":
raise TransportEOFError("End of Pipe reached")
return data
def close(self) -> None:
os.close(self.r_fd)
@ -37,7 +44,11 @@ class PipeReadCloser(TransportReadCloser):
self.r_fd = r_fd
def read(self, n: int) -> bytes:
return os.read(self.r_fd, n)
data = os.read(self.r_fd, n)
if data == b"":
raise TransportEOFError("End of Pipe reached")
return data
def close(self) -> None:
os.close(self.r_fd)
@ -52,7 +63,10 @@ class PipeWriteCloser(TransportWriteCloser):
self.w_fd = w_fd
def write(self, data: bytes) -> None:
os.write(self.w_fd, data)
try:
os.write(self.w_fd, data)
except OSError:
raise TransportEOFError("Pipe write error, maybe the read end is closed")
def close(self) -> None:
os.close(self.w_fd)

View File

@ -1,5 +1,6 @@
from queue import Queue
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import TransportReadCloser
@ -39,6 +40,9 @@ class QueueTransportReadCloser(TransportReadCloser):
Initialize the QueueTransportReadCloser with write function.
"""
self.q = Queue[bytes | None]()
self._read_buffer = bytearray()
self._closed = False
self._write_channel_closed = False
def get_write_handler(self) -> WriteHandler:
"""
@ -50,17 +54,47 @@ class QueueTransportReadCloser(TransportReadCloser):
"""
Close the transport by putting a sentinel value in the queue.
"""
if self._write_channel_closed:
raise TransportEOFError("Write channel already closed")
self._write_channel_closed = True
self.q.put(None)
def read(self, n: int) -> bytes:
"""
Read up to n bytes from the queue.
NEVER USE IT IN A MULTI-THREADED CONTEXT WITHOUT PROPER SYNCHRONIZATION.
"""
data = bytearray()
while len(data) < n:
if n <= 0:
return b""
if self._closed:
raise TransportEOFError("Transport is closed")
to_return = self._drain_buffer(n)
while len(to_return) < n and not self._closed:
chunk = self.q.get()
if chunk is None:
break
data.extend(chunk)
self._closed = True
raise TransportEOFError("Transport is closed")
return bytes(data)
self._read_buffer.extend(chunk)
if n - len(to_return) > 0:
# Drain the buffer if we still need more data
to_return += self._drain_buffer(n - len(to_return))
else:
# No more data needed, break
break
if self.q.qsize() == 0:
# If no more data is available, break to return what we have
break
return to_return
def _drain_buffer(self, n: int) -> bytes:
data = bytes(self._read_buffer[:n])
del self._read_buffer[:n]
return data

View File

@ -1,5 +1,6 @@
import socket
from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
@ -12,10 +13,19 @@ class SocketTransport(Transport):
self.sock = sock
def write(self, data: bytes) -> None:
self.sock.write(data)
try:
self.sock.write(data)
except (ConnectionResetError, BrokenPipeError):
raise TransportEOFError("Socket write error, maybe the read end is closed")
def read(self, n: int) -> bytes:
return self.sock.read(n)
try:
data = self.sock.read(n)
if data == b"":
raise TransportEOFError("End of Socket reached")
except (ConnectionResetError, BrokenPipeError):
raise TransportEOFError("Socket connection reset")
return data
def close(self) -> None:
self.sock.close()
@ -30,7 +40,13 @@ class SocketReadCloser(TransportReadCloser):
self.sock = sock
def read(self, n: int) -> bytes:
return self.sock.read(n)
try:
data = self.sock.read(n)
if data == b"":
raise TransportEOFError("End of Socket reached")
return data
except (ConnectionResetError, BrokenPipeError):
raise TransportEOFError("Socket connection reset")
def close(self) -> None:
self.sock.close()
@ -45,7 +61,10 @@ class SocketWriteCloser(TransportWriteCloser):
self.sock = sock
def write(self, data: bytes) -> None:
self.sock.write(data)
try:
self.sock.write(data)
except (ConnectionResetError, BrokenPipeError):
raise TransportEOFError("Socket write error, maybe the read end is closed")
def close(self) -> None:
self.sock.close()

View File

@ -23,6 +23,8 @@ class TransportWriter(Protocol):
def write(self, data: bytes) -> None:
"""
Write data to the transport.
Raises TransportEOFError if the transport is closed.
"""
@ -35,6 +37,8 @@ class TransportReader(Protocol):
def read(self, n: int) -> bytes:
"""
Read up to n bytes from the transport.
Raises TransportEOFError if the end of the transport is reached.
"""