mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
feat: mcp client init
This commit is contained in:
134
api/services/tools/mcp_tools_mange_service.py
Normal file
134
api/services/tools/mcp_tools_mange_service.py
Normal file
@ -0,0 +1,134 @@
|
||||
import json
|
||||
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from extensions.ext_database import db
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""
|
||||
Service class for managing mcp tools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider | None:
|
||||
return (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(
|
||||
MCPToolProvider.id == provider_id,
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_mcp_provider(
|
||||
tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str
|
||||
) -> dict:
|
||||
if (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.name == name)
|
||||
.first()
|
||||
):
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
server_url=server_url,
|
||||
user_id=user_id,
|
||||
authed=False,
|
||||
tools="[]",
|
||||
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def retrieve_mcp_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
mcp_providers = db.session.query(MCPToolProvider).filter(MCPToolProvider.tenant_id == tenant_id).all()
|
||||
return [ToolTransformService.mcp_provider_to_user_provider(mcp_provider) for mcp_provider in mcp_providers]
|
||||
|
||||
@classmethod
|
||||
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
mcp_provider.authed = True
|
||||
db.session.commit()
|
||||
return ToolProviderApiEntity(
|
||||
id=mcp_provider.id,
|
||||
name=mcp_provider.name,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
|
||||
type=ToolProviderType.MCP,
|
||||
icon=mcp_provider.icon,
|
||||
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
||||
server_url=mcp_provider.server_url,
|
||||
updated_at=mcp_provider.updated_at,
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def retrieve_mcp_provider(cls, tenant_id: str, provider_id: str):
|
||||
provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
return ToolTransformService.mcp_provider_to_user_provider(provider).to_dict()
|
||||
|
||||
@classmethod
|
||||
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
|
||||
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_tool is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
db.session.delete(mcp_tool)
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: str,
|
||||
name: str,
|
||||
server_url: str,
|
||||
icon: str,
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
encrypted_credentials: dict,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
mcp_provider.name = name
|
||||
mcp_provider.server_url = server_url
|
||||
mcp_provider.icon = (
|
||||
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
||||
)
|
||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **encrypted_credentials})
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
|
||||
mcp_provider.authed = authed
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_mcp_token(cls, provider_id: str, tenant_id: str):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP provider not found")
|
||||
return mcp_provider.credentials.get("access_token", None)
|
||||
@ -5,6 +5,7 @@ from typing import Optional, Union, cast
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@ -21,7 +22,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -187,6 +188,38 @@ class ToolTransformService:
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mcp_provider_to_user_provider(db_provider: MCPToolProvider) -> ToolProviderApiEntity:
|
||||
return ToolProviderApiEntity(
|
||||
id=db_provider.id,
|
||||
author=db_provider.user.name if db_provider.user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
icon=db_provider.provider_icon,
|
||||
type=ToolProviderType.MCP,
|
||||
is_team_authorization=db_provider.authed,
|
||||
server_url=db_provider.server_url,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(
|
||||
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||
),
|
||||
updated_at=db_provider.updated_at,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
|
||||
return [
|
||||
ToolApiEntity(
|
||||
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
||||
name=tool.name,
|
||||
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
|
||||
description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
|
||||
labels=[],
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def api_provider_to_user_provider(
|
||||
cls,
|
||||
@ -304,3 +337,59 @@ class ToolTransformService:
|
||||
parameters=tool.parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
|
||||
"""
|
||||
Convert MCP JSON schema to tool parameters
|
||||
|
||||
:param schema: JSON schema dictionary
|
||||
:return: list of ToolParameter instances
|
||||
"""
|
||||
|
||||
def create_parameter(name: str, description: str, param_type: str, required: bool) -> ToolParameter:
|
||||
"""Create a ToolParameter instance with given attributes"""
|
||||
return ToolParameter(
|
||||
name=name,
|
||||
llm_description=description,
|
||||
label=I18nObject(en_US=name),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
required=required,
|
||||
type=ToolParameter.ToolParameterType(param_type),
|
||||
human_description=I18nObject(en_US=description),
|
||||
)
|
||||
|
||||
def process_array(name: str, description: str, items: dict, required: bool) -> list[ToolParameter]:
|
||||
"""Process array type properties"""
|
||||
item_type = items.get("type", "string")
|
||||
if item_type == "object" and "properties" in items:
|
||||
return process_properties(items["properties"], items.get("required", []), f"{name}[0]")
|
||||
|
||||
return [create_parameter(name, description, item_type, required)]
|
||||
|
||||
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
|
||||
"""Process properties recursively"""
|
||||
parameters = []
|
||||
for name, prop in props.items():
|
||||
current_name = f"{prefix}.{name}" if prefix else name
|
||||
current_description = prop.get("description", "")
|
||||
prop_type = prop.get("type", "string")
|
||||
|
||||
if isinstance(prop_type, list):
|
||||
prop_type = prop_type[0]
|
||||
if prop_type == "integer":
|
||||
prop_type = "number"
|
||||
if prop_type == "array":
|
||||
parameters.extend(
|
||||
process_array(current_name, current_description, prop.get("items", {}), name in required)
|
||||
)
|
||||
elif prop_type == "object" and "properties" in prop:
|
||||
parameters.extend(process_properties(prop["properties"], prop.get("required", []), current_name))
|
||||
else:
|
||||
parameters.append(create_parameter(current_name, current_description, prop_type, name in required))
|
||||
|
||||
return parameters
|
||||
|
||||
if schema.get("type") == "object" and "properties" in schema:
|
||||
return process_properties(schema["properties"], schema.get("required", []))
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user