mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 18:06:14 +08:00
feat(sandbox): refactor sandbox file handling to include app_id
- Updated API routes to use app_id instead of sandbox_id for file operations, aligning with user-specific sandbox workspaces. - Enhanced SandboxFileService and related classes to accommodate app_id in file listing and download functionalities. - Refactored storage key generation for sandbox archives to include app_id, ensuring proper file organization. - Adjusted frontend contracts and services to reflect the new app_id parameter in API calls.
This commit is contained in:
@ -269,7 +269,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
SandboxService.delete_draft_storage(app_model.tenant_id, current_user.id)
|
||||
SandboxService.delete_draft_storage(app_model.tenant_id, app_model.id, current_user.id)
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
|
||||
@ -51,20 +51,28 @@ sandbox_file_node_model = console_ns.model("SandboxFileNode", SANDBOX_FILE_NODE_
|
||||
sandbox_file_download_ticket_model = console_ns.model("SandboxFileDownloadTicket", SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS)
|
||||
|
||||
|
||||
@console_ns.route("/sandboxes/<string:sandbox_id>/files")
|
||||
@console_ns.route("/apps/<string:app_id>/sandbox/files")
|
||||
class SandboxFilesApi(Resource):
|
||||
"""List sandbox files for the current user.
|
||||
|
||||
The sandbox_id is derived from the current user's ID, as each user has
|
||||
their own sandbox workspace per app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[SandboxFileListQuery.__name__])
|
||||
@console_ns.marshal_list_with(sandbox_file_node_model)
|
||||
def get(self, sandbox_id: str):
|
||||
def get(self, app_id: str):
|
||||
args = SandboxFileListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore[arg-type]
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
sandbox_id = account.id
|
||||
return [
|
||||
e.__dict__
|
||||
for e in SandboxFileService.list_files(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
sandbox_id=sandbox_id,
|
||||
path=args.path,
|
||||
recursive=args.recursive,
|
||||
@ -72,15 +80,24 @@ class SandboxFilesApi(Resource):
|
||||
]
|
||||
|
||||
|
||||
@console_ns.route("/sandboxes/<string:sandbox_id>/files/download")
|
||||
@console_ns.route("/apps/<string:app_id>/sandbox/files/download")
|
||||
class SandboxFileDownloadApi(Resource):
|
||||
"""Download a sandbox file for the current user.
|
||||
|
||||
The sandbox_id is derived from the current user's ID, as each user has
|
||||
their own sandbox workspace per app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[SandboxFileDownloadRequest.__name__])
|
||||
@console_ns.marshal_with(sandbox_file_download_ticket_model)
|
||||
def post(self, sandbox_id: str):
|
||||
def post(self, app_id: str):
|
||||
payload = SandboxFileDownloadRequest.model_validate(console_ns.payload or {})
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
res = SandboxFileService.download_file(tenant_id=tenant_id, sandbox_id=sandbox_id, path=payload.path)
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
sandbox_id = account.id
|
||||
res = SandboxFileService.download_file(
|
||||
tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id, path=payload.path
|
||||
)
|
||||
return res.__dict__
|
||||
|
||||
@ -68,7 +68,9 @@ print(json.dumps(entries))
|
||||
"""Get a pre-signed download URL for the sandbox archive."""
|
||||
from extensions.storage.file_presign_storage import FilePresignStorage
|
||||
|
||||
storage_key = f"sandbox_archives/{self._tenant_id}/{self._sandbox_id}.tar.gz"
|
||||
storage_key = SandboxFilePaths.archive(self._tenant_id, self._app_id, self._sandbox_id)
|
||||
if not storage.exists(storage_key):
|
||||
raise ValueError("Sandbox archive not found")
|
||||
presign_storage = FilePresignStorage(storage.storage_runner)
|
||||
return presign_storage.get_download_url(storage_key, self._EXPORT_EXPIRES_IN_SECONDS)
|
||||
|
||||
@ -76,11 +78,11 @@ print(json.dumps(entries))
|
||||
"""Create a ZipSandbox instance for archive operations."""
|
||||
from core.zip_sandbox import ZipSandbox
|
||||
|
||||
return ZipSandbox(tenant_id=self._tenant_id, user_id="system", app_id="sandbox-archive-browser")
|
||||
return ZipSandbox(tenant_id=self._tenant_id, user_id="system", app_id=self._app_id)
|
||||
|
||||
def exists(self) -> bool:
|
||||
"""Check if the sandbox archive exists in storage."""
|
||||
storage_key = f"sandbox_archives/{self._tenant_id}/{self._sandbox_id}.tar.gz"
|
||||
storage_key = SandboxFilePaths.archive(self._tenant_id, self._app_id, self._sandbox_id)
|
||||
return storage.exists(storage_key)
|
||||
|
||||
def list_files(self, *, path: str, recursive: bool) -> list[SandboxFileNode]:
|
||||
@ -195,6 +197,7 @@ raise SystemExit(2)
|
||||
file_data = zs.read_file(target_path)
|
||||
export_key = SandboxFilePaths.export(
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._sandbox_id,
|
||||
export_id,
|
||||
os.path.basename(path) or "file",
|
||||
@ -206,6 +209,7 @@ raise SystemExit(2)
|
||||
tar_data = zs.read_file(tar_file.path)
|
||||
export_key = SandboxFilePaths.export(
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._sandbox_id,
|
||||
export_id,
|
||||
f"{export_name}.tar.gz",
|
||||
|
||||
@ -10,8 +10,9 @@ class SandboxFileSource(abc.ABC):
|
||||
_UPLOAD_TIMEOUT_SECONDS = 60 * 10
|
||||
_EXPORT_EXPIRES_IN_SECONDS = 60 * 10
|
||||
|
||||
def __init__(self, *, tenant_id: str, sandbox_id: str):
|
||||
def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str):
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._sandbox_id = sandbox_id
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@ -10,8 +10,9 @@ from core.sandbox.manager import SandboxManager
|
||||
|
||||
|
||||
class SandboxFileBrowser:
|
||||
def __init__(self, *, tenant_id: str, sandbox_id: str):
|
||||
def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str):
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._sandbox_id = sandbox_id
|
||||
|
||||
@staticmethod
|
||||
@ -30,10 +31,19 @@ class SandboxFileBrowser:
|
||||
return "." if normalized in (".", "") else normalized
|
||||
|
||||
def _backend(self) -> SandboxFileSource:
|
||||
runtime = SandboxManager.get(self._sandbox_id)
|
||||
if runtime is not None:
|
||||
return SandboxFileRuntimeSource(tenant_id=self._tenant_id, sandbox_id=self._sandbox_id, runtime=runtime)
|
||||
return SandboxFileArchiveSource(tenant_id=self._tenant_id, sandbox_id=self._sandbox_id)
|
||||
sandbox = SandboxManager.get(self._sandbox_id)
|
||||
if sandbox is not None:
|
||||
return SandboxFileRuntimeSource(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
sandbox_id=self._sandbox_id,
|
||||
runtime=sandbox.vm,
|
||||
)
|
||||
return SandboxFileArchiveSource(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
sandbox_id=self._sandbox_id,
|
||||
)
|
||||
|
||||
def exists(self) -> bool:
|
||||
"""Check if the sandbox source exists and is available."""
|
||||
|
||||
@ -16,8 +16,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxFileRuntimeSource(SandboxFileSource):
|
||||
def __init__(self, *, tenant_id: str, sandbox_id: str, runtime: VirtualEnvironment):
|
||||
super().__init__(tenant_id=tenant_id, sandbox_id=sandbox_id)
|
||||
def __init__(self, *, tenant_id: str, app_id: str, sandbox_id: str, runtime: VirtualEnvironment):
|
||||
super().__init__(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id)
|
||||
self._runtime = runtime
|
||||
|
||||
def exists(self) -> bool:
|
||||
@ -122,6 +122,7 @@ print(json.dumps(entries))
|
||||
export_id = uuid4().hex
|
||||
export_key = SandboxFilePaths.export(
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._sandbox_id,
|
||||
export_id,
|
||||
filename,
|
||||
|
||||
@ -2,81 +2,89 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Final
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
if TYPE_CHECKING:
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxManager:
|
||||
"""Registry for active Sandbox instances.
|
||||
|
||||
Stores complete Sandbox objects (not just VirtualEnvironment) to provide
|
||||
access to sandbox metadata like tenant_id, app_id, user_id, assets_id.
|
||||
"""
|
||||
|
||||
_NUM_SHARDS: Final[int] = 1024
|
||||
_SHARD_MASK: Final[int] = _NUM_SHARDS - 1
|
||||
|
||||
_shard_locks: Final[tuple[threading.Lock, ...]] = tuple(threading.Lock() for _ in range(_NUM_SHARDS))
|
||||
_shards: list[dict[str, VirtualEnvironment]] = [{} for _ in range(_NUM_SHARDS)]
|
||||
_shards: list[dict[str, Sandbox]] = [{} for _ in range(_NUM_SHARDS)]
|
||||
|
||||
@classmethod
|
||||
def _shard_index(cls, workflow_execution_id: str) -> int:
|
||||
return hash(workflow_execution_id) & cls._SHARD_MASK
|
||||
def _shard_index(cls, sandbox_id: str) -> int:
|
||||
return hash(sandbox_id) & cls._SHARD_MASK
|
||||
|
||||
@classmethod
|
||||
def register(cls, workflow_execution_id: str, sandbox: VirtualEnvironment) -> None:
|
||||
if not workflow_execution_id:
|
||||
raise ValueError("workflow_execution_id cannot be empty")
|
||||
def register(cls, sandbox_id: str, sandbox: Sandbox) -> None:
|
||||
if not sandbox_id:
|
||||
raise ValueError("sandbox_id cannot be empty")
|
||||
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
shard_index = cls._shard_index(sandbox_id)
|
||||
with cls._shard_locks[shard_index]:
|
||||
shard = cls._shards[shard_index]
|
||||
if workflow_execution_id in shard:
|
||||
if sandbox_id in shard:
|
||||
raise RuntimeError(
|
||||
f"Sandbox already registered for workflow_execution_id={workflow_execution_id}. "
|
||||
f"Sandbox already registered for sandbox_id={sandbox_id}. "
|
||||
"Call unregister() first if you need to replace it."
|
||||
)
|
||||
|
||||
new_shard = dict(shard)
|
||||
new_shard[workflow_execution_id] = sandbox
|
||||
new_shard[sandbox_id] = sandbox
|
||||
cls._shards[shard_index] = new_shard
|
||||
|
||||
logger.debug(
|
||||
"Registered sandbox for workflow_execution_id=%s, sandbox_id=%s",
|
||||
workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
"Registered sandbox: sandbox_id=%s, vm_id=%s, app_id=%s",
|
||||
sandbox_id,
|
||||
sandbox.vm.metadata.id,
|
||||
sandbox.app_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
return cls._shards[shard_index].get(workflow_execution_id)
|
||||
def get(cls, sandbox_id: str) -> Sandbox | None:
|
||||
shard_index = cls._shard_index(sandbox_id)
|
||||
return cls._shards[shard_index].get(sandbox_id)
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, workflow_execution_id: str) -> VirtualEnvironment | None:
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
def unregister(cls, sandbox_id: str) -> Sandbox | None:
|
||||
shard_index = cls._shard_index(sandbox_id)
|
||||
with cls._shard_locks[shard_index]:
|
||||
shard = cls._shards[shard_index]
|
||||
sandbox = shard.get(workflow_execution_id)
|
||||
sandbox = shard.get(sandbox_id)
|
||||
if sandbox is None:
|
||||
return None
|
||||
|
||||
new_shard = dict(shard)
|
||||
new_shard.pop(workflow_execution_id, None)
|
||||
new_shard.pop(sandbox_id, None)
|
||||
cls._shards[shard_index] = new_shard
|
||||
|
||||
logger.debug(
|
||||
"Unregistered sandbox for workflow_execution_id=%s, sandbox_id=%s",
|
||||
workflow_execution_id,
|
||||
sandbox.metadata.id,
|
||||
"Unregistered sandbox: sandbox_id=%s, vm_id=%s",
|
||||
sandbox_id,
|
||||
sandbox.vm.metadata.id,
|
||||
)
|
||||
return sandbox
|
||||
|
||||
@classmethod
|
||||
def has(cls, workflow_execution_id: str) -> bool:
|
||||
shard_index = cls._shard_index(workflow_execution_id)
|
||||
return workflow_execution_id in cls._shards[shard_index]
|
||||
def has(cls, sandbox_id: str) -> bool:
|
||||
shard_index = cls._shard_index(sandbox_id)
|
||||
return sandbox_id in cls._shards[shard_index]
|
||||
|
||||
@classmethod
|
||||
def is_sandbox_runtime(cls, workflow_execution_id: str) -> bool:
|
||||
return cls.has(workflow_execution_id)
|
||||
def is_sandbox_runtime(cls, sandbox_id: str) -> bool:
|
||||
return cls.has(sandbox_id)
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
|
||||
@ -11,6 +11,7 @@ from extensions.storage.base_storage import BaseStorage
|
||||
from extensions.storage.cached_presign_storage import CachedPresignStorage
|
||||
from extensions.storage.file_presign_storage import FilePresignStorage
|
||||
|
||||
from .sandbox_file_storage import SandboxFilePaths
|
||||
from .sandbox_storage import SandboxStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -24,13 +25,14 @@ class ArchiveSandboxStorage(SandboxStorage):
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
sandbox_id: str,
|
||||
storage: BaseStorage,
|
||||
exclude_patterns: list[str] | None = None,
|
||||
):
|
||||
self._sandbox_id = sandbox_id
|
||||
self._exclude_patterns = exclude_patterns or []
|
||||
self._storage_key = f"sandbox_archives/{tenant_id}/{sandbox_id}.tar.gz"
|
||||
self._storage_key = SandboxFilePaths.archive(tenant_id, app_id, sandbox_id)
|
||||
self._storage = CachedPresignStorage(
|
||||
storage=FilePresignStorage(storage),
|
||||
cache_key_prefix="sandbox_archives",
|
||||
|
||||
@ -1,18 +1,21 @@
|
||||
"""Sandbox file storage key generation.
|
||||
|
||||
Provides SandboxFilePaths facade for generating storage keys for sandbox file exports.
|
||||
Provides SandboxFilePaths facade for generating storage keys for sandbox files.
|
||||
Storage instances are obtained via SandboxFileService.get_storage().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
_BASE = "sandbox_files"
|
||||
|
||||
|
||||
class SandboxFilePaths:
|
||||
"""Facade for generating sandbox file export storage keys."""
|
||||
"""Facade for generating sandbox file storage keys."""
|
||||
|
||||
@staticmethod
|
||||
def export(tenant_id: str, sandbox_id: str, export_id: str, filename: str) -> str:
|
||||
"""sandbox_files/{tenant}/{sandbox}/{export_id}/{filename}"""
|
||||
return f"{_BASE}/{tenant_id}/{sandbox_id}/{export_id}/{filename}"
|
||||
def export(tenant_id: str, app_id: str, sandbox_id: str, export_id: str, filename: str) -> str:
|
||||
"""sandbox_files/{tenant}/{app}/{sandbox}/{export_id}/{filename}"""
|
||||
return f"sandbox_files/{tenant_id}/{app_id}/{sandbox_id}/{export_id}/{filename}"
|
||||
|
||||
@staticmethod
|
||||
def archive(tenant_id: str, app_id: str, sandbox_id: str) -> str:
|
||||
"""sandbox_archives/{tenant}/{app}/{sandbox}.tar.gz"""
|
||||
return f"sandbox_archives/{tenant_id}/{app_id}/{sandbox_id}.tar.gz"
|
||||
|
||||
@ -132,7 +132,7 @@ class ZipSandbox:
|
||||
self._sandbox.wait_ready(timeout=60)
|
||||
self._vm = self._sandbox.vm
|
||||
|
||||
SandboxManager.register(self._sandbox_id, self._vm)
|
||||
SandboxManager.register(self._sandbox_id, self._sandbox)
|
||||
|
||||
def _stop(self) -> None:
|
||||
if self._vm is None:
|
||||
|
||||
@ -21,9 +21,9 @@ class SandboxFileService:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def exists(cls, *, tenant_id: str, sandbox_id: str) -> bool:
|
||||
def exists(cls, *, tenant_id: str, app_id: str, sandbox_id: str) -> bool:
|
||||
"""Check if the sandbox source exists and is available."""
|
||||
browser = SandboxFileBrowser(tenant_id=tenant_id, sandbox_id=sandbox_id)
|
||||
browser = SandboxFileBrowser(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id)
|
||||
return browser.exists()
|
||||
|
||||
@classmethod
|
||||
@ -31,18 +31,19 @@ class SandboxFileService:
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
sandbox_id: str,
|
||||
path: str | None = None,
|
||||
recursive: bool = False,
|
||||
) -> list[SandboxFileNode]:
|
||||
browser = SandboxFileBrowser(tenant_id=tenant_id, sandbox_id=sandbox_id)
|
||||
browser = SandboxFileBrowser(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id)
|
||||
if not browser.exists():
|
||||
return []
|
||||
return browser.list_files(path=path, recursive=recursive)
|
||||
|
||||
@classmethod
|
||||
def download_file(cls, *, tenant_id: str, sandbox_id: str, path: str) -> SandboxFileDownloadTicket:
|
||||
browser = SandboxFileBrowser(tenant_id=tenant_id, sandbox_id=sandbox_id)
|
||||
def download_file(cls, *, tenant_id: str, app_id: str, sandbox_id: str, path: str) -> SandboxFileDownloadTicket:
|
||||
browser = SandboxFileBrowser(tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id)
|
||||
if not browser.exists():
|
||||
raise ValueError("Sandbox source not found")
|
||||
return browser.download_file(path=path)
|
||||
|
||||
@ -31,7 +31,7 @@ class SandboxService:
|
||||
if not assets:
|
||||
raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}")
|
||||
|
||||
archive_storage = ArchiveSandboxStorage(tenant_id, workflow_execution_id, storage.storage_runner)
|
||||
archive_storage = ArchiveSandboxStorage(tenant_id, app_id, workflow_execution_id, storage.storage_runner)
|
||||
sandbox = (
|
||||
SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
|
||||
.options(sandbox_provider.config)
|
||||
@ -49,8 +49,10 @@ class SandboxService:
|
||||
return sandbox
|
||||
|
||||
@classmethod
|
||||
def delete_draft_storage(cls, tenant_id: str, user_id: str) -> None:
|
||||
archive_storage = ArchiveSandboxStorage(tenant_id, SandboxBuilder.draft_id(user_id), storage.storage_runner)
|
||||
def delete_draft_storage(cls, tenant_id: str, app_id: str, user_id: str) -> None:
|
||||
archive_storage = ArchiveSandboxStorage(
|
||||
tenant_id, app_id, SandboxBuilder.draft_id(user_id), storage.storage_runner
|
||||
)
|
||||
archive_storage.delete()
|
||||
|
||||
@classmethod
|
||||
@ -65,10 +67,12 @@ class SandboxService:
|
||||
if not assets:
|
||||
raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}")
|
||||
|
||||
SandboxService.delete_draft_storage(tenant_id, app_id, user_id)
|
||||
|
||||
AppAssetPackageService.build_assets(tenant_id, app_id, assets)
|
||||
sandbox_id = SandboxBuilder.draft_id(user_id)
|
||||
archive_storage = ArchiveSandboxStorage(
|
||||
tenant_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH]
|
||||
tenant_id, app_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH]
|
||||
)
|
||||
|
||||
sandbox = (
|
||||
@ -102,7 +106,7 @@ class SandboxService:
|
||||
AppAssetPackageService.build_assets(tenant_id, app_id, assets)
|
||||
sandbox_id = SandboxBuilder.draft_id(user_id)
|
||||
archive_storage = ArchiveSandboxStorage(
|
||||
tenant_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH]
|
||||
tenant_id, app_id, sandbox_id, storage.storage_runner, exclude_patterns=[AppAssets.PATH]
|
||||
)
|
||||
|
||||
sandbox = (
|
||||
|
||||
@ -2,10 +2,9 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.sandbox import SandboxManager
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.virtual_environment.__base.entities import (
|
||||
Arch,
|
||||
CommandStatus,
|
||||
@ -24,7 +23,9 @@ from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
class FakeSandbox(VirtualEnvironment):
|
||||
class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
"""Fake VirtualEnvironment for testing CommandNode execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -98,17 +99,39 @@ class FakeSandbox(VirtualEnvironment):
|
||||
def validate(cls, options: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_config_schema(cls) -> list[BasicProviderConfig]:
|
||||
return []
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
yield
|
||||
SandboxManager.clear()
|
||||
|
||||
def _make_mock_sandbox(vm: VirtualEnvironment) -> MagicMock:
|
||||
"""Create a mock Sandbox wrapping a VirtualEnvironment for testing."""
|
||||
sandbox = MagicMock()
|
||||
sandbox.vm = vm
|
||||
sandbox.tenant_id = "test-tenant"
|
||||
sandbox.app_id = "test-app"
|
||||
sandbox.user_id = "test-user"
|
||||
sandbox.assets_id = "test-assets"
|
||||
sandbox.wait_ready = MagicMock() # No-op for tests
|
||||
return sandbox
|
||||
|
||||
|
||||
def _make_node(
|
||||
*, command: str, working_directory: str = "", workflow_execution_id: str = "test-workflow-exec-id"
|
||||
*,
|
||||
command: str,
|
||||
working_directory: str = "",
|
||||
workflow_execution_id: str = "test-workflow-exec-id",
|
||||
vm: FakeVirtualEnvironment | None = None,
|
||||
) -> CommandNode:
|
||||
"""Create a CommandNode for testing.
|
||||
|
||||
Args:
|
||||
command: The shell command to execute.
|
||||
working_directory: Optional working directory for command execution.
|
||||
workflow_execution_id: Identifier for the workflow execution.
|
||||
vm: Optional FakeVirtualEnvironment. If provided, a mock Sandbox
|
||||
wrapping this VM will be set on the runtime state.
|
||||
"""
|
||||
system_variables = SystemVariable(workflow_execution_id=workflow_execution_id)
|
||||
variable_pool = VariablePool(system_variables=system_variables, user_inputs={})
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
@ -123,6 +146,10 @@ def _make_node(
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
if vm is not None:
|
||||
sandbox = _make_mock_sandbox(vm)
|
||||
runtime_state.set_sandbox(sandbox)
|
||||
|
||||
return CommandNode(
|
||||
id="node-instance",
|
||||
config={
|
||||
@ -139,17 +166,14 @@ def _make_node(
|
||||
|
||||
|
||||
def test_command_node_success_executes_in_sandbox():
|
||||
workflow_execution_id = "test-exec-success"
|
||||
vm = FakeVirtualEnvironment(stdout=b"ok\n", stderr=b"")
|
||||
node = _make_node(
|
||||
command="echo {{#pre_node_id.number#}}",
|
||||
working_directory="dir-{{#pre_node_id.number#}}",
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
vm=vm,
|
||||
)
|
||||
node.graph_runtime_state.variable_pool.add(("pre_node_id", "number"), 42)
|
||||
|
||||
sandbox = FakeSandbox(stdout=b"ok\n", stderr=b"")
|
||||
SandboxManager.register(workflow_execution_id, sandbox)
|
||||
|
||||
result = node._run() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
@ -157,20 +181,19 @@ def test_command_node_success_executes_in_sandbox():
|
||||
assert result.outputs["stderr"] == ""
|
||||
assert result.outputs["exit_code"] == 0
|
||||
|
||||
assert sandbox.last_execute_command is not None
|
||||
assert sandbox.last_execute_command == ["echo", "42"]
|
||||
assert sandbox.last_execute_cwd == "dir-42"
|
||||
assert vm.last_execute_command is not None
|
||||
# CommandNode wraps commands in bash -c
|
||||
assert vm.last_execute_command == ["bash", "-c", "echo 42"]
|
||||
assert vm.last_execute_cwd == "dir-42"
|
||||
|
||||
|
||||
def test_command_node_nonzero_exit_code_returns_failed_result():
|
||||
workflow_execution_id = "test-exec-nonzero"
|
||||
node = _make_node(command="false", workflow_execution_id=workflow_execution_id)
|
||||
sandbox = FakeSandbox(
|
||||
vm = FakeVirtualEnvironment(
|
||||
stdout=b"out",
|
||||
stderr=b"err",
|
||||
statuses=[CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=2)],
|
||||
)
|
||||
SandboxManager.register(workflow_execution_id, sandbox)
|
||||
node = _make_node(command="false", vm=vm)
|
||||
|
||||
result = node._run() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@ -184,15 +207,13 @@ def test_command_node_timeout_returns_failed_result_and_closes_transports(monkey
|
||||
|
||||
monkeypatch.setattr(command_node_module, "COMMAND_NODE_TIMEOUT_SECONDS", 1)
|
||||
|
||||
workflow_execution_id = "test-exec-timeout"
|
||||
node = _make_node(command="sleep 10", workflow_execution_id=workflow_execution_id)
|
||||
sandbox = FakeSandbox(
|
||||
vm = FakeVirtualEnvironment(
|
||||
stdout=b"",
|
||||
stderr=b"",
|
||||
statuses=[CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)] * 1000,
|
||||
close_streams=False,
|
||||
)
|
||||
SandboxManager.register(workflow_execution_id, sandbox)
|
||||
node = _make_node(command="sleep 10", vm=vm)
|
||||
|
||||
result = node._run() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@ -202,8 +223,7 @@ def test_command_node_timeout_returns_failed_result_and_closes_transports(monkey
|
||||
|
||||
|
||||
def test_command_node_no_sandbox_returns_failed():
|
||||
workflow_execution_id = "test-exec-no-sandbox"
|
||||
node = _make_node(command="echo hello", workflow_execution_id=workflow_execution_id)
|
||||
node = _make_node(command="echo hello")
|
||||
|
||||
result = node._run() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@ -9,22 +9,22 @@ import { base } from '../base'
|
||||
|
||||
export const listFilesContract = base
|
||||
.route({
|
||||
path: '/sandboxes/{sandboxId}/files',
|
||||
path: '/apps/{appId}/sandbox/files',
|
||||
method: 'GET',
|
||||
})
|
||||
.input(type<{
|
||||
params: { sandboxId: string }
|
||||
params: { appId: string }
|
||||
query?: SandboxFileListQuery
|
||||
}>())
|
||||
.output(type<SandboxFileNode[]>())
|
||||
|
||||
export const downloadFileContract = base
|
||||
.route({
|
||||
path: '/sandboxes/{sandboxId}/files/download',
|
||||
path: '/apps/{appId}/sandbox/files/download',
|
||||
method: 'POST',
|
||||
})
|
||||
.input(type<{
|
||||
params: { sandboxId: string }
|
||||
params: { appId: string }
|
||||
body: SandboxFileDownloadRequest
|
||||
}>())
|
||||
.output(type<SandboxFileDownloadTicket>())
|
||||
|
||||
@ -15,7 +15,7 @@ type UseGetSandboxFilesOptions = {
|
||||
}
|
||||
|
||||
export function useGetSandboxFiles(
|
||||
sandboxId: string | undefined,
|
||||
appId: string | undefined,
|
||||
options?: UseGetSandboxFilesOptions,
|
||||
) {
|
||||
const query: SandboxFileListQuery = {
|
||||
@ -25,38 +25,38 @@ export function useGetSandboxFiles(
|
||||
|
||||
return useQuery({
|
||||
queryKey: consoleQuery.sandboxFile.listFiles.queryKey({
|
||||
input: { params: { sandboxId: sandboxId! }, query },
|
||||
input: { params: { appId: appId! }, query },
|
||||
}),
|
||||
queryFn: () => consoleClient.sandboxFile.listFiles({
|
||||
params: { sandboxId: sandboxId! },
|
||||
params: { appId: appId! },
|
||||
query,
|
||||
}),
|
||||
enabled: !!sandboxId && (options?.enabled ?? true),
|
||||
enabled: !!appId && (options?.enabled ?? true),
|
||||
refetchInterval: options?.refetchInterval,
|
||||
})
|
||||
}
|
||||
|
||||
export function useSandboxFileDownloadUrl(
|
||||
sandboxId: string | undefined,
|
||||
appId: string | undefined,
|
||||
path: string | undefined,
|
||||
) {
|
||||
return useQuery({
|
||||
queryKey: ['sandboxFileDownloadUrl', sandboxId, path],
|
||||
queryKey: ['sandboxFileDownloadUrl', appId, path],
|
||||
queryFn: () => consoleClient.sandboxFile.downloadFile({
|
||||
params: { sandboxId: sandboxId! },
|
||||
params: { appId: appId! },
|
||||
body: { path: path! },
|
||||
}),
|
||||
enabled: !!sandboxId && !!path,
|
||||
enabled: !!appId && !!path,
|
||||
})
|
||||
}
|
||||
|
||||
export function useDownloadSandboxFile(sandboxId: string | undefined) {
|
||||
export function useDownloadSandboxFile(appId: string | undefined) {
|
||||
return useMutation({
|
||||
mutationFn: (path: string) => {
|
||||
if (!sandboxId)
|
||||
throw new Error('sandboxId is required')
|
||||
if (!appId)
|
||||
throw new Error('appId is required')
|
||||
return consoleClient.sandboxFile.downloadFile({
|
||||
params: { sandboxId },
|
||||
params: { appId },
|
||||
body: { path },
|
||||
})
|
||||
},
|
||||
@ -103,10 +103,10 @@ function buildTreeFromFlatList(nodes: SandboxFileNode[]): SandboxFileTreeNode[]
|
||||
}
|
||||
|
||||
export function useSandboxFilesTree(
|
||||
sandboxId: string | undefined,
|
||||
appId: string | undefined,
|
||||
options?: UseGetSandboxFilesOptions,
|
||||
) {
|
||||
const { data, isLoading, error, refetch } = useGetSandboxFiles(sandboxId, {
|
||||
const { data, isLoading, error, refetch } = useGetSandboxFiles(appId, {
|
||||
...options,
|
||||
recursive: true,
|
||||
})
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
/**
|
||||
* Sandbox file node from API (flat format)
|
||||
* Returned by GET /sandboxes/{sandbox_id}/files
|
||||
* Returned by GET /apps/{app_id}/sandbox/files
|
||||
*/
|
||||
export type SandboxFileNode = {
|
||||
/** Relative path (POSIX format), e.g. "folder/file.txt" */
|
||||
@ -23,7 +23,7 @@ export type SandboxFileNode = {
|
||||
}
|
||||
|
||||
/**
|
||||
* Download ticket returned by POST /sandboxes/{sandbox_id}/files/download
|
||||
* Download ticket returned by POST /apps/{app_id}/sandbox/files/download
|
||||
*/
|
||||
export type SandboxFileDownloadTicket = {
|
||||
/** Signed download URL */
|
||||
|
||||
Reference in New Issue
Block a user