refactor: async skill compile and context sharing

This commit is contained in:
Harry
2026-03-11 00:42:32 +08:00
parent d61be086ed
commit 0776e16fdc
14 changed files with 325 additions and 232 deletions

View File

@ -5,10 +5,10 @@ each into a ``SkillDocument``, assembles a ``SkillBundle`` (with
transitive tool/file dependency resolution), and returns ``AssetItem``
objects whose *content* field carries the resolved bytes in-process.
No S3 writes happen here — the only persistence is the ``SkillBundle``
saved via ``SkillManager`` (S3 + Redis cache invalidation) so that
downstream consumers (``SkillInitializer``, ``DifyCliInitializer``) can
load it later.
The assembled ``SkillBundle`` is persisted via ``SkillManager``
(S3 + Redis) **and** retained on the ``bundle`` property so that
callers (e.g. ``DraftAppAssetsInitializer``) can pass it directly to
``sandbox.attrs`` without a redundant Redis/S3 round-trip.
"""
import json
@ -29,10 +29,17 @@ logger = logging.getLogger(__name__)
class SkillBuilder:
_nodes: list[tuple[AppAssetNode, str]]
_accessor: CachedContentAccessor
_bundle: SkillBundle | None
def __init__(self, accessor: CachedContentAccessor) -> None:
    """Create a builder that loads skill draft content through *accessor*."""
    # (node, path) pairs collected for the build phase — see _nodes annotation.
    self._nodes = []
    self._accessor = accessor
    # Set by build(); exposed read-only via the ``bundle`` property.
    self._bundle = None
@property
def bundle(self) -> SkillBundle | None:
    """The ``SkillBundle`` produced by the last ``build()`` call, or *None*.

    *None* until ``build()`` has run. Afterwards it holds the same bundle
    that was persisted via ``SkillManager.save_bundle``, so in-process
    callers can reuse it without a redundant Redis/S3 round-trip.
    """
    return self._bundle
def accept(self, node: AppAssetNode) -> bool:
return node.extension == "md"
@ -44,9 +51,9 @@ class SkillBuilder:
from core.skill.skill_manager import SkillManager
if not self._nodes:
SkillManager.save_bundle(
ctx.tenant_id, ctx.app_id, ctx.build_id, SkillBundle(assets_id=ctx.build_id, asset_tree=tree)
)
bundle = SkillBundle(assets_id=ctx.build_id, asset_tree=tree)
SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
self._bundle = bundle
return []
# Batch-load all skill draft content in one DB query (with S3 fallback on miss).
@ -69,6 +76,7 @@ class SkillBuilder:
bundle = SkillBundleAssembler(tree).assemble_bundle(documents, ctx.build_id)
SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
self._bundle = bundle
items: list[AssetItem] = []
for node, path in self._nodes:

View File

@ -16,6 +16,7 @@ if TYPE_CHECKING:
from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType
from .initializer import (
AsyncSandboxInitializer,
SandboxInitializeContext,
SandboxInitializer,
SyncSandboxInitializer,
)
@ -44,6 +45,7 @@ __all__ = [
"Sandbox",
"SandboxBashSession",
"SandboxBuilder",
"SandboxInitializeContext",
"SandboxInitializer",
"SandboxManager",
"SandboxProviderApiEntity",
@ -72,6 +74,7 @@ _LAZY_IMPORTS = {
"Sandbox": ("core.sandbox.sandbox", "Sandbox"),
"SandboxBashSession": ("core.sandbox.bash.session", "SandboxBashSession"),
"SandboxBuilder": ("core.sandbox.builder", "SandboxBuilder"),
"SandboxInitializeContext": ("core.sandbox.initializer", "SandboxInitializeContext"),
"SandboxInitializer": ("core.sandbox.initializer", "SandboxInitializer"),
"SandboxManager": ("core.sandbox.manager", "SandboxManager"),
"SandboxProviderApiEntity": ("core.sandbox.entities", "SandboxProviderApiEntity"),

View File

@ -11,7 +11,7 @@ from core.entities.provider_entities import BasicProviderConfig
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from .entities.sandbox_type import SandboxType
from .initializer import AsyncSandboxInitializer, SandboxInitializer, SyncSandboxInitializer
from .initializer import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer
from .sandbox import Sandbox
if TYPE_CHECKING:
@ -125,10 +125,17 @@ class SandboxBuilder:
assets_id=self._assets_id,
)
ctx = SandboxInitializeContext(
tenant_id=self._tenant_id,
app_id=self._app_id,
assets_id=self._assets_id,
user_id=self._user_id,
)
# Run synchronous initializers before marking sandbox as ready.
for init in self._initializers:
if isinstance(init, SyncSandboxInitializer):
init.initialize(sandbox)
init.initialize(sandbox, ctx)
# Run sandbox setup asynchronously so workflow execution can proceed.
# Capture the Flask app before starting the thread for database access.
@ -143,7 +150,7 @@ class SandboxBuilder:
if sandbox.is_cancelled():
return
init.initialize(sandbox)
init.initialize(sandbox, ctx)
if sandbox.is_cancelled():
return
sandbox.mount()

View File

@ -1,7 +1,8 @@
from .base import AsyncSandboxInitializer, SandboxInitializer, SyncSandboxInitializer
from .base import AsyncSandboxInitializer, SandboxInitializeContext, SandboxInitializer, SyncSandboxInitializer
__all__ = [
"AsyncSandboxInitializer",
"SandboxInitializeContext",
"SandboxInitializer",
"SyncSandboxInitializer",
]

View File

@ -4,7 +4,7 @@ from core.app_assets.constants import AppAssetsAttrs
from core.sandbox.sandbox import Sandbox
from services.app_asset_package_service import AppAssetPackageService
from .base import SyncSandboxInitializer
from .base import SandboxInitializeContext, SyncSandboxInitializer
logger = logging.getLogger(__name__)
@ -12,13 +12,8 @@ APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10
class AppAssetAttrsInitializer(SyncSandboxInitializer):
def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None:
self._tenant_id = tenant_id
self._app_id = app_id
self._assets_id = assets_id
def initialize(self, sandbox: Sandbox) -> None:
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
# Load published app assets and unzip the artifact bundle.
app_assets = AppAssetPackageService.get_tenant_app_assets(self._tenant_id, self._assets_id)
app_assets = AppAssetPackageService.get_tenant_app_assets(ctx.tenant_id, ctx.assets_id)
sandbox.attrs.set(AppAssetsAttrs.FILE_TREE, app_assets.asset_tree)
sandbox.attrs.set(AppAssetsAttrs.APP_ASSETS_ID, self._assets_id)
sandbox.attrs.set(AppAssetsAttrs.APP_ASSETS_ID, ctx.assets_id)

View File

@ -5,7 +5,7 @@ from core.sandbox.sandbox import Sandbox
from core.virtual_environment.__base.helpers import pipeline
from ..entities import AppAssets
from .base import AsyncSandboxInitializer
from .base import AsyncSandboxInitializer, SandboxInitializeContext
logger = logging.getLogger(__name__)
@ -13,18 +13,13 @@ APP_ASSETS_DOWNLOAD_TIMEOUT = 60 * 10
class AppAssetsInitializer(AsyncSandboxInitializer):
def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None:
self._tenant_id = tenant_id
self._app_id = app_id
self._assets_id = assets_id
def initialize(self, sandbox: Sandbox) -> None:
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
from services.app_asset_service import AppAssetService
# Load published app assets and unzip the artifact bundle.
vm = sandbox.vm
asset_storage = AppAssetService.get_storage()
key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id)
key = AssetPaths.build_zip(ctx.tenant_id, ctx.app_id, ctx.assets_id)
download_url = asset_storage.get_download_url(key)
(
@ -54,6 +49,6 @@ class AppAssetsInitializer(AsyncSandboxInitializer):
logger.info(
"App assets initialized for app_id=%s, published_id=%s",
self._app_id,
self._assets_id,
ctx.app_id,
ctx.assets_id,
)

View File

@ -1,11 +1,41 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from core.sandbox.sandbox import Sandbox
if TYPE_CHECKING:
from core.app_assets.entities.assets import AssetItem
@dataclass
class SandboxInitializeContext:
    """Shared per-build context handed to every ``SandboxInitializer``.

    Bundles the identity fields (tenant/app/assets/user) that nearly all
    initializers need, together with optional artefact slots. The sync
    phase writes each artefact slot at most once; the async phase only
    reads them. Identity fields are treated as immutable by convention.
    """

    tenant_id: str
    app_id: str
    assets_id: str
    user_id: str

    # Filled in by DraftAppAssetsInitializer (sync phase) so that
    # DraftAppAssetsDownloader (async phase) can push the compiled
    # artefacts into the VM.
    built_assets: list[AssetItem] | None = None
class SandboxInitializer(ABC):
@abstractmethod
def initialize(self, sandbox: Sandbox) -> None: ...
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None: ...
class SyncSandboxInitializer(SandboxInitializer):

View File

@ -7,36 +7,26 @@ from pathlib import Path
from core.sandbox.sandbox import Sandbox
from core.session.cli_api import CliApiSessionManager, CliContext
from core.skill.constants import SkillAttrs
from core.skill.entities import ToolAccessPolicy
from core.skill.skill_manager import SkillManager
from core.virtual_environment.__base.helpers import pipeline
from ..bash.dify_cli import DifyCliConfig, DifyCliLocator
from ..entities import DifyCli
from .base import SyncSandboxInitializer
from .base import AsyncSandboxInitializer, SandboxInitializeContext
logger = logging.getLogger(__name__)
class DifyCliInitializer(SyncSandboxInitializer):
def __init__(
self,
tenant_id: str,
user_id: str,
app_id: str,
assets_id: str,
cli_root: str | Path | None = None,
) -> None:
self._tenant_id = tenant_id
self._app_id = app_id
self._user_id = user_id
self._assets_id = assets_id
self._locator = DifyCliLocator(root=cli_root)
class DifyCliInitializer(AsyncSandboxInitializer):
_cli_api_session: object | None
self._tools = []
def __init__(self, cli_root: str | Path | None = None) -> None:
self._locator = DifyCliLocator(root=cli_root)
self._tools: list[object] = []
self._cli_api_session = None
def initialize(self, sandbox: Sandbox) -> None:
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
vm = sandbox.vm
# FIXME(Mairuis): should be more robust, effectively.
binary = self._locator.resolve(vm.metadata.os, vm.metadata.arch)
@ -60,14 +50,14 @@ class DifyCliInitializer(SyncSandboxInitializer):
logger.info("Dify CLI uploaded to sandbox, path=%s", DifyCli.PATH)
bundle = SkillManager.load_bundle(self._tenant_id, self._app_id, self._assets_id)
bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
if bundle is None or bundle.get_tool_dependencies().is_empty():
logger.info("No tools found in bundle for assets_id=%s", self._assets_id)
logger.info("No tools found in bundle for assets_id=%s", ctx.assets_id)
return
self._cli_api_session = CliApiSessionManager().create(
tenant_id=self._tenant_id,
user_id=self._user_id,
tenant_id=ctx.tenant_id,
user_id=ctx.user_id,
context=CliContext(tool_access=ToolAccessPolicy.from_dependencies(bundle.get_tool_dependencies())),
)
@ -75,7 +65,7 @@ class DifyCliInitializer(SyncSandboxInitializer):
["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir"
).execute(raise_on_error=True)
config = DifyCliConfig.create(self._cli_api_session, self._tenant_id, bundle.get_tool_dependencies())
config = DifyCliConfig.create(self._cli_api_session, ctx.tenant_id, bundle.get_tool_dependencies())
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
config_path = f"{DifyCli.GLOBAL_TOOLS_PATH}/{DifyCli.CONFIG_FILENAME}"
vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))

View File

@ -1,4 +1,4 @@
"""Async initializer that populates a draft sandbox with app asset files.
"""Synchronous initializer that compiles draft app assets.
Unlike ``AppAssetsInitializer`` (which downloads a pre-built ZIP for
published assets), this initializer runs the build pipeline on the fly
@ -6,11 +6,12 @@ so that ``.md`` skill documents are compiled and their resolved content
is embedded directly into the download script — avoiding the S3
round-trip that was previously required for resolved keys.
Execution order guarantee:
This runs as an ``AsyncSandboxInitializer`` in the background thread.
By the time it finishes, ``SkillManager.save_bundle()`` has been
called (inside ``SkillBuilder.build()``), so subsequent initializers
like ``DifyCliInitializer`` can safely load the bundle from Redis/S3.
Execution order:
``DraftAppAssetsInitializer`` (sync) compiles assets and publishes
the ``SkillBundle`` to ``sandbox.attrs`` in-memory, so the
downstream ``SkillInitializer`` can skip the Redis/S3 round-trip.
``DraftAppAssetsDownloader`` (async) then pushes the compiled
artefacts into the sandbox VM in the background.
"""
import logging
@ -23,54 +24,68 @@ from core.app_assets.constants import AppAssetsAttrs
from core.sandbox.entities import AppAssets
from core.sandbox.sandbox import Sandbox
from core.sandbox.services import AssetDownloadService
from core.skill import SkillAttrs
from core.virtual_environment.__base.helpers import pipeline
from services.app_asset_service import AppAssetService
from .base import SyncSandboxInitializer
from .base import AsyncSandboxInitializer, SandboxInitializeContext, SyncSandboxInitializer
logger = logging.getLogger(__name__)
_TIMEOUT = 600 # 10 minutes
class DraftAppAssetsInitializer(SyncSandboxInitializer):
"""Compile draft assets and push them into the sandbox VM.
"""Compile draft assets and publish the ``SkillBundle`` to attrs.
``.md`` (skill) files are compiled in-process and their resolved
content is embedded as base64 heredocs in the download script.
All other files are fetched from S3 via presigned URLs.
The build pipeline compiles ``.md`` skill files in-process.
The resulting ``SkillBundle`` is persisted to Redis/S3 (by
``SkillBuilder``) **and** written to ``sandbox.attrs[BUNDLE]``
so that ``SkillInitializer`` can read it without a round-trip.
Built asset items are stored on ``ctx.built_assets`` for the
async ``DraftAppAssetsDownloader`` to consume.
"""
def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None:
self._tenant_id = tenant_id
self._app_id = app_id
self._assets_id = assets_id
def initialize(self, sandbox: Sandbox) -> None:
vm = sandbox.vm
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
tree = sandbox.attrs.get(AppAssetsAttrs.FILE_TREE)
if tree.empty():
return
# --- 1. Run the build pipeline (SkillBuilder compiles .md inline) ---
accessor = AppAssetService.get_accessor(self._tenant_id, self._app_id)
build_pipeline = AssetBuildPipeline([SkillBuilder(accessor=accessor), FileBuilder()])
ctx = BuildContext(tenant_id=self._tenant_id, app_id=self._app_id, build_id=self._assets_id)
built_assets = build_pipeline.build_all(tree, ctx)
accessor = AppAssetService.get_accessor(ctx.tenant_id, ctx.app_id)
skill_builder = SkillBuilder(accessor=accessor)
build_pipeline = AssetBuildPipeline([skill_builder, FileBuilder()])
build_ctx = BuildContext(tenant_id=ctx.tenant_id, app_id=ctx.app_id, build_id=ctx.assets_id)
built_assets = build_pipeline.build_all(tree, build_ctx)
ctx.built_assets = built_assets
if not built_assets:
# Publish the in-memory bundle so SkillInitializer skips Redis/S3.
if skill_builder.bundle is not None:
sandbox.attrs.set(SkillAttrs.BUNDLE, skill_builder.bundle)
class DraftAppAssetsDownloader(AsyncSandboxInitializer):
"""Download the compiled assets into the sandbox VM.
The download script is generated by ``DraftAppAssetsInitializer`` and
includes inline base64 content for compiled skills, as well as
presigned URLs for other files.
"""
_TIMEOUT = 600 # 10 minutes
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
if not ctx.built_assets:
logger.debug("No built assets found for assets_id=%s", ctx.assets_id)
return
# --- 2. Convert to unified download items and execute ---
download_items = AppAssetService.to_download_items(built_assets)
download_items = AppAssetService.to_download_items(ctx.built_assets)
script = AssetDownloadService.build_download_script(download_items, AppAssets.PATH)
pipeline(vm).add(
pipeline(sandbox.vm).add(
["sh", "-c", script],
error_message="Failed to download draft assets",
).execute(timeout=_TIMEOUT, raise_on_error=True)
).execute(timeout=self._TIMEOUT, raise_on_error=True)
logger.info(
"Draft app assets initialized for app_id=%s, assets_id=%s",
self._app_id,
self._assets_id,
ctx.app_id,
ctx.assets_id,
)

View File

@ -6,31 +6,29 @@ from core.sandbox.sandbox import Sandbox
from core.skill import SkillAttrs
from core.skill.skill_manager import SkillManager
from .base import SyncSandboxInitializer
from .base import SandboxInitializeContext, SyncSandboxInitializer
logger = logging.getLogger(__name__)
class SkillInitializer(SyncSandboxInitializer):
def __init__(
self,
tenant_id: str,
user_id: str,
app_id: str,
assets_id: str,
) -> None:
self._tenant_id = tenant_id
self._app_id = app_id
self._user_id = user_id
self._assets_id = assets_id
"""Ensure ``sandbox.attrs[BUNDLE]`` is populated for downstream consumers.
def initialize(self, sandbox: Sandbox) -> None:
In the draft path ``DraftAppAssetsInitializer`` already sets the
bundle on attrs from the in-memory build result, so this initializer
becomes a no-op. In the published path no prior initializer sets
it, so we fall back to ``SkillManager.load_bundle()`` (Redis/S3).
"""
def initialize(self, sandbox: Sandbox, ctx: SandboxInitializeContext) -> None:
# Draft path: bundle already populated by DraftAppAssetsInitializer.
if sandbox.attrs.has(SkillAttrs.BUNDLE):
return
# Published path: load from Redis/S3.
bundle = SkillManager.load_bundle(
self._tenant_id,
self._app_id,
self._assets_id,
)
sandbox.attrs.set(
SkillAttrs.BUNDLE,
bundle,
ctx.tenant_id,
ctx.app_id,
ctx.assets_id,
)
sandbox.attrs.set(SkillAttrs.BUNDLE, bundle)

View File

@ -4,14 +4,13 @@ from core.skill.entities.tool_dependencies import ToolDependency
class NodeSkillInfo(BaseModel):
"""Information about skills referenced by a workflow node."""
"""Information about skills referenced by a workflow node.
Used by the whole-workflow skills endpoint to return per-node
tool dependency information.
"""
node_id: str = Field(description="The node ID")
tool_dependencies: list[ToolDependency] = Field(
default_factory=list, description="Tool dependencies extracted from skill prompts"
)
@staticmethod
def empty(node_id: str = "") -> "NodeSkillInfo":
"""Create an empty NodeSkillInfo with no tool dependencies."""
return NodeSkillInfo(node_id=node_id, tool_dependencies=[])

View File

@ -28,6 +28,6 @@ class SkillManager:
@staticmethod
def save_bundle(tenant_id: str, app_id: str, assets_id: str, bundle: SkillBundle) -> None:
key = AssetPaths.skill_bundle(tenant_id, app_id, assets_id)
AppAssetService.get_storage().save(key, bundle.model_dump_json(indent=2).encode("utf-8"))
AppAssetService.get_storage().save(key, data=bundle.model_dump_json(indent=2).encode("utf-8"))
cache_key = f"{_CACHE_PREFIX}:{tenant_id}:{app_id}:{assets_id}"
redis_client.delete(cache_key)

View File

@ -22,7 +22,7 @@ from core.sandbox.entities.providers import SandboxProviderEntity
from core.sandbox.initializer.app_asset_attrs_initializer import AppAssetAttrsInitializer
from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer
from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer
from core.sandbox.initializer.draft_app_assets_initializer import DraftAppAssetsInitializer
from core.sandbox.initializer.draft_app_assets_initializer import DraftAppAssetsDownloader, DraftAppAssetsInitializer
from core.sandbox.initializer.skill_initializer import SkillInitializer
from core.sandbox.sandbox import Sandbox
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
@ -52,10 +52,10 @@ class SandboxService:
.options(sandbox_provider.config)
.user(user_id)
.app(app_id)
.initializer(AppAssetAttrsInitializer(tenant_id, app_id, assets.id))
.initializer(AppAssetsInitializer(tenant_id, app_id, assets.id))
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(AppAssetAttrsInitializer())
.initializer(AppAssetsInitializer())
.initializer(SkillInitializer())
.initializer(DifyCliInitializer())
.storage(archive_storage, assets.id)
.build()
)
@ -92,10 +92,11 @@ class SandboxService:
.options(sandbox_provider.config)
.user(user_id)
.app(app_id)
.initializer(AppAssetAttrsInitializer(tenant_id, app_id, assets.id))
.initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id))
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(AppAssetAttrsInitializer())
.initializer(DraftAppAssetsInitializer())
.initializer(DraftAppAssetsDownloader())
.initializer(SkillInitializer())
.initializer(DifyCliInitializer())
.storage(archive_storage, assets.id)
.build()
)
@ -125,10 +126,11 @@ class SandboxService:
.options(sandbox_provider.config)
.user(user_id)
.app(app_id)
.initializer(AppAssetAttrsInitializer(tenant_id, app_id, assets.id))
.initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id))
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id))
.initializer(AppAssetAttrsInitializer())
.initializer(DraftAppAssetsInitializer())
.initializer(DraftAppAssetsDownloader())
.initializer(SkillInitializer())
.initializer(DifyCliInitializer())
.storage(archive_storage, assets.id)
.build()
)

View File

@ -1,17 +1,33 @@
"""Service for extracting tool dependencies from LLM node skill prompts.
Two public entry points:
- ``extract_tool_dependencies`` — takes raw node data from the client,
real-time builds a ``SkillBundle`` from current draft ``.md`` assets,
and resolves transitive tool dependencies. Used by the per-node POST
endpoint.
- ``get_workflow_skills`` — scans all LLM nodes in a persisted draft
workflow and returns per-node skill info. Uses a cached bundle.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Mapping
from functools import reduce
from typing import Any, cast
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.sandbox.entities.config import AppAssets
from core.skill.assembler import SkillDocumentAssembler
from core.skill.assembler import SkillBundleAssembler, SkillDocumentAssembler
from core.skill.entities.api_entities import NodeSkillInfo
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.skill_metadata import SkillMetadata
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.skill.skill_manager import SkillManager
from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict
from core.workflow.enums import NodeType
from models._workflow_exc import NodeNotFoundError
from models.model import App
from models.workflow import Workflow
from services.app_asset_service import AppAssetService
@ -20,159 +36,193 @@ logger = logging.getLogger(__name__)
class SkillService:
"""
Service for managing and retrieving skill information from workflows.
"""
"""Service for managing and retrieving skill information from workflows."""
# ------------------------------------------------------------------
# Per-node: client sends node data, server builds bundle in real-time
# ------------------------------------------------------------------
@staticmethod
def get_node_skill_info(app: App, workflow: Workflow, node_id: str, user_id: str) -> NodeSkillInfo:
def extract_tool_dependencies(
app: App,
node_data: Mapping[str, Any],
user_id: str,
) -> list[ToolDependency]:
"""Extract tool dependencies from an LLM node's skill prompts.
Builds a fresh ``SkillBundle`` from current draft ``.md`` assets
every time — no cached bundle is used. The caller supplies the
full node ``data`` dict directly (not a ``node_id``).
Returns an empty list when the node has no skill prompts or when
no draft assets exist.
"""
Get skill information for a specific node in a workflow.
if node_data.get("type", "") != NodeType.LLM.value:
return []
Args:
app: The app model
workflow: The workflow containing the node
node_id: The ID of the node to get skill info for
user_id: The user ID for asset access
Returns:
NodeSkillInfo containing tool dependencies for the node
"""
node_config: NodeConfigDict = workflow.get_node_config_by_id(node_id)
if not node_config:
raise NodeNotFoundError(f"Node with ID {node_id} not found in workflow {workflow.id}")
node_data: NodeConfigData = node_config["data"]
node_type = node_data.get("type", "")
# Only LLM nodes support skills currently
if node_type != NodeType.LLM.value:
return NodeSkillInfo(node_id=node_id)
# Check if node has any skill prompts
if not SkillService._has_skill(node_data):
return NodeSkillInfo(node_id=node_id)
return []
tool_dependencies = SkillService._extract_tool_dependencies_with_compiler(app, node_data, user_id)
bundle = SkillService._build_bundle(app, user_id)
if bundle is None:
return []
return NodeSkillInfo(
node_id=node_id,
tool_dependencies=tool_dependencies,
)
return SkillService._resolve_prompt_dependencies(node_data, bundle)
# ------------------------------------------------------------------
# Whole-workflow: reads persisted draft + cached bundle
# ------------------------------------------------------------------
@staticmethod
def get_workflow_skills(app: App, workflow: Workflow, user_id: str) -> list[NodeSkillInfo]:
"""
Get skill information for all nodes in a workflow that have skill references.
"""Get skill information for all LLM nodes in a persisted workflow.
Args:
app: The app model
workflow: The workflow to scan for skills
user_id: The user ID for asset access
Returns:
List of NodeSkillInfo for nodes that have skill references
Uses the cached ``SkillBundle`` (Redis / S3). This method is
kept for the whole-workflow GET endpoint.
"""
result: list[NodeSkillInfo] = []
# Only scan LLM nodes since they're the only ones that support skills
for node_id, node_data in workflow.walk_nodes(specific_node_type=NodeType.LLM):
has_skill = SkillService._has_skill(dict(node_data))
if not SkillService._has_skill(dict(node_data)):
continue
if has_skill:
tool_dependencies = SkillService._extract_tool_dependencies_with_compiler(app, dict(node_data), user_id)
result.append(
NodeSkillInfo(
node_id=node_id,
tool_dependencies=tool_dependencies,
)
)
tool_dependencies = SkillService._extract_tool_dependencies_cached(app, dict(node_data), user_id)
result.append(NodeSkillInfo(node_id=node_id, tool_dependencies=tool_dependencies))
return result
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _has_skill(node_data: Mapping[str, Any]) -> bool:
"""Check if node has any skill prompts."""
prompt_template_raw = node_data.get("prompt_template", [])
if isinstance(prompt_template_raw, list):
prompt_template = cast(list[object], prompt_template_raw)
for prompt_item in prompt_template:
if not isinstance(prompt_item, dict):
continue
prompt = cast(dict[str, Any], prompt_item)
if prompt.get("skill", False):
for prompt_item in cast(list[object], prompt_template_raw):
if isinstance(prompt_item, dict) and prompt_item.get("skill", False):
return True
return False
@staticmethod
def _extract_tool_dependencies_with_compiler(
app: App, node_data: Mapping[str, Any], user_id: str
) -> list[ToolDependency]:
"""Extract tool dependencies using SkillDocumentAssembler.
def _build_bundle(app: App, user_id: str) -> SkillBundle | None:
"""Real-time build a SkillBundle from current draft .md assets.
This method loads the SkillBundle and AppAssetFileTree, then uses
SkillDocumentAssembler.assemble_document() to properly extract tool dependencies
including transitive dependencies from referenced skill files.
Reads all ``.md`` nodes from the draft file tree, bulk-loads
their content from the DB cache, parses into ``SkillDocument``
objects, and assembles a full bundle with transitive dependency
resolution.
The bundle is **not** persisted — it is built fresh for each
request so the response always reflects the latest draft state.
"""
# Get the draft assets to obtain assets_id and file_tree
assets = AppAssetService.get_assets(
tenant_id=app.tenant_id,
app_id=app.id,
user_id=user_id,
is_draft=True,
)
if not assets:
logger.warning("No draft assets found for app_id=%s", app.id)
return []
return None
assets_id = assets.id
file_tree: AppAssetFileTree = assets.asset_tree
if file_tree.empty():
return SkillBundle(assets_id=assets.id, asset_tree=file_tree)
# Load the skill bundle
try:
bundle = SkillManager.load_bundle(
tenant_id=app.tenant_id,
app_id=app.id,
assets_id=assets_id,
)
except Exception as e:
logger.debug("Failed to load skill bundle for app_id=%s: %s", app.id, e)
# Return empty if bundle doesn't exist (no skills compiled yet)
return []
# Collect all .md file nodes from the tree.
md_nodes: list[AppAssetNode] = [n for n in file_tree.walk_files() if n.extension == "md"]
if not md_nodes:
return SkillBundle(assets_id=assets.id, asset_tree=file_tree)
# Compile each skill prompt and collect tool dependencies
# Bulk-load content from DB (with S3 fallback).
accessor = AppAssetService.get_accessor(app.tenant_id, app.id)
raw_contents = accessor.bulk_load(md_nodes)
# Parse into SkillDocuments.
documents: dict[str, SkillDocument] = {}
for node in md_nodes:
raw = raw_contents.get(node.id)
if not raw:
continue
try:
data = {"skill_id": node.id, **json.loads(raw)}
documents[node.id] = SkillDocument.model_validate(data)
except (json.JSONDecodeError, TypeError, ValueError):
logger.warning("Skipping unparseable skill document node_id=%s", node.id)
continue
return SkillBundleAssembler(file_tree).assemble_bundle(documents, assets.id)
@staticmethod
def _resolve_prompt_dependencies(
node_data: Mapping[str, Any],
bundle: SkillBundle,
) -> list[ToolDependency]:
"""Resolve tool dependencies from skill prompts against a bundle."""
assembler = SkillDocumentAssembler(bundle)
tool_deps_list: list[ToolDependencies] = []
prompt_template_raw = node_data.get("prompt_template", [])
if isinstance(prompt_template_raw, list):
prompt_template = cast(list[object], prompt_template_raw)
for prompt_item in prompt_template:
if not isinstance(prompt_item, dict):
continue
prompt = cast(dict[str, Any], prompt_item)
if prompt.get("skill", False):
text_raw = prompt.get("text", "")
text = text_raw if isinstance(text_raw, str) else str(text_raw)
if not isinstance(prompt_template_raw, list):
return []
metadata_obj: object = prompt.get("metadata")
metadata = cast(dict[str, Any], metadata_obj) if isinstance(metadata_obj, dict) else {}
for prompt_item in cast(list[object], prompt_template_raw):
if not isinstance(prompt_item, dict):
continue
prompt = cast(dict[str, Any], prompt_item)
if not prompt.get("skill", False):
continue
skill_entry = assembler.assemble_document(
document=SkillDocument(
skill_id="anonymous",
content=text,
metadata=SkillMetadata.model_validate(metadata),
),
base_path=AppAssets.PATH,
)
tool_deps_list.append(skill_entry.dependance.tools)
text_raw = prompt.get("text", "")
text = text_raw if isinstance(text_raw, str) else str(text_raw)
metadata_obj: object = prompt.get("metadata")
metadata = cast(dict[str, Any], metadata_obj) if isinstance(metadata_obj, dict) else {}
skill_entry = assembler.assemble_document(
document=SkillDocument(
skill_id="anonymous",
content=text,
metadata=SkillMetadata.model_validate(metadata),
),
base_path=AppAssets.PATH,
)
tool_deps_list.append(skill_entry.dependance.tools)
if not tool_deps_list:
return []
# Merge all tool dependencies
from functools import reduce
merged = reduce(lambda x, y: x.merge(y), tool_deps_list)
return merged.dependencies
@staticmethod
def _extract_tool_dependencies_cached(
    app: App,
    node_data: Mapping[str, Any],
    user_id: str,
) -> list[ToolDependency]:
    """Extract tool dependencies from *node_data* using a cached SkillBundle.

    Used by ``get_workflow_skills`` for the whole-workflow endpoint.

    Args:
        app: App whose draft assets and cached bundle are consulted.
        node_data: Raw node ``data`` mapping containing skill prompts.
        user_id: User ID passed through for draft-asset access.

    Returns:
        Resolved tool dependencies, or an empty list when no draft
        assets exist or the cached bundle cannot be loaded.
    """
    assets = AppAssetService.get_assets(
        tenant_id=app.tenant_id,
        app_id=app.id,
        user_id=user_id,
        is_draft=True,
    )
    if not assets:
        # No draft assets — nothing to resolve against.
        return []
    try:
        bundle = SkillManager.load_bundle(
            tenant_id=app.tenant_id,
            app_id=app.id,
            assets_id=assets.id,
        )
    except Exception:
        # Best-effort: the bundle may simply not exist yet (skills never
        # compiled), so log at debug level and report no dependencies
        # instead of propagating the failure.
        logger.debug("Failed to load cached skill bundle for app_id=%s", app.id, exc_info=True)
        return []
    return SkillService._resolve_prompt_dependencies(node_data, bundle)