mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
feat(sandbox): skill initialize & draft run
This commit is contained in:
@ -523,6 +523,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
SandboxLayer(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_version=workflow.version,
|
||||
sandbox_id=application_generate_entity.workflow_run_id,
|
||||
sandbox_storage=ArchiveSandboxStorage(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
|
||||
@ -497,6 +497,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
SandboxLayer(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_version=workflow.version,
|
||||
sandbox_id=application_generate_entity.workflow_execution_id,
|
||||
sandbox_storage=ArchiveSandboxStorage(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
import logging
|
||||
|
||||
from core.sandbox import SandboxManager
|
||||
from core.sandbox import AppAssetsInitializer, DifyCliInitializer, SandboxManager
|
||||
from core.sandbox.storage.sandbox_storage import SandboxStorage
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events.base import GraphEngineEvent
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from models.workflow import Workflow
|
||||
from services.app_asset_service import AppAssetService
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -15,10 +18,18 @@ class SandboxInitializationError(Exception):
|
||||
|
||||
|
||||
class SandboxLayer(GraphEngineLayer):
|
||||
def __init__(self, tenant_id: str, app_id: str, sandbox_id: str, sandbox_storage: SandboxStorage) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_version: str,
|
||||
sandbox_id: str,
|
||||
sandbox_storage: SandboxStorage,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._workflow_version = workflow_version
|
||||
self._sandbox_id = sandbox_id
|
||||
self._sandbox_storage = sandbox_storage
|
||||
|
||||
@ -31,16 +42,34 @@ class SandboxLayer(GraphEngineLayer):
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
try:
|
||||
# Initialize sandbox
|
||||
from core.sandbox import AppAssetsInitializer, DifyCliInitializer
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
is_draft = self._workflow_version == Workflow.VERSION_DRAFT
|
||||
assets = AppAssetService.get_assets(self._tenant_id, self._app_id, is_draft=is_draft)
|
||||
if not assets:
|
||||
raise ValueError(
|
||||
f"No assets found for tid={self._tenant_id}, app_id={self._app_id}, wf={self._workflow_version}"
|
||||
)
|
||||
if is_draft:
|
||||
logger.info(
|
||||
"Building draft assets for tenant_id=%s, app_id=%s, workflow_version=%s, assets_id=%s",
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._workflow_version,
|
||||
assets.id,
|
||||
)
|
||||
AppAssetService.build_assets(self._tenant_id, self._app_id, assets)
|
||||
|
||||
logger.info("Initializing sandbox for tenant_id=%s, app_id=%s", self._tenant_id, self._app_id)
|
||||
logger.info(
|
||||
"Initializing sandbox for tenant_id=%s, app_id=%s, workflow_version=%s, assets_id=%s",
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._workflow_version,
|
||||
assets.id,
|
||||
)
|
||||
|
||||
builder = (
|
||||
SandboxProviderService.create_sandbox_builder(self._tenant_id)
|
||||
.initializer(DifyCliInitializer())
|
||||
.initializer(AppAssetsInitializer(self._tenant_id, self._app_id))
|
||||
.initializer(AppAssetsInitializer(self._tenant_id, self._app_id, assets.id))
|
||||
.initializer(DifyCliInitializer(self._tenant_id, self._app_id, assets.id))
|
||||
)
|
||||
sandbox = builder.build()
|
||||
|
||||
@ -65,10 +94,6 @@ class SandboxLayer(GraphEngineLayer):
|
||||
return
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
if self._sandbox_id is None:
|
||||
logger.debug("No workflow_execution_id set, nothing to release")
|
||||
return
|
||||
|
||||
sandbox = SandboxManager.unregister(self._sandbox_id)
|
||||
if sandbox is None:
|
||||
logger.debug("No sandbox to release for sandbox_id=%s", self._sandbox_id)
|
||||
|
||||
@ -61,7 +61,6 @@ class SkillMetadata(BaseModel):
|
||||
class SkillAsset(AssetItem):
|
||||
storage_key: str
|
||||
metadata: SkillMetadata
|
||||
content: str
|
||||
tool_references: list[ToolReference] = field(default_factory=list)
|
||||
file_references: list[FileReference] = field(default_factory=list)
|
||||
|
||||
|
||||
@ -1,34 +1,20 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app_assets.entities import AssetItem
|
||||
from core.app_assets.paths import AssetPaths
|
||||
|
||||
from .base import AssetItemParser, FileAssetParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_storage import Storage
|
||||
|
||||
|
||||
class AssetParser:
|
||||
_tree: AppAssetFileTree
|
||||
_tenant_id: str
|
||||
_app_id: str
|
||||
_storage: "Storage"
|
||||
_parsers: dict[str, AssetItemParser]
|
||||
_default_parser: AssetItemParser
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tree: AppAssetFileTree,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
storage: "Storage",
|
||||
) -> None:
|
||||
self._tree = tree
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._storage = storage
|
||||
self._parsers = {}
|
||||
self._default_parser = FileAssetParser()
|
||||
|
||||
@ -41,11 +27,10 @@ class AssetParser:
|
||||
for node in self._tree.walk_files():
|
||||
path = self._tree.get_path(node.id).lstrip("/")
|
||||
storage_key = AssetPaths.draft_file(self._tenant_id, self._app_id, node.id)
|
||||
raw_bytes = self._storage.load_once(storage_key)
|
||||
extension = node.extension or ""
|
||||
|
||||
parser = self._parsers.get(extension, self._default_parser)
|
||||
asset = parser.parse(node.id, path, node.name, extension, storage_key, raw_bytes)
|
||||
asset = parser.parse(node.id, path, node.name, extension, storage_key)
|
||||
assets.append(asset)
|
||||
|
||||
return assets
|
||||
|
||||
@ -12,7 +12,6 @@ class AssetItemParser(ABC):
|
||||
file_name: str,
|
||||
extension: str,
|
||||
storage_key: str,
|
||||
raw_bytes: bytes,
|
||||
) -> AssetItem:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -25,7 +24,6 @@ class FileAssetParser(AssetItemParser):
|
||||
file_name: str,
|
||||
extension: str,
|
||||
storage_key: str,
|
||||
raw_bytes: bytes,
|
||||
) -> FileAsset:
|
||||
return FileAsset(
|
||||
node_id=node_id,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
from core.app_assets.entities import (
|
||||
FileReference,
|
||||
@ -9,36 +10,26 @@ from core.app_assets.entities import (
|
||||
ToolReference,
|
||||
)
|
||||
from core.app_assets.paths import AssetPaths
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
from .base import AssetItemParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_storage import Storage
|
||||
|
||||
TOOL_REFERENCE_PATTERN = re.compile(r"§\[tool\]\.\[([^\]]+)\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
|
||||
FILE_REFERENCE_PATTERN = re.compile(r"§\[file\]\.\[([^\]]+)\]\.\[([^\]]+)\]§")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillAssetParser(AssetItemParser):
|
||||
_tenant_id: str
|
||||
_app_id: str
|
||||
_publish_id: str
|
||||
_storage: "Storage"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
publish_id: str,
|
||||
storage: "Storage",
|
||||
assets_id: str,
|
||||
) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._publish_id = publish_id
|
||||
self._storage = storage
|
||||
|
||||
def _get_resolved_key(self, node_id: str) -> str:
|
||||
return AssetPaths.published_resolved_file(self._tenant_id, self._app_id, self._publish_id, node_id)
|
||||
self._assets_id = assets_id
|
||||
|
||||
def parse(
|
||||
self,
|
||||
@ -47,12 +38,40 @@ class SkillAssetParser(AssetItemParser):
|
||||
file_name: str,
|
||||
extension: str,
|
||||
storage_key: str,
|
||||
raw_bytes: bytes,
|
||||
) -> SkillAsset:
|
||||
try:
|
||||
data = json.loads(raw_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise ValueError(f"Invalid skill document JSON for {node_id}: {e}") from e
|
||||
return self._parse_skill_asset(node_id, path, file_name, extension, storage_key)
|
||||
except Exception:
|
||||
logger.exception("Failed to parse skill asset %s: %s", node_id)
|
||||
# handle as plain text
|
||||
return SkillAsset(
|
||||
node_id=node_id,
|
||||
path=path,
|
||||
file_name=file_name,
|
||||
extension=extension,
|
||||
storage_key=storage_key,
|
||||
metadata=SkillMetadata(),
|
||||
tool_references=[],
|
||||
file_references=[],
|
||||
)
|
||||
|
||||
def _parse_skill_asset(
|
||||
self, node_id: str, path: str, file_name: str, extension: str, storage_key: str
|
||||
) -> SkillAsset:
|
||||
try:
|
||||
data = json.loads(storage.load_once(storage_key))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
# handle as plain text
|
||||
return SkillAsset(
|
||||
node_id=node_id,
|
||||
path=path,
|
||||
file_name=file_name,
|
||||
extension=extension,
|
||||
storage_key=storage_key,
|
||||
metadata=SkillMetadata(),
|
||||
tool_references=[],
|
||||
file_references=[],
|
||||
)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Skill document {node_id} must be a JSON object")
|
||||
@ -66,30 +85,12 @@ class SkillAssetParser(AssetItemParser):
|
||||
|
||||
metadata = SkillMetadata.model_validate(metadata_raw)
|
||||
|
||||
tool_references: list[ToolReference] = []
|
||||
for match in TOOL_REFERENCE_PATTERN.finditer(content):
|
||||
tool_references.append(
|
||||
ToolReference(
|
||||
provider=match.group(1),
|
||||
tool_name=match.group(2),
|
||||
uuid=match.group(3),
|
||||
raw=match.group(0),
|
||||
)
|
||||
)
|
||||
|
||||
file_references: list[FileReference] = []
|
||||
for match in FILE_REFERENCE_PATTERN.finditer(content):
|
||||
file_references.append(
|
||||
FileReference(
|
||||
source=match.group(1),
|
||||
uuid=match.group(2),
|
||||
raw=match.group(0),
|
||||
)
|
||||
)
|
||||
tool_references: list[ToolReference] = self._parse_tool_references(content)
|
||||
file_references: list[FileReference] = self._parse_file_references(content)
|
||||
|
||||
resolved_content = self._resolve_content(content, tool_references, file_references)
|
||||
resolved_key = self._get_resolved_key(node_id)
|
||||
self._storage.save(resolved_key, resolved_content.encode("utf-8"))
|
||||
resolved_key = AssetPaths.build_resolved_file(self._tenant_id, self._app_id, self._assets_id, node_id)
|
||||
storage.save(resolved_key, resolved_content.encode("utf-8"))
|
||||
|
||||
return SkillAsset(
|
||||
node_id=node_id,
|
||||
@ -98,7 +99,6 @@ class SkillAssetParser(AssetItemParser):
|
||||
extension=extension,
|
||||
storage_key=resolved_key,
|
||||
metadata=metadata,
|
||||
content=resolved_content,
|
||||
tool_references=tool_references,
|
||||
file_references=file_references,
|
||||
)
|
||||
@ -110,7 +110,7 @@ class SkillAssetParser(AssetItemParser):
|
||||
file_references: list[FileReference],
|
||||
) -> str:
|
||||
for ref in tool_references:
|
||||
replacement = f"{ref.provider}/{ref.tool_name}"
|
||||
replacement = f"{ref.tool_name}"
|
||||
content = content.replace(ref.raw, replacement)
|
||||
|
||||
for ref in file_references:
|
||||
@ -118,3 +118,29 @@ class SkillAssetParser(AssetItemParser):
|
||||
content = content.replace(ref.raw, replacement)
|
||||
|
||||
return content
|
||||
|
||||
def _parse_tool_references(self, content: str) -> list[ToolReference]:
|
||||
tool_references: list[ToolReference] = []
|
||||
for match in TOOL_REFERENCE_PATTERN.finditer(content):
|
||||
tool_references.append(
|
||||
ToolReference(
|
||||
provider=match.group(1),
|
||||
tool_name=match.group(2),
|
||||
uuid=match.group(3),
|
||||
raw=match.group(0),
|
||||
)
|
||||
)
|
||||
|
||||
return tool_references
|
||||
|
||||
def _parse_file_references(self, content: str) -> list[FileReference]:
|
||||
file_references: list[FileReference] = []
|
||||
for match in FILE_REFERENCE_PATTERN.finditer(content):
|
||||
file_references.append(
|
||||
FileReference(
|
||||
source=match.group(1),
|
||||
uuid=match.group(2),
|
||||
raw=match.group(0),
|
||||
)
|
||||
)
|
||||
return file_references
|
||||
|
||||
@ -6,13 +6,13 @@ class AssetPaths:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/draft/{node_id}"
|
||||
|
||||
@staticmethod
|
||||
def published_zip(tenant_id: str, app_id: str, publish_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/published/{publish_id}.zip"
|
||||
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/build/{assets_id}.zip"
|
||||
|
||||
@staticmethod
|
||||
def published_resolved_file(tenant_id: str, app_id: str, publish_id: str, node_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/published/{publish_id}/resolved/{node_id}"
|
||||
def build_resolved_file(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/build/{assets_id}/resolved/{node_id}"
|
||||
|
||||
@staticmethod
|
||||
def published_tool_manifest(tenant_id: str, app_id: str, publish_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/published/{publish_id}/tools.json"
|
||||
def build_tool_manifest(tenant_id: str, app_id: str, assets_id: str) -> str:
|
||||
return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/build/{assets_id}/tools.json"
|
||||
|
||||
@ -8,10 +8,12 @@ from .bash.dify_cli import (
|
||||
from .constants import (
|
||||
APP_ASSETS_PATH,
|
||||
APP_ASSETS_ZIP_PATH,
|
||||
DIFY_CLI_CONFIG_PATH,
|
||||
DIFY_CLI_CONFIG_FILENAME,
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH,
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_PATH_PATTERN,
|
||||
DIFY_CLI_ROOT,
|
||||
DIFY_CLI_TOOLS_ROOT,
|
||||
)
|
||||
from .initializer import AppAssetsInitializer, DifyCliInitializer, SandboxInitializer
|
||||
from .manager import SandboxManager
|
||||
@ -24,10 +26,12 @@ from .vm import SandboxBuilder, SandboxType, VMConfig
|
||||
__all__ = [
|
||||
"APP_ASSETS_PATH",
|
||||
"APP_ASSETS_ZIP_PATH",
|
||||
"DIFY_CLI_CONFIG_PATH",
|
||||
"DIFY_CLI_CONFIG_FILENAME",
|
||||
"DIFY_CLI_GLOBAL_TOOLS_PATH",
|
||||
"DIFY_CLI_PATH",
|
||||
"DIFY_CLI_PATH_PATTERN",
|
||||
"DIFY_CLI_ROOT",
|
||||
"DIFY_CLI_TOOLS_ROOT",
|
||||
"AppAssetsInitializer",
|
||||
"ArchiveSandboxStorage",
|
||||
"DifyCliBinary",
|
||||
|
||||
@ -21,8 +21,9 @@ COMMAND_TIMEOUT_SECONDS = 60
|
||||
|
||||
|
||||
class SandboxBashTool(Tool):
|
||||
def __init__(self, sandbox: VirtualEnvironment, tenant_id: str):
|
||||
def __init__(self, sandbox: VirtualEnvironment, tenant_id: str, tools_path: str) -> None:
|
||||
self._sandbox = sandbox
|
||||
self._tools_path = tools_path
|
||||
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
@ -71,9 +72,10 @@ class SandboxBashTool(Tool):
|
||||
try:
|
||||
with with_connection(self._sandbox) as conn:
|
||||
cmd_list = ["bash", "-c", command]
|
||||
env_vars = {"PATH": f"{self._tools_path}:/usr/local/bin:/usr/bin:/bin"}
|
||||
|
||||
sandbox_debug("bash_tool", "cmd_list", cmd_list)
|
||||
future = submit_command(self._sandbox, conn, cmd_list)
|
||||
future = submit_command(self._sandbox, conn, cmd_list, environments=env_vars)
|
||||
timeout = COMMAND_TIMEOUT_SECONDS if COMMAND_TIMEOUT_SECONDS > 0 else None
|
||||
result = future.result(timeout=timeout)
|
||||
|
||||
|
||||
@ -6,8 +6,11 @@ DIFY_CLI_PATH: Final[str] = "/tmp/.dify/bin/dify"
|
||||
|
||||
DIFY_CLI_PATH_PATTERN: Final[str] = "dify-cli-{os}-{arch}"
|
||||
|
||||
DIFY_CLI_CONFIG_PATH: Final[str] = "/tmp/.dify/.dify_cli.json"
|
||||
DIFY_CLI_CONFIG_FILENAME: Final[str] = ".dify_cli.json"
|
||||
|
||||
DIFY_CLI_TOOLS_ROOT: Final[str] = "/tmp/.dify/tools"
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH: Final[str] = "/tmp/.dify/tools/global"
|
||||
|
||||
# App Assets (relative path - stays in sandbox workdir)
|
||||
APP_ASSETS_PATH: Final[str] = "assets"
|
||||
APP_ASSETS_ZIP_PATH: Final[str] = "/tmp/.dify/tmp/assets.zip"
|
||||
APP_ASSETS_ZIP_PATH: Final[str] = "/tmp/assets.zip"
|
||||
|
||||
@ -1,33 +1,25 @@
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app_assets.paths import AssetPaths
|
||||
from core.virtual_environment.__base.helpers import execute, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.app_asset import AppAssets
|
||||
|
||||
from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH, DIFY_CLI_ROOT
|
||||
from ..constants import APP_ASSETS_PATH, APP_ASSETS_ZIP_PATH
|
||||
from .base import SandboxInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppAssetsInitializer(SandboxInitializer):
|
||||
def __init__(self, tenant_id: str, app_id: str) -> None:
|
||||
def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._assets_id = assets_id
|
||||
|
||||
def initialize(self, env: VirtualEnvironment) -> None:
|
||||
published = self._get_latest_published()
|
||||
if not published:
|
||||
logger.debug("No published assets for app_id=%s, skipping", self._app_id)
|
||||
return
|
||||
|
||||
zip_key = AssetPaths.published_zip(self._tenant_id, self._app_id, published.id)
|
||||
zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id)
|
||||
try:
|
||||
zip_data = storage.load_once(zip_key)
|
||||
except Exception:
|
||||
@ -42,18 +34,6 @@ class AppAssetsInitializer(SandboxInitializer):
|
||||
env.upload_file(APP_ASSETS_ZIP_PATH, BytesIO(zip_data))
|
||||
|
||||
with with_connection(env) as conn:
|
||||
execute(
|
||||
env,
|
||||
["mkdir", "-p", f"{DIFY_CLI_ROOT}/tmp"],
|
||||
connection=conn,
|
||||
error_message="Failed to create temp directory",
|
||||
)
|
||||
execute(
|
||||
env,
|
||||
["mkdir", "-p", APP_ASSETS_PATH],
|
||||
connection=conn,
|
||||
error_message="Failed to create assets directory",
|
||||
)
|
||||
execute(
|
||||
env,
|
||||
["unzip", "-o", APP_ASSETS_ZIP_PATH, "-d", APP_ASSETS_PATH],
|
||||
@ -71,18 +51,5 @@ class AppAssetsInitializer(SandboxInitializer):
|
||||
logger.info(
|
||||
"App assets initialized for app_id=%s, published_id=%s",
|
||||
self._app_id,
|
||||
published.id,
|
||||
self._assets_id,
|
||||
)
|
||||
|
||||
def _get_latest_published(self) -> AppAssets | None:
|
||||
with Session(db.engine) as session:
|
||||
return (
|
||||
session.query(AppAssets)
|
||||
.filter(
|
||||
AppAssets.tenant_id == self._tenant_id,
|
||||
AppAssets.app_id == self._app_id,
|
||||
AppAssets.version != AppAssets.VERSION_DRAFT,
|
||||
)
|
||||
.order_by(AppAssets.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -1,37 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from core.virtual_environment.__base.helpers import execute
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app_assets.entities import ToolType
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
from core.skill.entities import ToolManifest
|
||||
from core.skill.skill_manager import SkillManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.virtual_environment.__base.helpers import execute, with_connection
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
from ..bash.dify_cli import DifyCliLocator
|
||||
from ..constants import DIFY_CLI_PATH, DIFY_CLI_ROOT
|
||||
from ..bash.dify_cli import DifyCliConfig, DifyCliLocator
|
||||
from ..constants import (
|
||||
DIFY_CLI_CONFIG_FILENAME,
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH,
|
||||
DIFY_CLI_PATH,
|
||||
DIFY_CLI_ROOT,
|
||||
)
|
||||
from .base import SandboxInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyCliInitializer(SandboxInitializer):
|
||||
def __init__(self, cli_root: str | Path | None = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
assets_id: str,
|
||||
cli_root: str | Path | None = None,
|
||||
) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._assets_id = assets_id
|
||||
self._locator = DifyCliLocator(root=cli_root)
|
||||
|
||||
self._tools = []
|
||||
self._cli_api_session = None
|
||||
|
||||
def initialize(self, env: VirtualEnvironment) -> None:
|
||||
binary = self._locator.resolve(env.metadata.os, env.metadata.arch)
|
||||
|
||||
execute(
|
||||
env,
|
||||
["mkdir", "-p", f"{DIFY_CLI_ROOT}/bin"],
|
||||
timeout=10,
|
||||
error_message="Failed to create dify CLI directory",
|
||||
)
|
||||
with with_connection(env) as conn:
|
||||
execute(
|
||||
env,
|
||||
["mkdir", "-p", f"{DIFY_CLI_ROOT}/bin"],
|
||||
connection=conn,
|
||||
timeout=10,
|
||||
error_message="Failed to create dify CLI directory",
|
||||
)
|
||||
|
||||
env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes()))
|
||||
env.upload_file(DIFY_CLI_PATH, BytesIO(binary.path.read_bytes()))
|
||||
|
||||
execute(
|
||||
env,
|
||||
["chmod", "+x", DIFY_CLI_PATH],
|
||||
timeout=10,
|
||||
error_message="Failed to mark dify CLI as executable",
|
||||
)
|
||||
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)
|
||||
execute(
|
||||
env,
|
||||
["chmod", "+x", DIFY_CLI_PATH],
|
||||
connection=conn,
|
||||
timeout=10,
|
||||
error_message="Failed to mark dify CLI as executable",
|
||||
)
|
||||
|
||||
logger.info("Dify CLI uploaded to sandbox, path=%s", DIFY_CLI_PATH)
|
||||
|
||||
manifest = SkillManager.load_tool_manifest(
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._assets_id,
|
||||
)
|
||||
|
||||
if manifest is None or not manifest.tools:
|
||||
logger.info("No tools found in manifest for assets_id=%s", self._assets_id)
|
||||
return
|
||||
|
||||
self._tools = self._resolve_tools_from_manifest(manifest)
|
||||
self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id="system")
|
||||
|
||||
execute(
|
||||
env,
|
||||
["mkdir", "-p", DIFY_CLI_GLOBAL_TOOLS_PATH],
|
||||
connection=conn,
|
||||
timeout=10,
|
||||
error_message="Failed to create global tools directory",
|
||||
)
|
||||
|
||||
config_json = json.dumps(
|
||||
DifyCliConfig.create(self._cli_api_session, self._tools).model_dump(mode="json"), ensure_ascii=False
|
||||
)
|
||||
env.upload_file(
|
||||
f"{DIFY_CLI_GLOBAL_TOOLS_PATH}/{DIFY_CLI_CONFIG_FILENAME}", BytesIO(config_json.encode("utf-8"))
|
||||
)
|
||||
|
||||
execute(
|
||||
env,
|
||||
[DIFY_CLI_PATH, "init"],
|
||||
connection=conn,
|
||||
timeout=30,
|
||||
cwd=DIFY_CLI_GLOBAL_TOOLS_PATH,
|
||||
error_message="Failed to initialize Dify CLI",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Global tools initialized, path=%s, tool_count=%d",
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH,
|
||||
len(self._tools),
|
||||
)
|
||||
|
||||
def _resolve_tools_from_manifest(self, manifest: ToolManifest) -> list[Tool]:
|
||||
tools: list[Tool] = []
|
||||
|
||||
for entry in manifest.tools.values():
|
||||
if entry.provider is None or entry.tool_name is None:
|
||||
logger.warning("Skipping tool entry with missing provider or tool_name: %s", entry.uuid)
|
||||
continue
|
||||
|
||||
try:
|
||||
provider_type = self._convert_tool_type(entry.type)
|
||||
tool = ToolManager.get_tool_runtime(
|
||||
tenant_id=self._tenant_id,
|
||||
provider_type=provider_type,
|
||||
provider_id=entry.provider,
|
||||
tool_name=entry.tool_name,
|
||||
invoke_from=InvokeFrom.AGENT,
|
||||
credential_id=entry.credential_id,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to resolve tool %s/%s: %s", entry.provider, entry.tool_name, e)
|
||||
continue
|
||||
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_type(tool_type: ToolType) -> ToolProviderType:
|
||||
match tool_type:
|
||||
case ToolType.BUILTIN:
|
||||
return ToolProviderType.BUILT_IN
|
||||
case ToolType.MCP:
|
||||
return ToolProviderType.MCP
|
||||
case _:
|
||||
raise ValueError(f"Unsupported tool type: {tool_type}")
|
||||
|
||||
@ -1,79 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
from core.virtual_environment.__base.helpers import execute
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
from .bash.dify_cli import DifyCliConfig
|
||||
from .constants import DIFY_CLI_CONFIG_PATH, DIFY_CLI_PATH
|
||||
from .constants import (
|
||||
DIFY_CLI_GLOBAL_TOOLS_PATH,
|
||||
)
|
||||
from .manager import SandboxManager
|
||||
from .utils.debug import sandbox_debug
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
from .bash.bash_tool import SandboxBashTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxSession:
|
||||
_workflow_execution_id: str
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
_node_id: str | None
|
||||
_allow_tools: list[str] | None
|
||||
|
||||
_sandbox: VirtualEnvironment | None
|
||||
_bash_tool: SandboxBashTool | None
|
||||
_session_id: str | None
|
||||
_tools_path: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workflow_execution_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
tools: list[Tool],
|
||||
node_id: str | None = None,
|
||||
allow_tools: list[str] | None = None,
|
||||
) -> None:
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._tools = tools
|
||||
self._node_id = node_id
|
||||
self._allow_tools = allow_tools
|
||||
|
||||
self._sandbox: VirtualEnvironment | None = None
|
||||
self._bash_tool: SandboxBashTool | None = None
|
||||
self._session_id: str | None = None
|
||||
self._sandbox = None
|
||||
self._bash_tool = None
|
||||
self._session_id = None
|
||||
self._tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH
|
||||
|
||||
def __enter__(self) -> SandboxSession:
|
||||
sandbox = SandboxManager.get(self._workflow_execution_id)
|
||||
if sandbox is None:
|
||||
raise RuntimeError(f"Sandbox not found for workflow_execution_id={self._workflow_execution_id}")
|
||||
|
||||
session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id)
|
||||
self._session_id = session.id
|
||||
self._sandbox = sandbox
|
||||
|
||||
try:
|
||||
config = DifyCliConfig.create(session, self._tools)
|
||||
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
|
||||
|
||||
sandbox_debug("sandbox", "config_json", config_json)
|
||||
sandbox.upload_file(DIFY_CLI_CONFIG_PATH, BytesIO(config_json.encode("utf-8")))
|
||||
|
||||
execute(
|
||||
sandbox,
|
||||
[DIFY_CLI_PATH, "init"],
|
||||
timeout=30,
|
||||
error_message="Failed to initialize Dify CLI in sandbox",
|
||||
)
|
||||
|
||||
except Exception:
|
||||
CliApiSessionManager().delete(session.id)
|
||||
self._session_id = None
|
||||
raise
|
||||
if self._allow_tools is not None:
|
||||
# TODO: Implement node tools directory setup
|
||||
if self._node_id is None:
|
||||
raise ValueError("node_id is required when allow_tools is specified")
|
||||
# self._tools_path = self._setup_node_tools_directory(sandbox, self._node_id, self._allow_tools)
|
||||
else:
|
||||
self._tools_path = DIFY_CLI_GLOBAL_TOOLS_PATH
|
||||
|
||||
from .bash.bash_tool import SandboxBashTool
|
||||
|
||||
self._sandbox = sandbox
|
||||
self._bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id)
|
||||
self._bash_tool = SandboxBashTool(sandbox=sandbox, tenant_id=self._tenant_id, tools_path=self._tools_path)
|
||||
return self
|
||||
|
||||
def _setup_node_tools_directory(
|
||||
self,
|
||||
sandbox: VirtualEnvironment,
|
||||
node_id: str,
|
||||
allow_tools: list[str],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_name_from_config(tool_config: dict) -> str:
|
||||
identity = tool_config.get("identity", {})
|
||||
provider = identity.get("provider", "")
|
||||
name = identity.get("name", "")
|
||||
return f"{provider}__{name}"
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
|
||||
@ -22,15 +22,28 @@ class SkillManager:
|
||||
def save_tool_manifest(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
publish_id: str,
|
||||
assets_id: str,
|
||||
manifest: ToolManifest,
|
||||
) -> None:
|
||||
if not manifest.tools:
|
||||
return
|
||||
|
||||
key = AssetPaths.published_tool_manifest(tenant_id, app_id, publish_id)
|
||||
key = AssetPaths.build_tool_manifest(tenant_id, app_id, assets_id)
|
||||
storage.save(key, manifest.model_dump_json(indent=2).encode("utf-8"))
|
||||
|
||||
@staticmethod
|
||||
def load_tool_manifest(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
assets_id: str,
|
||||
) -> ToolManifest | None:
|
||||
key = AssetPaths.build_tool_manifest(tenant_id, app_id, assets_id)
|
||||
try:
|
||||
data = storage.load_once(key)
|
||||
return ToolManifest.model_validate_json(data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _collect_asset_manifest(asset: SkillAsset) -> ToolManifest:
|
||||
tools: dict[str, ToolManifestEntry] = {}
|
||||
|
||||
@ -13,7 +13,7 @@ from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.patterns import StrategyFactory
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
@ -1580,34 +1580,17 @@ class LLMNode(Node[LLMNodeData]):
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
return result
|
||||
|
||||
def _prepare_sandbox_tools(self) -> list[Tool]:
|
||||
"""Prepare sandbox tools."""
|
||||
tool_instances = []
|
||||
def _get_allow_tools_list(self) -> list[str] | None:
|
||||
if not self._node_data.tools:
|
||||
return None
|
||||
|
||||
for tool in self._node_data.tools or []:
|
||||
try:
|
||||
# Get tool runtime from ToolManager
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
tenant_id=self.tenant_id,
|
||||
tool_name=tool.tool_name,
|
||||
provider_id=tool.provider_name,
|
||||
provider_type=tool.type,
|
||||
invoke_from=InvokeFrom.AGENT,
|
||||
credential_id=tool.credential_id,
|
||||
)
|
||||
allow_tools = []
|
||||
for tool in self._node_data.tools:
|
||||
if tool.enabled:
|
||||
tool_name = f"{tool.tool_name}"
|
||||
allow_tools.append(tool_name)
|
||||
|
||||
# Apply custom description from extra field if available
|
||||
if tool.extra.get("description") and tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
tool.extra.get("description") or tool_runtime.entity.description.llm
|
||||
)
|
||||
|
||||
tool_instances.append(tool_runtime)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load tool %s: %s", tool, str(e))
|
||||
continue
|
||||
|
||||
return tool_instances
|
||||
return allow_tools or None
|
||||
|
||||
def _invoke_llm_with_sandbox(
|
||||
self,
|
||||
@ -1620,7 +1603,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if not workflow_execution_id:
|
||||
raise LLMNodeError("workflow_execution_id is required for sandbox runtime mode")
|
||||
|
||||
configured_tools = self._prepare_sandbox_tools()
|
||||
allow_tools = self._get_allow_tools_list()
|
||||
|
||||
result: LLMGenerationData | None = None
|
||||
|
||||
@ -1628,7 +1611,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
tools=configured_tools,
|
||||
node_id=self.id,
|
||||
allow_tools=allow_tools,
|
||||
) as sandbox_session:
|
||||
prompt_files = self._extract_prompt_files(variable_pool)
|
||||
model_features = self._get_model_features(model_instance)
|
||||
|
||||
@ -61,6 +61,27 @@ class AppAssetService:
|
||||
session.commit()
|
||||
return assets
|
||||
|
||||
@staticmethod
|
||||
def get_assets(tenant_id: str, app_id: str, *, is_draft: bool) -> AppAssets | None:
|
||||
with Session(db.engine) as session:
|
||||
if is_draft:
|
||||
stmt = session.query(AppAssets).filter(
|
||||
AppAssets.tenant_id == tenant_id,
|
||||
AppAssets.app_id == app_id,
|
||||
AppAssets.version == AppAssets.VERSION_DRAFT,
|
||||
)
|
||||
else:
|
||||
stmt = (
|
||||
session.query(AppAssets)
|
||||
.filter(
|
||||
AppAssets.tenant_id == tenant_id,
|
||||
AppAssets.app_id == app_id,
|
||||
AppAssets.version != AppAssets.VERSION_DRAFT,
|
||||
)
|
||||
.order_by(AppAssets.created_at.desc())
|
||||
)
|
||||
return stmt.first()
|
||||
|
||||
@staticmethod
|
||||
def get_asset_tree(app_model: App, account_id: str) -> AppAssetFileTree:
|
||||
with Session(db.engine) as session:
|
||||
@ -284,10 +305,10 @@ class AppAssetService:
|
||||
session.add(published)
|
||||
session.flush()
|
||||
|
||||
parser = AssetParser(tree, tenant_id, app_id, storage)
|
||||
parser = AssetParser(tree, tenant_id, app_id)
|
||||
parser.register(
|
||||
"md",
|
||||
SkillAssetParser(tenant_id, app_id, publish_id, storage),
|
||||
SkillAssetParser(tenant_id, app_id, publish_id),
|
||||
)
|
||||
|
||||
assets = parser.parse()
|
||||
@ -306,13 +327,40 @@ class AppAssetService:
|
||||
packager = ZipPackager(storage)
|
||||
|
||||
zip_bytes = packager.package(assets)
|
||||
zip_key = AssetPaths.published_zip(tenant_id, app_id, publish_id)
|
||||
zip_key = AssetPaths.build_zip(tenant_id, app_id, publish_id)
|
||||
storage.save(zip_key, zip_bytes)
|
||||
|
||||
session.commit()
|
||||
|
||||
return published
|
||||
|
||||
@staticmethod
|
||||
def build_assets(tenant_id: str, app_id: str, assets: AppAssets) -> None:
|
||||
tree = assets.asset_tree
|
||||
|
||||
parser = AssetParser(tree, tenant_id, app_id)
|
||||
parser.register(
|
||||
"md",
|
||||
SkillAssetParser(tenant_id, app_id, assets.id),
|
||||
)
|
||||
|
||||
parsed_assets = parser.parse()
|
||||
manifest = SkillManager.generate_tool_manifest(
|
||||
assets=[asset for asset in parsed_assets if isinstance(asset, SkillAsset)]
|
||||
)
|
||||
|
||||
SkillManager.save_tool_manifest(
|
||||
tenant_id,
|
||||
app_id,
|
||||
assets.id,
|
||||
manifest,
|
||||
)
|
||||
|
||||
packager = ZipPackager(storage)
|
||||
zip_bytes = packager.package(parsed_assets)
|
||||
zip_key = AssetPaths.build_zip(tenant_id, app_id, assets.id)
|
||||
storage.save(zip_key, zip_bytes)
|
||||
|
||||
@staticmethod
|
||||
def get_file_download_url(
|
||||
app_model: App,
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.layers.sandbox_layer import SandboxInitializationError, SandboxLayer
|
||||
from core.sandbox import SandboxManager
|
||||
from core.sandbox.storage.sandbox_storage import SandboxStorage
|
||||
from core.virtual_environment.__base.entities import Arch
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
|
||||
@ -12,6 +14,7 @@ from core.workflow.graph_events.graph import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from models.app_asset import AppAssets
|
||||
|
||||
|
||||
class MockMetadata:
|
||||
@ -30,16 +33,18 @@ class MockVirtualEnvironment:
|
||||
|
||||
|
||||
class MockVMBuilder:
|
||||
def __init__(self, sandbox: VirtualEnvironment):
|
||||
_sandbox: VirtualEnvironment
|
||||
|
||||
def __init__(self, sandbox: VirtualEnvironment) -> None:
|
||||
self._sandbox = sandbox
|
||||
|
||||
def environments(self, _):
|
||||
def environments(self, _: object) -> "MockVMBuilder":
|
||||
return self
|
||||
|
||||
def initializer(self, _):
|
||||
def initializer(self, _: object) -> "MockVMBuilder":
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
def build(self) -> VirtualEnvironment:
|
||||
return self._sandbox
|
||||
|
||||
|
||||
@ -51,68 +56,107 @@ def clean_sandbox_manager():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_archive_storage():
|
||||
with patch("core.app.layers.sandbox_layer.ArchiveSandboxStorage") as mock_class:
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.mount.return_value = False
|
||||
mock_instance.unmount.return_value = True
|
||||
mock_class.return_value = mock_instance
|
||||
yield mock_instance
|
||||
def mock_sandbox_storage() -> MagicMock:
|
||||
mock_storage = MagicMock(spec=SandboxStorage)
|
||||
mock_storage.mount.return_value = False
|
||||
mock_storage.unmount.return_value = True
|
||||
return mock_storage
|
||||
|
||||
|
||||
def create_mock_builder(sandbox):
|
||||
def create_mock_builder(sandbox: Any) -> MockVMBuilder:
|
||||
return MockVMBuilder(sandbox)
|
||||
|
||||
|
||||
def create_layer(
|
||||
tenant_id: str = "test-tenant",
|
||||
app_id: str = "test-app",
|
||||
workflow_version: str = AppAssets.VERSION_DRAFT,
|
||||
sandbox_id: str = "test-sandbox",
|
||||
sandbox_storage: Any = None,
|
||||
) -> SandboxLayer:
|
||||
if sandbox_storage is None:
|
||||
sandbox_storage = MagicMock(spec=SandboxStorage)
|
||||
sandbox_storage.mount.return_value = False
|
||||
sandbox_storage.unmount.return_value = True
|
||||
return SandboxLayer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_version=workflow_version,
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_storage=sandbox_storage,
|
||||
)
|
||||
|
||||
|
||||
class TestSandboxLayer:
|
||||
def test_init_with_parameters(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox")
|
||||
def test_init_with_parameters(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
sandbox_id="test-sandbox",
|
||||
sandbox_storage=mock_sandbox_storage,
|
||||
)
|
||||
|
||||
assert layer._tenant_id == "test-tenant" # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._app_id == "test-app" # pyright: ignore[reportPrivateUsage]
|
||||
assert layer._sandbox_id == "test-sandbox" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_sandbox_property_raises_when_not_initialized(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox")
|
||||
def test_sandbox_property_raises_when_not_initialized(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
_ = layer.sandbox
|
||||
|
||||
assert "Sandbox not found" in str(exc_info.value)
|
||||
|
||||
def test_sandbox_property_returns_sandbox_after_initialization(self, mock_archive_storage):
|
||||
def test_sandbox_property_returns_sandbox_after_initialization(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-id"
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id)
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
assert layer.sandbox is mock_sandbox
|
||||
|
||||
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self, mock_archive_storage):
|
||||
def test_on_graph_start_creates_sandbox_and_registers_with_manager(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-123"
|
||||
layer = SandboxLayer(tenant_id="test-tenant-123", app_id="test-app-123", sandbox_id=sandbox_id)
|
||||
layer = create_layer(
|
||||
tenant_id="test-tenant-123",
|
||||
app_id="test-app-123",
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_storage=mock_sandbox_storage,
|
||||
)
|
||||
mock_sandbox = MockVirtualEnvironment()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
) as mock_create:
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
) as mock_create,
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
mock_create.assert_called_once_with("test-tenant-123")
|
||||
|
||||
assert SandboxManager.get(sandbox_id) is mock_sandbox
|
||||
|
||||
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox")
|
||||
def test_on_graph_start_raises_sandbox_initialization_error_on_failure(
|
||||
self, mock_sandbox_storage: MagicMock
|
||||
) -> None:
|
||||
layer = create_layer(sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
side_effect=Exception("Sandbox provider not available"),
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
side_effect=Exception("Sandbox provider not available"),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
with pytest.raises(SandboxInitializationError) as exc_info:
|
||||
layer.on_graph_start()
|
||||
@ -120,22 +164,27 @@ class TestSandboxLayer:
|
||||
assert "Failed to initialize sandbox" in str(exc_info.value)
|
||||
assert "Sandbox provider not available" in str(exc_info.value)
|
||||
|
||||
def test_on_event_is_noop(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox")
|
||||
def test_on_event_is_noop(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
layer.on_event(GraphRunStartedEvent())
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||
|
||||
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_archive_storage):
|
||||
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(
|
||||
self, mock_sandbox_storage: MagicMock
|
||||
) -> None:
|
||||
sandbox_id = "test-exec-456"
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id)
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -146,15 +195,18 @@ class TestSandboxLayer:
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert not SandboxManager.has(sandbox_id)
|
||||
|
||||
def test_on_graph_end_releases_sandbox_even_on_error(self, mock_archive_storage):
|
||||
def test_on_graph_end_releases_sandbox_even_on_error(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-789"
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id)
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -163,16 +215,19 @@ class TestSandboxLayer:
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
assert not SandboxManager.has(sandbox_id)
|
||||
|
||||
def test_on_graph_end_handles_release_failure_gracefully(self, mock_archive_storage):
|
||||
def test_on_graph_end_handles_release_failure_gracefully(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-fail"
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id)
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
mock_sandbox.release_environment.side_effect = Exception("Container already removed")
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -180,20 +235,23 @@ class TestSandboxLayer:
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_on_graph_end_noop_when_sandbox_not_registered(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="nonexistent-sandbox")
|
||||
def test_on_graph_end_noop_when_sandbox_not_registered(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(sandbox_id="nonexistent-sandbox", sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
layer.on_graph_end(error=None)
|
||||
|
||||
def test_on_graph_end_is_idempotent(self, mock_archive_storage):
|
||||
def test_on_graph_end_is_idempotent(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "test-exec-idempotent"
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id=sandbox_id)
|
||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -202,8 +260,8 @@ class TestSandboxLayer:
|
||||
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_layer_inherits_from_graph_engine_layer(self):
|
||||
layer = SandboxLayer(tenant_id="test-tenant", app_id="test-app", sandbox_id="test-sandbox")
|
||||
def test_layer_inherits_from_graph_engine_layer(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
layer = create_layer(sandbox_storage=mock_sandbox_storage)
|
||||
|
||||
with pytest.raises(GraphEngineLayerNotInitializedError):
|
||||
_ = layer.graph_runtime_state
|
||||
@ -212,15 +270,23 @@ class TestSandboxLayer:
|
||||
|
||||
|
||||
class TestSandboxLayerIntegration:
|
||||
def test_full_lifecycle_with_mocked_provider(self, mock_archive_storage):
|
||||
def test_full_lifecycle_with_mocked_provider(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "integration-test-exec"
|
||||
layer = SandboxLayer(tenant_id="integration-tenant", app_id="integration-app", sandbox_id=sandbox_id)
|
||||
layer = create_layer(
|
||||
tenant_id="integration-tenant",
|
||||
app_id="integration-app",
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_storage=mock_sandbox_storage,
|
||||
)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata(sandbox_id="integration-sandbox")
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
@ -232,15 +298,23 @@ class TestSandboxLayerIntegration:
|
||||
assert not SandboxManager.has(sandbox_id)
|
||||
mock_sandbox.release_environment.assert_called_once()
|
||||
|
||||
def test_lifecycle_with_workflow_error(self, mock_archive_storage):
|
||||
def test_lifecycle_with_workflow_error(self, mock_sandbox_storage: MagicMock) -> None:
|
||||
sandbox_id = "integration-error-test"
|
||||
layer = SandboxLayer(tenant_id="error-tenant", app_id="error-app", sandbox_id=sandbox_id)
|
||||
layer = create_layer(
|
||||
tenant_id="error-tenant",
|
||||
app_id="error-app",
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_storage=mock_sandbox_storage,
|
||||
)
|
||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||
mock_sandbox.metadata = MockMetadata()
|
||||
|
||||
with patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
with (
|
||||
patch(
|
||||
"services.sandbox.sandbox_provider_service.SandboxProviderService.create_sandbox_builder",
|
||||
return_value=create_mock_builder(mock_sandbox),
|
||||
),
|
||||
patch("services.app_asset_service.AppAssetService.get_assets", return_value=None),
|
||||
):
|
||||
layer.on_graph_start()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user