mirror of
https://github.com/langgenius/dify.git
synced 2026-03-23 23:37:55 +08:00
feat(skill): tool switcher implementation
- Introduced a new regex pattern for tool groups to support multiple tool placeholders. - Updated the DefaultToolResolver to format outputs for specific built-in tools (bash, python). - Enhanced the SkillCompiler to filter out disabled tools in tool groups, ensuring only enabled tools are rendered. - Added tests to verify the correct behavior of tool group filtering and rendering.
This commit is contained in:
@ -29,6 +29,7 @@ class ToolReference(BaseModel):
|
||||
type: ToolProviderType
|
||||
provider: str
|
||||
tool_name: str
|
||||
enabled: bool = True
|
||||
credential_id: str | None = None
|
||||
configuration: ToolConfiguration | None = None
|
||||
|
||||
|
||||
@ -25,6 +25,11 @@ class ToolResolver(Protocol):
|
||||
@dataclass(frozen=True)
|
||||
class CompilerConfig:
|
||||
tool_pattern: re.Pattern[str] = re.compile(r"§\[tool\]\.\[.*?\]\.\[.*?\]\.\[(.*?)\]§")
|
||||
# Evolved format: a group of tool placeholders wrapped by "[...]".
|
||||
# Example: [§[tool].[provider].[name].[uuid-a]§, §[tool].[provider].[name].[uuid-b]§]
|
||||
tool_group_pattern: re.Pattern[str] = re.compile(
|
||||
r"\[\s*§\[tool\]\.\[[^\]]+\]\.\[[^\]]+\]\.\[[^\]]+\]§(?:\s*,\s*§\[tool\]\.\[[^\]]+\]\.\[[^\]]+\]\.\[[^\]]+\]§)*\s*\]"
|
||||
)
|
||||
file_pattern: re.Pattern[str] = re.compile(r"§\[file\]\.\[.*?\]\.\[(.*?)\]§")
|
||||
|
||||
|
||||
@ -51,6 +56,11 @@ class FileTreePathResolver:
|
||||
|
||||
class DefaultToolResolver:
|
||||
def resolve(self, tool_ref: ToolReference) -> str:
|
||||
# Keep outputs readable for the most common built-in tools.
|
||||
if tool_ref.provider == "sandbox" and tool_ref.tool_name == "bash":
|
||||
return f"[Bash Command: {tool_ref.tool_name}_{tool_ref.uuid}]"
|
||||
if tool_ref.provider == "sandbox" and tool_ref.tool_name == "python":
|
||||
return f"[Python Code: {tool_ref.tool_name}_{tool_ref.uuid}]"
|
||||
return f"[Executable: {tool_ref.tool_name}_{tool_ref.uuid} --help command]"
|
||||
|
||||
|
||||
@ -347,12 +357,33 @@ class SkillCompiler:
|
||||
|
||||
def replace_tool(match: re.Match[str]) -> str:
|
||||
tool_id = match.group(1)
|
||||
tool_ref = metadata.tools.get(tool_id)
|
||||
tool_ref: ToolReference | None = metadata.tools.get(tool_id)
|
||||
if not tool_ref:
|
||||
return f"[Tool not found: {tool_id}]"
|
||||
if not tool_ref.enabled:
|
||||
return ""
|
||||
return self._tool_resolver.resolve(tool_ref)
|
||||
|
||||
def replace_tool_group(match: re.Match[str]) -> str:
|
||||
group_text = match.group(0)
|
||||
enabled_renders: list[str] = []
|
||||
|
||||
for tool_match in self._config.tool_pattern.finditer(group_text):
|
||||
tool_id = tool_match.group(1)
|
||||
tool_ref: ToolReference | None = metadata.tools.get(tool_id)
|
||||
if not tool_ref:
|
||||
enabled_renders.append(f"[Tool not found: {tool_id}]")
|
||||
continue
|
||||
if not tool_ref.enabled:
|
||||
continue
|
||||
enabled_renders.append(self._tool_resolver.resolve(tool_ref))
|
||||
|
||||
if not enabled_renders:
|
||||
return ""
|
||||
return "[" + ", ".join(enabled_renders) + "]"
|
||||
|
||||
content = self._config.file_pattern.sub(replace_file, content)
|
||||
content = self._config.tool_group_pattern.sub(replace_tool_group, content)
|
||||
content = self._config.tool_pattern.sub(replace_tool, content)
|
||||
return content
|
||||
|
||||
@ -368,16 +399,18 @@ class SkillCompiler:
|
||||
if isinstance(meta, ToolReference):
|
||||
tools[uuid] = meta
|
||||
elif isinstance(meta, dict):
|
||||
tool_type_str = cast(str | None, meta.get("type"))
|
||||
meta_dict = cast(dict[str, Any], meta)
|
||||
tool_type_str = cast(str | None, meta_dict.get("type"))
|
||||
if tool_type_str:
|
||||
tools[uuid] = ToolReference(
|
||||
uuid=uuid,
|
||||
type=ToolProviderType.value_of(tool_type_str),
|
||||
provider=provider,
|
||||
tool_name=name,
|
||||
credential_id=cast(str | None, meta.get("credential_id")),
|
||||
configuration=ToolConfiguration.model_validate(meta.get("configuration", {}))
|
||||
if meta.get("configuration")
|
||||
enabled=cast(bool, meta_dict.get("enabled", True)),
|
||||
credential_id=cast(str | None, meta_dict.get("credential_id")),
|
||||
configuration=ToolConfiguration.model_validate(meta_dict.get("configuration", {}))
|
||||
if meta_dict.get("configuration")
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
@ -223,6 +223,57 @@ class TestSkillCompilerCompileOne:
|
||||
assert result.tools.dependencies[0].tool_name == "python"
|
||||
|
||||
|
||||
class TestSkillCompilerToolGroups:
|
||||
def test_compile_tool_group_filters_disabled(self):
|
||||
# given
|
||||
doc = SkillDocument(
|
||||
skill_id="skill-1",
|
||||
content="Tools:[§[tool].[sandbox].[bash].[tool-a]§, §[tool].[sandbox].[bash].[tool-b]§]",
|
||||
metadata={
|
||||
"tools": {
|
||||
"tool-a": {"type": ToolProviderType.BUILT_IN.value, "enabled": True},
|
||||
"tool-b": {"type": ToolProviderType.BUILT_IN.value, "enabled": False},
|
||||
}
|
||||
},
|
||||
)
|
||||
tree = create_file_tree(
|
||||
AppAssetNode.create_file("skill-1", "skill.md"),
|
||||
)
|
||||
compiler = SkillCompiler()
|
||||
|
||||
# when
|
||||
artifact_set = compiler.compile_all([doc], tree, "assets-1")
|
||||
|
||||
# then
|
||||
artifact = artifact_set.get("skill-1")
|
||||
assert artifact is not None
|
||||
assert artifact.content == "Tools:[[Bash Command: bash_tool-a]]"
|
||||
|
||||
def test_compile_tool_group_renders_nothing_when_all_disabled(self):
|
||||
# given
|
||||
doc = SkillDocument(
|
||||
skill_id="skill-1",
|
||||
content="Tools:[§[tool].[sandbox].[bash].[tool-b]§]",
|
||||
metadata={
|
||||
"tools": {
|
||||
"tool-b": {"type": ToolProviderType.BUILT_IN.value, "enabled": False},
|
||||
}
|
||||
},
|
||||
)
|
||||
tree = create_file_tree(
|
||||
AppAssetNode.create_file("skill-1", "skill.md"),
|
||||
)
|
||||
compiler = SkillCompiler()
|
||||
|
||||
# when
|
||||
artifact_set = compiler.compile_all([doc], tree, "assets-1")
|
||||
|
||||
# then
|
||||
artifact = artifact_set.get("skill-1")
|
||||
assert artifact is not None
|
||||
assert artifact.content == "Tools:"
|
||||
|
||||
|
||||
class TestSkillCompilerComplexGraph:
|
||||
def test_large_complex_dependency_graph(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user