mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
feat(skills): enhance skill retrieval by incorporating user context and app model in API endpoints
This commit is contained in:
@ -1,10 +1,16 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.sandbox.entities.config import AppAssets
|
||||
from core.skill.entities.api_entities import NodeSkillInfo
|
||||
from core.skill.entities.skill_metadata import ToolReference
|
||||
from core.skill.entities.tool_dependencies import ToolDependency
|
||||
from core.skill.entities.skill_document import SkillDocument
|
||||
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
|
||||
from core.skill.skill_compiler import SkillCompiler
|
||||
from core.skill.skill_manager import SkillManager
|
||||
from core.workflow.enums import NodeType
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.app_asset_service import AppAssetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -15,13 +21,15 @@ class SkillService:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_node_skill_info(workflow: Workflow, node_id: str) -> NodeSkillInfo:
|
||||
def get_node_skill_info(app: App, workflow: Workflow, node_id: str, user_id: str) -> NodeSkillInfo:
|
||||
"""
|
||||
Get skill information for a specific node in a workflow.
|
||||
|
||||
Args:
|
||||
app: The app model
|
||||
workflow: The workflow containing the node
|
||||
node_id: The ID of the node to get skill info for
|
||||
user_id: The user ID for asset access
|
||||
|
||||
Returns:
|
||||
NodeSkillInfo containing tool dependencies for the node
|
||||
@ -34,7 +42,11 @@ class SkillService:
|
||||
if node_type != NodeType.LLM.value:
|
||||
return NodeSkillInfo(node_id=node_id)
|
||||
|
||||
tool_dependencies = SkillService._extract_tool_dependencies(node_data)
|
||||
# Check if node has any skill prompts
|
||||
if not SkillService._has_skill(node_data):
|
||||
return NodeSkillInfo(node_id=node_id)
|
||||
|
||||
tool_dependencies = SkillService._extract_tool_dependencies_with_compiler(app, node_data, user_id)
|
||||
|
||||
return NodeSkillInfo(
|
||||
node_id=node_id,
|
||||
@ -42,12 +54,14 @@ class SkillService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_workflow_skills(workflow: Workflow) -> list[NodeSkillInfo]:
|
||||
def get_workflow_skills(app: App, workflow: Workflow, user_id: str) -> list[NodeSkillInfo]:
|
||||
"""
|
||||
Get skill information for all nodes in a workflow that have skill references.
|
||||
|
||||
Args:
|
||||
app: The app model
|
||||
workflow: The workflow to scan for skills
|
||||
user_id: The user ID for asset access
|
||||
|
||||
Returns:
|
||||
List of NodeSkillInfo for nodes that have skill references
|
||||
@ -56,10 +70,10 @@ class SkillService:
|
||||
|
||||
# Only scan LLM nodes since they're the only ones that support skills
|
||||
for node_id, node_data in workflow.walk_nodes(specific_node_type=NodeType.LLM):
|
||||
has_skill = SkillService._has_skill(node_data)
|
||||
has_skill = SkillService._has_skill(dict(node_data))
|
||||
|
||||
if has_skill:
|
||||
tool_dependencies = SkillService._extract_tool_dependencies(node_data)
|
||||
tool_dependencies = SkillService._extract_tool_dependencies_with_compiler(app, dict(node_data), user_id)
|
||||
result.append(
|
||||
NodeSkillInfo(
|
||||
node_id=node_id,
|
||||
@ -70,7 +84,7 @@ class SkillService:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _has_skill(node_data: dict) -> bool:
|
||||
def _has_skill(node_data: dict[str, Any]) -> bool:
|
||||
"""Check if node has any skill prompts."""
|
||||
prompt_template = node_data.get("prompt_template", [])
|
||||
if isinstance(prompt_template, list):
|
||||
@ -80,29 +94,67 @@ class SkillService:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_dependencies(node_data: dict) -> list[ToolDependency]:
|
||||
"""Extract deduplicated tool dependencies from node data."""
|
||||
dependencies: dict[str, ToolDependency] = {}
|
||||
def _extract_tool_dependencies_with_compiler(
|
||||
app: App, node_data: dict[str, Any], user_id: str
|
||||
) -> list[ToolDependency]:
|
||||
"""Extract tool dependencies using SkillCompiler.
|
||||
|
||||
This method loads the SkillBundle and AppAssetFileTree, then uses
|
||||
SkillCompiler.compile_one() to properly extract tool dependencies
|
||||
including transitive dependencies from referenced skill files.
|
||||
"""
|
||||
# Get the draft assets to obtain assets_id and file_tree
|
||||
assets = AppAssetService.get_assets(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
user_id=user_id,
|
||||
is_draft=True,
|
||||
)
|
||||
|
||||
if not assets:
|
||||
logger.warning("No draft assets found for app_id=%s", app.id)
|
||||
return []
|
||||
|
||||
assets_id = assets.id
|
||||
file_tree = assets.asset_tree
|
||||
|
||||
# Load the skill bundle
|
||||
try:
|
||||
bundle = SkillManager.load_bundle(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
assets_id=assets_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to load skill bundle for app_id=%s: %s", app.id, e)
|
||||
# Return empty if bundle doesn't exist (no skills compiled yet)
|
||||
return []
|
||||
|
||||
# Compile each skill prompt and collect tool dependencies
|
||||
compiler = SkillCompiler()
|
||||
tool_deps_list: list[ToolDependencies] = []
|
||||
|
||||
prompt_template = node_data.get("prompt_template", [])
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if isinstance(prompt, dict) and prompt.get("skill", False):
|
||||
metadata_dict = prompt.get("metadata") or {}
|
||||
tools_dict = metadata_dict.get("tools", {})
|
||||
text: str = prompt.get("text", "")
|
||||
metadata: dict[str, Any] = prompt.get("metadata") or {}
|
||||
|
||||
for uuid, tool_data in tools_dict.items():
|
||||
if isinstance(tool_data, dict):
|
||||
try:
|
||||
ref = ToolReference.model_validate({"uuid": uuid, **tool_data})
|
||||
key = f"{ref.provider}.{ref.tool_name}"
|
||||
if key not in dependencies:
|
||||
dependencies[key] = ToolDependency(
|
||||
type=ref.type,
|
||||
provider=ref.provider,
|
||||
tool_name=ref.tool_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Skipping invalid tool reference: uuid=%s", uuid)
|
||||
skill_entry = compiler.compile_one(
|
||||
bundle=bundle,
|
||||
document=SkillDocument(skill_id="anonymous", content=text, metadata=metadata),
|
||||
file_tree=file_tree,
|
||||
base_path=AppAssets.PATH,
|
||||
)
|
||||
tool_deps_list.append(skill_entry.tools)
|
||||
|
||||
return list(dependencies.values())
|
||||
if not tool_deps_list:
|
||||
return []
|
||||
|
||||
# Merge all tool dependencies
|
||||
from functools import reduce
|
||||
|
||||
merged = reduce(lambda x, y: x.merge(y), tool_deps_list)
|
||||
|
||||
return merged.dependencies
|
||||
|
||||
Reference in New Issue
Block a user