feat: implement MCP specification 2025-06-18 (#25766)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Novice
2025-10-27 17:07:51 +08:00
committed by GitHub
parent b6e0abadab
commit 0ded6303c1
33 changed files with 4863 additions and 1128 deletions

View File

@ -0,0 +1,328 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import create_provider_encrypter
if TYPE_CHECKING:
from models.tools import MCPToolProvider
# Constants
CLIENT_NAME = "Dify"
CLIENT_URI = "https://github.com/langgenius/dify"
DEFAULT_TOKEN_TYPE = "Bearer"
DEFAULT_EXPIRES_IN = 3600
MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6
class MCPSupportGrantType(StrEnum):
"""The supported grant types for MCP"""
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"
REFRESH_TOKEN = "refresh_token"
class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
class MCPConfiguration(BaseModel):
timeout: float = 30
sse_read_timeout: float = 300
class MCPProviderEntity(BaseModel):
"""MCP Provider domain entity for business logic operations"""
# Basic identification
id: str
provider_id: str # server_identifier
name: str
tenant_id: str
user_id: str
# Server connection info
server_url: str # encrypted URL
headers: dict[str, str] # encrypted headers
timeout: float
sse_read_timeout: float
# Authentication related
authed: bool
credentials: dict[str, Any] # encrypted credentials
code_verifier: str | None = None # for OAuth
# Tools and display info
tools: list[dict[str, Any]] # parsed tools list
icon: str | dict[str, str] # parsed icon
# Timestamps
created_at: datetime
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
"""Create entity from database model with decryption"""
return cls(
id=db_provider.id,
provider_id=db_provider.server_identifier,
name=db_provider.name,
tenant_id=db_provider.tenant_id,
user_id=db_provider.user_id,
server_url=db_provider.server_url,
headers=db_provider.headers,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
authed=db_provider.authed,
credentials=db_provider.credentials,
tools=db_provider.tool_dict,
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
)
@property
def redirect_url(self) -> str:
"""OAuth redirect URL"""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
# Get grant type from credentials
credentials = self.decrypt_credentials()
# Try to get grant_type from different locations
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
# For nested structure, check if client_information has grant_types
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
client_info = credentials["client_information"]
# If grant_types is specified in client_information, use it to determine grant_type
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
if "client_credentials" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
elif "authorization_code" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
# Configure based on grant type
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
grant_types = ["refresh_token"]
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
response_types = [] if is_client_credentials else ["code"]
redirect_uris = [] if is_client_credentials else [self.redirect_url]
return OAuthClientMetadata(
redirect_uris=redirect_uris,
token_endpoint_auth_method="none",
grant_types=grant_types,
response_types=response_types,
client_name=CLIENT_NAME,
client_uri=CLIENT_URI,
)
@property
def provider_icon(self) -> dict[str, str] | str:
"""Get provider icon, handling both dict and string formats"""
if isinstance(self.icon, dict):
return self.icon
try:
return json.loads(self.icon)
except (json.JSONDecodeError, TypeError):
# If not JSON, assume it's a file path
return file_helpers.get_signed_file_url(self.icon)
def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]:
"""Convert to API response format
Args:
user_name: User name to display
include_sensitive: If False, skip expensive decryption operations (for list view optimization)
"""
response = {
"id": self.id,
"author": user_name or "Anonymous",
"name": self.name,
"icon": self.provider_icon,
"type": ToolProviderType.MCP.value,
"is_team_authorization": self.authed,
"server_url": self.masked_server_url(),
"server_identifier": self.provider_id,
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
}
# Add configuration
response["configuration"] = {
"timeout": str(self.timeout),
"sse_read_timeout": str(self.sse_read_timeout),
}
# Skip expensive operations when sensitive data is not needed (e.g., list view)
if not include_sensitive:
response["masked_headers"] = {}
response["is_dynamic_registration"] = True
else:
# Add masked headers
response["masked_headers"] = self.masked_headers()
# Add authentication info if available
masked_creds = self.masked_credentials()
if masked_creds:
response["authentication"] = masked_creds
response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get(
"is_dynamic_registration", True
)
return response
def retrieve_client_information(self) -> OAuthClientInformation | None:
"""OAuth client information if available"""
credentials = self.decrypt_credentials()
if not credentials:
return None
# Check if we have nested client_information structure
if "client_information" not in credentials:
return None
client_info_data = credentials["client_information"]
if isinstance(client_info_data, dict):
if "encrypted_client_secret" in client_info_data:
client_info_data["client_secret"] = encrypter.decrypt_token(
self.tenant_id, client_info_data["encrypted_client_secret"]
)
return OAuthClientInformation.model_validate(client_info_data)
return None
def retrieve_tokens(self) -> OAuthTokens | None:
"""OAuth tokens if available"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),
)
def masked_server_url(self) -> str:
"""Masked server URL for display"""
parsed = urlparse(self.decrypt_server_url())
if parsed.path and parsed.path != "/":
masked = parsed._replace(path="/******")
return masked.geturl()
return parsed.geturl()
def _mask_value(self, value: str) -> str:
"""Mask a sensitive value for display"""
if len(value) > MIN_UNMASK_LENGTH:
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
else:
return MASK_CHAR * len(value)
def masked_headers(self) -> dict[str, str]:
"""Masked headers for display"""
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
def masked_credentials(self) -> dict[str, str]:
"""Masked credentials for display"""
credentials = self.decrypt_credentials()
if not credentials:
return {}
masked = {}
if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
return {}
client_info = credentials["client_information"]
# Mask sensitive fields from nested structure
if client_info.get("client_id"):
masked["client_id"] = self._mask_value(client_info["client_id"])
if client_info.get("encrypted_client_secret"):
masked["client_secret"] = self._mask_value(
encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
)
if client_info.get("client_secret"):
masked["client_secret"] = self._mask_value(client_info["client_secret"])
return masked
def decrypt_server_url(self) -> str:
"""Decrypt server URL"""
return encrypter.decrypt_token(self.tenant_id, self.server_url)
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
if not data:
return {}
# Only decrypt fields that are actually encrypted
# For nested structures, client_information is not encrypted as a whole
encrypted_fields = []
for key, value in data.items():
# Skip nested objects - they are not encrypted
if isinstance(value, dict):
continue
# Only process string values that might be encrypted
if isinstance(value, str) and value:
encrypted_fields.append(key)
if not encrypted_fields:
return data
# Create dynamic config only for encrypted fields
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# Decrypt only the encrypted fields
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
# Merge decrypted data with original data (preserving non-encrypted fields)
result = data.copy()
result.update(decrypted_data)
return result
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
return self._decrypt_dict(self.headers)
def decrypt_credentials(self) -> dict[str, Any]:
"""Decrypt credentials"""
return self._decrypt_dict(self.credentials)
def decrypt_authentication(self) -> dict[str, Any]:
"""Decrypt authentication"""
# Option 1: if headers is provided, use it and don't need to get token
headers = self.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not self.headers and self.authed:
token = self.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
return headers

View File

@ -6,11 +6,15 @@ import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse
import httpx
from pydantic import BaseModel, ValidationError
from httpx import ConnectError, HTTPStatusError, RequestError
from pydantic import ValidationError
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
from core.helper import ssrf_proxy
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
@ -19,21 +23,10 @@ from core.mcp.types import (
)
from extensions.ext_redis import redis_client
LATEST_PROTOCOL_VERSION = "1.0"
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
@ -80,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
"""
Handle the callback from the OAuth provider.
Returns:
A tuple of (callback_state, tokens) that can be used by the caller to save data.
"""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
@ -93,30 +91,32 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
provider.save_tokens(tokens)
return full_state_data
return full_state_data, tokens
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
"""Check if the server supports OAuth 2.0 Resource Discovery."""
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
if b_query:
url_for_resource_discovery += f"?{b_query}"
if b_fragment:
url_for_resource_discovery += f"#{b_fragment}"
try:
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
response = httpx.get(url_for_resource_discovery, headers=headers)
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
if 200 <= response.status_code < 300:
body = response.json()
if "authorization_server_url" in body:
# Support both singular and plural forms
if body.get("authorization_servers"):
return True, body["authorization_servers"][0]
elif body.get("authorization_server_url"):
return True, body["authorization_server_url"][0]
else:
return False, ""
return False, ""
except httpx.RequestError:
except RequestError:
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""
@ -126,27 +126,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
url = oauth_discovery_url
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
try:
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
response = httpx.get(url, headers=headers)
if response.status_code == 404:
return None
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
for url in urls_to_try:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 404:
return None
continue
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
raise
except (RequestError, HTTPStatusError) as e:
if isinstance(e, ConnectError):
response = ssrf_proxy.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
# For other errors, try next URL
continue
return None # No metadata found
def start_authorization(
@ -213,7 +223,7 @@ def exchange_authorization(
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = "authorization_code"
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
if metadata:
token_url = metadata.token_endpoint
@ -233,7 +243,7 @@ def exchange_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = httpx.post(token_url, data=params)
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
@ -246,7 +256,7 @@ def refresh_authorization(
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = "refresh_token"
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
if metadata:
token_url = metadata.token_endpoint
@ -263,10 +273,55 @@ def refresh_authorization(
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = httpx.post(token_url, data=params)
try:
response = ssrf_proxy.post(token_url, data=params)
except ssrf_proxy.MaxRetriesExceededError as e:
raise MCPRefreshTokenError(e) from e
if not response.is_success:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
raise MCPRefreshTokenError(response.text)
return OAuthTokens.model_validate(response.json())
def client_credentials_flow(
server_url: str,
metadata: OAuthMetadata | None,
client_information: OAuthClientInformation,
scope: str | None = None,
) -> OAuthTokens:
"""Execute Client Credentials Flow to get access token."""
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Support both Basic Auth and body parameters for client authentication
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {"grant_type": grant_type}
if scope:
data["scope"] = scope
# If client_secret is provided, use Basic Auth (preferred method)
if client_information.client_secret:
credentials = f"{client_information.client_id}:{client_information.client_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
else:
# Fall back to including credentials in the body
data["client_id"] = client_information.client_id
if client_information.client_secret:
data["client_secret"] = client_information.client_secret
response = ssrf_proxy.post(token_url, headers=headers, data=data)
if not response.is_success:
raise ValueError(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
)
return OAuthTokens.model_validate(response.json())
@ -283,7 +338,7 @@ def register_client(
else:
registration_url = urljoin(server_url, "/register")
response = httpx.post(
response = ssrf_proxy.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
@ -294,28 +349,111 @@ def register_client(
def auth(
provider: OAuthClientProvider,
server_url: str,
provider: MCPProviderEntity,
authorization_code: str | None = None,
state_param: str | None = None,
for_list: bool = False,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
) -> AuthResult:
"""
Orchestrates the full auth flow with a server using secure Redis state storage.
This function performs only network operations and returns actions that need
to be performed by the caller (such as saving data to database).
Args:
provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback
Returns:
AuthResult containing actions to be performed and response data
"""
actions: list[AuthAction] = []
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
client_metadata = provider.client_metadata
provider_id = provider.id
tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
# Determine grant type based on server metadata
if not server_metadata:
raise ValueError("Failed to discover OAuth metadata from server")
supported_grant_types = server_metadata.grant_types_supported or []
# Convert to lowercase for comparison
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
# Determine which grant type to use
effective_grant_type = None
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Get stored credentials
credentials = provider.decrypt_credentials()
# Handle client registration if needed
client_information = provider.client_information()
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
# For client credentials flow, we don't need to register client dynamically
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Client should provide client_id and client_secret directly
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
try:
full_information = register_client(server_url, metadata, provider.client_metadata)
except httpx.RequestError as e:
full_information = register_client(server_url, server_metadata, client_metadata)
except RequestError as e:
raise ValueError(f"Could not register OAuth client: {e}")
provider.save_client_information(full_information)
# Return action to save client information
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CLIENT_INFO,
data={"client_information": full_information.model_dump()},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
client_information = full_information
# Exchange authorization code for tokens
# Handle client credentials flow
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction
try:
scope = credentials.get("scope")
tokens = client_credentials_flow(
server_url,
server_metadata,
client_information,
scope,
)
# Return action to save tokens and grant type
token_data = tokens.model_dump()
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=token_data,
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Client credentials flow failed: {e}")
# Exchange authorization code for tokens (Authorization Code flow)
if authorization_code is not None:
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
@ -335,35 +473,69 @@ def auth(
tokens = exchange_authorization(
server_url,
metadata,
server_metadata,
client_information,
authorization_code,
code_verifier,
redirect_uri,
)
provider.save_tokens(tokens)
return {"result": "success"}
provider_tokens = provider.tokens()
# Return action to save tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
provider_tokens = provider.retrieve_tokens()
# Handle token refresh or new authorization
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens)
return {"result": "success"}
except Exception as e:
new_tokens = refresh_authorization(
server_url, server_metadata, client_information, provider_tokens.refresh_token
)
# Return action to save new tokens
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_TOKENS,
data=new_tokens.model_dump(),
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"result": "success"})
except (RequestError, ValueError, KeyError) as e:
# RequestError: HTTP request failed
# ValueError: Invalid response data
# KeyError: Missing required fields in response
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow
# Start new authorization flow (only for authorization code flow)
authorization_url, code_verifier = start_authorization(
server_url,
metadata,
server_metadata,
client_information,
provider.redirect_url,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
redirect_url,
provider_id,
tenant_id,
)
provider.save_code_verifier(code_verifier)
return {"authorization_url": authorization_url}
# Return action to save code verifier
actions.append(
AuthAction(
action_type=AuthActionType.SAVE_CODE_VERIFIER,
data={"code_verifier": code_verifier},
provider_id=provider_id,
tenant_id=tenant_id,
)
)
return AuthResult(actions=actions, response={"authorization_url": authorization_url})

View File

@ -1,77 +0,0 @@
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
class OAuthClientProvider:
mcp_provider: MCPToolProvider
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
if for_list:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
else:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
)
def client_information(self) -> OAuthClientInformation | None:
"""Loads information about this OAuth client."""
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull):
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider,
{"client_information": client_information.model_dump()},
)
def tokens(self) -> OAuthTokens | None:
"""Loads any existing OAuth tokens for the current session."""
credentials = self.mcp_provider.decrypted_credentials
if not credentials:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=int(credentials.get("expires_in", "3600") or 3600),
refresh_token=credentials.get("refresh_token", ""),
)
def save_tokens(self, tokens: OAuthTokens):
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str):
"""Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))

191
api/core/mcp/auth_client.py Normal file
View File

@ -0,0 +1,191 @@
"""
MCP Client with Authentication Retry Support
This module provides an enhanced MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""
import logging
from collections.abc import Callable
from typing import Any
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
from extensions.ext_database import db
logger = logging.getLogger(__name__)
class MCPClientWithAuthRetry(MCPClient):
"""
An enhanced MCPClient that provides automatic authentication retry.
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
Note: This class uses lazy session creation - database sessions are only
created when authentication retry is actually needed, not on every request.
"""
def __init__(
self,
server_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
provider_entity: MCPProviderEntity | None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
):
"""
Initialize the MCP client with auth retry capability.
Args:
server_url: The MCP server URL
headers: Optional headers for requests
timeout: Request timeout
sse_read_timeout: SSE read timeout
provider_entity: Provider entity for authentication
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
Handle authentication error by refreshing tokens.
This method creates a short-lived database session only when authentication
retry is needed, minimizing database connection hold time.
Args:
error: The authentication error
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if not self.provider_entity:
raise error
if self._has_retried:
raise error
self._has_retried = True
try:
# Create a temporary session only for auth retry
# This session is short-lived and only exists during the auth operation
from services.tools.mcp_tools_manage_service import MCPToolManageService
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
# Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity(
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
)
# Session is closed here, before we update headers
token = self.provider_entity.retrieve_tokens()
if not token:
raise MCPAuthError("Authentication failed - no token received")
# Update headers with new token
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
# Clear authorization code after first use
self.authorization_code = None
except MCPAuthError:
# Re-raise MCPAuthError as is
raise
except Exception as e:
# Catch all exceptions during auth retry
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
"""
Execute a function with authentication retry logic.
Args:
func: The function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
The result of the function call
Raises:
MCPAuthError: If authentication fails after retries
Any other exceptions from the function
"""
try:
return func(*args, **kwargs)
except MCPAuthError as e:
self._handle_auth_error(e)
# Re-initialize the connection with new headers
if self._initialized:
# Clean up existing connection
self._exit_stack.close()
self._session = None
self._initialized = False
# Re-initialize with new headers
self._initialize()
self._initialized = True
return func(*args, **kwargs)
finally:
# Reset retry flag after operation completes
self._has_retried = False
def __enter__(self):
"""Enter the context manager with retry support."""
def initialize_with_retry():
super(MCPClientWithAuthRetry, self).__enter__()
return self
return self._execute_with_retry(initialize_with_retry)
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server with auth retry.
Returns:
List of available tools
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server with auth retry.
Args:
tool_name: Name of the tool to invoke
tool_args: Arguments for the tool
Returns:
Result of the tool invocation
Raises:
MCPAuthError: If authentication fails after retries
"""
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)

View File

View File

@ -46,7 +46,7 @@ class SSETransport:
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
):
"""Initialize the SSE transport.
@ -255,7 +255,7 @@ def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.
@ -276,31 +276,34 @@ def sse_client(
read_queue: ReadQueue | None = None
write_queue: WriteQueue | None = None
with ThreadPoolExecutor() as executor:
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
executor = ThreadPoolExecutor()
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
read_queue, write_queue = transport.connect(executor, client, event_source)
read_queue, write_queue = transport.connect(executor, client, event_source)
yield read_queue, write_queue
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):

View File

@ -434,45 +434,48 @@ def streamablehttp_client(
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor:
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
executor = ThreadPoolExecutor(max_workers=2)
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
try:
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)

View File

@ -1,10 +1,13 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Generic, TypeVar
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
from pydantic import BaseModel
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
class AuthActionType(StrEnum):
"""Types of actions that can be performed during auth flow."""
SAVE_CLIENT_INFO = "save_client_info"
SAVE_TOKENS = "save_tokens"
SAVE_CODE_VERIFIER = "save_code_verifier"
START_AUTHORIZATION = "start_authorization"
SUCCESS = "success"
class AuthAction(BaseModel):
"""Represents an action that needs to be performed as a result of auth flow."""
action_type: AuthActionType
data: dict[str, Any]
provider_id: str | None = None
tenant_id: str | None = None
class AuthResult(BaseModel):
"""Result of auth function containing actions to be performed and response data."""
actions: list[AuthAction]
response: dict[str, str]
class OAuthCallbackState(BaseModel):
"""State data stored in Redis during OAuth callback flow."""
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str

View File

@ -8,3 +8,7 @@ class MCPConnectionError(MCPError):
class MCPAuthError(MCPConnectionError):
pass
class MCPRefreshTokenError(MCPError):
pass

View File

@ -7,9 +7,9 @@ from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.error import MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import Tool
from core.mcp.types import CallToolResult, Tool
logger = logging.getLogger(__name__)
@ -18,40 +18,18 @@ class MCPClient:
def __init__(
self,
server_url: str,
provider_id: str,
tenant_id: str,
authed: bool = True,
authorization_code: str | None = None,
for_list: bool = False,
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
self.token = self.provider.tokens()
# Initialize session and client objects
self._session: ClientSession | None = None
self._streams_context: AbstractContextManager[Any] | None = None
self._session_context: ClientSession | None = None
self._exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
def __enter__(self):
@ -85,61 +63,42 @@ class MCPClient:
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
):
from core.mcp.auth.auth_flow import auth
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
"""
Connect to the MCP server using streamable http or sse.
Default to streamable http.
Args:
client_factory: The client factory to use(streamablehttp_client or sse_client).
method_name: The method name to use(mcp or sse).
"""
streams_context = client_factory(
url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(streams_context)
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
self._session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
try:
auth(self.provider, self.server_url, self.authorization_code)
except Exception as e:
raise ValueError(f"Failed to authenticate: {e}")
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(session_context)
self._session.initialize()
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
if not self._initialized or not self._session:
"""List available tools from the MCP server"""
if not self._session:
raise ValueError("Session not initialized.")
response = self._session.list_tools()
tools = response.tools
return tools
return response.tools
def invoke_tool(self, tool_name: str, tool_args: dict):
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""Call a tool"""
if not self._initialized or not self._session:
if not self._session:
raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args)
@ -153,6 +112,4 @@ class MCPClient:
raise ValueError(f"Error during cleanup: {e}")
finally:
self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False

View File

@ -201,11 +201,14 @@ class BaseSession(
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
pass
# Cancel the future to interrupt the receiver loop
self._receiver_future.cancel()
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
# Use non-blocking shutdown to prevent hanging
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request(
self,

View File

@ -284,7 +284,7 @@ class ClientSession(
def complete(
self,
ref: types.ResourceReference | types.PromptReference,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""

View File

@ -1,13 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import (
Annotated,
Any,
Generic,
Literal,
TypeAlias,
TypeVar,
)
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints
@ -33,6 +26,7 @@ for reference.
LATEST_PROTOCOL_VERSION = "2025-03-26"
# Server support 2024-11-05 to allow claude to use.
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
@ -55,14 +49,22 @@ class RequestParams(BaseModel):
meta: Meta | None = Field(alias="_meta", default=None)
class PaginatedRequestParams(RequestParams):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class NotificationParams(BaseModel):
class Meta(BaseModel):
model_config = ConfigDict(extra="allow")
meta: Meta | None = Field(alias="_meta", default=None)
"""
This parameter name is reserved by MCP to allow clients and servers to attach
additional metadata to their notifications.
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")
class PaginatedRequest(Request[RequestParamsT, MethodT]):
cursor: Cursor | None = None
"""
An opaque token representing the current pagination position.
If provided, the server should return results starting after this cursor.
"""
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
class Result(BaseModel):
"""Base class for JSON-RPC results."""
model_config = ConfigDict(extra="allow")
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
This result property is reserved by the protocol to allow clients and servers to
attach additional metadata to their responses.
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class PaginatedResult(Result):
@ -186,10 +186,26 @@ class EmptyResult(Result):
"""A response that indicates success but carries no data."""
class Implementation(BaseModel):
"""Describes the name and version of an MCP implementation."""
class BaseMetadata(BaseModel):
"""Base class for entities with name and optional title fields."""
name: str
"""The programmatic name of the entity."""
title: str | None = None
"""
Intended for UI and end-user contexts — optimized to be human-readable and easily understood,
even by those unfamiliar with domain-specific terminology.
If not provided, the name should be used for display (except for Tool,
where `annotations.title` should be given precedence over using `name`,
if present).
"""
class Implementation(BaseMetadata):
"""Describes the name and version of an MCP implementation."""
version: str
model_config = ConfigDict(extra="allow")
@ -203,7 +219,7 @@ class RootsCapability(BaseModel):
class SamplingCapability(BaseModel):
"""Capability for logging operations."""
"""Capability for sampling operations."""
model_config = ConfigDict(extra="allow")
@ -252,6 +268,12 @@ class LoggingCapability(BaseModel):
model_config = ConfigDict(extra="allow")
class CompletionsCapability(BaseModel):
"""Capability for completions operations."""
model_config = ConfigDict(extra="allow")
class ServerCapabilities(BaseModel):
"""Capabilities that a server may support."""
@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel):
"""Present if the server offers any resources to read."""
tools: ToolsCapability | None = None
"""Present if the server offers any tools to call."""
completions: CompletionsCapability | None = None
"""Present if the server offers autocompletion suggestions for prompts and resources."""
model_config = ConfigDict(extra="allow")
@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]])
to begin initialization.
"""
method: Literal["initialize"]
method: Literal["initialize"] = "initialize"
params: InitializeRequestParams
@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n
finished.
"""
method: Literal["notifications/initialized"]
method: Literal["notifications/initialized"] = "notifications/initialized"
params: NotificationParams | None = None
@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]):
still alive.
"""
method: Literal["ping"]
method: Literal["ping"] = "ping"
params: RequestParams | None = None
@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams):
"""
total: float | None = None
"""Total number of items to process (or total progress required), if known."""
message: str | None = None
"""
Message related to progress. This should provide relevant human readable
progress information.
"""
model_config = ConfigDict(extra="allow")
@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not
long-running request.
"""
method: Literal["notifications/progress"]
method: Literal["notifications/progress"] = "notifications/progress"
params: ProgressNotificationParams
class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
"""Sent from the client to request a list of resources the server has."""
method: Literal["resources/list"]
params: RequestParams | None = None
method: Literal["resources/list"] = "resources/list"
class Annotations(BaseModel):
@ -360,13 +388,11 @@ class Annotations(BaseModel):
model_config = ConfigDict(extra="allow")
class Resource(BaseModel):
class Resource(BaseMetadata):
"""A known resource that the server is capable of reading."""
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
"""The URI of this resource."""
name: str
"""A human-readable name for this resource."""
description: str | None = None
"""A description of what this resource represents."""
mimeType: str | None = None
@ -379,10 +405,15 @@ class Resource(BaseModel):
This can be used by Hosts to display file sizes and estimate context window usage.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class ResourceTemplate(BaseModel):
class ResourceTemplate(BaseMetadata):
"""A template description for resources available on the server."""
uriTemplate: str
@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel):
A URI template (according to RFC 6570) that can be used to construct resource
URIs.
"""
name: str
"""A human-readable name for the type of resource this template refers to."""
description: str | None = None
"""A human-readable description of what this template is for."""
mimeType: str | None = None
@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel):
included if all resources matching this template have the same type.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult):
resources: list[Resource]
class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
"""Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"]
params: RequestParams | None = None
method: Literal["resources/templates/list"] = "resources/templates/list"
class ListResourceTemplatesResult(PaginatedResult):
@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams):
class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
"""Sent from the client to the server, to read a specific resource URI."""
method: Literal["resources/read"]
method: Literal["resources/read"] = "resources/read"
params: ReadResourceRequestParams
@ -447,6 +480,11 @@ class ResourceContents(BaseModel):
"""The URI of this resource."""
mimeType: str | None = None
"""The MIME type of this resource, if known."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -481,7 +519,7 @@ class ResourceListChangedNotification(
of resources it can read from has changed.
"""
method: Literal["notifications/resources/list_changed"]
method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed"
params: NotificationParams | None = None
@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr
whenever a particular resource changes.
"""
method: Literal["resources/subscribe"]
method: Literal["resources/subscribe"] = "resources/subscribe"
params: SubscribeRequestParams
@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un
the server.
"""
method: Literal["resources/unsubscribe"]
method: Literal["resources/unsubscribe"] = "resources/unsubscribe"
params: UnsubscribeRequestParams
@ -543,15 +581,14 @@ class ResourceUpdatedNotification(
changed and may need to be read again.
"""
method: Literal["notifications/resources/updated"]
method: Literal["notifications/resources/updated"] = "notifications/resources/updated"
params: ResourceUpdatedNotificationParams
class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
"""Sent from the client to request a list of prompts and prompt templates."""
method: Literal["prompts/list"]
params: RequestParams | None = None
method: Literal["prompts/list"] = "prompts/list"
class PromptArgument(BaseModel):
@ -566,15 +603,18 @@ class PromptArgument(BaseModel):
model_config = ConfigDict(extra="allow")
class Prompt(BaseModel):
class Prompt(BaseMetadata):
"""A prompt or prompt template that the server offers."""
name: str
"""The name of the prompt or prompt template."""
description: str | None = None
"""An optional description of what this prompt provides."""
arguments: list[PromptArgument] | None = None
"""A list of arguments to use for templating the prompt."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams):
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
"""Used by the client to get a prompt provided by the server."""
method: Literal["prompts/get"]
method: Literal["prompts/get"] = "prompts/get"
params: GetPromptRequestParams
@ -608,6 +648,11 @@ class TextContent(BaseModel):
text: str
"""The text content of the message."""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -623,6 +668,31 @@ class ImageContent(BaseModel):
image types.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class AudioContent(BaseModel):
"""Audio content for a message."""
type: Literal["audio"]
data: str
"""The base64-encoded audio data."""
mimeType: str
"""
The MIME type of the audio. Different providers may support different
audio types.
"""
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -630,7 +700,7 @@ class SamplingMessage(BaseModel):
"""Describes a message issued to or received from an LLM API."""
role: Role
content: TextContent | ImageContent
content: TextContent | ImageContent | AudioContent
model_config = ConfigDict(extra="allow")
@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel):
type: Literal["resource"]
resource: TextResourceContents | BlobResourceContents
annotations: Annotations | None = None
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
class ResourceLink(Resource):
"""
A resource that the server is capable of reading, included in a prompt or tool call result.
Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
"""
type: Literal["resource_link"]
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
"""A content block that can be used in prompts and tool results."""
Content: TypeAlias = ContentBlock
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
class PromptMessage(BaseModel):
"""Describes a message returned as part of a prompt."""
role: Role
content: TextContent | ImageContent | EmbeddedResource
content: ContentBlock
model_config = ConfigDict(extra="allow")
@ -672,15 +764,14 @@ class PromptListChangedNotification(
of prompts it offers has changed.
"""
method: Literal["notifications/prompts/list_changed"]
method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed"
params: NotificationParams | None = None
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""
method: Literal["tools/list"]
params: RequestParams | None = None
method: Literal["tools/list"] = "tools/list"
class ToolAnnotations(BaseModel):
@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel):
model_config = ConfigDict(extra="allow")
class Tool(BaseModel):
class Tool(BaseMetadata):
"""Definition for a tool the client can call."""
name: str
"""The name of the tool."""
description: str | None = None
"""A human-readable description of the tool."""
inputSchema: dict[str, Any]
"""A JSON Schema object defining the expected parameters for the tool."""
outputSchema: dict[str, Any] | None = None
"""
An optional JSON Schema object defining the structure of the tool's output
returned in the structuredContent field of a CallToolResult.
"""
annotations: ToolAnnotations | None = None
"""Optional additional tool information."""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams):
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
"""Used by the client to invoke a tool provided by the server."""
method: Literal["tools/call"]
method: Literal["tools/call"] = "tools/call"
params: CallToolRequestParams
class CallToolResult(Result):
"""The server's response to a tool call."""
content: list[TextContent | ImageContent | EmbeddedResource]
content: list[ContentBlock]
structuredContent: dict[str, Any] | None = None
"""An optional JSON object that represents the structured result of the tool call."""
isError: bool = False
@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera
of tools it offers has changed.
"""
method: Literal["notifications/tools/list_changed"]
method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed"
params: NotificationParams | None = None
@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams):
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
"""A request from the client to the server, to enable or adjust logging."""
method: Literal["logging/setLevel"]
method: Literal["logging/setLevel"] = "logging/setLevel"
params: SetLevelRequestParams
@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams):
"""The severity of this log message."""
logger: str | None = None
"""An optional name of the logger issuing this message."""
data: Any = None
data: Any
"""
The data to be logged, such as a string message or an object. Any JSON serializable
type is allowed here.
@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams):
class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
"""Notification of a log message passed from server to client."""
method: Literal["notifications/message"]
method: Literal["notifications/message"] = "notifications/message"
params: LoggingMessageNotificationParams
@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams):
class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
"""A request from the server to sample an LLM via the client."""
method: Literal["sampling/createMessage"]
method: Literal["sampling/createMessage"] = "sampling/createMessage"
params: CreateMessageRequestParams
@ -925,14 +1026,14 @@ class CreateMessageResult(Result):
"""The client's response to a sampling/create_message request from the server."""
role: Role
content: TextContent | ImageContent
content: TextContent | ImageContent | AudioContent
model: str
"""The name of the model that generated the message."""
stopReason: StopReason | None = None
"""The reason why sampling stopped, if known."""
class ResourceReference(BaseModel):
class ResourceTemplateReference(BaseModel):
"""A reference to a resource or resource template definition."""
type: Literal["ref/resource"]
@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel):
model_config = ConfigDict(extra="allow")
class CompletionContext(BaseModel):
"""Additional, optional context for completions."""
arguments: dict[str, str] | None = None
"""Previously-resolved variables in a URI template or prompt."""
model_config = ConfigDict(extra="allow")
class CompleteRequestParams(RequestParams):
"""Parameters for completion requests."""
ref: ResourceReference | PromptReference
ref: ResourceTemplateReference | PromptReference
argument: CompletionArgument
context: CompletionContext | None = None
"""Additional, optional context for completions"""
model_config = ConfigDict(extra="allow")
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
"""A request from the client to the server, to ask for completion options."""
method: Literal["completion/complete"]
method: Literal["completion/complete"] = "completion/complete"
params: CompleteRequestParams
@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
structure or access specific locations that the client has permission to read from.
"""
method: Literal["roots/list"]
method: Literal["roots/list"] = "roots/list"
params: RequestParams | None = None
@ -1029,6 +1140,11 @@ class Root(BaseModel):
identifier for the root, which may be useful for display purposes or for
referencing the root in other parts of the application.
"""
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
"""
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
for notes on _meta usage.
"""
model_config = ConfigDict(extra="allow")
@ -1054,7 +1170,7 @@ class RootsListChangedNotification(
using the ListRootsRequest.
"""
method: Literal["notifications/roots/list_changed"]
method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed"
params: NotificationParams | None = None
@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n
previously-issued request.
"""
method: Literal["notifications/cancelled"]
method: Literal["notifications/cancelled"] = "notifications/cancelled"
params: CancelledNotificationParams

View File

@ -217,3 +217,16 @@ class Tool(ABC):
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
)
def create_variable_message(
self, variable_name: str, variable_value: Any, stream: bool = False
) -> ToolInvokeMessage:
"""
create a variable message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.VARIABLE,
message=ToolInvokeMessage.VariableMessage(
variable_name=variable_name, variable_value=variable_value, stream=stream
),
)

View File

@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
@ -44,10 +45,14 @@ class ToolProviderApiEntity(BaseModel):
server_url: str | None = Field(default="", description="The server url of the tool")
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
@field_validator("tools", mode="before")
@classmethod
@ -70,8 +75,15 @@ class ToolProviderApiEntity(BaseModel):
if self.type == ToolProviderType.MCP:
optional_fields.update(self.optional_field("updated_at", self.updated_at))
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
optional_fields.update(self.optional_field("timeout", self.timeout))
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
optional_fields.update(
self.optional_field(
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
)
)
optional_fields.update(
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
)
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
return {

View File

@ -1,6 +1,6 @@
import json
from typing import Any, Self
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -52,18 +52,25 @@ class MCPToolProviderController(ToolProviderController):
"""
from db provider
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
user = db_provider.load_user()
# Convert to entity first
provider_entity = db_provider.to_entity()
return cls.from_entity(provider_entity)
@classmethod
def from_entity(cls, entity: MCPProviderEntity) -> Self:
"""
create a MCPToolProviderController from a MCPProviderEntity
"""
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
tools = [
ToolEntity(
identity=ToolIdentity(
author=user.name if user else "Anonymous",
author="Anonymous", # Tool level author is not stored
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=db_provider.server_identifier,
icon=db_provider.icon,
provider=entity.provider_id,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
description=ToolDescription(
@ -72,31 +79,32 @@ class MCPToolProviderController(ToolProviderController):
),
llm=remote_mcp_tool.description or "",
),
output_schema=remote_mcp_tool.outputSchema or {},
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
)
for remote_mcp_tool in remote_mcp_tools
]
if not db_provider.icon:
if not entity.icon:
raise ValueError("Database provider icon is required")
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author=user.name if user else "Anonymous",
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
author="Anonymous", # Provider level author is not stored in entity
name=entity.name,
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
description=I18nObject(en_US="", zh_Hans=""),
icon=db_provider.icon,
icon=entity.icon if isinstance(entity.icon, str) else "",
),
plugin_id=None,
credentials_schema=[],
tools=tools,
),
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers=db_provider.decrypted_headers or {},
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
provider_id=entity.provider_id,
tenant_id=entity.tenant_id,
server_url=entity.server_url,
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):

View File

@ -3,12 +3,13 @@ import json
from collections.abc import Generator
from typing import Any
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
class MCPTool(Tool):
@ -44,40 +45,32 @@ class MCPTool(Tool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
from core.tools.errors import ToolInvokeError
try:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ToolInvokeError("Please auth the tool first") from e
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
result = self.invoke_remote_mcp_tool(tool_parameters)
# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
# handle MCP structured output
if self.entity.output_schema and result.structuredContent:
for k, v in result.structuredContent.items():
yield self.create_variable_message(k, v)
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
"""Process text content and yield appropriate messages."""
try:
content_json = json.loads(content.text)
yield from self._process_json_content(content_json)
except json.JSONDecodeError:
yield self.create_text_message(content.text)
# Check if content looks like JSON before attempting to parse
text = content.text.strip()
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
try:
content_json = json.loads(text)
yield from self._process_json_content(content_json)
return
except json.JSONDecodeError:
pass
# If not JSON or parsing failed, treat as plain text
yield self.create_text_message(content.text)
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
"""Process JSON content based on its type."""
@ -126,3 +119,44 @@ class MCPTool(Tool):
for key, value in parameter.items()
if value is not None and not (isinstance(value, str) and value.strip() == "")
}
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
from sqlalchemy.orm import Session
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import MCPToolManageService
# Step 1: Load provider entity and credentials in a short-lived session
# This minimizes database connection hold time
with Session(db.engine, expire_on_commit=False) as session:
mcp_service = MCPToolManageService(session=session)
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
# Decrypt and prepare all credentials before closing session
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_headers()
# Try to get existing token and add to headers
if not headers:
tokens = provider_entity.retrieve_tokens()
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

View File

@ -14,17 +14,32 @@ from sqlalchemy.orm import Session
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.runtime.variable_pool import VariablePool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
@ -40,21 +55,11 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
@ -719,7 +724,9 @@ class ToolManager:
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
@ -774,17 +781,12 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
if provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
controller = MCPToolProviderController.from_db(provider)
@ -922,16 +924,15 @@ class ToolManager:
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)
if mcp_provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
return mcp_provider.provider_icon
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
try:
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}