diff --git a/api/core/app_assets/builder/skill_builder.py b/api/core/app_assets/builder/skill_builder.py index a39e5891c2..fd2a2fb946 100644 --- a/api/core/app_assets/builder/skill_builder.py +++ b/api/core/app_assets/builder/skill_builder.py @@ -5,10 +5,10 @@ each into a ``SkillDocument``, assembles a ``SkillBundle`` (with transitive tool/file dependency resolution), and returns ``AssetItem`` objects whose *content* field carries the resolved bytes in-process. -No S3 writes happen here — the only persistence is the ``SkillBundle`` -saved via ``SkillManager`` (S3 + Redis cache invalidation) so that -downstream consumers (``SkillInitializer``, ``DifyCliInitializer``) can -load it later. +The assembled ``SkillBundle`` is persisted via ``SkillManager`` +(S3 + Redis) **and** retained on the ``bundle`` property so that +callers (e.g. ``DraftAppAssetsInitializer``) can pass it directly to +``sandbox.attrs`` without a redundant Redis/S3 round-trip. """ import json @@ -29,10 +29,17 @@ logger = logging.getLogger(__name__) class SkillBuilder: _nodes: list[tuple[AppAssetNode, str]] _accessor: CachedContentAccessor + _bundle: SkillBundle | None def __init__(self, accessor: CachedContentAccessor) -> None: self._nodes = [] self._accessor = accessor + self._bundle = None + + @property + def bundle(self) -> SkillBundle | None: + """The ``SkillBundle`` produced by the last ``build()`` call, or *None*.""" + return self._bundle def accept(self, node: AppAssetNode) -> bool: return node.extension == "md" @@ -44,9 +51,9 @@ class SkillBuilder: from core.skill.skill_manager import SkillManager if not self._nodes: - SkillManager.save_bundle( - ctx.tenant_id, ctx.app_id, ctx.build_id, SkillBundle(assets_id=ctx.build_id, asset_tree=tree) - ) + bundle = SkillBundle(assets_id=ctx.build_id, asset_tree=tree) + SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle) + self._bundle = bundle return [] # Batch-load all skill draft content in one DB query (with S3 fallback on miss). @@ -69,6 +76,7 @@ class SkillBuilder: bundle = SkillBundleAssembler(tree).assemble_bundle(documents, ctx.build_id) SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle) + self._bundle = bundle items: list[AssetItem] = [] for node, path in self._nodes: diff --git a/api/core/sandbox/__init__.py b/api/core/sandbox/__init__.py index bf18a226c1..608e57546a 100644 --- a/api/core/sandbox/__init__.py +++ b/api/core/sandbox/__init__.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType from .initializer import ( AsyncSandboxInitializer, + SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer, ) @@ -44,6 +45,7 @@ __all__ = [ "Sandbox", "SandboxBashSession", "SandboxBuilder", + "SandboxInitializeContext", "SandboxInitializer", "SandboxManager", "SandboxProviderApiEntity", @@ -72,6 +74,7 @@ _LAZY_IMPORTS = { "Sandbox": ("core.sandbox.sandbox", "Sandbox"), "SandboxBashSession": ("core.sandbox.bash.session", "SandboxBashSession"), "SandboxBuilder": ("core.sandbox.builder", "SandboxBuilder"), + "SandboxInitializeContext": ("core.sandbox.initializer", "SandboxInitializeContext"), "SandboxInitializer": ("core.sandbox.initializer", "SandboxInitializer"), "SandboxManager": ("core.sandbox.manager", "SandboxManager"), "SandboxProviderApiEntity": ("core.sandbox.entities", "SandboxProviderApiEntity"), diff --git a/api/core/sandbox/builder.py b/api/core/sandbox/builder.py index 3d19158d45..394a4d9972 100644 --- a/api/core/sandbox/builder.py +++ b/api/core/sandbox/builder.py @@ -11,7 +11,7 @@ from core.entities.provider_entities import BasicProviderConfig from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from .entities.sandbox_type import SandboxType -from .initializer import AsyncSandboxInitializer, SandboxInitializer, SyncSandboxInitializer +from .initializer import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer from .sandbox import Sandbox if TYPE_CHECKING: @@ -125,10 +125,17 @@ class SandboxBuilder: assets_id=self._assets_id, ) + ctx = SandboxInitializeContext( + tenant_id=self._tenant_id, + app_id=self._app_id, + assets_id=self._assets_id, + user_id=self._user_id, + ) + # Run synchronous initializers before marking sandbox as ready. for init in self._initializers: if isinstance(init, SyncSandboxInitializer): - init.initialize(sandbox) + init.initialize(sandbox, ctx) # Run sandbox setup asynchronously so workflow execution can proceed. # Capture the Flask app before starting the thread for database access. @@ -143,7 +150,7 @@ class SandboxBuilder: if sandbox.is_cancelled(): return - init.initialize(sandbox) + init.initialize(sandbox, ctx) if sandbox.is_cancelled(): return sandbox.mount() diff --git a/api/core/sandbox/initializer/__init__.py b/api/core/sandbox/initializer/__init__.py index 6933b9f385..0d1f476cc4 100644 --- a/api/core/sandbox/initializer/__init__.py +++ b/api/core/sandbox/initializer/__init__.py @@ -1,7 +1,8 @@ -from .base import AsyncSandboxInitializer, SandboxInitializer, SyncSandboxInitializer +from .base import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer __all__ = [ "AsyncSandboxInitializer", + "SandboxInitializeContext", "SandboxInitializer", "SyncSandboxInitializer", ] diff --git a/api/core/sandbox/initializer/app_asset_attrs_initializer.py b/api/core/sandbox/initializer/app_asset_attrs_initializer.py index d4cb623ec5..ed4b51a01b 100644 --- a/api/core/sandbox/initializer/app_asset_attrs_initializer.py +++ b/api/core/sandbox/initializer/app_asset_attrs_initializer.py @@ -4,7 +4,7 @@ from core.app_assets.constants import AppAssetsAttrs from core.sandbox.sandbox import Sandbox from services.app_asset_package_service import AppAssetPackageService -from .base import SyncSandboxInitializer +from .base import SandboxInitializeContext, SyncSandboxInitializer logger = logging.getLogger(__name__) @@ -12,13 +12,8 @@ APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10 class AppAssetAttrsInitializer(SyncSandboxInitializer): - def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None: - self._tenant_id = tenant_id - self._app_id = app_id - self._assets_id = assets_id - - def initialize(self, sandbox: Sandbox) -> None: + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: # Load published app assets and unzip the artifact bundle. - app_assets = AppAssetPackageService.get_tenant_app_assets(self._tenant_id, self._assets_id) + app_assets = AppAssetPackageService.get_tenant_app_assets(ctx.tenant_id, ctx.assets_id) sandbox.attrs.set(AppAssetsAttrs.FILE_TREE, app_assets.asset_tree) - sandbox.attrs.set(AppAssetsAttrs.APP_ASSETS_ID, self._assets_id) + sandbox.attrs.set(AppAssetsAttrs.APP_ASSETS_ID, ctx.assets_id) diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py index c3876727cb..9c689595c0 100644 --- a/api/core/sandbox/initializer/app_assets_initializer.py +++ b/api/core/sandbox/initializer/app_assets_initializer.py @@ -5,7 +5,7 @@ from core.sandbox.sandbox import Sandbox from core.virtual_environment.__base.helpers import pipeline from ..entities import AppAssets -from .base import AsyncSandboxInitializer +from .base import AsyncSandboxInitializer, SandboxInitializeContext logger = logging.getLogger(__name__) @@ -13,18 +13,13 @@ APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10 class AppAssetsInitializer(AsyncSandboxInitializer): - def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None: - self._tenant_id = tenant_id - self._app_id = app_id - self._assets_id = assets_id - - def initialize(self, sandbox: Sandbox) -> None: + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: from services.app_asset_service import AppAssetService # Load published app assets and unzip the artifact bundle. vm = sandbox.vm asset_storage = AppAssetService.get_storage() - key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id) + key = AssetPaths.build_zip(ctx.tenant_id, ctx.app_id, ctx.assets_id) download_url = asset_storage.get_download_url(key) ( @@ -54,6 +49,6 @@ class AppAssetsInitializer(AsyncSandboxInitializer): logger.info( "App assets initialized for app_id=%s, published_id=%s", - self._app_id, - self._assets_id, + ctx.app_id, + ctx.assets_id, ) diff --git a/api/core/sandbox/initializer/base.py b/api/core/sandbox/initializer/base.py index 93bd27b9b7..7021a814e5 100644 --- a/api/core/sandbox/initializer/base.py +++ b/api/core/sandbox/initializer/base.py @@ -1,11 +1,41 @@ +from __future__ import annotations + from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING from core.sandbox.sandbox import Sandbox +if TYPE_CHECKING: + from core.app_assets.entities.assets import AssetItem + + +@dataclass +class SandboxInitializeContext: + """Shared identity context passed to every ``SandboxInitializer``. + + Carries the common identity fields that virtually every initializer + needs, plus optional artefact slots that sync initializers populate + for async initializers to consume. + + Identity fields are immutable by convention; artefact slots are + written at most once during the sync phase and read during the + async phase. + """ + + tenant_id: str + app_id: str + assets_id: str + user_id: str + + # Populated by DraftAppAssetsInitializer (sync) for + # DraftAppAssetsDownloader (async) to download into the VM. + built_assets: list[AssetItem] | None = field(default=None) + class SandboxInitializer(ABC): @abstractmethod - def initialize(self, sandbox: Sandbox) -> None: ... + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: ... class SyncSandboxInitializer(SandboxInitializer): diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py index 7e6dbf1ad1..2efabf1c14 100644 --- a/api/core/sandbox/initializer/dify_cli_initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -7,36 +7,26 @@ from pathlib import Path from core.sandbox.sandbox import Sandbox from core.session.cli_api import CliApiSessionManager, CliContext +from core.skill.constants import SkillAttrs from core.skill.entities import ToolAccessPolicy -from core.skill.skill_manager import SkillManager from core.virtual_environment.__base.helpers import pipeline from ..bash.dify_cli import DifyCliConfig, DifyCliLocator from ..entities import DifyCli -from .base import SyncSandboxInitializer +from .base import AsyncSandboxInitializer, SandboxInitializeContext logger = logging.getLogger(__name__) -class DifyCliInitializer(SyncSandboxInitializer): - def __init__( - self, - tenant_id: str, - user_id: str, - app_id: str, - assets_id: str, - cli_root: str | Path | None = None, - ) -> None: - self._tenant_id = tenant_id - self._app_id = app_id - self._user_id = user_id - self._assets_id = assets_id - self._locator = DifyCliLocator(root=cli_root) +class DifyCliInitializer(AsyncSandboxInitializer): + _cli_api_session: object | None - self._tools = [] + def __init__(self, cli_root: str | Path | None = None) -> None: + self._locator = DifyCliLocator(root=cli_root) + self._tools: list[object] = [] self._cli_api_session = None - def initialize(self, sandbox: Sandbox) -> None: + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: vm = sandbox.vm # FIXME(Mairuis): should be more robust, effectively. binary = self._locator.resolve(vm.metadata.os, vm.metadata.arch) @@ -60,14 +50,14 @@ class DifyCliInitializer(SyncSandboxInitializer): logger.info("Dify CLI uploaded to sandbox, path=%s", DifyCli.PATH) - bundle = SkillManager.load_bundle(self._tenant_id, self._app_id, self._assets_id) + bundle = sandbox.attrs.get(SkillAttrs.BUNDLE) if bundle is None or bundle.get_tool_dependencies().is_empty(): - logger.info("No tools found in bundle for assets_id=%s", self._assets_id) + logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id) return self._cli_api_session = CliApiSessionManager().create( - tenant_id=self._tenant_id, - user_id=self._user_id, + tenant_id=ctx.tenant_id, + user_id=ctx.user_id, context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())), ) @@ -75,7 +65,7 @@ class DifyCliInitializer(SyncSandboxInitializer): ["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir" ).execute(raise_on_error=True) - config = DifyCliConfig.create(self._cli_api_session, self._tenant_id, bundle.get_tool_dependencies()) + config = DifyCliConfig.create(self._cli_api_session, ctx.tenant_id, bundle.get_tool_dependencies()) config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False) config_path = f"{DifyCli.GLOBAL_TOOLS_PATH}/{DifyCli.CONFIG_FILENAME}" vm.upload_file(config_path, BytesIO(config_json.encode("utf-8"))) diff --git a/api/core/sandbox/initializer/draft_app_assets_initializer.py b/api/core/sandbox/initializer/draft_app_assets_initializer.py index f8163ff6e4..68f4bfb4e9 100644 --- a/api/core/sandbox/initializer/draft_app_assets_initializer.py +++ b/api/core/sandbox/initializer/draft_app_assets_initializer.py @@ -1,4 +1,4 @@ -"""Async initializer that populates a draft sandbox with app asset files. +"""Synchronous initializer that compiles draft app assets. Unlike ``AppAssetsInitializer`` (which downloads a pre-built ZIP for published assets), this initializer runs the build pipeline on the fly @@ -6,11 +6,12 @@ so that ``.md`` skill documents are compiled and their resolved content is embedded directly into the download script — avoiding the S3 round-trip that was previously required for resolved keys. -Execution order guarantee: - This runs as an ``AsyncSandboxInitializer`` in the background thread. - By the time it finishes, ``SkillManager.save_bundle()`` has been - called (inside ``SkillBuilder.build()``), so subsequent initializers - like ``DifyCliInitializer`` can safely load the bundle from Redis/S3. +Execution order: + ``DraftAppAssetsInitializer`` (sync) compiles assets and publishes + the ``SkillBundle`` to ``sandbox.attrs`` in-memory, so the + downstream ``SkillInitializer`` can skip the Redis/S3 round-trip. + ``DraftAppAssetsDownloader`` (async) then pushes the compiled + artefacts into the sandbox VM in the background. """ import logging @@ -23,54 +24,68 @@ from core.app_assets.constants import AppAssetsAttrs from core.sandbox.entities import AppAssets from core.sandbox.sandbox import Sandbox from core.sandbox.services import AssetDownloadService +from core.skill import SkillAttrs from core.virtual_environment.__base.helpers import pipeline from services.app_asset_service import AppAssetService -from .base import SyncSandboxInitializer +from .base import AsyncSandboxInitializer, SandboxInitializeContext, SyncSandboxInitializer logger = logging.getLogger(__name__) -_TIMEOUT = 600 # 10 minutes - class DraftAppAssetsInitializer(SyncSandboxInitializer): - """Compile draft assets and push them into the sandbox VM. + """Compile draft assets and publish the ``SkillBundle`` to attrs. - ``.md`` (skill) files are compiled in-process and their resolved - content is embedded as base64 heredocs in the download script. - All other files are fetched from S3 via presigned URLs. + The build pipeline compiles ``.md`` skill files in-process. + The resulting ``SkillBundle`` is persisted to Redis/S3 (by + ``SkillBuilder``) **and** written to ``sandbox.attrs[BUNDLE]`` + so that ``SkillInitializer`` can read it without a round-trip. + Built asset items are stored on ``ctx.built_assets`` for the + async ``DraftAppAssetsDownloader`` to consume. """ - def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None: - self._tenant_id = tenant_id - self._app_id = app_id - self._assets_id = assets_id - - def initialize(self, sandbox: Sandbox) -> None: - vm = sandbox.vm + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: tree = sandbox.attrs.get(AppAssetsAttrs.FILE_TREE) if tree.empty(): return # --- 1. Run the build pipeline (SkillBuilder compiles .md inline) --- - accessor = AppAssetService.get_accessor(self._tenant_id, self._app_id) - build_pipeline = AssetBuildPipeline([SkillBuilder(accessor=accessor), FileBuilder()]) - ctx = BuildContext(tenant_id=self._tenant_id, app_id=self._app_id, build_id=self._assets_id) - built_assets = build_pipeline.build_all(tree, ctx) + accessor = AppAssetService.get_accessor(ctx.tenant_id, ctx.app_id) + skill_builder = SkillBuilder(accessor=accessor) + build_pipeline = AssetBuildPipeline([skill_builder, FileBuilder()]) + build_ctx = BuildContext(tenant_id=ctx.tenant_id, app_id=ctx.app_id, build_id=ctx.assets_id) + built_assets = build_pipeline.build_all(tree, build_ctx) + ctx.built_assets = built_assets - if not built_assets: + # Publish the in-memory bundle so SkillInitializer skips Redis/S3. + if skill_builder.bundle is not None: + sandbox.attrs.set(SkillAttrs.BUNDLE, skill_builder.bundle) + + +class DraftAppAssetsDownloader(AsyncSandboxInitializer): + """Download the compiled assets into the sandbox VM. + + The download script is generated by ``DraftAppAssetsInitializer`` and + includes inline base64 content for compiled skills, as well as + presigned URLs for other files. + """ + + _TIMEOUT = 600 # 10 minutes + + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: + if not ctx.built_assets: + logger.debug("No built assets found for assets_id=%s", ctx.assets_id) return - # --- 2. Convert to unified download items and execute --- - download_items = AppAssetService.to_download_items(built_assets) + download_items = AppAssetService.to_download_items(ctx.built_assets) script = AssetDownloadService.build_download_script(download_items, AppAssets.PATH) - pipeline(vm).add( + pipeline(sandbox.vm).add( ["sh", "-c", script], error_message="Failed to download draft assets", - ).execute(timeout=_TIMEOUT, raise_on_error=True) + ).execute(timeout=self._TIMEOUT, raise_on_error=True) logger.info( "Draft app assets initialized for app_id=%s, assets_id=%s", - self._app_id, - self._assets_id, + ctx.app_id, + ctx.assets_id, ) diff --git a/api/core/sandbox/initializer/skill_initializer.py b/api/core/sandbox/initializer/skill_initializer.py index 92375db78a..b6983d92c8 100644 --- a/api/core/sandbox/initializer/skill_initializer.py +++ b/api/core/sandbox/initializer/skill_initializer.py @@ -6,31 +6,29 @@ from core.sandbox.sandbox import Sandbox from core.skill import SkillAttrs from core.skill.skill_manager import SkillManager -from .base import SyncSandboxInitializer +from .base import SandboxInitializeContext, SyncSandboxInitializer logger = logging.getLogger(__name__) class SkillInitializer(SyncSandboxInitializer): - def __init__( - self, - tenant_id: str, - user_id: str, - app_id: str, - assets_id: str, - ) -> None: - self._tenant_id = tenant_id - self._app_id = app_id - self._user_id = user_id - self._assets_id = assets_id + """Ensure ``sandbox.attrs[BUNDLE]`` is populated for downstream consumers. - def initialize(self, sandbox: Sandbox) -> None: + In the draft path ``DraftAppAssetsInitializer`` already sets the + bundle on attrs from the in-memory build result, so this initializer + becomes a no-op. In the published path no prior initializer sets + it, so we fall back to ``SkillManager.load_bundle()`` (Redis/S3). + """ + + def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: + # Draft path: bundle already populated by DraftAppAssetsInitializer. + if sandbox.attrs.has(SkillAttrs.BUNDLE): + return + + # Published path: load from Redis/S3. bundle = SkillManager.load_bundle( - self._tenant_id, - self._app_id, - self._assets_id, - ) - sandbox.attrs.set( - SkillAttrs.BUNDLE, - bundle, + ctx.tenant_id, + ctx.app_id, + ctx.assets_id, ) + sandbox.attrs.set(SkillAttrs.BUNDLE, bundle) diff --git a/api/core/skill/entities/api_entities.py b/api/core/skill/entities/api_entities.py index 969aedaef9..aeb54d503f 100644 --- a/api/core/skill/entities/api_entities.py +++ b/api/core/skill/entities/api_entities.py @@ -4,14 +4,13 @@ from core.skill.entities.tool_dependencies import ToolDependency class NodeSkillInfo(BaseModel): - """Information about skills referenced by a workflow node.""" + """Information about skills referenced by a workflow node. + + Used by the whole-workflow skills endpoint to return per-node + tool dependency information. + """ node_id: str = Field(description="The node ID") tool_dependencies: list[ToolDependency] = Field( default_factory=list, description="Tool dependencies extracted from skill prompts" ) - - @staticmethod - def empty(node_id: str = "") -> "NodeSkillInfo": - """Create an empty NodeSkillInfo with no tool dependencies.""" - return NodeSkillInfo(node_id=node_id, tool_dependencies=[]) diff --git a/api/core/skill/skill_manager.py b/api/core/skill/skill_manager.py index 3145a3f31b..b480794f69 100644 --- a/api/core/skill/skill_manager.py +++ b/api/core/skill/skill_manager.py @@ -28,6 +28,6 @@ class SkillManager: @staticmethod def save_bundle(tenant_id: str, app_id: str, assets_id: str, bundle: SkillBundle) -> None: key = AssetPaths.skill_bundle(tenant_id, app_id, assets_id) - AppAssetService.get_storage().save(key, bundle.model_dump_json(indent=2).encode("utf-8")) + AppAssetService.get_storage().save(key, data=bundle.model_dump_json(indent=2).encode("utf-8")) cache_key = f"{_CACHE_PREFIX}:{tenant_id}:{app_id}:{assets_id}" redis_client.delete(cache_key) diff --git a/api/services/sandbox/sandbox_service.py b/api/services/sandbox/sandbox_service.py index 60b8877261..81c3b92c71 100644 --- a/api/services/sandbox/sandbox_service.py +++ b/api/services/sandbox/sandbox_service.py @@ -22,7 +22,7 @@ from core.sandbox.entities.providers import SandboxProviderEntity from core.sandbox.initializer.app_asset_attrs_initializer import AppAssetAttrsInitializer from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer -from core.sandbox.initializer.draft_app_assets_initializer import DraftAppAssetsInitializer +from core.sandbox.initializer.draft_app_assets_initializer import DraftAppAssetsDownloader, DraftAppAssetsInitializer from core.sandbox.initializer.skill_initializer import SkillInitializer from core.sandbox.sandbox import Sandbox from core.sandbox.storage.archive_storage import ArchiveSandboxStorage @@ -52,10 +52,10 @@ class SandboxService: .options(sandbox_provider.config) .user(user_id) .app(app_id) - .initializer(AppAssetAttrsInitializer(tenant_id, app_id, assets.id)) - .initializer(AppAssetsInitializer(tenant_id, app_id, assets.id)) - .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) - .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id)) + .initializer(AppAssetAttrsInitializer()) + .initializer(AppAssetsInitializer()) + .initializer(SkillInitializer()) + .initializer(DifyCliInitializer()) .storage(archive_storage, assets.id) .build() ) @@ -92,10 +92,11 @@ class SandboxService: .options(sandbox_provider.config) .user(user_id) .app(app_id) - .initializer(AppAssetAttrsInitializer(tenant_id, app_id, assets.id)) - .initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id)) - .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) - .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id)) + .initializer(AppAssetAttrsInitializer()) + .initializer(DraftAppAssetsInitializer()) + .initializer(DraftAppAssetsDownloader()) + .initializer(SkillInitializer()) + .initializer(DifyCliInitializer()) .storage(archive_storage, assets.id) .build() ) @@ -125,10 +126,11 @@ class SandboxService: .options(sandbox_provider.config) .user(user_id) .app(app_id) - .initializer(AppAssetAttrsInitializer(tenant_id, app_id, assets.id)) - .initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id)) - .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) - .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id)) + .initializer(AppAssetAttrsInitializer()) + .initializer(DraftAppAssetsInitializer()) + .initializer(DraftAppAssetsDownloader()) + .initializer(SkillInitializer()) + .initializer(DifyCliInitializer()) .storage(archive_storage, assets.id) .build() ) diff --git a/api/services/skill_service.py b/api/services/skill_service.py index 8e181d5cef..352f5a50ab 100644 --- a/api/services/skill_service.py +++ b/api/services/skill_service.py @@ -1,17 +1,33 @@ +"""Service for extracting tool dependencies from LLM node skill prompts. + +Two public entry points: + +- ``extract_tool_dependencies`` — takes raw node data from the client, + real-time builds a ``SkillBundle`` from current draft ``.md`` assets, + and resolves transitive tool dependencies. Used by the per-node POST + endpoint. +- ``get_workflow_skills`` — scans all LLM nodes in a persisted draft + workflow and returns per-node skill info. Uses a cached bundle. +""" + +from __future__ import annotations + +import json import logging from collections.abc import Mapping +from functools import reduce from typing import Any, cast +from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode from core.sandbox.entities.config import AppAssets -from core.skill.assembler import SkillDocumentAssembler +from core.skill.assembler import SkillBundleAssembler, SkillDocumentAssembler from core.skill.entities.api_entities import NodeSkillInfo +from core.skill.entities.skill_bundle import SkillBundle from core.skill.entities.skill_document import SkillDocument from core.skill.entities.skill_metadata import SkillMetadata from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency from core.skill.skill_manager import SkillManager -from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict from core.workflow.enums import NodeType -from models._workflow_exc import NodeNotFoundError from models.model import App from models.workflow import Workflow from services.app_asset_service import AppAssetService @@ -20,159 +36,193 @@ logger = logging.getLogger(__name__) class SkillService: - """ - Service for managing and retrieving skill information from workflows. - """ + """Service for managing and retrieving skill information from workflows.""" + + # ------------------------------------------------------------------ + # Per-node: client sends node data, server builds bundle in real-time + # ------------------------------------------------------------------ @staticmethod - def get_node_skill_info(app: App, workflow: Workflow, node_id: str, user_id: str) -> NodeSkillInfo: + def extract_tool_dependencies( + app: App, + node_data: Mapping[str, Any], + user_id: str, + ) -> list[ToolDependency]: + """Extract tool dependencies from an LLM node's skill prompts. + + Builds a fresh ``SkillBundle`` from current draft ``.md`` assets + every time — no cached bundle is used. The caller supplies the + full node ``data`` dict directly (not a ``node_id``). + + Returns an empty list when the node has no skill prompts or when + no draft assets exist. """ - Get skill information for a specific node in a workflow. + if node_data.get("type", "") != NodeType.LLM.value: + return [] - Args: - app: The app model - workflow: The workflow containing the node - node_id: The ID of the node to get skill info for - user_id: The user ID for asset access - - Returns: - NodeSkillInfo containing tool dependencies for the node - """ - node_config: NodeConfigDict = workflow.get_node_config_by_id(node_id) - if not node_config: - raise NodeNotFoundError(f"Node with ID {node_id} not found in workflow {workflow.id}") - node_data: NodeConfigData = node_config["data"] - node_type = node_data.get("type", "") - - # Only LLM nodes support skills currently - if node_type != NodeType.LLM.value: - return NodeSkillInfo(node_id=node_id) - - # Check if node has any skill prompts if not SkillService._has_skill(node_data): - return NodeSkillInfo(node_id=node_id) + return [] - tool_dependencies = SkillService._extract_tool_dependencies_with_compiler(app, node_data, user_id) + bundle = SkillService._build_bundle(app, user_id) + if bundle is None: + return [] - return NodeSkillInfo( - node_id=node_id, - tool_dependencies=tool_dependencies, - ) + return SkillService._resolve_prompt_dependencies(node_data, bundle) + + # ------------------------------------------------------------------ + # Whole-workflow: reads persisted draft + cached bundle + # ------------------------------------------------------------------ @staticmethod def get_workflow_skills(app: App, workflow: Workflow, user_id: str) -> list[NodeSkillInfo]: - """ - Get skill information for all nodes in a workflow that have skill references. + """Get skill information for all LLM nodes in a persisted workflow. - Args: - app: The app model - workflow: The workflow to scan for skills - user_id: The user ID for asset access - - Returns: - List of NodeSkillInfo for nodes that have skill references + Uses the cached ``SkillBundle`` (Redis / S3). This method is + kept for the whole-workflow GET endpoint. """ result: list[NodeSkillInfo] = [] - # Only scan LLM nodes since they're the only ones that support skills for node_id, node_data in workflow.walk_nodes(specific_node_type=NodeType.LLM): - has_skill = SkillService._has_skill(dict(node_data)) + if not SkillService._has_skill(dict(node_data)): + continue - if has_skill: - tool_dependencies = SkillService._extract_tool_dependencies_with_compiler(app, dict(node_data), user_id) - result.append( - NodeSkillInfo( - node_id=node_id, - tool_dependencies=tool_dependencies, - ) - ) + tool_dependencies = SkillService._extract_tool_dependencies_cached(app, dict(node_data), user_id) + result.append(NodeSkillInfo(node_id=node_id, tool_dependencies=tool_dependencies)) return result + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + @staticmethod def _has_skill(node_data: Mapping[str, Any]) -> bool: """Check if node has any skill prompts.""" prompt_template_raw = node_data.get("prompt_template", []) if isinstance(prompt_template_raw, list): - prompt_template = cast(list[object], prompt_template_raw) - for prompt_item in prompt_template: - if not isinstance(prompt_item, dict): - continue - prompt = cast(dict[str, Any], prompt_item) - if prompt.get("skill", False): + for prompt_item in cast(list[object], prompt_template_raw): + if isinstance(prompt_item, dict) and prompt_item.get("skill", False): return True return False @staticmethod - def _extract_tool_dependencies_with_compiler( - app: App, node_data: Mapping[str, Any], user_id: str - ) -> list[ToolDependency]: - """Extract tool dependencies using SkillDocumentAssembler. + def _build_bundle(app: App, user_id: str) -> SkillBundle | None: + """Real-time build a SkillBundle from current draft .md assets. - This method loads the SkillBundle and AppAssetFileTree, then uses - SkillDocumentAssembler.assemble_document() to properly extract tool dependencies - including transitive dependencies from referenced skill files. + Reads all ``.md`` nodes from the draft file tree, bulk-loads + their content from the DB cache, parses into ``SkillDocument`` + objects, and assembles a full bundle with transitive dependency + resolution. + + The bundle is **not** persisted — it is built fresh for each + request so the response always reflects the latest draft state. """ - # Get the draft assets to obtain assets_id and file_tree assets = AppAssetService.get_assets( tenant_id=app.tenant_id, app_id=app.id, user_id=user_id, is_draft=True, ) - if not assets: - logger.warning("No draft assets found for app_id=%s", app.id) - return [] + return None - assets_id = assets.id + file_tree: AppAssetFileTree = assets.asset_tree + if file_tree.empty(): + return SkillBundle(assets_id=assets.id, asset_tree=file_tree) - # Load the skill bundle - try: - bundle = SkillManager.load_bundle( - tenant_id=app.tenant_id, - app_id=app.id, - assets_id=assets_id, - ) - except Exception as e: - logger.debug("Failed to load skill bundle for app_id=%s: %s", app.id, e) - # Return empty if bundle doesn't exist (no skills compiled yet) - return [] + # Collect all .md file nodes from the tree. + md_nodes: list[AppAssetNode] = [n for n in file_tree.walk_files() if n.extension == "md"] + if not md_nodes: + return SkillBundle(assets_id=assets.id, asset_tree=file_tree) - # Compile each skill prompt and collect tool dependencies + # Bulk-load content from DB (with S3 fallback). + accessor = AppAssetService.get_accessor(app.tenant_id, app.id) + raw_contents = accessor.bulk_load(md_nodes) + + # Parse into SkillDocuments. + documents: dict[str, SkillDocument] = {} + for node in md_nodes: + raw = raw_contents.get(node.id) + if not raw: + continue + try: + data = {"skill_id": node.id, **json.loads(raw)} + documents[node.id] = SkillDocument.model_validate(data) + except (json.JSONDecodeError, TypeError, ValueError): + logger.warning("Skipping unparseable skill document node_id=%s", node.id) + continue + + return SkillBundleAssembler(file_tree).assemble_bundle(documents, assets.id) + + @staticmethod + def _resolve_prompt_dependencies( + node_data: Mapping[str, Any], + bundle: SkillBundle, + ) -> list[ToolDependency]: + """Resolve tool dependencies from skill prompts against a bundle.""" assembler = SkillDocumentAssembler(bundle) tool_deps_list: list[ToolDependencies] = [] prompt_template_raw = node_data.get("prompt_template", []) - if isinstance(prompt_template_raw, list): - prompt_template = cast(list[object], prompt_template_raw) - for prompt_item in prompt_template: - if not isinstance(prompt_item, dict): - continue - prompt = cast(dict[str, Any], prompt_item) - if prompt.get("skill", False): - text_raw = prompt.get("text", "") - text = text_raw if isinstance(text_raw, str) else str(text_raw) + if not isinstance(prompt_template_raw, list): + return [] - metadata_obj: object = prompt.get("metadata") - metadata = cast(dict[str, Any], metadata_obj) if isinstance(metadata_obj, dict) else {} + for prompt_item in cast(list[object], prompt_template_raw): + if not isinstance(prompt_item, dict): + continue + prompt = cast(dict[str, Any], prompt_item) + if not prompt.get("skill", False): + continue - skill_entry = assembler.assemble_document( - document=SkillDocument( - skill_id="anonymous", - content=text, - metadata=SkillMetadata.model_validate(metadata), - ), - base_path=AppAssets.PATH, - ) - tool_deps_list.append(skill_entry.dependance.tools) + text_raw = prompt.get("text", "") + text = text_raw if isinstance(text_raw, str) else str(text_raw) + + metadata_obj: object = prompt.get("metadata") + metadata = cast(dict[str, Any], metadata_obj) if isinstance(metadata_obj, dict) else {} + + skill_entry = assembler.assemble_document( + document=SkillDocument( + skill_id="anonymous", + content=text, + metadata=SkillMetadata.model_validate(metadata), + ), + base_path=AppAssets.PATH, + ) + tool_deps_list.append(skill_entry.dependance.tools) if not tool_deps_list: return [] - # Merge all tool dependencies - from functools import reduce - merged = reduce(lambda x, y: x.merge(y), tool_deps_list) - return merged.dependencies + + @staticmethod + def _extract_tool_dependencies_cached( + app: App, + node_data: Mapping[str, Any], + user_id: str, + ) -> list[ToolDependency]: + """Extract tool dependencies using a cached SkillBundle. + + Used by ``get_workflow_skills`` for the whole-workflow endpoint. + """ + assets = AppAssetService.get_assets( + tenant_id=app.tenant_id, + app_id=app.id, + user_id=user_id, + is_draft=True, + ) + if not assets: + return [] + + try: + bundle = SkillManager.load_bundle( + tenant_id=app.tenant_id, + app_id=app.id, + assets_id=assets.id, + ) + except Exception: + logger.debug("Failed to load cached skill bundle for app_id=%s", app.id, exc_info=True) + return [] + + return SkillService._resolve_prompt_dependencies(node_data, bundle)