refactor: a lot of optimization and enhancement

This commit is contained in:
Harry
2026-03-12 18:22:57 +08:00
parent 4a64362193
commit f0c6c0159c
17 changed files with 248 additions and 173 deletions

View File

@ -26,6 +26,9 @@ class CommandFuture:
Lightweight future for command execution.
Mirrors concurrent.futures.Future API with 4 essential methods:
result(), done(), cancel(), cancelled().
When a command is cancelled or times out the future now asks the provider
to terminate the underlying process/session before marking itself done.
"""
def __init__(
@ -35,6 +38,7 @@ class CommandFuture:
stdout_transport: TransportReadCloser,
stderr_transport: TransportReadCloser,
poll_status: Callable[[], CommandStatus],
terminate_command: Callable[[], bool] | None = None,
poll_interval: float = 0.1,
):
self._pid = pid
@ -42,6 +46,7 @@ class CommandFuture:
self._stdout_transport = stdout_transport
self._stderr_transport = stderr_transport
self._poll_status = poll_status
self._terminate_command = terminate_command
self._poll_interval = poll_interval
self._done_event = threading.Event()
@ -49,7 +54,9 @@ class CommandFuture:
self._result: CommandResult | None = None
self._exception: BaseException | None = None
self._cancelled = False
self._timed_out = False
self._started = False
self._termination_requested = False
def result(self, timeout: float | None = None) -> CommandResult:
"""
@ -61,15 +68,22 @@ class CommandFuture:
Raises:
CommandTimeoutError: If timeout exceeded.
CommandCancelledError: If command was cancelled.
A timeout is terminal for this future: it triggers best-effort command
termination and subsequent ``result()`` calls keep raising timeout.
"""
self._ensure_started()
if not self._done_event.wait(timeout):
self._request_stop(timed_out=True)
raise CommandTimeoutError(f"Command timed out after {timeout}s")
if self._cancelled:
raise CommandCancelledError("Command was cancelled")
if self._timed_out:
raise CommandTimeoutError("Command timed out")
if self._exception is not None:
raise self._exception
@ -82,16 +96,10 @@ class CommandFuture:
def cancel(self) -> bool:
"""
Attempt to cancel command by closing transports.
Attempt to cancel command by terminating it and closing transports.
Returns True if cancelled, False if already completed.
"""
with self._lock:
if self._done_event.is_set():
return False
self._cancelled = True
self._close_transports()
self._done_event.set()
return True
return self._request_stop(cancelled=True)
def cancelled(self) -> bool:
return self._cancelled
@ -103,6 +111,28 @@ class CommandFuture:
thread = threading.Thread(target=self._execute, daemon=True)
thread.start()
def _request_stop(self, *, cancelled: bool = False, timed_out: bool = False) -> bool:
should_terminate = False
with self._lock:
if self._done_event.is_set():
return False
if cancelled:
self._cancelled = True
if timed_out:
self._timed_out = True
should_terminate = not self._termination_requested
if should_terminate:
self._termination_requested = True
self._close_transports()
self._done_event.set()
if should_terminate:
self._terminate_running_command()
return True
def _execute(self) -> None:
stdout_buf = bytearray()
stderr_buf = bytearray()
@ -141,7 +171,7 @@ class CommandFuture:
self._close_transports()
def _wait_for_completion(self) -> int | None:
while not self._cancelled:
while not self._cancelled and not self._timed_out:
try:
status = self._poll_status()
except NotSupportedOperationError:
@ -167,3 +197,12 @@ class CommandFuture:
for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
with contextlib.suppress(Exception):
transport.close()
def _terminate_running_command(self) -> None:
if self._terminate_command is None:
return
try:
self._terminate_command()
except Exception:
logger.exception("Failed to terminate command for pid %s", self._pid)

View File

@ -71,6 +71,7 @@ def submit_command(
stdout_transport=stdout_transport,
stderr_transport=stderr_transport,
poll_status=partial(env.get_command_status, connection, pid),
terminate_command=partial(env.terminate_command, connection, pid),
)

View File

@ -11,8 +11,19 @@ from core.virtual_environment.channel.transport import TransportReadCloser, Tran
class VirtualEnvironment(ABC):
"""
Base class for virtual environment implementations.
``VirtualEnvironment`` instances are configured at construction time but do
not allocate provider resources until ``open_enviroment()`` is called.
This keeps object construction side-effect free and gives callers a chance
to own startup error handling explicitly.
"""
tenant_id: str
user_id: str | None
options: Mapping[str, Any]
_environments: Mapping[str, str]
_metadata: Metadata | None
def __init__(
self,
tenant_id: str,
@ -21,19 +32,45 @@ class VirtualEnvironment(ABC):
user_id: str | None = None,
) -> None:
"""
Initialize the virtual environment with metadata.
Initialize the virtual environment configuration.
Args:
tenant_id: The tenant ID associated with this environment (required).
options: Provider-specific configuration options.
environments: Environment variables to set in the virtual environment.
user_id: The user ID associated with this environment (optional).
The provider runtime itself is created later by ``open_enviroment()``.
"""
self.tenant_id = tenant_id
self.user_id = user_id
self.options = options
self.metadata = self._construct_environment(options, environments or {})
self._environments = dict(environments or {})
self._metadata = None
@property
def metadata(self) -> Metadata:
"""Provider metadata for a started environment.
Raises:
RuntimeError: If the environment has not been started yet.
"""
if self._metadata is None:
raise RuntimeError("Virtual environment has not been started")
return self._metadata
def open_enviroment(self) -> Metadata:
"""Allocate provider resources and return the resulting metadata.
Multiple calls are safe and return the existing metadata after the first
successful start.
"""
if self._metadata is None:
self._metadata = self._construct_environment(self.options, self._environments)
return self._metadata
@abstractmethod
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
@ -131,6 +168,18 @@ class VirtualEnvironment(ABC):
Multiple calls to `release_environment` with the same `environment_id` is acceptable.
"""
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
"""Best-effort termination hook for a running command.
Providers that can map ``pid`` back to a real process/session should
override this method and stop the command. The default implementation is
a no-op so providers without a termination mechanism remain compatible.
"""
_ = connection_handle
_ = pid
return False
@abstractmethod
def execute_command(
self,