mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
feat(sandbox): enhance sandbox initialization with draft support and asset management
- Introduced DraftAppAssetsInitializer for handling draft assets. - Updated SandboxLayer to conditionally set sandbox ID and storage based on workflow version. - Improved asset initialization logging and error handling. - Refactored ArchiveSandboxStorage to support exclusion patterns during archiving. - Modified command and LLM nodes to retrieve sandbox from workflow context, supporting draft workflows.
This commit is contained in:
@ -1,6 +1,8 @@
|
||||
import logging
|
||||
|
||||
from core.sandbox import AppAssetsInitializer, DifyCliInitializer, SandboxManager
|
||||
from core.sandbox.constants import APP_ASSETS_PATH
|
||||
from core.sandbox.initializer.app_assets_initializer import DraftAppAssetsInitializer
|
||||
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@ -33,12 +35,11 @@ class SandboxLayer(GraphEngineLayer):
|
||||
self._user_id = user_id
|
||||
self._workflow_version = workflow_version
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
self._sandbox_id = (
|
||||
self._workflow_execution_id
|
||||
if self._workflow_version == Workflow.VERSION_DRAFT
|
||||
else SandboxBuilder.draft_id(self._user_id)
|
||||
is_draft = self._workflow_version == Workflow.VERSION_DRAFT
|
||||
self._sandbox_id = SandboxBuilder.draft_id(self._user_id) if is_draft else self._workflow_execution_id
|
||||
self._sandbox_storage = ArchiveSandboxStorage(
|
||||
self._tenant_id, self._sandbox_id, exclude_patterns=[APP_ASSETS_PATH] if is_draft else None
|
||||
)
|
||||
self._sandbox_storage = ArchiveSandboxStorage(self._tenant_id, self._sandbox_id)
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
try:
|
||||
@ -61,9 +62,15 @@ class SandboxLayer(GraphEngineLayer):
|
||||
)
|
||||
AppAssetService.build_assets(self._tenant_id, self._app_id, assets)
|
||||
|
||||
assets_initializer = (
|
||||
DraftAppAssetsInitializer(self._tenant_id, self._app_id, assets.id)
|
||||
if is_draft
|
||||
else AppAssetsInitializer(self._tenant_id, self._app_id, assets.id)
|
||||
)
|
||||
|
||||
builder = (
|
||||
SandboxProviderService.create_sandbox_builder(self._tenant_id)
|
||||
.initializer(AppAssetsInitializer(self._tenant_id, self._app_id, assets.id))
|
||||
.initializer(assets_initializer)
|
||||
.initializer(DifyCliInitializer(self._tenant_id, self._user_id, self._app_id, assets.id))
|
||||
)
|
||||
try:
|
||||
@ -78,12 +85,6 @@ class SandboxLayer(GraphEngineLayer):
|
||||
raise SandboxInitializationError(f"Failed to build sandbox: {e}") from e
|
||||
|
||||
SandboxManager.register(self._sandbox_id, sandbox)
|
||||
logger.info(
|
||||
"Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s",
|
||||
self._sandbox_id,
|
||||
sandbox.metadata.id,
|
||||
sandbox.metadata.arch,
|
||||
)
|
||||
|
||||
# mount sandbox files from storage
|
||||
mounted = self._sandbox_storage.mount(sandbox)
|
||||
|
||||
@ -18,7 +18,6 @@ from ..constants import (
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_TOOLS_ROOT,
|
||||
)
|
||||
from ..manager import SandboxManager
|
||||
from .bash_tool import SandboxBashTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -28,7 +27,7 @@ class SandboxBashSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workflow_execution_id: str,
|
||||
sandbox: VirtualEnvironment,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
@ -36,7 +35,7 @@ class SandboxBashSession:
|
||||
assets_id: str,
|
||||
allow_tools: list[tuple[str, str]] | None,
|
||||
) -> None:
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
self._sandbox = sandbox
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._node_id = node_id
|
||||
@ -46,25 +45,18 @@ class SandboxBashSession:
|
||||
self._assets_id = assets_id
|
||||
self._allow_tools = allow_tools
|
||||
|
||||
self._sandbox = None
|
||||
self._bash_tool = None
|
||||
self._session_id = None
|
||||
|
||||
def __enter__(self) -> SandboxBashSession:
|
||||
sandbox = SandboxManager.get(self._workflow_execution_id)
|
||||
if sandbox is None:
|
||||
raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}")
|
||||
|
||||
self._sandbox = sandbox
|
||||
|
||||
if self._allow_tools is not None:
|
||||
if self._node_id is None:
|
||||
raise ValueError("node_id is required when allow_tools is specified")
|
||||
tools_path = self._setup_node_tools_directory(sandbox, self._node_id, self._allow_tools)
|
||||
tools_path = self._setup_node_tools_directory(self._sandbox, self._node_id, self._allow_tools)
|
||||
else:
|
||||
tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH
|
||||
|
||||
self._bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id, tools_path=tools_path)
|
||||
self._bash_tool = SandboxBashTool(sandbox=self._sandbox, tenant_id=self._tenant_id, tools_path=tools_path)
|
||||
return self
|
||||
|
||||
def _setup_node_tools_directory(
|
||||
|
||||
@ -6,7 +6,7 @@ from core.virtual_environment.__base.virtual_environment import VirtualEnvironme
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.file_presign_storage import FilePresignStorage
|
||||
|
||||
from ..constants import APP_ASSETS_ZIP_PATH
|
||||
from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH
|
||||
from .base import SandboxInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -33,7 +33,36 @@ class AppAssetsInitializer(SandboxInitializer):
|
||||
["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} 2>/dev/null || [ $? -eq 1 ]"],
|
||||
error_message="Failed to unzip assets",
|
||||
)
|
||||
.add(["rm", "-f", APP_ASSETS_ZIP_PATH], error_message="Failed to cleanup temp zip file")
|
||||
.execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"App assets initialized for app_id=%s, published_id=%s",
|
||||
self._app_id,
|
||||
self._assets_id,
|
||||
)
|
||||
|
||||
|
||||
class DraftAppAssetsInitializer(SandboxInitializer):
|
||||
def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._assets_id = assets_id
|
||||
|
||||
def initialize(self, env: VirtualEnvironment) -> None:
|
||||
zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id)
|
||||
download_url = FilePresignStorage(storage.storage_runner).get_download_url(zip_key)
|
||||
|
||||
(
|
||||
pipeline(env)
|
||||
.add(["rm", "-rf", APP_ASSETS_PATH])
|
||||
.add(["wget", "-q", download_url, "-O", APP_ASSETS_ZIP_PATH], error_message="Failed to download assets zip")
|
||||
# unzip with silent error and return 1 if the zip is empty
|
||||
# FIXME(Mairuis): should use a more robust way to check if the zip is empty
|
||||
.add(
|
||||
["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} 2>/dev/null || [ $? -eq 1 ]"],
|
||||
error_message="Failed to unzip assets",
|
||||
)
|
||||
.execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
|
||||
)
|
||||
|
||||
|
||||
@ -18,10 +18,19 @@ ARCHIVE_DOWNLOAD_TIMEOUT = 60 * 5
|
||||
ARCHIVE_UPLOAD_TIMEOUT = 60 * 5
|
||||
|
||||
|
||||
def build_tar_exclude_args(patterns: list[str]) -> list[str]:
|
||||
return [f"--exclude={p}" for p in patterns]
|
||||
|
||||
|
||||
class ArchiveSandboxStorage(SandboxStorage):
|
||||
def __init__(self, tenant_id: str, sandbox_id: str):
|
||||
_tenant_id: str
|
||||
_sandbox_id: str
|
||||
_exclude_patterns: list[str]
|
||||
|
||||
def __init__(self, tenant_id: str, sandbox_id: str, exclude_patterns: list[str] | None = None):
|
||||
self._tenant_id = tenant_id
|
||||
self._sandbox_id = sandbox_id
|
||||
self._exclude_patterns = exclude_patterns or []
|
||||
|
||||
@property
|
||||
def _storage_key(self) -> str:
|
||||
@ -36,7 +45,7 @@ class ArchiveSandboxStorage(SandboxStorage):
|
||||
try:
|
||||
(
|
||||
pipeline(sandbox)
|
||||
.add(["wget", download_url, "-O", ARCHIVE_NAME], error_message="Failed to download archive")
|
||||
.add(["wget", "-q", download_url, "-O", ARCHIVE_NAME], error_message="Failed to download archive")
|
||||
.add(["tar", "-xzf", ARCHIVE_NAME], error_message="Failed to extract archive")
|
||||
.add(["rm", ARCHIVE_NAME], error_message="Failed to cleanup archive")
|
||||
.execute(timeout=ARCHIVE_DOWNLOAD_TIMEOUT, raise_on_error=True)
|
||||
@ -53,10 +62,22 @@ class ArchiveSandboxStorage(SandboxStorage):
|
||||
(
|
||||
pipeline(sandbox)
|
||||
.add(
|
||||
["tar", "-czf", ARCHIVE_PATH, "--warning=no-file-changed", "-C", WORKSPACE_DIR, "."],
|
||||
[
|
||||
"tar",
|
||||
"-czf",
|
||||
ARCHIVE_PATH,
|
||||
"--warning=no-file-changed",
|
||||
*build_tar_exclude_args(self._exclude_patterns),
|
||||
"-C",
|
||||
WORKSPACE_DIR,
|
||||
".",
|
||||
],
|
||||
error_message="Failed to create archive",
|
||||
)
|
||||
.add(["wget", upload_url, "-O", ARCHIVE_PATH], error_message="Failed to upload archive")
|
||||
.add(
|
||||
["curl", "-s", "-f", "-X", "PUT", "-T", ARCHIVE_PATH, upload_url],
|
||||
error_message="Failed to upload archive",
|
||||
)
|
||||
.execute(timeout=ARCHIVE_UPLOAD_TIMEOUT, raise_on_error=True)
|
||||
)
|
||||
logger.info("Unmounted archive for sandbox %s", self._sandbox_id)
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.sandbox import SandboxManager, sandbox_debug
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
|
||||
from core.virtual_environment.__base.helpers import submit_command, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
@ -24,11 +25,18 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60
|
||||
class CommandNode(Node[CommandNodeData]):
|
||||
node_type = NodeType.COMMAND
|
||||
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
def _get_sandbox(self) -> VirtualEnvironment | None:
|
||||
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
return None
|
||||
return SandboxManager.get(workflow_execution_id)
|
||||
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
|
||||
if sandbox_by_workflow_run_id is not None:
|
||||
return sandbox_by_workflow_run_id
|
||||
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
|
||||
if sandbox_by_draft_id is not None:
|
||||
return sandbox_by_draft_id
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str) -> str:
|
||||
parser = VariableTemplateParser(template=template)
|
||||
|
||||
@ -51,6 +51,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.sandbox import SandboxBashSession, SandboxManager
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@ -63,6 +64,7 @@ from core.variables import (
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.entities.tool_entities import ToolCallResult
|
||||
@ -172,6 +174,19 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
def _get_sandbox(self) -> VirtualEnvironment | None:
|
||||
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
return None
|
||||
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
|
||||
if sandbox_by_workflow_run_id is not None:
|
||||
return sandbox_by_workflow_run_id
|
||||
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
|
||||
if sandbox_by_draft_id is not None:
|
||||
return sandbox_by_draft_id
|
||||
return None
|
||||
|
||||
def _run(self) -> Generator:
|
||||
node_inputs: dict[str, Any] = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
@ -287,13 +302,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
if self.tool_call_enabled:
|
||||
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
|
||||
is_sandbox_runtime = workflow_execution_id is not None and SandboxManager.is_sandbox_runtime(
|
||||
workflow_execution_id
|
||||
)
|
||||
|
||||
if is_sandbox_runtime:
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
sandbox = self._get_sandbox()
|
||||
if sandbox:
|
||||
generator = self._invoke_llm_with_sandbox(
|
||||
sandbox=sandbox,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
@ -1827,21 +1840,18 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
def _invoke_llm_with_sandbox(
|
||||
self,
|
||||
sandbox: VirtualEnvironment,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
variable_pool: VariablePool,
|
||||
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
||||
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
raise LLMNodeError("workflow_execution_id is required for sandbox runtime mode")
|
||||
|
||||
allow_tools = self._get_allow_tools_list()
|
||||
|
||||
result: LLMGenerationData | None = None
|
||||
|
||||
with SandboxBashSession(
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
sandbox=sandbox,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
node_id=self.id,
|
||||
|
||||
@ -15,8 +15,9 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.file import File
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.sandbox import SandboxManager
|
||||
from core.sandbox.constants import APP_ASSETS_PATH
|
||||
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
|
||||
from core.sandbox.storage.sandbox_storage import SandboxStorage
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.variables import Variable, VariableBase
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -706,20 +707,24 @@ class WorkflowService:
|
||||
from core.sandbox import AppAssetsInitializer, DifyCliInitializer
|
||||
from services.app_asset_service import AppAssetService
|
||||
|
||||
assets = AppAssetService.get_or_create_assets(draft_workflow.tenant_id, app_model.id, is_draft=True)
|
||||
assets = AppAssetService.get_or_create_assets(session, app_model, account.id)
|
||||
if not assets:
|
||||
raise ValueError(f"No assets found for tid={draft_workflow.tenant_id}, app_id={app_model.id}")
|
||||
|
||||
# FIXME(Mairuis): single step execution
|
||||
AppAssetService.build_assets(draft_workflow.tenant_id, app_model.id, assets)
|
||||
sandbox_id = SandboxBuilder.draft_id(account.id)
|
||||
sandbox_storage = ArchiveSandboxStorage(
|
||||
draft_workflow.tenant_id, sandbox_id, exclude_patterns=[APP_ASSETS_PATH]
|
||||
)
|
||||
|
||||
sandbox = (
|
||||
SandboxProviderService.create_sandbox_builder(draft_workflow.tenant_id)
|
||||
.initializer(DifyCliInitializer(draft_workflow.tenant_id, account.id, app_model.id, assets.id))
|
||||
.initializer(AppAssetsInitializer(draft_workflow.tenant_id, app_model.id, assets.id))
|
||||
.storage(ArchiveSandboxStorage(draft_workflow.tenant_id, SandboxStorage.draft_id(account.id)))
|
||||
.build()
|
||||
)
|
||||
sandbox_storage.mount(sandbox)
|
||||
single_step_execution_id = f"single-step-{uuid.uuid4()}"
|
||||
|
||||
SandboxManager.register(single_step_execution_id, sandbox)
|
||||
@ -742,6 +747,9 @@ class WorkflowService:
|
||||
start_at=start_at,
|
||||
node_id=node_id,
|
||||
)
|
||||
# FIXME(Mairuis): fidn a better way to handle this
|
||||
if sandbox is not None:
|
||||
sandbox_storage.unmount(sandbox)
|
||||
finally:
|
||||
if single_step_execution_id:
|
||||
sandbox = SandboxManager.unregister(single_step_execution_id)
|
||||
|
||||
@ -1,324 +0,0 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer
|
||||
from core.sandbox import SandboxManager
|
||||
from core.sandbox.storage.sandbox_storage import SandboxStorage
|
||||
from core.virtual_environment.__base.entities import Arch
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from models.app_asset import AppAssets
|
||||
|
||||
|
||||
class MockMetadata:
|
||||
def __init__(self, sandbox_id: str = "test-sandbox-id", arch: Arch = Arch.AMD64):
|
||||
self.id = sandbox_id
|
||||
self.arch = arch
|
||||
|
||||
|
||||
class MockVirtualEnvironment:
|
||||
def __init__(self, sandbox_id: str = "test-sandbox-id"):
|
||||
self.metadata = MockMetadata(sandbox_id=sandbox_id)
|
||||
self._released = False
|
||||
|
||||
def release_environment(self) -> None:
|
||||
self._released = True
|
||||
|
||||
|
||||
class MockVMBuilder:
|
||||
_sandbox: VirtualEnvironment
|
||||
|
||||
def __init__(self, sandbox: VirtualEnvironment) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
def environments(self, _: object) -> "MockVMBuilder":
|
||||
return self
|
||||
|
||||
def initializer(self, _: object) -> "MockVMBuilder":
|
||||
return self
|
||||
|
||||
def build(self) -> VirtualEnvironment:
|
||||
return self._sandbox
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
yield
|
||||
SandboxManager.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sandbox_storage() -> MagicMock:
|
||||
mock_storage = MagicMock(spec=SandboxStorage)
|
||||
mock_storage.mount.return_value = False
|
||||
mock_storage.unmount.return_value = True
|
||||
return mock_storage
|
||||
|
||||
|
||||
def create_mock_builder(sandbox: Any) -> MockVMBuilder:
|
||||
return MockVMBuilder(sandbox)
|
||||
|
||||
|
||||
def create_layer(
|
||||
tenant_id: str = "test-tenant",
|
||||
app_id: str = "test-app",
|
||||
workflow_version: str = AppAssets.VERSION_DRAFT,
|
||||
sandbox_id: str = "test-sandbox",
|
||||
sandbox_storage: Any = None,
|
||||
) -> SandboxLayer:
|
||||
if sandbox_storage is None:
|
||||
sandbox_storage = MagicMock(spec=SandboxStorage)
|
||||
sandbox_storage.mount.return_value = False
|
||||
sandbox_storage.unmount.return_value = True
|
||||
return SandboxLayer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_version=workflow_version,
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_storage=sandbox_storage,
|
||||
)
|
||||
|
||||
|
||||
class TestSandboxLayer:
|
||||
def test_init_with_parameters(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
sandbox_id="test-sandbox",
|
||||
sandbox_storage=mock_sandbox_storage,
|
||||
)
|
||||
|
||||
assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._app_id == "test-app" # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._sandbox_id == "test-sandbox" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_sandbox_property_raises_when_not_initialized(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
_ = layer.sandbox
|
||||
|
||||
assert "Sandbox not found" in str(exc_info.value)
|
||||
|
||||
def test_sandbox_property_returns_sandbox_after_initialization(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-id"
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert layer.sandbox is mock_sandbox
|
||||
|
||||
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-123"
|
||||
layer = create_layer(
|
||||
tenant_id="test-tenant-123",
|
||||
app_id="test-app-123",
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_storage=mock_sandbox_storage,
|
||||
)
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
) as mock_create,
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
mock_create.assert_called_once_with("test-tenant-123")
|
||||
|
||||
assert SandboxManager.get(sandbox_id) is mock_sandbox
|
||||
|
||||
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(
|
||||
self, mock_sandbox_storage: MagicMock
|
||||
) -> None:
|
||||
layer = create_layer(sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
side_effect=Exception("Sandbox provider not available"),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
with pytest.raises(SandboxInitializationError) as exc_info:
|
||||
layer.on_graph_start()
|
||||
|
||||
assert "Failed to initialize sandbox" in str(exc_info.value)
|
||||
assert "Sandbox provider not available" in str(exc_info.value)
|
||||
|
||||
def test_on_event_is_noop(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
layer.on_event(GraphRunStartedEvent())
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||
|
||||
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-456"
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert SandboxManager.has(sandbox_id)
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert not SandboxManager.has(sandbox_id)
|
||||
|
||||
def test_on_graph_end_releases_sandbox_even_on_error(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-789"
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_end(error=Exception("Workflow failed"))
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert not SandboxManager.has(sandbox_id)
|
||||
|
||||
def test_on_graph_end_handles_release_failure_gracefully(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-fail"
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
mock_sandbox.release_environment.side_effect = Exception("Container already removed")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_on_graph_end_noop_when_sandbox_not_registered(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(sandbox_id="nonexistent-sandbox", sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
def test_on_graph_end_is_idempotent(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-idempotent"
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_layer_inherits_from_graph_engine_layer(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||
_ = layer.graph_runtime_state
|
||||
|
||||
assert layer.command_channel is None
|
||||
|
||||
|
||||
class TestSandboxLayerIntegration:
|
||||
def test_full_lifecycle_with_mocked_provider(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "integration-test-exec"
|
||||
layer = create_layer(
|
||||
tenant_id="integration-tenant",
|
||||
app_id="integration-app",
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_storage=mock_sandbox_storage,
|
||||
)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert layer.sandbox is mock_sandbox
|
||||
assert SandboxManager.get(sandbox_id) is mock_sandbox
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
assert not SandboxManager.has(sandbox_id)
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_lifecycle_with_workflow_error(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "integration-error-test"
|
||||
layer = create_layer(
|
||||
tenant_id="error-tenant",
|
||||
app_id="error-app",
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_storage=mock_sandbox_storage,
|
||||
)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert layer.sandbox.metadata.id is not None
|
||||
|
||||
layer.on_graph_end(error=Exception("Workflow execution failed"))
|
||||
|
||||
assert not SandboxManager.has(sandbox_id)
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
Reference in New Issue
Block a user