mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
refactor: a lot of optimization and enhancement
This commit is contained in:
@ -23,7 +23,6 @@ if TYPE_CHECKING:
|
|||||||
from .initializer.app_assets_initializer import AppAssetsInitializer
|
from .initializer.app_assets_initializer import AppAssetsInitializer
|
||||||
from .initializer.dify_cli_initializer import DifyCliInitializer
|
from .initializer.dify_cli_initializer import DifyCliInitializer
|
||||||
from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer
|
from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer
|
||||||
from .manager import SandboxManager
|
|
||||||
from .sandbox import Sandbox
|
from .sandbox import Sandbox
|
||||||
from .storage import ArchiveSandboxStorage, SandboxStorage
|
from .storage import ArchiveSandboxStorage, SandboxStorage
|
||||||
from .utils.debug import sandbox_debug
|
from .utils.debug import sandbox_debug
|
||||||
@ -47,7 +46,6 @@ __all__ = [
|
|||||||
"SandboxBuilder",
|
"SandboxBuilder",
|
||||||
"SandboxInitializeContext",
|
"SandboxInitializeContext",
|
||||||
"SandboxInitializer",
|
"SandboxInitializer",
|
||||||
"SandboxManager",
|
|
||||||
"SandboxProviderApiEntity",
|
"SandboxProviderApiEntity",
|
||||||
"SandboxStorage",
|
"SandboxStorage",
|
||||||
"SandboxType",
|
"SandboxType",
|
||||||
|
|||||||
@ -131,6 +131,7 @@ class SandboxBuilder:
|
|||||||
environments=self._environments,
|
environments=self._environments,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
)
|
)
|
||||||
|
vm.open_enviroment()
|
||||||
sandbox = Sandbox(
|
sandbox = Sandbox(
|
||||||
vm=vm,
|
vm=vm,
|
||||||
storage=self._storage,
|
storage=self._storage,
|
||||||
@ -174,7 +175,11 @@ class SandboxBuilder:
|
|||||||
|
|
||||||
if sandbox.is_cancelled():
|
if sandbox.is_cancelled():
|
||||||
return
|
return
|
||||||
sandbox.mount()
|
# Storage mount is part of readiness. If restore/mount fails,
|
||||||
|
# the sandbox must surface initialization failure instead of
|
||||||
|
# becoming "ready" with missing files.
|
||||||
|
if not sandbox.mount():
|
||||||
|
raise RuntimeError("Sandbox storage mount failed")
|
||||||
sandbox.mark_ready()
|
sandbox.mark_ready()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -19,12 +19,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DifyCliInitializer(AsyncSandboxInitializer):
|
class DifyCliInitializer(AsyncSandboxInitializer):
|
||||||
_cli_api_session: object | None
|
|
||||||
|
|
||||||
def __init__(self, cli_root: str | Path | None = None) -> None:
|
def __init__(self, cli_root: str | Path | None = None) -> None:
|
||||||
self._locator = DifyCliLocator(root=cli_root)
|
self._locator = DifyCliLocator(root=cli_root)
|
||||||
self._tools: list[object] = []
|
self._tools: list[object] = []
|
||||||
self._cli_api_session = None
|
|
||||||
|
|
||||||
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
|
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
|
||||||
vm = sandbox.vm
|
vm = sandbox.vm
|
||||||
@ -57,7 +54,7 @@ class DifyCliInitializer(AsyncSandboxInitializer):
|
|||||||
logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id)
|
logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._cli_api_session = CliApiSessionManager().create(
|
global_cli_session = CliApiSessionManager().create(
|
||||||
tenant_id=ctx.tenant_id,
|
tenant_id=ctx.tenant_id,
|
||||||
user_id=ctx.user_id,
|
user_id=ctx.user_id,
|
||||||
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
|
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
|
||||||
@ -67,7 +64,7 @@ class DifyCliInitializer(AsyncSandboxInitializer):
|
|||||||
["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir"
|
["mkdir", "-p", cli.global_tools_path], error_message="Failed to create global tools dir"
|
||||||
).execute(raise_on_error=True)
|
).execute(raise_on_error=True)
|
||||||
|
|
||||||
config = DifyCliConfig.create(self._cli_api_session, ctx.tenant_id, bundle.get_tool_dependencies())
|
config = DifyCliConfig.create(global_cli_session, ctx.tenant_id, bundle.get_tool_dependencies())
|
||||||
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
|
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
|
||||||
config_path = cli.global_config_path
|
config_path = cli.global_config_path
|
||||||
vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))
|
vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))
|
||||||
|
|||||||
@ -5,8 +5,6 @@ from pathlib import PurePosixPath
|
|||||||
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
|
from core.sandbox.entities.files import SandboxFileDownloadTicket, SandboxFileNode
|
||||||
from core.sandbox.inspector.archive_source import SandboxFileArchiveSource
|
from core.sandbox.inspector.archive_source import SandboxFileArchiveSource
|
||||||
from core.sandbox.inspector.base import SandboxFileSource
|
from core.sandbox.inspector.base import SandboxFileSource
|
||||||
from core.sandbox.inspector.runtime_source import SandboxFileRuntimeSource
|
|
||||||
from core.sandbox.manager import SandboxManager
|
|
||||||
|
|
||||||
|
|
||||||
class SandboxFileBrowser:
|
class SandboxFileBrowser:
|
||||||
@ -31,14 +29,6 @@ class SandboxFileBrowser:
|
|||||||
return "." if normalized in (".", "") else normalized
|
return "." if normalized in (".", "") else normalized
|
||||||
|
|
||||||
def _backend(self) -> SandboxFileSource:
|
def _backend(self) -> SandboxFileSource:
|
||||||
sandbox = SandboxManager.get(self._sandbox_id)
|
|
||||||
if sandbox is not None:
|
|
||||||
return SandboxFileRuntimeSource(
|
|
||||||
tenant_id=self._tenant_id,
|
|
||||||
app_id=self._app_id,
|
|
||||||
sandbox_id=self._sandbox_id,
|
|
||||||
runtime=sandbox.vm,
|
|
||||||
)
|
|
||||||
return SandboxFileArchiveSource(
|
return SandboxFileArchiveSource(
|
||||||
tenant_id=self._tenant_id,
|
tenant_id=self._tenant_id,
|
||||||
app_id=self._app_id,
|
app_id=self._app_id,
|
||||||
|
|||||||
@ -1,103 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from typing import TYPE_CHECKING, Final
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.sandbox.sandbox import Sandbox
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SandboxManager:
|
|
||||||
"""Registry for active Sandbox instances.
|
|
||||||
|
|
||||||
Stores complete Sandbox objects (not just VirtualEnvironment) to provide
|
|
||||||
access to sandbox metadata like tenant_id, app_id, user_id, assets_id.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_NUM_SHARDS: Final[int] = 1024
|
|
||||||
_SHARD_MASK: Final[int] = _NUM_SHARDS - 1
|
|
||||||
|
|
||||||
_shard_locks: Final[tuple[threading.Lock, ...]] = tuple(threading.Lock() for _ in range(_NUM_SHARDS))
|
|
||||||
_shards: list[dict[str, Sandbox]] = [{} for _ in range(_NUM_SHARDS)]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _shard_index(cls, sandbox_id: str) -> int:
|
|
||||||
return hash(sandbox_id) & cls._SHARD_MASK
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, sandbox_id: str, sandbox: Sandbox) -> None:
|
|
||||||
if not sandbox_id:
|
|
||||||
raise ValueError("sandbox_id cannot be empty")
|
|
||||||
|
|
||||||
shard_index = cls._shard_index(sandbox_id)
|
|
||||||
with cls._shard_locks[shard_index]:
|
|
||||||
shard = cls._shards[shard_index]
|
|
||||||
if sandbox_id in shard:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Sandbox already registered for sandbox_id={sandbox_id}. "
|
|
||||||
"Call unregister() first if you need to replace it."
|
|
||||||
)
|
|
||||||
|
|
||||||
new_shard = dict(shard)
|
|
||||||
new_shard[sandbox_id] = sandbox
|
|
||||||
cls._shards[shard_index] = new_shard
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Registered sandbox: sandbox_id=%s, id=%s, app_id=%s",
|
|
||||||
sandbox_id,
|
|
||||||
sandbox.id,
|
|
||||||
sandbox.app_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls, sandbox_id: str) -> Sandbox | None:
|
|
||||||
shard_index = cls._shard_index(sandbox_id)
|
|
||||||
return cls._shards[shard_index].get(sandbox_id)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def unregister(cls, sandbox_id: str) -> Sandbox | None:
|
|
||||||
shard_index = cls._shard_index(sandbox_id)
|
|
||||||
with cls._shard_locks[shard_index]:
|
|
||||||
shard = cls._shards[shard_index]
|
|
||||||
sandbox = shard.get(sandbox_id)
|
|
||||||
if sandbox is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
new_shard = dict(shard)
|
|
||||||
new_shard.pop(sandbox_id, None)
|
|
||||||
cls._shards[shard_index] = new_shard
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Unregistered sandbox: sandbox_id=%s, id=%s",
|
|
||||||
sandbox_id,
|
|
||||||
sandbox.id,
|
|
||||||
)
|
|
||||||
return sandbox
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def has(cls, sandbox_id: str) -> bool:
|
|
||||||
shard_index = cls._shard_index(sandbox_id)
|
|
||||||
return sandbox_id in cls._shards[shard_index]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_sandbox_runtime(cls, sandbox_id: str) -> bool:
|
|
||||||
return cls.has(sandbox_id)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def clear(cls) -> None:
|
|
||||||
for lock in cls._shard_locks:
|
|
||||||
lock.acquire()
|
|
||||||
try:
|
|
||||||
for i in range(cls._NUM_SHARDS):
|
|
||||||
cls._shards[i] = {}
|
|
||||||
logger.debug("Cleared all registered sandboxes")
|
|
||||||
finally:
|
|
||||||
for lock in reversed(cls._shard_locks):
|
|
||||||
lock.release()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def count(cls) -> int:
|
|
||||||
return sum(len(shard) for shard in cls._shards)
|
|
||||||
@ -2,6 +2,7 @@ import secrets
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.skill.entities import ToolAccessPolicy
|
from core.skill.entities import ToolAccessPolicy
|
||||||
|
|
||||||
from .session import BaseSession, SessionManager
|
from .session import BaseSession, SessionManager
|
||||||
@ -17,7 +18,11 @@ class CliContext(BaseModel):
|
|||||||
|
|
||||||
class CliApiSessionManager(SessionManager[CliApiSession]):
|
class CliApiSessionManager(SessionManager[CliApiSession]):
|
||||||
def __init__(self, ttl: int | None = None):
|
def __init__(self, ttl: int | None = None):
|
||||||
super().__init__(key_prefix="cli_api_session", session_class=CliApiSession, ttl=ttl)
|
super().__init__(
|
||||||
|
key_prefix="cli_api_session",
|
||||||
|
session_class=CliApiSession,
|
||||||
|
ttl=ttl or dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||||
|
)
|
||||||
|
|
||||||
def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession:
|
def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession:
|
||||||
session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json"))
|
session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json"))
|
||||||
|
|||||||
@ -26,6 +26,9 @@ class CommandFuture:
|
|||||||
Lightweight future for command execution.
|
Lightweight future for command execution.
|
||||||
Mirrors concurrent.futures.Future API with 4 essential methods:
|
Mirrors concurrent.futures.Future API with 4 essential methods:
|
||||||
result(), done(), cancel(), cancelled().
|
result(), done(), cancel(), cancelled().
|
||||||
|
|
||||||
|
When a command is cancelled or times out the future now asks the provider
|
||||||
|
to terminate the underlying process/session before marking itself done.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -35,6 +38,7 @@ class CommandFuture:
|
|||||||
stdout_transport: TransportReadCloser,
|
stdout_transport: TransportReadCloser,
|
||||||
stderr_transport: TransportReadCloser,
|
stderr_transport: TransportReadCloser,
|
||||||
poll_status: Callable[[], CommandStatus],
|
poll_status: Callable[[], CommandStatus],
|
||||||
|
terminate_command: Callable[[], bool] | None = None,
|
||||||
poll_interval: float = 0.1,
|
poll_interval: float = 0.1,
|
||||||
):
|
):
|
||||||
self._pid = pid
|
self._pid = pid
|
||||||
@ -42,6 +46,7 @@ class CommandFuture:
|
|||||||
self._stdout_transport = stdout_transport
|
self._stdout_transport = stdout_transport
|
||||||
self._stderr_transport = stderr_transport
|
self._stderr_transport = stderr_transport
|
||||||
self._poll_status = poll_status
|
self._poll_status = poll_status
|
||||||
|
self._terminate_command = terminate_command
|
||||||
self._poll_interval = poll_interval
|
self._poll_interval = poll_interval
|
||||||
|
|
||||||
self._done_event = threading.Event()
|
self._done_event = threading.Event()
|
||||||
@ -49,7 +54,9 @@ class CommandFuture:
|
|||||||
self._result: CommandResult | None = None
|
self._result: CommandResult | None = None
|
||||||
self._exception: BaseException | None = None
|
self._exception: BaseException | None = None
|
||||||
self._cancelled = False
|
self._cancelled = False
|
||||||
|
self._timed_out = False
|
||||||
self._started = False
|
self._started = False
|
||||||
|
self._termination_requested = False
|
||||||
|
|
||||||
def result(self, timeout: float | None = None) -> CommandResult:
|
def result(self, timeout: float | None = None) -> CommandResult:
|
||||||
"""
|
"""
|
||||||
@ -61,15 +68,22 @@ class CommandFuture:
|
|||||||
Raises:
|
Raises:
|
||||||
CommandTimeoutError: If timeout exceeded.
|
CommandTimeoutError: If timeout exceeded.
|
||||||
CommandCancelledError: If command was cancelled.
|
CommandCancelledError: If command was cancelled.
|
||||||
|
|
||||||
|
A timeout is terminal for this future: it triggers best-effort command
|
||||||
|
termination and subsequent ``result()`` calls keep raising timeout.
|
||||||
"""
|
"""
|
||||||
self._ensure_started()
|
self._ensure_started()
|
||||||
|
|
||||||
if not self._done_event.wait(timeout):
|
if not self._done_event.wait(timeout):
|
||||||
|
self._request_stop(timed_out=True)
|
||||||
raise CommandTimeoutError(f"Command timed out after {timeout}s")
|
raise CommandTimeoutError(f"Command timed out after {timeout}s")
|
||||||
|
|
||||||
if self._cancelled:
|
if self._cancelled:
|
||||||
raise CommandCancelledError("Command was cancelled")
|
raise CommandCancelledError("Command was cancelled")
|
||||||
|
|
||||||
|
if self._timed_out:
|
||||||
|
raise CommandTimeoutError("Command timed out")
|
||||||
|
|
||||||
if self._exception is not None:
|
if self._exception is not None:
|
||||||
raise self._exception
|
raise self._exception
|
||||||
|
|
||||||
@ -82,16 +96,10 @@ class CommandFuture:
|
|||||||
|
|
||||||
def cancel(self) -> bool:
|
def cancel(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Attempt to cancel command by closing transports.
|
Attempt to cancel command by terminating it and closing transports.
|
||||||
Returns True if cancelled, False if already completed.
|
Returns True if cancelled, False if already completed.
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
return self._request_stop(cancelled=True)
|
||||||
if self._done_event.is_set():
|
|
||||||
return False
|
|
||||||
self._cancelled = True
|
|
||||||
self._close_transports()
|
|
||||||
self._done_event.set()
|
|
||||||
return True
|
|
||||||
|
|
||||||
def cancelled(self) -> bool:
|
def cancelled(self) -> bool:
|
||||||
return self._cancelled
|
return self._cancelled
|
||||||
@ -103,6 +111,28 @@ class CommandFuture:
|
|||||||
thread = threading.Thread(target=self._execute, daemon=True)
|
thread = threading.Thread(target=self._execute, daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
|
def _request_stop(self, *, cancelled: bool = False, timed_out: bool = False) -> bool:
|
||||||
|
should_terminate = False
|
||||||
|
with self._lock:
|
||||||
|
if self._done_event.is_set():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if cancelled:
|
||||||
|
self._cancelled = True
|
||||||
|
if timed_out:
|
||||||
|
self._timed_out = True
|
||||||
|
|
||||||
|
should_terminate = not self._termination_requested
|
||||||
|
if should_terminate:
|
||||||
|
self._termination_requested = True
|
||||||
|
|
||||||
|
self._close_transports()
|
||||||
|
self._done_event.set()
|
||||||
|
|
||||||
|
if should_terminate:
|
||||||
|
self._terminate_running_command()
|
||||||
|
return True
|
||||||
|
|
||||||
def _execute(self) -> None:
|
def _execute(self) -> None:
|
||||||
stdout_buf = bytearray()
|
stdout_buf = bytearray()
|
||||||
stderr_buf = bytearray()
|
stderr_buf = bytearray()
|
||||||
@ -141,7 +171,7 @@ class CommandFuture:
|
|||||||
self._close_transports()
|
self._close_transports()
|
||||||
|
|
||||||
def _wait_for_completion(self) -> int | None:
|
def _wait_for_completion(self) -> int | None:
|
||||||
while not self._cancelled:
|
while not self._cancelled and not self._timed_out:
|
||||||
try:
|
try:
|
||||||
status = self._poll_status()
|
status = self._poll_status()
|
||||||
except NotSupportedOperationError:
|
except NotSupportedOperationError:
|
||||||
@ -167,3 +197,12 @@ class CommandFuture:
|
|||||||
for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
|
for transport in (self._stdin_transport, self._stdout_transport, self._stderr_transport):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
transport.close()
|
transport.close()
|
||||||
|
|
||||||
|
def _terminate_running_command(self) -> None:
|
||||||
|
if self._terminate_command is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._terminate_command()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to terminate command for pid %s", self._pid)
|
||||||
|
|||||||
@ -71,6 +71,7 @@ def submit_command(
|
|||||||
stdout_transport=stdout_transport,
|
stdout_transport=stdout_transport,
|
||||||
stderr_transport=stderr_transport,
|
stderr_transport=stderr_transport,
|
||||||
poll_status=partial(env.get_command_status, connection, pid),
|
poll_status=partial(env.get_command_status, connection, pid),
|
||||||
|
terminate_command=partial(env.terminate_command, connection, pid),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -11,8 +11,19 @@ from core.virtual_environment.channel.transport import TransportReadCloser, Tran
|
|||||||
class VirtualEnvironment(ABC):
|
class VirtualEnvironment(ABC):
|
||||||
"""
|
"""
|
||||||
Base class for virtual environment implementations.
|
Base class for virtual environment implementations.
|
||||||
|
|
||||||
|
``VirtualEnvironment`` instances are configured at construction time but do
|
||||||
|
not allocate provider resources until ``open_enviroment()`` is called.
|
||||||
|
This keeps object construction side-effect free and gives callers a chance
|
||||||
|
to own startup error handling explicitly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
tenant_id: str
|
||||||
|
user_id: str | None
|
||||||
|
options: Mapping[str, Any]
|
||||||
|
_environments: Mapping[str, str]
|
||||||
|
_metadata: Metadata | None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
@ -21,19 +32,45 @@ class VirtualEnvironment(ABC):
|
|||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the virtual environment with metadata.
|
Initialize the virtual environment configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tenant_id: The tenant ID associated with this environment (required).
|
tenant_id: The tenant ID associated with this environment (required).
|
||||||
options: Provider-specific configuration options.
|
options: Provider-specific configuration options.
|
||||||
environments: Environment variables to set in the virtual environment.
|
environments: Environment variables to set in the virtual environment.
|
||||||
user_id: The user ID associated with this environment (optional).
|
user_id: The user ID associated with this environment (optional).
|
||||||
|
|
||||||
|
The provider runtime itself is created later by ``open_enviroment()``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.options = options
|
self.options = options
|
||||||
self.metadata = self._construct_environment(options, environments or {})
|
self._environments = dict(environments or {})
|
||||||
|
self._metadata = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metadata(self) -> Metadata:
|
||||||
|
"""Provider metadata for a started environment.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the environment has not been started yet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._metadata is None:
|
||||||
|
raise RuntimeError("Virtual environment has not been started")
|
||||||
|
return self._metadata
|
||||||
|
|
||||||
|
def open_enviroment(self) -> Metadata:
|
||||||
|
"""Allocate provider resources and return the resulting metadata.
|
||||||
|
|
||||||
|
Multiple calls are safe and return the existing metadata after the first
|
||||||
|
successful start.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._metadata is None:
|
||||||
|
self._metadata = self._construct_environment(self.options, self._environments)
|
||||||
|
return self._metadata
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||||
@ -131,6 +168,18 @@ class VirtualEnvironment(ABC):
|
|||||||
Multiple calls to `release_environment` with the same `environment_id` is acceptable.
|
Multiple calls to `release_environment` with the same `environment_id` is acceptable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
|
||||||
|
"""Best-effort termination hook for a running command.
|
||||||
|
|
||||||
|
Providers that can map ``pid`` back to a real process/session should
|
||||||
|
override this method and stop the command. The default implementation is
|
||||||
|
a no-op so providers without a termination mechanism remain compatible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_ = connection_handle
|
||||||
|
_ = pid
|
||||||
|
return False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute_command(
|
def execute_command(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import posixpath
|
import posixpath
|
||||||
import shlex
|
import shlex
|
||||||
import threading
|
import threading
|
||||||
@ -32,6 +33,8 @@ from core.virtual_environment.channel.transport import (
|
|||||||
)
|
)
|
||||||
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS
|
from core.virtual_environment.constants import COMMAND_EXECUTION_TIMEOUT_SECONDS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@ -132,35 +135,53 @@ class E2BEnvironment(VirtualEnvironment):
|
|||||||
The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the
|
The sandbox lifetime is capped by ``WORKFLOW_MAX_EXECUTION_TIME`` so the
|
||||||
provider can rely on E2B's native timeout instead of a background
|
provider can rely on E2B's native timeout instead of a background
|
||||||
keepalive thread that continuously extends the session.
|
keepalive thread that continuously extends the session.
|
||||||
|
|
||||||
|
E2B allocates the remote sandbox before metadata probing completes, so
|
||||||
|
startup failures must best-effort terminate the sandbox before the
|
||||||
|
exception escapes.
|
||||||
"""
|
"""
|
||||||
# Import E2B SDK lazily so it is loaded after gevent monkey-patching.
|
# Import E2B SDK lazily so it is loaded after gevent monkey-patching.
|
||||||
from e2b_code_interpreter import Sandbox # type: ignore[import-untyped]
|
from e2b_code_interpreter import Sandbox # type: ignore[import-untyped]
|
||||||
|
|
||||||
# TODO: add Dify as the user agent
|
# TODO: add Dify as the user agent
|
||||||
sandbox = Sandbox.create(
|
sandbox = None
|
||||||
template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"),
|
sandbox_id: str | None = None
|
||||||
timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
api_key = options.get(self.OptionsKey.API_KEY, "")
|
||||||
api_key=options.get(self.OptionsKey.API_KEY, ""),
|
try:
|
||||||
api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL),
|
sandbox = Sandbox.create(
|
||||||
envs=dict(environments),
|
template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"),
|
||||||
)
|
timeout=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||||
info = sandbox.get_info(api_key=options.get(self.OptionsKey.API_KEY, ""))
|
api_key=api_key,
|
||||||
system_info = sandbox.commands.run("uname -m -s").stdout.strip()
|
api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL),
|
||||||
system_parts = system_info.split()
|
envs=dict(environments),
|
||||||
if len(system_parts) == 2:
|
)
|
||||||
os_part, arch_part = system_parts
|
info = sandbox.get_info(api_key=api_key)
|
||||||
else:
|
sandbox_id = info.sandbox_id
|
||||||
arch_part = system_parts[0]
|
system_info = sandbox.commands.run("uname -m -s").stdout.strip()
|
||||||
os_part = system_parts[1] if len(system_parts) > 1 else ""
|
system_parts = system_info.split()
|
||||||
|
if len(system_parts) == 2:
|
||||||
|
os_part, arch_part = system_parts
|
||||||
|
else:
|
||||||
|
arch_part = system_parts[0]
|
||||||
|
os_part = system_parts[1] if len(system_parts) > 1 else ""
|
||||||
|
|
||||||
return Metadata(
|
return Metadata(
|
||||||
id=info.sandbox_id,
|
id=info.sandbox_id,
|
||||||
arch=self._convert_architecture(arch_part.strip()),
|
arch=self._convert_architecture(arch_part.strip()),
|
||||||
os=self._convert_operating_system(os_part.strip()),
|
os=self._convert_operating_system(os_part.strip()),
|
||||||
store={
|
store={
|
||||||
self.StoreKey.SANDBOX: sandbox,
|
self.StoreKey.SANDBOX: sandbox,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
except Exception:
|
||||||
|
if sandbox_id is None and sandbox is not None:
|
||||||
|
sandbox_id = getattr(sandbox, "sandbox_id", None)
|
||||||
|
if sandbox_id is not None:
|
||||||
|
try:
|
||||||
|
Sandbox.kill(api_key=api_key, sandbox_id=sandbox_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to cleanup E2B sandbox after startup failure")
|
||||||
|
raise
|
||||||
|
|
||||||
def release_environment(self) -> None:
|
def release_environment(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
@ -246,6 +247,16 @@ class LocalVirtualEnvironment(VirtualEnvironment):
|
|||||||
except ChildProcessError:
|
except ChildProcessError:
|
||||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
||||||
|
|
||||||
|
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
|
||||||
|
"""Terminate a locally spawned process by PID when cancellation is requested."""
|
||||||
|
|
||||||
|
_ = connection_handle
|
||||||
|
try:
|
||||||
|
os.kill(int(pid), signal.SIGTERM)
|
||||||
|
except ProcessLookupError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _get_os_architecture(self) -> Arch:
|
def _get_os_architecture(self) -> Arch:
|
||||||
"""
|
"""
|
||||||
Get the operating system architecture.
|
Get the operating system architecture.
|
||||||
|
|||||||
@ -76,6 +76,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._connections: dict[str, Any] = {}
|
self._connections: dict[str, Any] = {}
|
||||||
self._commands: dict[str, CommandStatus] = {}
|
self._commands: dict[str, CommandStatus] = {}
|
||||||
|
self._command_channels: dict[str, Any] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)
|
super().__init__(tenant_id=tenant_id, options=options, environments=environments, user_id=user_id)
|
||||||
|
|
||||||
@ -163,6 +164,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
|||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
self._commands[pid] = CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
||||||
|
self._command_channels[pid] = channel
|
||||||
|
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=self._consume_channel_output,
|
target=self._consume_channel_output,
|
||||||
@ -179,6 +181,23 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
|||||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
def terminate_command(self, connection_handle: ConnectionHandle, pid: str) -> bool:
|
||||||
|
"""Best-effort termination by closing the SSH channel that owns the command."""
|
||||||
|
|
||||||
|
_ = connection_handle
|
||||||
|
with self._lock:
|
||||||
|
channel = self._command_channels.get(pid)
|
||||||
|
if channel is None:
|
||||||
|
return False
|
||||||
|
self._commands[pid] = CommandStatus(
|
||||||
|
status=CommandStatus.Status.COMPLETED,
|
||||||
|
exit_code=self._COMMAND_TIMEOUT_EXIT_CODE,
|
||||||
|
)
|
||||||
|
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
channel.close()
|
||||||
|
return True
|
||||||
|
|
||||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||||
destination_path = self._workspace_path(path)
|
destination_path = self._workspace_path(path)
|
||||||
with self._client() as client:
|
with self._client() as client:
|
||||||
@ -424,6 +443,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
|||||||
channel.close()
|
channel.close()
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
self._command_channels.pop(pid, None)
|
||||||
self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
||||||
|
|
||||||
def _set_sftp_operation_timeout(self, sftp: Any) -> None:
|
def _set_sftp_operation_timeout(self, sftp: Any) -> None:
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from core.sandbox.builder import SandboxBuilder
|
from core.sandbox.builder import SandboxBuilder
|
||||||
from core.sandbox.entities.sandbox_type import SandboxType
|
from core.sandbox.entities.sandbox_type import SandboxType
|
||||||
from core.sandbox.manager import SandboxManager
|
|
||||||
from core.sandbox.sandbox import Sandbox
|
from core.sandbox.sandbox import Sandbox
|
||||||
from core.sandbox.storage.noop_storage import NoopSandboxStorage
|
from core.sandbox.storage.noop_storage import NoopSandboxStorage
|
||||||
from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError
|
from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError
|
||||||
@ -100,26 +99,29 @@ class ZipSandbox:
|
|||||||
self._sandbox_id = uuid4().hex
|
self._sandbox_id = uuid4().hex
|
||||||
|
|
||||||
storage = NoopSandboxStorage()
|
storage = NoopSandboxStorage()
|
||||||
self._sandbox = (
|
try:
|
||||||
SandboxBuilder(self._tenant_id, SandboxType(provider_type))
|
self._sandbox = (
|
||||||
.options(provider_options)
|
SandboxBuilder(self._tenant_id, SandboxType(provider_type))
|
||||||
.user(self._user_id)
|
.options(provider_options)
|
||||||
.app(self._app_id)
|
.user(self._user_id)
|
||||||
.storage(storage, assets_id="zip-sandbox")
|
.app(self._app_id)
|
||||||
.build()
|
.storage(storage, assets_id="zip-sandbox")
|
||||||
)
|
.build()
|
||||||
self._sandbox.wait_ready(timeout=60)
|
)
|
||||||
self._vm = self._sandbox.vm
|
self._sandbox.wait_ready(timeout=60)
|
||||||
|
self._vm = self._sandbox.vm
|
||||||
SandboxManager.register(self._sandbox_id, self._sandbox)
|
except Exception:
|
||||||
|
if self._sandbox is not None:
|
||||||
|
self._sandbox.release()
|
||||||
|
self._vm = None
|
||||||
|
self._sandbox = None
|
||||||
|
self._sandbox_id = None
|
||||||
|
raise
|
||||||
|
|
||||||
def _stop(self) -> None:
|
def _stop(self) -> None:
|
||||||
if self._vm is None:
|
if self._vm is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self._sandbox_id:
|
|
||||||
SandboxManager.unregister(self._sandbox_id)
|
|
||||||
|
|
||||||
if self._sandbox is not None:
|
if self._sandbox is not None:
|
||||||
self._sandbox.release()
|
self._sandbox.release()
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import threading
|
import threading
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -18,6 +19,7 @@ def _make_future(
|
|||||||
exit_code: int = 0,
|
exit_code: int = 0,
|
||||||
delay_completion: float = 0,
|
delay_completion: float = 0,
|
||||||
close_streams: bool = True,
|
close_streams: bool = True,
|
||||||
|
terminate_command: Callable[[], bool] | None = None,
|
||||||
) -> CommandFuture:
|
) -> CommandFuture:
|
||||||
stdout_transport = QueueTransportReadCloser()
|
stdout_transport = QueueTransportReadCloser()
|
||||||
stderr_transport = QueueTransportReadCloser()
|
stderr_transport = QueueTransportReadCloser()
|
||||||
@ -48,6 +50,7 @@ def _make_future(
|
|||||||
stdout_transport=stdout_transport,
|
stdout_transport=stdout_transport,
|
||||||
stderr_transport=stderr_transport,
|
stderr_transport=stderr_transport,
|
||||||
poll_status=poll_status,
|
poll_status=poll_status,
|
||||||
|
terminate_command=terminate_command,
|
||||||
poll_interval=0.05,
|
poll_interval=0.05,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -78,6 +81,21 @@ def test_result_raises_timeout_error_when_exceeded():
|
|||||||
future.result(timeout=0.2)
|
future.result(timeout=0.2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeout_requests_command_termination():
|
||||||
|
terminated = threading.Event()
|
||||||
|
|
||||||
|
future = _make_future(
|
||||||
|
delay_completion=10.0,
|
||||||
|
close_streams=False,
|
||||||
|
terminate_command=lambda: terminated.set() or True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(CommandTimeoutError):
|
||||||
|
future.result(timeout=0.2)
|
||||||
|
|
||||||
|
assert terminated.wait(timeout=1.0)
|
||||||
|
|
||||||
|
|
||||||
def test_done_returns_false_while_running():
|
def test_done_returns_false_while_running():
|
||||||
future = _make_future(delay_completion=10.0, close_streams=False)
|
future = _make_future(delay_completion=10.0, close_streams=False)
|
||||||
|
|
||||||
@ -115,6 +133,19 @@ def test_result_raises_cancelled_error_after_cancel():
|
|||||||
future.result()
|
future.result()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_requests_command_termination():
|
||||||
|
terminated = threading.Event()
|
||||||
|
|
||||||
|
future = _make_future(
|
||||||
|
delay_completion=10.0,
|
||||||
|
close_streams=False,
|
||||||
|
terminate_command=lambda: terminated.set() or True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert future.cancel() is True
|
||||||
|
assert terminated.wait(timeout=1.0)
|
||||||
|
|
||||||
|
|
||||||
def test_nonzero_exit_code_is_returned():
|
def test_nonzero_exit_code_is_returned():
|
||||||
future = _make_future(stdout=b"err", exit_code=42)
|
future = _make_future(stdout=b"err", exit_code=42)
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Any
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
from core.virtual_environment.__base.entities import (
|
from core.virtual_environment.__base.entities import (
|
||||||
Arch,
|
Arch,
|
||||||
CommandStatus,
|
CommandStatus,
|
||||||
@ -64,6 +65,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
|||||||
self._establish_count = 0
|
self._establish_count = 0
|
||||||
self._release_count = 0
|
self._release_count = 0
|
||||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||||
|
self.open_enviroment()
|
||||||
|
|
||||||
def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata:
|
def _construct_environment(self, _options: Mapping[str, Any], _environments: Mapping[str, str]) -> Metadata:
|
||||||
return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX)
|
return Metadata(id="fake-id", arch=Arch.AMD64, os=OperatingSystem.LINUX)
|
||||||
@ -111,6 +113,10 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
|||||||
def validate(cls, _options: Mapping[str, Any]) -> None:
|
def validate(cls, _options: Mapping[str, Any]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config_schema(cls) -> list[BasicProviderConfig]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class TestWithConnection:
|
class TestWithConnection:
|
||||||
def test_connection_established_and_released(self):
|
def test_connection_established_and_released(self):
|
||||||
|
|||||||
@ -26,7 +26,9 @@ def _drain_transport(transport: TransportReadCloser) -> bytes:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
|
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
|
||||||
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
|
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
|
||||||
return LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
|
env = LocalVirtualEnvironment(tenant_id="test-tenant", options={"base_working_path": str(tmp_path)})
|
||||||
|
env.open_enviroment()
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
|
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
|
||||||
|
|||||||
@ -42,6 +42,7 @@ class FakeVirtualEnvironment(VirtualEnvironment):
|
|||||||
self.last_execute_cwd: str | None = None
|
self.last_execute_cwd: str | None = None
|
||||||
self.released_connections: list[str] = []
|
self.released_connections: list[str] = []
|
||||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||||
|
self.open_enviroment()
|
||||||
|
|
||||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||||
return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)
|
return Metadata(id="fake", arch=Arch.ARM64, os=OperatingSystem.LINUX)
|
||||||
|
|||||||
Reference in New Issue
Block a user