mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 19:27:40 +08:00
refactor: consolidate sandbox management and initialization
- Moved sandbox-related classes and functions into a dedicated module for better organization. - Updated the sandbox initialization process to streamline asset management and environment setup. - Removed deprecated constants and refactored related code to utilize new sandbox entities. - Enhanced the workflow context to support sandbox integration, allowing for improved state management during execution. - Adjusted various components to utilize the new sandbox structure, ensuring compatibility across the application.
This commit is contained in:
@ -23,7 +23,7 @@ from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import Document
|
||||
from core.sandbox.vm import SandboxBuilder, SandboxType
|
||||
from core.sandbox import SandboxBuilder, SandboxType
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -30,6 +30,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.sandbox import Sandbox, SandboxManager
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
@ -43,6 +44,7 @@ from models import Account, App, Conversation, EndUser, Message, Workflow, Workf
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow_features import WorkflowFeatures
|
||||
from services.conversation_service import ConversationService
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
@ -514,19 +516,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
sandbox: Sandbox | None = None
|
||||
graph_engine_layers: tuple = ()
|
||||
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
if application_generate_entity.workflow_run_id is None:
|
||||
raise ValueError("workflow_run_id is required when sandbox is enabled")
|
||||
graph_engine_layers = (
|
||||
SandboxLayer(
|
||||
sandbox_provider = SandboxProviderService.get_sandbox_provider(
|
||||
application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
sandbox = SandboxManager.create_draft(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
else:
|
||||
if application_generate_entity.workflow_run_id is None:
|
||||
raise ValueError("workflow_run_id is required when sandbox is enabled")
|
||||
sandbox = SandboxManager.create(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
workflow_version=workflow.version,
|
||||
workflow_execution_id=application_generate_entity.workflow_run_id,
|
||||
),
|
||||
)
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
graph_engine_layers = (SandboxLayer(sandbox=sandbox),)
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
is_external_api_call = application_generate_entity.invoke_from in {
|
||||
@ -559,6 +572,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
sandbox=sandbox,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -24,6 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.sandbox import Sandbox
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
@ -66,6 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
sandbox: Sandbox | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -82,6 +84,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._app = app
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._sandbox = sandbox
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@ -156,6 +159,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
# init graph
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
|
||||
|
||||
if self._sandbox:
|
||||
graph_runtime_state.set_sandbox(self._sandbox)
|
||||
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@ -29,6 +29,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.sandbox import Sandbox, SandboxManager
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
@ -40,6 +41,7 @@ from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow_features import WorkflowFeatures
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
@ -490,16 +492,29 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
if workflow is None:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
sandbox: Sandbox | None = None
|
||||
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
graph_engine_layers = (
|
||||
*graph_engine_layers,
|
||||
SandboxLayer(
|
||||
sandbox_provider = SandboxProviderService.get_sandbox_provider(
|
||||
application_generate_entity.app_config.tenant_id
|
||||
)
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
sandbox = SandboxManager.create_draft(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
else:
|
||||
sandbox = SandboxManager.create(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
workflow_version=workflow.version,
|
||||
workflow_execution_id=application_generate_entity.workflow_execution_id,
|
||||
),
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
graph_engine_layers = (
|
||||
*graph_engine_layers,
|
||||
SandboxLayer(sandbox=sandbox),
|
||||
)
|
||||
|
||||
# Determine system_user_id based on invocation source
|
||||
@ -526,6 +541,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
sandbox=sandbox,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -7,6 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@ -42,6 +43,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
sandbox: Sandbox | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@ -55,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self._root_node_id = root_node_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._sandbox = sandbox
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@ -99,6 +102,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
if self._sandbox:
|
||||
graph_runtime_state.set_sandbox(self._sandbox)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
|
||||
@ -1,122 +1,22 @@
|
||||
import logging
|
||||
|
||||
from core.sandbox import AppAssetsInitializer, DifyCliInitializer, SandboxManager
|
||||
from core.sandbox.constants import APP_ASSETS_PATH
|
||||
from core.sandbox.initializer.app_assets_initializer import DraftAppAssetsInitializer
|
||||
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from models.workflow import Workflow
|
||||
from services.app_asset_service import AppAssetService
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxInitializationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SandboxLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
workflow_version: str,
|
||||
workflow_execution_id: str,
|
||||
) -> None:
|
||||
def __init__(self, sandbox: Sandbox) -> None:
|
||||
super().__init__()
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._user_id = user_id
|
||||
self._workflow_version = workflow_version
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
is_draft = self._workflow_version == Workflow.VERSION_DRAFT
|
||||
self._sandbox_id = SandboxBuilder.draft_id(self._user_id) if is_draft else self._workflow_execution_id
|
||||
self._sandbox_storage = ArchiveSandboxStorage(
|
||||
self._tenant_id, self._sandbox_id, exclude_patterns=[APP_ASSETS_PATH] if is_draft else None
|
||||
)
|
||||
self._sandbox = sandbox
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
try:
|
||||
is_draft = self._workflow_version == Workflow.VERSION_DRAFT
|
||||
assets = AppAssetService.get_assets(self._tenant_id, self._app_id, self._user_id, is_draft=is_draft)
|
||||
if not assets:
|
||||
raise ValueError(
|
||||
f"No assets found for tid={self._tenant_id}, app_id={self._app_id}, wf={self._workflow_version}"
|
||||
)
|
||||
|
||||
self._assets_id = assets.id
|
||||
|
||||
if is_draft:
|
||||
logger.info(
|
||||
"Building draft assets for tenant_id=%s, app_id=%s, workflow_version=%s, assets_id=%s",
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._workflow_version,
|
||||
assets.id,
|
||||
)
|
||||
AppAssetService.build_assets(self._tenant_id, self._app_id, assets)
|
||||
|
||||
assets_initializer = (
|
||||
DraftAppAssetsInitializer(self._tenant_id, self._app_id, assets.id)
|
||||
if is_draft
|
||||
else AppAssetsInitializer(self._tenant_id, self._app_id, assets.id)
|
||||
)
|
||||
|
||||
builder = (
|
||||
SandboxProviderService.create_sandbox_builder(self._tenant_id)
|
||||
.initializer(assets_initializer)
|
||||
.initializer(DifyCliInitializer(self._tenant_id, self._user_id, self._app_id, assets.id))
|
||||
)
|
||||
try:
|
||||
sandbox = builder.build()
|
||||
logger.info(
|
||||
"Sandbox initialized, workflow_execution_id=%s, sandbox_id=%s, sandbox_arch=%s",
|
||||
self._sandbox_id,
|
||||
sandbox.metadata.id,
|
||||
sandbox.metadata.arch,
|
||||
)
|
||||
except Exception as e:
|
||||
raise SandboxInitializationError(f"Failed to build sandbox: {e}") from e
|
||||
|
||||
SandboxManager.register(self._sandbox_id, sandbox)
|
||||
|
||||
# mount sandbox files from storage
|
||||
mounted = self._sandbox_storage.mount(sandbox)
|
||||
logger.info("Sandbox files mount status: %s", mounted)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to initialize sandbox")
|
||||
raise SandboxInitializationError(f"Failed to initialize sandbox: {e}") from e
|
||||
|
||||
def on_node_run_start(self, node: Node) -> None:
|
||||
# FIXME(Mairuis): should read from workflow run context...
|
||||
node.assets_id = self._assets_id
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
# TODO: handle graph run paused event
|
||||
if not isinstance(event, GraphRunPausedEvent):
|
||||
return
|
||||
pass
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
sandbox = SandboxManager.unregister(self._sandbox_id)
|
||||
if sandbox is None:
|
||||
logger.debug("No sandbox to release for sandbox_id=%s", self._sandbox_id)
|
||||
return
|
||||
|
||||
try:
|
||||
self._sandbox_storage.unmount(sandbox)
|
||||
logger.info("Sandbox files persisted, sandbox_id=%s", self._sandbox_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to persist sandbox files, sandbox_id=%s", self._sandbox_id)
|
||||
|
||||
try:
|
||||
sandbox.release_environment()
|
||||
logger.info("Sandbox released, sandbox_id=%s", self._sandbox_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to release sandbox, sandbox_id=%s", self._sandbox_id)
|
||||
self._sandbox.release()
|
||||
|
||||
@ -6,44 +6,32 @@ from .bash.dify_cli import (
|
||||
DifyCliToolConfig,
|
||||
)
|
||||
from .bash.session import SandboxBashSession
|
||||
from .constants import (
|
||||
APP_ASSETS_PATH,
|
||||
APP_ASSETS_ZIP_PATH,
|
||||
DIFY_CLI_CONFIG_FILENAME,
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH,
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_PATH_PATTERN,
|
||||
DIFY_CLI_ROOT,
|
||||
DIFY_CLI_TOOLS_ROOT,
|
||||
)
|
||||
from .builder import SandboxBuilder, VMConfig
|
||||
from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType
|
||||
from .initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer
|
||||
from .manager import SandboxManager
|
||||
from .sandbox import Sandbox
|
||||
from .storage import ArchiveSandboxStorage, SandboxStorage
|
||||
from .utils.debug import sandbox_debug
|
||||
from .utils.encryption import create_sandbox_config_encrypter, masked_config
|
||||
from .vm import SandboxBuilder, SandboxType, VMConfig
|
||||
|
||||
__all__ = [
|
||||
"APP_ASSETS_PATH",
|
||||
"APP_ASSETS_ZIP_PATH",
|
||||
"DIFY_CLI_CONFIG_FILENAME",
|
||||
"DIFY_CLI_GLOBAL_TOOLS_PATH",
|
||||
"DIFY_CLI_PATH",
|
||||
"DIFY_CLI_PATH_PATTERN",
|
||||
"DIFY_CLI_ROOT",
|
||||
"DIFY_CLI_TOOLS_ROOT",
|
||||
"AppAssets",
|
||||
"AppAssetsInitializer",
|
||||
"ArchiveSandboxStorage",
|
||||
"DifyCli",
|
||||
"DifyCliBinary",
|
||||
"DifyCliConfig",
|
||||
"DifyCliEnvConfig",
|
||||
"DifyCliInitializer",
|
||||
"DifyCliLocator",
|
||||
"DifyCliToolConfig",
|
||||
"Sandbox",
|
||||
"SandboxBashSession",
|
||||
"SandboxBuilder",
|
||||
"SandboxInitializer",
|
||||
"SandboxManager",
|
||||
"SandboxProviderApiEntity",
|
||||
"SandboxStorage",
|
||||
"SandboxType",
|
||||
"VMConfig",
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.sandbox.constants import DIFY_CLI_CONFIG_FILENAME
|
||||
from core.sandbox.entities import DifyCli
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -79,7 +79,7 @@ class SandboxBashTool(Tool):
|
||||
if self._tools_path:
|
||||
environments = {
|
||||
"PATH": f"{self._tools_path}:/usr/local/bin:/usr/bin:/bin",
|
||||
"DIFY_CLI_CONFIG": self._tools_path + f"/{DIFY_CLI_CONFIG_FILENAME}",
|
||||
"DIFY_CLI_CONFIG": self._tools_path + f"/{DifyCli.CONFIG_FILENAME}",
|
||||
}
|
||||
future = submit_command(
|
||||
self._sandbox,
|
||||
|
||||
@ -14,7 +14,7 @@ from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.virtual_environment.__base.entities import Arch, OperatingSystem
|
||||
|
||||
from ..constants import DIFY_CLI_PATH_PATTERN
|
||||
from ..entities import DifyCli
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
@ -44,7 +44,7 @@ class DifyCliLocator:
|
||||
self._root = api_root / "bin"
|
||||
|
||||
def resolve(self, operating_system: OperatingSystem, arch: Arch) -> DifyCliBinary:
|
||||
filename = DIFY_CLI_PATH_PATTERN.format(os=operating_system.value, arch=arch.value)
|
||||
filename = DifyCli.PATH_PATTERN.format(os=operating_system.value, arch=arch.value)
|
||||
candidate = self._root / filename
|
||||
if not candidate.is_file():
|
||||
raise FileNotFoundError(
|
||||
|
||||
@ -5,19 +5,14 @@ import logging
|
||||
from io import BytesIO
|
||||
from types import TracebackType
|
||||
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.session.cli_api import CliApiSession, CliApiSessionManager
|
||||
from core.skill.entities.tool_artifact import ToolArtifact
|
||||
from core.skill.skill_manager import SkillManager
|
||||
from core.virtual_environment.__base.helpers import pipeline
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
from ..bash.dify_cli import DifyCliConfig
|
||||
from ..constants import (
|
||||
DIFY_CLI_CONFIG_FILENAME,
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH,
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_TOOLS_ROOT,
|
||||
)
|
||||
from ..entities import DifyCli
|
||||
from .bash_tool import SandboxBashTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -27,46 +22,46 @@ class SandboxBashSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sandbox: VirtualEnvironment,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
sandbox: Sandbox,
|
||||
node_id: str,
|
||||
app_id: str,
|
||||
assets_id: str,
|
||||
allow_tools: list[tuple[str, str]] | None,
|
||||
) -> None:
|
||||
self._sandbox = sandbox
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._node_id = node_id
|
||||
self._app_id = app_id
|
||||
|
||||
# FIXME(Mairuis): should read from workflow run context...
|
||||
self._assets_id = assets_id
|
||||
self._allow_tools = allow_tools
|
||||
|
||||
self._bash_tool = None
|
||||
self._session_id = None
|
||||
self._bash_tool: SandboxBashTool | None = None
|
||||
self._cli_api_session: CliApiSession | None = None
|
||||
self._tenant_id = sandbox.tenant_id
|
||||
self._user_id = sandbox.user_id
|
||||
self._app_id = sandbox.app_id
|
||||
self._assets_id = sandbox.assets_id
|
||||
|
||||
def __enter__(self) -> SandboxBashSession:
|
||||
self._cli_api_session = CliApiSessionManager().create(
|
||||
tenant_id=self._tenant_id,
|
||||
user_id=self._user_id,
|
||||
)
|
||||
if self._allow_tools is not None:
|
||||
if self._node_id is None:
|
||||
raise ValueError("node_id is required when allow_tools is specified")
|
||||
tools_path = self._setup_node_tools_directory(self._sandbox, self._node_id, self._allow_tools)
|
||||
tools_path = self._setup_node_tools_directory(self._node_id, self._allow_tools, self._cli_api_session)
|
||||
else:
|
||||
tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH
|
||||
tools_path = DifyCli.GLOBAL_TOOLS_PATH
|
||||
|
||||
self._bash_tool = SandboxBashTool(sandbox=self._sandbox, tenant_id=self._tenant_id, tools_path=tools_path)
|
||||
self._bash_tool = SandboxBashTool(
|
||||
sandbox=self._sandbox.vm,
|
||||
tenant_id=self._tenant_id,
|
||||
tools_path=tools_path,
|
||||
)
|
||||
return self
|
||||
|
||||
def _setup_node_tools_directory(
|
||||
self,
|
||||
sandbox: VirtualEnvironment,
|
||||
node_id: str,
|
||||
allow_tools: list[tuple[str, str]],
|
||||
cli_api_session: CliApiSession,
|
||||
) -> str | None:
|
||||
artifact: ToolArtifact | None = SkillManager.load_tool_artifact(
|
||||
self._tenant_id,
|
||||
self._sandbox.tenant_id,
|
||||
self._app_id,
|
||||
self._assets_id,
|
||||
)
|
||||
@ -80,26 +75,26 @@ class SandboxBashSession:
|
||||
logger.info("No tools found in artifact for assets_id=%s", self._assets_id)
|
||||
return None
|
||||
|
||||
self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id)
|
||||
node_tools_path = f"{DIFY_CLI_TOOLS_ROOT}/{node_id}"
|
||||
node_tools_path = f"{DifyCli.TOOLS_ROOT}/{node_id}"
|
||||
|
||||
vm = self._sandbox.vm
|
||||
(
|
||||
pipeline(sandbox)
|
||||
.add(["mkdir", "-p", DIFY_CLI_GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir")
|
||||
pipeline(vm)
|
||||
.add(["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir")
|
||||
.add(["mkdir", "-p", node_tools_path], error_message="Failed to create node tools dir")
|
||||
.execute(raise_on_error=True)
|
||||
)
|
||||
|
||||
config_json = json.dumps(
|
||||
DifyCliConfig.create(
|
||||
session=self._cli_api_session, tenant_id=self._tenant_id, artifact=artifact
|
||||
).model_dump(mode="json"),
|
||||
DifyCliConfig.create(session=cli_api_session, tenant_id=self._tenant_id, artifact=artifact).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
sandbox.upload_file(f"{node_tools_path}/{DIFY_CLI_CONFIG_FILENAME}", BytesIO(config_json.encode("utf-8")))
|
||||
vm.upload_file(f"{node_tools_path}/{DifyCli.CONFIG_FILENAME}", BytesIO(config_json.encode("utf-8")))
|
||||
|
||||
pipeline(sandbox, cwd=node_tools_path).add(
|
||||
[DIFY_CLI_PATH, "init"], error_message="Failed to initialize Dify CLI"
|
||||
pipeline(vm, cwd=node_tools_path).add(
|
||||
[DifyCli.PATH, "init"], error_message="Failed to initialize Dify CLI"
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
logger.info(
|
||||
@ -114,7 +109,10 @@ class SandboxBashSession:
|
||||
tb: TracebackType | None,
|
||||
) -> bool:
|
||||
try:
|
||||
self.cleanup()
|
||||
if self._session_id is not None:
|
||||
CliApiSessionManager().delete(self._session_id)
|
||||
logger.debug("Cleaned up SandboxSession session_id=%s", self._session_id)
|
||||
self._session_id = None
|
||||
except Exception:
|
||||
logger.exception("Failed to cleanup SandboxSession")
|
||||
return False
|
||||
@ -124,11 +122,3 @@ class SandboxBashSession:
|
||||
if self._bash_tool is None:
|
||||
raise RuntimeError("SandboxSession is not initialized")
|
||||
return self._bash_tool
|
||||
|
||||
def cleanup(self) -> None:
|
||||
if self._session_id is None:
|
||||
return
|
||||
|
||||
CliApiSessionManager().delete(self._session_id)
|
||||
logger.debug("Cleaned up SandboxSession session_id=%s", self._session_id)
|
||||
self._session_id = None
|
||||
|
||||
@ -1,41 +1,17 @@
|
||||
"""
|
||||
Facade module for virtual machine providers.
|
||||
|
||||
Provides unified interfaces to access different VM provider implementations
|
||||
(E2B, Docker, Local) through VMType, VMBuilder, and VMConfig.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
from .entities.sandbox_type import SandboxType
|
||||
from .initializer import SandboxInitializer
|
||||
from .sandbox import Sandbox
|
||||
|
||||
|
||||
class SandboxType(StrEnum):
|
||||
"""
|
||||
Sandbox types.
|
||||
"""
|
||||
|
||||
DOCKER = "docker"
|
||||
E2B = "e2b"
|
||||
LOCAL = "local"
|
||||
|
||||
@classmethod
|
||||
def get_all(cls) -> list[str]:
|
||||
"""
|
||||
Get all available sandbox types.
|
||||
"""
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return [p.value for p in cls]
|
||||
else:
|
||||
return [p.value for p in cls if p != SandboxType.LOCAL]
|
||||
if TYPE_CHECKING:
|
||||
from .storage.sandbox_storage import SandboxStorage
|
||||
|
||||
|
||||
def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]:
|
||||
@ -57,18 +33,35 @@ def _get_sandbox_class(sandbox_type: SandboxType) -> type[VirtualEnvironment]:
|
||||
|
||||
|
||||
class SandboxBuilder:
|
||||
_tenant_id: str
|
||||
_sandbox_type: SandboxType
|
||||
_user_id: str | None
|
||||
_app_id: str | None
|
||||
_options: dict[str, Any]
|
||||
_environments: dict[str, str]
|
||||
_initializers: list[SandboxInitializer]
|
||||
_storage: SandboxStorage | None
|
||||
_assets_id: str | None
|
||||
|
||||
def __init__(self, tenant_id: str, sandbox_type: SandboxType) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._sandbox_type = sandbox_type
|
||||
self._user_id: str | None = None
|
||||
self._options: dict[str, Any] = {}
|
||||
self._environments: dict[str, str] = {}
|
||||
self._initializers: list[SandboxInitializer] = []
|
||||
self._user_id = None
|
||||
self._app_id = None
|
||||
self._options = {}
|
||||
self._environments = {}
|
||||
self._initializers = []
|
||||
self._storage = None
|
||||
self._assets_id = None
|
||||
|
||||
def user(self, user_id: str) -> SandboxBuilder:
|
||||
self._user_id = user_id
|
||||
return self
|
||||
|
||||
def app(self, app_id: str) -> SandboxBuilder:
|
||||
self._app_id = app_id
|
||||
return self
|
||||
|
||||
def options(self, options: Mapping[str, Any]) -> SandboxBuilder:
|
||||
self._options = dict(options)
|
||||
return self
|
||||
@ -85,7 +78,21 @@ class SandboxBuilder:
|
||||
self._initializers.extend(initializers)
|
||||
return self
|
||||
|
||||
def build(self) -> VirtualEnvironment:
|
||||
def storage(self, storage: SandboxStorage, assets_id: str) -> SandboxBuilder:
|
||||
self._storage = storage
|
||||
self._assets_id = assets_id
|
||||
return self
|
||||
|
||||
def build(self) -> Sandbox:
|
||||
if self._storage is None:
|
||||
raise ValueError("storage is required, call .storage() before .build()")
|
||||
if self._assets_id is None:
|
||||
raise ValueError("assets_id is required, call .storage() before .build()")
|
||||
if self._user_id is None:
|
||||
raise ValueError("user_id is required, call .user() before .build()")
|
||||
if self._app_id is None:
|
||||
raise ValueError("app_id is required, call .app() before .build()")
|
||||
|
||||
vm_class = _get_sandbox_class(self._sandbox_type)
|
||||
vm = vm_class(
|
||||
tenant_id=self._tenant_id,
|
||||
@ -95,7 +102,17 @@ class SandboxBuilder:
|
||||
)
|
||||
for init in self._initializers:
|
||||
init.initialize(vm)
|
||||
return vm
|
||||
|
||||
sandbox = Sandbox(
|
||||
vm=vm,
|
||||
storage=self._storage,
|
||||
tenant_id=self._tenant_id,
|
||||
user_id=self._user_id,
|
||||
app_id=self._app_id,
|
||||
assets_id=self._assets_id,
|
||||
)
|
||||
sandbox.mount()
|
||||
return sandbox
|
||||
|
||||
@staticmethod
|
||||
def validate(vm_type: SandboxType, options: Mapping[str, Any]) -> None:
|
||||
@ -1,16 +0,0 @@
|
||||
from typing import Final
|
||||
|
||||
# Dify CLI (absolute path - hidden in /tmp, not in sandbox workdir)
|
||||
DIFY_CLI_ROOT: Final[str] = "/tmp/.dify"
|
||||
DIFY_CLI_PATH: Final[str] = "/tmp/.dify/bin/dify"
|
||||
|
||||
DIFY_CLI_PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}"
|
||||
|
||||
DIFY_CLI_CONFIG_FILENAME: Final[str] = ".dify_cli.json"
|
||||
|
||||
DIFY_CLI_TOOLS_ROOT: Final[str] = "/tmp/.dify/tools"
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH: Final[str] = "/tmp/.dify/tools/global"
|
||||
|
||||
# App Assets (relative path - stays in sandbox workdir)
|
||||
APP_ASSETS_PATH: Final[str] = "skills"
|
||||
APP_ASSETS_ZIP_PATH: Final[str] = "/tmp/assets.zip"
|
||||
@ -1,3 +1,10 @@
|
||||
from .config import AppAssets, DifyCli
|
||||
from .providers import SandboxProviderApiEntity
|
||||
from .sandbox_type import SandboxType
|
||||
|
||||
__all__ = ["SandboxProviderApiEntity"]
|
||||
__all__ = [
|
||||
"AppAssets",
|
||||
"DifyCli",
|
||||
"SandboxProviderApiEntity",
|
||||
"SandboxType",
|
||||
]
|
||||
|
||||
19
api/core/sandbox/entities/config.py
Normal file
19
api/core/sandbox/entities/config.py
Normal file
@ -0,0 +1,19 @@
|
||||
from typing import Final
|
||||
|
||||
|
||||
class DifyCli:
|
||||
"""Dify CLI constants (absolute path - hidden in /tmp, not in sandbox workdir)"""
|
||||
|
||||
ROOT: Final[str] = "/tmp/.dify"
|
||||
PATH: Final[str] = "/tmp/.dify/bin/dify"
|
||||
PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}"
|
||||
CONFIG_FILENAME: Final[str] = ".dify_cli.json"
|
||||
TOOLS_ROOT: Final[str] = "/tmp/.dify/tools"
|
||||
GLOBAL_TOOLS_PATH: Final[str] = "/tmp/.dify/tools/global"
|
||||
|
||||
|
||||
class AppAssets:
|
||||
"""App Assets constants (relative path - stays in sandbox workdir)"""
|
||||
|
||||
PATH: Final[str] = "skills"
|
||||
ZIP_PATH: Final[str] = "/tmp/assets.zip"
|
||||
16
api/core/sandbox/entities/sandbox_type.py
Normal file
16
api/core/sandbox/entities/sandbox_type.py
Normal file
@ -0,0 +1,16 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class SandboxType(StrEnum):
|
||||
DOCKER = "docker"
|
||||
E2B = "e2b"
|
||||
LOCAL = "local"
|
||||
|
||||
@classmethod
|
||||
def get_all(cls) -> list[str]:
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return [p.value for p in cls]
|
||||
else:
|
||||
return [p.value for p in cls if p != SandboxType.LOCAL]
|
||||
@ -6,7 +6,7 @@ from core.virtual_environment.__base.virtual_environment import VirtualEnvironme
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.file_presign_storage import FilePresignStorage
|
||||
|
||||
from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH
|
||||
from ..entities import AppAssets
|
||||
from .base import SandboxInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -26,11 +26,11 @@ class AppAssetsInitializer(SandboxInitializer):
|
||||
|
||||
(
|
||||
pipeline(env)
|
||||
.add(["wget", "-q", download_url, "-O", APP_ASSETS_ZIP_PATH], error_message="Failed to download assets zip")
|
||||
.add(["wget", "-q", download_url, "-O", AppAssets.ZIP_PATH], error_message="Failed to download assets zip")
|
||||
# unzip with silent error and return 1 if the zip is empty
|
||||
# FIXME(Mairuis): should use a more robust way to check if the zip is empty
|
||||
.add(
|
||||
["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} -d {APP_ASSETS_PATH} 2>/dev/null || [ $? -eq 1 ]"],
|
||||
["sh", "-c", f"unzip {AppAssets.ZIP_PATH} -d {AppAssets.PATH} 2>/dev/null || [ $? -eq 1 ]"],
|
||||
error_message="Failed to unzip assets",
|
||||
)
|
||||
.execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
|
||||
@ -55,12 +55,12 @@ class DraftAppAssetsInitializer(SandboxInitializer):
|
||||
|
||||
(
|
||||
pipeline(env)
|
||||
.add(["rm", "-rf", APP_ASSETS_PATH])
|
||||
.add(["wget", "-q", download_url, "-O", APP_ASSETS_ZIP_PATH], error_message="Failed to download assets zip")
|
||||
.add(["rm", "-rf", AppAssets.PATH])
|
||||
.add(["wget", "-q", download_url, "-O", AppAssets.ZIP_PATH], error_message="Failed to download assets zip")
|
||||
# unzip with silent error and return 1 if the zip is empty
|
||||
# FIXME(Mairuis): should use a more robust way to check if the zip is empty
|
||||
.add(
|
||||
["sh", "-c", f"unzip {APP_ASSETS_ZIP_PATH} -d {APP_ASSETS_PATH} 2>/dev/null || [ $? -eq 1 ]"],
|
||||
["sh", "-c", f"unzip {AppAssets.ZIP_PATH} -d {AppAssets.PATH} 2>/dev/null || [ $? -eq 1 ]"],
|
||||
error_message="Failed to unzip assets",
|
||||
)
|
||||
.execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
|
||||
|
||||
@ -11,12 +11,7 @@ from core.virtual_environment.__base.helpers import pipeline
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
from ..bash.dify_cli import DifyCliConfig, DifyCliLocator
|
||||
from ..constants import (
|
||||
DIFY_CLI_CONFIG_FILENAME,
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH,
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_ROOT,
|
||||
)
|
||||
from ..entities import DifyCli
|
||||
from .base import SandboxInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -44,10 +39,10 @@ class DifyCliInitializer(SandboxInitializer):
|
||||
binary = self._locator.resolve(env.metadata.os, env.metadata.arch)
|
||||
|
||||
pipeline(env).add(
|
||||
["mkdir", "-p", f"{DIFY_CLI_ROOT}/bin"], error_message="Failed to create dify CLI directory"
|
||||
["mkdir", "-p", f"{DifyCli.ROOT}/bin"], error_message="Failed to create dify CLI directory"
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes()))
|
||||
env.upload_file(DifyCli.PATH, BytesIO(binary.path.read_bytes()))
|
||||
|
||||
# Use 'cp' with mode preservation workaround: copy file to itself to claim ownership,
|
||||
# then use 'install' to set executable permission
|
||||
@ -55,14 +50,14 @@ class DifyCliInitializer(SandboxInitializer):
|
||||
[
|
||||
"sh",
|
||||
"-c",
|
||||
f"cat '{DIFY_CLI_PATH}' > '{DIFY_CLI_PATH}.tmp' && "
|
||||
f"mv '{DIFY_CLI_PATH}.tmp' '{DIFY_CLI_PATH}' && "
|
||||
f"chmod +x '{DIFY_CLI_PATH}'",
|
||||
f"cat '{DifyCli.PATH}' > '{DifyCli.PATH}.tmp' && "
|
||||
f"mv '{DifyCli.PATH}.tmp' '{DifyCli.PATH}' && "
|
||||
f"chmod +x '{DifyCli.PATH}'",
|
||||
],
|
||||
error_message="Failed to mark dify CLI as executable",
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)
|
||||
logger.info("Dify CLI uploaded to sandbox, path=%s", DifyCli.PATH)
|
||||
|
||||
artifact = SkillManager.load_tool_artifact(self._tenant_id, self._app_id, self._assets_id)
|
||||
if artifact is None or not artifact.references:
|
||||
@ -73,16 +68,16 @@ class DifyCliInitializer(SandboxInitializer):
|
||||
self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id)
|
||||
|
||||
pipeline(env).add(
|
||||
["mkdir", "-p", DIFY_CLI_GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir"
|
||||
["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir"
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
config = DifyCliConfig.create(self._cli_api_session, self._tenant_id, artifact)
|
||||
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
|
||||
config_path = f"{DIFY_CLI_GLOBAL_TOOLS_PATH}/{DIFY_CLI_CONFIG_FILENAME}"
|
||||
config_path = f"{DifyCli.GLOBAL_TOOLS_PATH}/{DifyCli.CONFIG_FILENAME}"
|
||||
env.upload_file(config_path, BytesIO(config_json.encode("utf-8")))
|
||||
|
||||
pipeline(env, cwd=DIFY_CLI_GLOBAL_TOOLS_PATH).add(
|
||||
[DIFY_CLI_PATH, "init"], error_message="Failed to initialize Dify CLI"
|
||||
pipeline(env, cwd=DifyCli.GLOBAL_TOOLS_PATH).add(
|
||||
[DifyCli.PATH, "init"], error_message="Failed to initialize Dify CLI"
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
logger.info("Global tools initialized, path=%s, tool_count=%d", DIFY_CLI_GLOBAL_TOOLS_PATH, len(self._tools))
|
||||
logger.info("Global tools initialized, path=%s, tool_count=%d", DifyCli.GLOBAL_TOOLS_PATH, len(self._tools))
|
||||
|
||||
@ -4,23 +4,20 @@ import logging
|
||||
import threading
|
||||
from typing import Final
|
||||
|
||||
from core.sandbox.builder import SandboxBuilder
|
||||
from core.sandbox.entities import AppAssets, SandboxType
|
||||
from core.sandbox.entities.providers import SandboxProviderEntity
|
||||
from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer, DraftAppAssetsInitializer
|
||||
from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from services.app_asset_service import AppAssetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxManager:
|
||||
"""Process-local registry for workflow sandboxes.
|
||||
|
||||
Stores `VirtualEnvironment` references keyed by `workflow_execution_id`.
|
||||
|
||||
Concurrency: the registry is split into hash shards and each shard is updated with
|
||||
copy-on-write under a shard lock. Reads are lock-free (snapshot dict) to reduce
|
||||
contention in hot paths like `get()`.
|
||||
"""
|
||||
|
||||
# FIXME:(sandbox) Prefer a workflow-level context on GraphRuntimeState to store workflow-scoped shared objects.
|
||||
|
||||
_NUM_SHARDS: Final[int] = 1024
|
||||
_SHARD_MASK: Final[int] = _NUM_SHARDS - 1
|
||||
|
||||
@ -104,3 +101,91 @@ class SandboxManager:
|
||||
@classmethod
|
||||
def count(cls) -> int:
|
||||
return sum(len(shard) for shard in cls._shards)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
workflow_execution_id: str,
|
||||
sandbox_provider: SandboxProviderEntity,
|
||||
) -> Sandbox:
|
||||
assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=False)
|
||||
if not assets:
|
||||
raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}")
|
||||
|
||||
storage = ArchiveSandboxStorage(tenant_id, workflow_execution_id)
|
||||
sandbox = (
|
||||
SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
|
||||
.options(sandbox_provider.config)
|
||||
.user(user_id)
|
||||
.app(app_id)
|
||||
.initializer(AppAssetsInitializer(tenant_id, app_id, assets.id))
|
||||
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
|
||||
.storage(storage, assets.id)
|
||||
.build()
|
||||
)
|
||||
|
||||
logger.info("Sandbox created: id=%s, assets=%s", sandbox.vm.metadata.id, sandbox.assets_id)
|
||||
return sandbox
|
||||
|
||||
@classmethod
|
||||
def create_draft(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
sandbox_provider: SandboxProviderEntity,
|
||||
) -> Sandbox:
|
||||
assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=True)
|
||||
if not assets:
|
||||
raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}")
|
||||
|
||||
AppAssetService.build_assets(tenant_id, app_id, assets)
|
||||
sandbox_id = SandboxBuilder.draft_id(user_id)
|
||||
storage = ArchiveSandboxStorage(tenant_id, sandbox_id, exclude_patterns=[AppAssets.PATH])
|
||||
|
||||
sandbox = (
|
||||
SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
|
||||
.options(sandbox_provider.config)
|
||||
.user(user_id)
|
||||
.app(app_id)
|
||||
.initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id))
|
||||
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
|
||||
.storage(storage, assets.id)
|
||||
.build()
|
||||
)
|
||||
|
||||
logger.info("Draft sandbox created: id=%s, assets=%s", sandbox.vm.metadata.id, sandbox.assets_id)
|
||||
return sandbox
|
||||
|
||||
@classmethod
|
||||
def create_for_single_step(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
sandbox_provider: SandboxProviderEntity,
|
||||
) -> Sandbox:
|
||||
assets = AppAssetService.get_assets(tenant_id, app_id, user_id, is_draft=True)
|
||||
if not assets:
|
||||
raise ValueError(f"No assets found for tid={tenant_id}, app_id={app_id}")
|
||||
|
||||
AppAssetService.build_assets(tenant_id, app_id, assets)
|
||||
sandbox_id = SandboxBuilder.draft_id(user_id)
|
||||
storage = ArchiveSandboxStorage(tenant_id, sandbox_id, exclude_patterns=[AppAssets.PATH])
|
||||
|
||||
sandbox = (
|
||||
SandboxBuilder(tenant_id, SandboxType(sandbox_provider.provider_type))
|
||||
.options(sandbox_provider.config)
|
||||
.user(user_id)
|
||||
.app(app_id)
|
||||
.initializer(AppAssetsInitializer(tenant_id, app_id, assets.id))
|
||||
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
|
||||
.storage(storage, assets.id)
|
||||
.build()
|
||||
)
|
||||
|
||||
logger.info("Single-step sandbox created: id=%s, assets=%s", sandbox.vm.metadata.id, sandbox.assets_id)
|
||||
return sandbox
|
||||
|
||||
73
api/core/sandbox/sandbox.py
Normal file
73
api/core/sandbox/sandbox.py
Normal file
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.sandbox.storage.sandbox_storage import SandboxStorage
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sandbox:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vm: VirtualEnvironment,
|
||||
storage: SandboxStorage,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
app_id: str,
|
||||
assets_id: str,
|
||||
) -> None:
|
||||
self._vm = vm
|
||||
self._storage = storage
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._app_id = app_id
|
||||
self._assets_id = assets_id
|
||||
|
||||
@property
|
||||
def vm(self) -> VirtualEnvironment:
|
||||
return self._vm
|
||||
|
||||
@property
|
||||
def storage(self) -> SandboxStorage:
|
||||
return self._storage
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> str:
|
||||
return self._tenant_id
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
return self._user_id
|
||||
|
||||
@property
|
||||
def app_id(self) -> str:
|
||||
return self._app_id
|
||||
|
||||
@property
|
||||
def assets_id(self) -> str:
|
||||
return self._assets_id
|
||||
|
||||
def mount(self) -> bool:
|
||||
return self._storage.mount(self._vm)
|
||||
|
||||
def unmount(self) -> bool:
|
||||
return self._storage.unmount(self._vm)
|
||||
|
||||
def release(self) -> None:
|
||||
sandbox_id = self._vm.metadata.id
|
||||
try:
|
||||
self._storage.unmount(self._vm)
|
||||
logger.info("Sandbox storage unmounted: sandbox_id=%s", sandbox_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to unmount sandbox storage: sandbox_id=%s", sandbox_id)
|
||||
|
||||
try:
|
||||
self._vm.release_environment()
|
||||
logger.info("Sandbox released: sandbox_id=%s", sandbox_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to release sandbox: sandbox_id=%s", sandbox_id)
|
||||
@ -17,7 +17,6 @@ from core.workflow.context.execution_context import (
|
||||
register_context_capturer,
|
||||
reset_context_provider,
|
||||
)
|
||||
from core.workflow.context.models import SandboxContext
|
||||
|
||||
__all__ = [
|
||||
"AppContext",
|
||||
@ -25,7 +24,6 @@ __all__ = [
|
||||
"ExecutionContext",
|
||||
"IExecutionContext",
|
||||
"NullAppContext",
|
||||
"SandboxContext",
|
||||
"capture_current_context",
|
||||
"read_context",
|
||||
"register_context",
|
||||
|
||||
@ -1,13 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
|
||||
|
||||
class SandboxContext(BaseModel):
|
||||
"""Typed context for sandbox integration. All fields optional by design."""
|
||||
|
||||
sandbox_url: AnyHttpUrl | None = None
|
||||
sandbox_token: str | None = None # optional, if later needed for auth
|
||||
|
||||
|
||||
__all__ = ["SandboxContext"]
|
||||
__all__: list[str] = []
|
||||
|
||||
@ -2,11 +2,9 @@ import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.sandbox import SandboxManager, sandbox_debug
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.sandbox import sandbox_debug
|
||||
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
|
||||
from core.virtual_environment.__base.helpers import submit_command, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
@ -24,19 +22,6 @@ COMMAND_NODE_TIMEOUT_SECONDS = 60
|
||||
class CommandNode(Node[CommandNodeData]):
|
||||
node_type = NodeType.COMMAND
|
||||
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
def _get_sandbox(self) -> VirtualEnvironment | None:
|
||||
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
return None
|
||||
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
|
||||
if sandbox_by_workflow_run_id is not None:
|
||||
return sandbox_by_workflow_run_id
|
||||
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
|
||||
if sandbox_by_draft_id is not None:
|
||||
return sandbox_by_draft_id
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str) -> str:
|
||||
parser = VariableTemplateParser(template=template)
|
||||
selectors = parser.extract_variable_selectors()
|
||||
@ -65,7 +50,7 @@ class CommandNode(Node[CommandNodeData]):
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
sandbox = self._get_sandbox()
|
||||
sandbox = self.graph_runtime_state.sandbox
|
||||
if sandbox is None:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -88,12 +73,12 @@ class CommandNode(Node[CommandNodeData]):
|
||||
timeout = COMMAND_NODE_TIMEOUT_SECONDS if COMMAND_NODE_TIMEOUT_SECONDS > 0 else None
|
||||
|
||||
try:
|
||||
with with_connection(sandbox) as conn:
|
||||
with with_connection(sandbox.vm) as conn:
|
||||
command = ["bash", "-c", raw_command]
|
||||
|
||||
sandbox_debug("command_node", "command", command)
|
||||
|
||||
future = submit_command(sandbox, conn, command, cwd=working_directory)
|
||||
future = submit_command(sandbox.vm, conn, command, cwd=working_directory)
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
outputs: dict[str, Any] = {
|
||||
|
||||
@ -50,8 +50,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.sandbox import SandboxBashSession, SandboxManager
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.sandbox import Sandbox
|
||||
from core.sandbox.bash.session import SandboxBashSession
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@ -64,7 +64,6 @@ from core.variables import (
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.entities.tool_entities import ToolCallResult
|
||||
@ -174,19 +173,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
def _get_sandbox(self) -> VirtualEnvironment | None:
|
||||
workflow_execution_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
if not workflow_execution_id:
|
||||
return None
|
||||
sandbox_by_workflow_run_id = SandboxManager.get(workflow_execution_id)
|
||||
if sandbox_by_workflow_run_id is not None:
|
||||
return sandbox_by_workflow_run_id
|
||||
sandbox_by_draft_id = SandboxManager.get(SandboxBuilder.draft_id(self.user_id))
|
||||
if sandbox_by_draft_id is not None:
|
||||
return sandbox_by_draft_id
|
||||
return None
|
||||
|
||||
def _run(self) -> Generator:
|
||||
node_inputs: dict[str, Any] = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
@ -301,8 +287,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
generation_data: LLMGenerationData | None = None
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
# FIXME(Mairuis): should read sandbox from workflow run context...
|
||||
sandbox = self._get_sandbox()
|
||||
sandbox = self.graph_runtime_state.sandbox
|
||||
if sandbox:
|
||||
generator = self._invoke_llm_with_sandbox(
|
||||
sandbox=sandbox,
|
||||
@ -1839,7 +1824,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
def _invoke_llm_with_sandbox(
|
||||
self,
|
||||
sandbox: VirtualEnvironment,
|
||||
sandbox: Sandbox,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
@ -1849,23 +1834,14 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
result: LLMGenerationData | None = None
|
||||
|
||||
with SandboxBashSession(
|
||||
sandbox=sandbox,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
node_id=self.id,
|
||||
app_id=self.app_id,
|
||||
# FIXME(Mairuis): should read from workflow run context...
|
||||
assets_id=getattr(self, "assets_id", ""),
|
||||
allow_tools=allow_tools,
|
||||
) as sandbox_session:
|
||||
with SandboxBashSession(sandbox=sandbox, node_id=self.id, allow_tools=allow_tools) as session:
|
||||
prompt_files = self._extract_prompt_files(variable_pool)
|
||||
model_features = self._get_model_features(model_instance)
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=model_instance,
|
||||
tools=[sandbox_session.bash_tool],
|
||||
tools=[session.bash_tool],
|
||||
files=prompt_files,
|
||||
max_iterations=self._node_data.max_iterations or 100,
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
|
||||
@ -11,6 +11,7 @@ from typing import Any, Protocol
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
@ -171,6 +172,8 @@ class GraphRuntimeState:
|
||||
self._paused_nodes: set[str] = set()
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
self._sandbox: Sandbox | None = None
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
@ -294,6 +297,16 @@ class GraphRuntimeState:
|
||||
raise ValueError("tokens must be non-negative")
|
||||
self._total_tokens += tokens
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sandbox context (workflow-scoped)
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def sandbox(self) -> Sandbox | None:
|
||||
return self._sandbox
|
||||
|
||||
def set_sandbox(self, sandbox: Sandbox) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@ -78,6 +78,10 @@ class ReadOnlyGraphRuntimeState(Protocol):
|
||||
"""Get a single output value (returns a copy)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def sandbox(self) -> Any:
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the runtime state into a JSON snapshot (read-only)."""
|
||||
...
|
||||
|
||||
@ -82,6 +82,10 @@ class ReadOnlyGraphRuntimeStateWrapper:
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
return self._state.get_output(key, default)
|
||||
|
||||
@property
|
||||
def sandbox(self) -> Any:
|
||||
return self._state.sandbox
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the underlying runtime state for external persistence."""
|
||||
return self._state.dumps()
|
||||
|
||||
@ -8,6 +8,7 @@ from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.sandbox import Sandbox
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@ -128,6 +129,7 @@ class WorkflowEntry:
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
sandbox: Sandbox | None = None,
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
@ -156,6 +158,9 @@ class WorkflowEntry:
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
if sandbox is not None:
|
||||
graph_runtime_state.set_sandbox(sandbox)
|
||||
|
||||
# init workflow run state
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
|
||||
@ -6,8 +6,14 @@ from typing import Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.sandbox import SandboxBuilder, SandboxType, VMConfig, create_sandbox_config_encrypter, masked_config
|
||||
from core.sandbox.entities import SandboxProviderApiEntity
|
||||
from core.sandbox import (
|
||||
SandboxBuilder,
|
||||
SandboxProviderApiEntity,
|
||||
SandboxType,
|
||||
VMConfig,
|
||||
create_sandbox_config_encrypter,
|
||||
masked_config,
|
||||
)
|
||||
from core.sandbox.entities.providers import SandboxProviderEntity
|
||||
from core.tools.utils.system_encryption import decrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
@ -206,7 +212,6 @@ class SandboxProviderService:
|
||||
raise ValueError(f"No system default provider configured for tenant {tenant_id}")
|
||||
|
||||
@classmethod
|
||||
def create_sandbox_builder(cls, tenant_id: str) -> SandboxBuilder:
|
||||
def get_sandbox_provider(cls, tenant_id: str) -> SandboxProviderEntity:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
active_config = cls.get_active_sandbox_config(session, tenant_id)
|
||||
return SandboxBuilder(tenant_id, SandboxType(active_config.provider_type)).options(active_config.config)
|
||||
return cls.get_active_sandbox_config(session, tenant_id)
|
||||
|
||||
@ -14,10 +14,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.file import File
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.sandbox import SandboxManager
|
||||
from core.sandbox.constants import APP_ASSETS_PATH
|
||||
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
|
||||
from core.sandbox.vm import SandboxBuilder
|
||||
from core.sandbox.manager import SandboxManager
|
||||
from core.variables import Variable, VariableBase
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -702,34 +699,15 @@ class WorkflowService:
|
||||
enclosing_node_id = None
|
||||
|
||||
sandbox = None
|
||||
single_step_execution_id: str | None = None
|
||||
if draft_workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
from core.sandbox import AppAssetsInitializer, DifyCliInitializer
|
||||
from services.app_asset_service import AppAssetService
|
||||
|
||||
assets = AppAssetService.get_or_create_assets(session, app_model, account.id)
|
||||
if not assets:
|
||||
raise ValueError(f"No assets found for tid={draft_workflow.tenant_id}, app_id={app_model.id}")
|
||||
|
||||
# FIXME(Mairuis): single step execution
|
||||
AppAssetService.build_assets(draft_workflow.tenant_id, app_model.id, assets)
|
||||
sandbox_id = SandboxBuilder.draft_id(account.id)
|
||||
sandbox_storage = ArchiveSandboxStorage(
|
||||
draft_workflow.tenant_id, sandbox_id, exclude_patterns=[APP_ASSETS_PATH]
|
||||
sandbox_provider = SandboxProviderService.get_sandbox_provider(draft_workflow.tenant_id)
|
||||
sandbox = SandboxManager.create_for_single_step(
|
||||
tenant_id=draft_workflow.tenant_id,
|
||||
app_id=app_model.id,
|
||||
user_id=account.id,
|
||||
sandbox_provider=sandbox_provider,
|
||||
)
|
||||
|
||||
sandbox = (
|
||||
SandboxProviderService.create_sandbox_builder(draft_workflow.tenant_id)
|
||||
.initializer(DifyCliInitializer(draft_workflow.tenant_id, account.id, app_model.id, assets.id))
|
||||
.initializer(AppAssetsInitializer(draft_workflow.tenant_id, app_model.id, assets.id))
|
||||
.build()
|
||||
)
|
||||
sandbox_storage.mount(sandbox)
|
||||
single_step_execution_id = f"single-step-{uuid.uuid4()}"
|
||||
|
||||
SandboxManager.register(single_step_execution_id, sandbox)
|
||||
variable_pool.system_variables.workflow_execution_id = single_step_execution_id
|
||||
|
||||
try:
|
||||
node, generator = WorkflowEntry.single_step_run(
|
||||
workflow=draft_workflow,
|
||||
@ -738,6 +716,7 @@ class WorkflowService:
|
||||
user_id=account.id,
|
||||
variable_pool=variable_pool,
|
||||
variable_loader=variable_loader,
|
||||
sandbox=sandbox,
|
||||
)
|
||||
|
||||
# Run draft workflow node
|
||||
@ -747,17 +726,9 @@ class WorkflowService:
|
||||
start_at=start_at,
|
||||
node_id=node_id,
|
||||
)
|
||||
# FIXME(Mairuis): fidn a better way to handle this
|
||||
if sandbox is not None:
|
||||
sandbox_storage.unmount(sandbox)
|
||||
finally:
|
||||
if single_step_execution_id:
|
||||
sandbox = SandboxManager.unregister(single_step_execution_id)
|
||||
if sandbox:
|
||||
try:
|
||||
sandbox.release_environment()
|
||||
except Exception:
|
||||
logger.exception("Failed to release sandbox")
|
||||
if sandbox is not None:
|
||||
sandbox.release()
|
||||
|
||||
# Set workflow_id on the NodeExecution
|
||||
node_execution.workflow_id = draft_workflow.id
|
||||
|
||||
@ -1,164 +0,0 @@
|
||||
import threading
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.sandbox import SandboxManager
|
||||
from core.virtual_environment.__base.entities import (
|
||||
Arch,
|
||||
CommandStatus,
|
||||
ConnectionHandle,
|
||||
FileState,
|
||||
Metadata,
|
||||
OperatingSystem,
|
||||
)
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
|
||||
class FakeVirtualEnvironment(VirtualEnvironment):
|
||||
def __init__(self, sandbox_id: str = "fake-id"):
|
||||
self._sandbox_id = sandbox_id
|
||||
super().__init__(tenant_id="test-tenant", options={}, environments={})
|
||||
|
||||
def _construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
return Metadata(id=self._sandbox_id, arch=Arch.AMD64, os=OperatingSystem.LINUX)
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> list[FileState]:
|
||||
return []
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
return ConnectionHandle(id="conn")
|
||||
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
pass
|
||||
|
||||
def release_environment(self) -> None:
|
||||
pass
|
||||
|
||||
def execute_command(
|
||||
self,
|
||||
connection_handle: ConnectionHandle,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> tuple[str, Any, Any, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=0)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, options: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_sandbox_manager():
|
||||
SandboxManager.clear()
|
||||
yield
|
||||
SandboxManager.clear()
|
||||
|
||||
|
||||
class TestSandboxManager:
    """Unit tests for the SandboxManager registry keyed by workflow execution id."""

    def test_register_and_get(self):
        """A registered sandbox is returned (same object) by a later get()."""
        env = FakeVirtualEnvironment("sandbox-1")
        SandboxManager.register("exec-1", env)
        assert SandboxManager.get("exec-1") is env

    def test_get_returns_none_for_unknown_id(self):
        """get() yields None for an id that was never registered."""
        assert SandboxManager.get("unknown-id") is None

    def test_register_raises_on_empty_workflow_execution_id(self):
        """An empty execution id is rejected with ValueError."""
        env = FakeVirtualEnvironment()
        with pytest.raises(ValueError, match="workflow_execution_id cannot be empty"):
            SandboxManager.register("", env)

    def test_register_raises_on_duplicate(self):
        """Registering the same execution id twice raises RuntimeError."""
        first = FakeVirtualEnvironment("sandbox-1")
        second = FakeVirtualEnvironment("sandbox-2")
        SandboxManager.register("exec-dup", first)
        with pytest.raises(RuntimeError, match="already registered"):
            SandboxManager.register("exec-dup", second)

    def test_unregister_returns_sandbox(self):
        """unregister() hands back the stored sandbox and removes the entry."""
        env = FakeVirtualEnvironment("sandbox-to-remove")
        SandboxManager.register("exec-remove", env)
        removed = SandboxManager.unregister("exec-remove")
        assert removed is env
        assert SandboxManager.get("exec-remove") is None

    def test_unregister_returns_none_for_unknown(self):
        """unregister() on an unknown id is a no-op that returns None."""
        assert SandboxManager.unregister("nonexistent") is None

    def test_has_returns_true_when_registered(self):
        """has() reports True for a currently registered id."""
        SandboxManager.register("exec-has", FakeVirtualEnvironment())
        assert SandboxManager.has("exec-has") is True

    def test_has_returns_false_when_not_registered(self):
        """has() reports False for an id that was never registered."""
        assert SandboxManager.has("exec-no") is False

    def test_clear_removes_all_sandboxes(self):
        """clear() empties the registry entirely."""
        for exec_id, label in (("exec-1", "s1"), ("exec-2", "s2")):
            SandboxManager.register(exec_id, FakeVirtualEnvironment(label))
        SandboxManager.clear()
        assert SandboxManager.count() == 0
        assert SandboxManager.get("exec-1") is None
        assert SandboxManager.get("exec-2") is None

    def test_count_returns_number_of_sandboxes(self):
        """count() tracks registrations and removals exactly."""
        assert SandboxManager.count() == 0
        SandboxManager.register("e1", FakeVirtualEnvironment("s1"))
        assert SandboxManager.count() == 1
        SandboxManager.register("e2", FakeVirtualEnvironment("s2"))
        assert SandboxManager.count() == 2
        SandboxManager.unregister("e1")
        assert SandboxManager.count() == 1

    def test_thread_safety(self):
        """Concurrent register() calls with distinct ids all succeed."""
        successes: list[bool] = []
        failures: list[Exception] = []

        def worker(exec_id: str):
            # list.append is atomic under the GIL, so the shared lists
            # need no extra locking here.
            try:
                SandboxManager.register(exec_id, FakeVirtualEnvironment(f"sandbox-{exec_id}"))
                successes.append(True)
            except Exception as exc:
                failures.append(exc)

        workers = [threading.Thread(target=worker, args=(f"exec-{i}",)) for i in range(10)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        assert len(failures) == 0
        assert len(successes) == 10
        assert SandboxManager.count() == 10
|
||||
Reference in New Issue
Block a user