mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 07:58:02 +08:00
refactor: a lot of optimization and enhancement
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
@ -18,6 +19,7 @@ def _make_future(
|
||||
exit_code: int = 0,
|
||||
delay_completion: float = 0,
|
||||
close_streams: bool = True,
|
||||
terminate_command: Callable[[], bool] | None = None,
|
||||
) -> CommandFuture:
|
||||
stdout_transport = QueueTransportReadCloser()
|
||||
stderr_transport = QueueTransportReadCloser()
|
||||
@ -48,6 +50,7 @@ def _make_future(
|
||||
stdout_transport=stdout_transport,
|
||||
stderr_transport=stderr_transport,
|
||||
poll_status=poll_status,
|
||||
terminate_command=terminate_command,
|
||||
poll_interval=0.05,
|
||||
)
|
||||
|
||||
@ -78,6 +81,21 @@ def test_result_raises_timeout_error_when_exceeded():
|
||||
future.result(timeout=0.2)
|
||||
|
||||
|
||||
def test_timeout_requests_command_termination():
|
||||
terminated = threading.Event()
|
||||
|
||||
future = _make_future(
|
||||
delay_completion=10.0,
|
||||
close_streams=False,
|
||||
terminate_command=lambda: terminated.set() or True,
|
||||
)
|
||||
|
||||
with pytest.raises(CommandTimeoutError):
|
||||
future.result(timeout=0.2)
|
||||
|
||||
assert terminated.wait(timeout=1.0)
|
||||
|
||||
|
||||
def test_done_returns_false_while_running():
|
||||
future = _make_future(delay_completion=10.0, close_streams=False)
|
||||
|
||||
@ -115,6 +133,19 @@ def test_result_raises_cancelled_error_after_cancel():
|
||||
future.result()
|
||||
|
||||
|
||||
def test_cancel_requests_command_termination():
|
||||
terminated = threading.Event()
|
||||
|
||||
future = _make_future(
|
||||
delay_completion=10.0,
|
||||
close_streams=False,
|
||||
terminate_command=lambda: terminated.set() or True,
|
||||
)
|
||||
|
||||
assert future.cancel() is True
|
||||
assert terminated.wait(timeout=1.0)
|
||||
|
||||
|
||||
def test_nonzero_exit_code_is_returned():
|
||||
future = _make_future(stdout=b"err", exit_code=42)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.virtual_environment.__base.entities import (
|
||||
Arch,
|
||||
CommandStatus,
|
||||
@ -64,6 +65,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
self._establish_count = 0
|
||||
self._release_count = 0
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
self.open_enviroment()
|
||||
|
||||
def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX)
|
||||
@ -111,6 +113,10 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
def validate(cls, _options: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_config_schema(cls) -> list[BasicProviderConfig]:
|
||||
return []
|
||||
|
||||
|
||||
class TestWithConnection:
|
||||
def test_connection_established_and_released(self):
|
||||
|
||||
@ -26,7 +26,9 @@ def _drain_transport(transport: TransportReadCloser) -> bytes:
|
||||
@pytest.fixture
|
||||
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
|
||||
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
|
||||
return LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
|
||||
env = LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
|
||||
env.open_enviroment()
|
||||
return env
|
||||
|
||||
|
||||
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
|
||||
|
||||
@ -42,6 +42,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
self.last_execute_cwd: str | None = None
|
||||
self.released_connections: list[str] = []
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
self.open_enviroment()
|
||||
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)
|
||||
|
||||
Reference in New Issue
Block a user