mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 07:28:05 +08:00
refactor(mcp): clean the client service code
This commit is contained in:
53
api/services/tools/mcp_oauth_service.py
Normal file
53
api/services/tools/mcp_oauth_service.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
MCP OAuth Service - handles OAuth-related database operations
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.types import OAuthClientInformationFull, OAuthTokens
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
|
||||
class MCPOAuthService:
|
||||
"""Service for handling MCP OAuth operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
self._mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
|
||||
"""Get provider entity by ID"""
|
||||
if by_server_id:
|
||||
db_provider = self._mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
|
||||
else:
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
return db_provider.to_entity()
|
||||
|
||||
def save_client_information(
|
||||
self, provider_id: str, tenant_id: str, client_information: OAuthClientInformationFull
|
||||
) -> None:
|
||||
"""Save OAuth client information"""
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
self._mcp_service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials={"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def save_tokens(self, provider_id: str, tenant_id: str, tokens: OAuthTokens, authed: bool = True) -> None:
|
||||
"""Save OAuth tokens"""
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
token_dict = tokens.model_dump()
|
||||
self._mcp_service.update_provider_credentials(provider=db_provider, credentials=token_dict, authed=authed)
|
||||
|
||||
def save_code_verifier(self, provider_id: str, tenant_id: str, code_verifier: str) -> None:
|
||||
"""Save PKCE code verifier"""
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
self._mcp_service.update_provider_credentials(
|
||||
provider=db_provider, credentials={"code_verifier": code_verifier}
|
||||
)
|
||||
|
||||
def clear_credentials(self, provider_id: str, tenant_id: str) -> None:
|
||||
"""Clear provider credentials"""
|
||||
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
|
||||
self._mcp_service.clear_provider_credentials(provider=db_provider)
|
||||
@ -1,24 +1,27 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import OAuthTokens
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||
|
||||
|
||||
@ -27,8 +30,10 @@ class MCPToolManageService:
|
||||
Service class for managing mcp tools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
|
||||
def _encrypt_headers(self, headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
||||
"""
|
||||
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
||||
|
||||
@ -57,48 +62,53 @@ class MCPToolManageService:
|
||||
|
||||
return encrypter_instance.encrypt(headers)
|
||||
|
||||
@staticmethod
|
||||
def _retrieve_remote_mcp_tools(server_url: str, headers: dict[str, str], timeout: float, sse_read_timeout: float):
|
||||
with MCPClient(
|
||||
def _retrieve_remote_mcp_tools(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
provider_entity: MCPProviderEntity,
|
||||
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]],
|
||||
):
|
||||
"""Retrieve tools from remote MCP server"""
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth_callback,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
def _process_headers(headers: dict[str, str], tokens: OAuthTokens | None = None):
|
||||
headers = headers or {}
|
||||
def _process_headers(self, headers: dict[str, str], tokens: OAuthTokens | None = None) -> dict[str, str]:
|
||||
"""Process headers and add OAuth token if available"""
|
||||
headers = headers.copy() if headers else {}
|
||||
if tokens:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not res:
|
||||
def get_provider_by_id(self, provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
"""Get MCP provider by ID"""
|
||||
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
||||
provider = self._session.scalar(stmt)
|
||||
if not provider:
|
||||
raise ValueError("MCP tool not found")
|
||||
return res
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
|
||||
.first()
|
||||
def get_provider_by_server_identifier(self, server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
||||
"""Get MCP provider by server identifier"""
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
|
||||
)
|
||||
if not res:
|
||||
provider = self._session.scalar(stmt)
|
||||
if not provider:
|
||||
raise ValueError("MCP tool not found")
|
||||
return res
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def create_mcp_provider(
|
||||
def create_provider(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
name: str,
|
||||
server_url: str,
|
||||
@ -111,19 +121,20 @@ class MCPToolManageService:
|
||||
sse_read_timeout: float,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider"""
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
existing_provider = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
or_(
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.server_url_hash == server_url_hash,
|
||||
MCPToolProvider.server_identifier == server_identifier,
|
||||
),
|
||||
)
|
||||
.first()
|
||||
|
||||
# Check for existing provider
|
||||
stmt = select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
or_(
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.server_url_hash == server_url_hash,
|
||||
MCPToolProvider.server_identifier == server_identifier,
|
||||
),
|
||||
)
|
||||
existing_provider = self._session.scalar(stmt)
|
||||
|
||||
if existing_provider:
|
||||
if existing_provider.name == name:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
@ -131,13 +142,17 @@ class MCPToolManageService:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
|
||||
# Encrypt server URL
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
|
||||
# Encrypt headers
|
||||
encrypted_headers = None
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
|
||||
encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
|
||||
# Create provider
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
@ -152,91 +167,68 @@ class MCPToolManageService:
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
|
||||
self._session.add(mcp_tool)
|
||||
self._session.commit()
|
||||
|
||||
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
|
||||
@staticmethod
|
||||
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||
mcp_providers = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
||||
.order_by(MCPToolProvider.name)
|
||||
.all()
|
||||
)
|
||||
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||
"""List all MCP providers for a tenant"""
|
||||
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
|
||||
mcp_providers = self._session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
|
||||
for mcp_provider in mcp_providers
|
||||
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
|
||||
for provider in mcp_providers
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
|
||||
"""List tools from remote MCP server"""
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
server_url = mcp_provider.decrypted_server_url
|
||||
authed = mcp_provider.authed
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
# Load provider and convert to entity
|
||||
db_provider = self.get_provider_by_id(provider_id, tenant_id)
|
||||
provider_entity = db_provider.to_entity()
|
||||
|
||||
# Handle authentication headers if authed
|
||||
if not authed:
|
||||
if not provider_entity.authed:
|
||||
raise ValueError("Please auth the tool first")
|
||||
|
||||
provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
||||
tokens = provider.tokens()
|
||||
headers = cls._process_headers(headers, tokens)
|
||||
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
headers = self._process_headers(provider_entity.headers, tokens)
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
try:
|
||||
tools = cls._retrieve_remote_mcp_tools(server_url, headers, timeout, sse_read_timeout)
|
||||
except MCPAuthError:
|
||||
try:
|
||||
auth(provider, server_url, None, None, False)
|
||||
tokens = provider.tokens()
|
||||
re_authed_headers = cls._process_headers(headers, tokens)
|
||||
tools = cls._retrieve_remote_mcp_tools(server_url, re_authed_headers, timeout, sse_read_timeout)
|
||||
except Exception:
|
||||
raise ValueError("Please auth the tool first")
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity, auth)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
try:
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
mcp_provider.authed = True
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
raise
|
||||
# Update database record with new tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.commit()
|
||||
|
||||
user = mcp_provider.load_user()
|
||||
return ToolProviderApiEntity(
|
||||
id=mcp_provider.id,
|
||||
name=mcp_provider.name,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
|
||||
type=ToolProviderType.MCP,
|
||||
icon=mcp_provider.icon,
|
||||
author=user.name if user else "Anonymous",
|
||||
server_url=mcp_provider.masked_server_url,
|
||||
updated_at=int(mcp_provider.updated_at.timestamp()),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
||||
plugin_unique_identifier=mcp_provider.server_identifier,
|
||||
# Create API response using entity
|
||||
user = db_provider.load_user()
|
||||
response = provider_entity.to_api_response(
|
||||
user_name=user.name if user else None,
|
||||
)
|
||||
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
|
||||
response["plugin_unique_identifier"] = provider_entity.provider_id
|
||||
|
||||
@classmethod
|
||||
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
|
||||
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
db.session.delete(mcp_tool)
|
||||
db.session.commit()
|
||||
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
|
||||
"""Delete an MCP provider"""
|
||||
mcp_tool = self.get_provider_by_id(provider_id, tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
self._session.commit()
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider(
|
||||
cls,
|
||||
def update_provider(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
provider_id: str,
|
||||
name: str,
|
||||
@ -248,21 +240,27 @@ class MCPToolManageService:
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
) -> None:
|
||||
"""Update an MCP provider"""
|
||||
mcp_provider = self.get_provider_by_id(provider_id, tenant_id)
|
||||
|
||||
reconnect_result = None
|
||||
encrypted_server_url = None
|
||||
server_url_hash = None
|
||||
|
||||
# Handle server URL update
|
||||
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
if server_url_hash != mcp_provider.server_url_hash:
|
||||
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
|
||||
reconnect_result = self._reconnect_provider(
|
||||
server_url=server_url,
|
||||
provider=mcp_provider,
|
||||
)
|
||||
|
||||
try:
|
||||
# Update basic fields
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.name = name
|
||||
mcp_provider.icon = (
|
||||
@ -270,6 +268,7 @@ class MCPToolManageService:
|
||||
)
|
||||
mcp_provider.server_identifier = server_identifier
|
||||
|
||||
# Update server URL if changed
|
||||
if encrypted_server_url is not None and server_url_hash is not None:
|
||||
mcp_provider.server_url = encrypted_server_url
|
||||
mcp_provider.server_url_hash = server_url_hash
|
||||
@ -279,6 +278,7 @@ class MCPToolManageService:
|
||||
mcp_provider.tools = reconnect_result["tools"]
|
||||
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
|
||||
|
||||
# Update optional fields
|
||||
if timeout is not None:
|
||||
mcp_provider.timeout = timeout
|
||||
if sse_read_timeout is not None:
|
||||
@ -286,13 +286,15 @@ class MCPToolManageService:
|
||||
if headers is not None:
|
||||
# Encrypt headers
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
else:
|
||||
mcp_provider.encrypted_headers = None
|
||||
db.session.commit()
|
||||
|
||||
self._session.commit()
|
||||
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
self._session.rollback()
|
||||
error_msg = str(e.orig)
|
||||
if "unique_mcp_provider_name" in error_msg:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
@ -302,54 +304,55 @@ class MCPToolManageService:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
raise
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
self._session.rollback()
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider_credentials(
|
||||
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||
):
|
||||
def update_provider_credentials(
|
||||
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||
) -> None:
|
||||
"""Update provider credentials"""
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
|
||||
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
||||
provider_controller = MCPToolProviderController.from_db(provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=mcp_provider.tenant_id,
|
||||
tenant_id=provider.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_config_cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
|
||||
mcp_provider.authed = authed
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
provider.updated_at = datetime.now()
|
||||
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
|
||||
provider.authed = authed
|
||||
if not authed:
|
||||
mcp_provider.tools = "[]"
|
||||
db.session.commit()
|
||||
provider.tools = "[]"
|
||||
|
||||
@classmethod
|
||||
def clear_mcp_provider_credentials(
|
||||
cls,
|
||||
mcp_provider: MCPToolProvider,
|
||||
):
|
||||
mcp_provider.tools = "[]"
|
||||
mcp_provider.encrypted_credentials = "{}"
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.authed = False
|
||||
db.session.commit()
|
||||
self._session.commit()
|
||||
|
||||
@classmethod
|
||||
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str) -> dict[str, Any]:
|
||||
# Get the existing provider to access headers and timeout settings
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
|
||||
"""Clear provider credentials"""
|
||||
provider.tools = "[]"
|
||||
provider.encrypted_credentials = "{}"
|
||||
provider.updated_at = datetime.now()
|
||||
provider.authed = False
|
||||
self._session.commit()
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
|
||||
"""Attempt to reconnect to MCP provider with new server URL"""
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
provider_entity = provider.to_entity()
|
||||
headers = provider_entity.headers
|
||||
timeout = provider_entity.timeout
|
||||
sse_read_timeout = provider_entity.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return {
|
||||
|
||||
@ -221,27 +221,20 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
|
||||
# Convert to entity and use its API response method
|
||||
provider_entity = db_provider.to_entity()
|
||||
user = db_provider.load_user()
|
||||
return ToolProviderApiEntity(
|
||||
id=db_provider.server_identifier if not for_list else db_provider.id,
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
icon=db_provider.provider_icon,
|
||||
type=ToolProviderType.MCP,
|
||||
is_team_authorization=db_provider.authed,
|
||||
server_url=db_provider.masked_server_url,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(
|
||||
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||
),
|
||||
updated_at=int(db_provider.updated_at.timestamp()),
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
server_identifier=db_provider.server_identifier,
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
masked_headers=db_provider.masked_headers,
|
||||
original_headers=db_provider.decrypted_headers,
|
||||
|
||||
response = provider_entity.to_api_response(user_name=user.name if user else None)
|
||||
|
||||
# Add additional fields specific to the transform
|
||||
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
|
||||
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(
|
||||
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||
)
|
||||
response["server_identifier"] = db_provider.server_identifier
|
||||
|
||||
return ToolProviderApiEntity(**response)
|
||||
|
||||
@staticmethod
|
||||
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
|
||||
@ -403,7 +396,7 @@ class ToolTransformService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
|
||||
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
|
||||
"""
|
||||
Convert MCP JSON schema to tool parameters
|
||||
|
||||
@ -412,7 +405,7 @@ class ToolTransformService:
|
||||
"""
|
||||
|
||||
def create_parameter(
|
||||
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
|
||||
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
|
||||
) -> ToolParameter:
|
||||
"""Create a ToolParameter instance with given attributes"""
|
||||
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
|
||||
@ -427,7 +420,9 @@ class ToolTransformService:
|
||||
**input_schema_dict,
|
||||
)
|
||||
|
||||
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
|
||||
def process_properties(
|
||||
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
|
||||
) -> list[ToolParameter]:
|
||||
"""Process properties recursively"""
|
||||
TYPE_MAPPING = {"integer": "number", "float": "number"}
|
||||
COMPLEX_TYPES = ["array", "object"]
|
||||
|
||||
Reference in New Issue
Block a user