refactor: a lot of optimization and enhancement

This commit is contained in:
Harry
2026-03-12 18:22:57 +08:00
parent 4a64362193
commit f0c6c0159c
17 changed files with 248 additions and 173 deletions

View File

@ -1,4 +1,5 @@
import threading
from collections.abc import Callable
import pytest
@ -18,6 +19,7 @@ def _make_future(
exit_code: int = 0,
delay_completion: float = 0,
close_streams: bool = True,
terminate_command: Callable[[], bool] | None = None,
) -> CommandFuture:
stdout_transport = QueueTransportReadCloser()
stderr_transport = QueueTransportReadCloser()
@ -48,6 +50,7 @@ def _make_future(
stdout_transport=stdout_transport,
stderr_transport=stderr_transport,
poll_status=poll_status,
terminate_command=terminate_command,
poll_interval=0.05,
)
@ -78,6 +81,21 @@ def test_result_raises_timeout_error_when_exceeded():
future.result(timeout=0.2)
def test_timeout_requests_command_termination():
terminated = threading.Event()
future = _make_future(
delay_completion=10.0,
close_streams=False,
terminate_command=lambda: terminated.set() or True,
)
with pytest.raises(CommandTimeoutError):
future.result(timeout=0.2)
assert terminated.wait(timeout=1.0)
def test_done_returns_false_while_running():
future = _make_future(delay_completion=10.0, close_streams=False)
@ -115,6 +133,19 @@ def test_result_raises_cancelled_error_after_cancel():
future.result()
def test_cancel_requests_command_termination():
terminated = threading.Event()
future = _make_future(
delay_completion=10.0,
close_streams=False,
terminate_command=lambda: terminated.set() or True,
)
assert future.cancel() is True
assert terminated.wait(timeout=1.0)
def test_nonzero_exit_code_is_returned():
future = _make_future(stdout=b"err", exit_code=42)

View File

@ -4,6 +4,7 @@ from typing import Any
import pytest
from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.entities import (
Arch,
CommandStatus,
@ -64,6 +65,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
self._establish_count = 0
self._release_count = 0
super().__init__(tenant_id="test-tenant", options={}, environments={})
self.open_enviroment()
def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata:
return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX)
@ -111,6 +113,10 @@ class FakeVirtualEnvironment(VirtualEnvironment):
def validate(cls, _options: Mapping[str, Any]) -> None:
pass
@classmethod
def get_config_schema(cls) -> list[BasicProviderConfig]:
return []
class TestWithConnection:
def test_connection_established_and_released(self):

View File

@ -26,7 +26,9 @@ def _drain_transport(transport: TransportReadCloser) -> bytes:
@pytest.fixture
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
return LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
env = LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
env.open_enviroment()
return env
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):

View File

@ -42,6 +42,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
self.last_execute_cwd: str | None = None
self.released_connections: list[str] = []
super().__init__(tenant_id="test-tenant", options={}, environments={})
self.open_enviroment()
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)