mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
feat: implement MCP specification 2025-06-18 (#25766)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
328
api/core/entities/mcp_provider.py
Normal file
328
api/core/entities/mcp_provider.py
Normal file
@ -0,0 +1,328 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.tools import MCPToolProvider
|
||||
|
||||
# Constants
|
||||
CLIENT_NAME = "Dify"
|
||||
CLIENT_URI = "https://github.com/langgenius/dify"
|
||||
DEFAULT_TOKEN_TYPE = "Bearer"
|
||||
DEFAULT_EXPIRES_IN = 3600
|
||||
MASK_CHAR = "*"
|
||||
MIN_UNMASK_LENGTH = 6
|
||||
|
||||
|
||||
class MCPSupportGrantType(StrEnum):
|
||||
"""The supported grant types for MCP"""
|
||||
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
|
||||
class MCPAuthentication(BaseModel):
|
||||
client_id: str
|
||||
client_secret: str | None = None
|
||||
|
||||
|
||||
class MCPConfiguration(BaseModel):
|
||||
timeout: float = 30
|
||||
sse_read_timeout: float = 300
|
||||
|
||||
|
||||
class MCPProviderEntity(BaseModel):
|
||||
"""MCP Provider domain entity for business logic operations"""
|
||||
|
||||
# Basic identification
|
||||
id: str
|
||||
provider_id: str # server_identifier
|
||||
name: str
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
|
||||
# Server connection info
|
||||
server_url: str # encrypted URL
|
||||
headers: dict[str, str] # encrypted headers
|
||||
timeout: float
|
||||
sse_read_timeout: float
|
||||
|
||||
# Authentication related
|
||||
authed: bool
|
||||
credentials: dict[str, Any] # encrypted credentials
|
||||
code_verifier: str | None = None # for OAuth
|
||||
|
||||
# Tools and display info
|
||||
tools: list[dict[str, Any]] # parsed tools list
|
||||
icon: str | dict[str, str] # parsed icon
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
|
||||
"""Create entity from database model with decryption"""
|
||||
|
||||
return cls(
|
||||
id=db_provider.id,
|
||||
provider_id=db_provider.server_identifier,
|
||||
name=db_provider.name,
|
||||
tenant_id=db_provider.tenant_id,
|
||||
user_id=db_provider.user_id,
|
||||
server_url=db_provider.server_url,
|
||||
headers=db_provider.headers,
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
authed=db_provider.authed,
|
||||
credentials=db_provider.credentials,
|
||||
tools=db_provider.tool_dict,
|
||||
icon=db_provider.icon or "",
|
||||
created_at=db_provider.created_at,
|
||||
updated_at=db_provider.updated_at,
|
||||
)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""OAuth redirect URL"""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
# Get grant type from credentials
|
||||
credentials = self.decrypt_credentials()
|
||||
|
||||
# Try to get grant_type from different locations
|
||||
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
|
||||
|
||||
# For nested structure, check if client_information has grant_types
|
||||
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
|
||||
client_info = credentials["client_information"]
|
||||
# If grant_types is specified in client_information, use it to determine grant_type
|
||||
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
|
||||
if "client_credentials" in client_info["grant_types"]:
|
||||
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
|
||||
elif "authorization_code" in client_info["grant_types"]:
|
||||
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
|
||||
|
||||
# Configure based on grant type
|
||||
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
|
||||
|
||||
grant_types = ["refresh_token"]
|
||||
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
|
||||
|
||||
response_types = [] if is_client_credentials else ["code"]
|
||||
redirect_uris = [] if is_client_credentials else [self.redirect_url]
|
||||
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=redirect_uris,
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=grant_types,
|
||||
response_types=response_types,
|
||||
client_name=CLIENT_NAME,
|
||||
client_uri=CLIENT_URI,
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_icon(self) -> dict[str, str] | str:
|
||||
"""Get provider icon, handling both dict and string formats"""
|
||||
if isinstance(self.icon, dict):
|
||||
return self.icon
|
||||
try:
|
||||
return json.loads(self.icon)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If not JSON, assume it's a file path
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
|
||||
def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]:
|
||||
"""Convert to API response format
|
||||
|
||||
Args:
|
||||
user_name: User name to display
|
||||
include_sensitive: If False, skip expensive decryption operations (for list view optimization)
|
||||
"""
|
||||
response = {
|
||||
"id": self.id,
|
||||
"author": user_name or "Anonymous",
|
||||
"name": self.name,
|
||||
"icon": self.provider_icon,
|
||||
"type": ToolProviderType.MCP.value,
|
||||
"is_team_authorization": self.authed,
|
||||
"server_url": self.masked_server_url(),
|
||||
"server_identifier": self.provider_id,
|
||||
"updated_at": int(self.updated_at.timestamp()),
|
||||
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||
}
|
||||
|
||||
# Add configuration
|
||||
response["configuration"] = {
|
||||
"timeout": str(self.timeout),
|
||||
"sse_read_timeout": str(self.sse_read_timeout),
|
||||
}
|
||||
|
||||
# Skip expensive operations when sensitive data is not needed (e.g., list view)
|
||||
if not include_sensitive:
|
||||
response["masked_headers"] = {}
|
||||
response["is_dynamic_registration"] = True
|
||||
else:
|
||||
# Add masked headers
|
||||
response["masked_headers"] = self.masked_headers()
|
||||
|
||||
# Add authentication info if available
|
||||
masked_creds = self.masked_credentials()
|
||||
if masked_creds:
|
||||
response["authentication"] = masked_creds
|
||||
response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get(
|
||||
"is_dynamic_registration", True
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def retrieve_client_information(self) -> OAuthClientInformation | None:
|
||||
"""OAuth client information if available"""
|
||||
credentials = self.decrypt_credentials()
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
# Check if we have nested client_information structure
|
||||
if "client_information" not in credentials:
|
||||
return None
|
||||
client_info_data = credentials["client_information"]
|
||||
if isinstance(client_info_data, dict):
|
||||
if "encrypted_client_secret" in client_info_data:
|
||||
client_info_data["client_secret"] = encrypter.decrypt_token(
|
||||
self.tenant_id, client_info_data["encrypted_client_secret"]
|
||||
)
|
||||
return OAuthClientInformation.model_validate(client_info_data)
|
||||
return None
|
||||
|
||||
def retrieve_tokens(self) -> OAuthTokens | None:
|
||||
"""OAuth tokens if available"""
|
||||
if not self.credentials:
|
||||
return None
|
||||
credentials = self.decrypt_credentials()
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
|
||||
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def masked_server_url(self) -> str:
|
||||
"""Masked server URL for display"""
|
||||
parsed = urlparse(self.decrypt_server_url())
|
||||
if parsed.path and parsed.path != "/":
|
||||
masked = parsed._replace(path="/******")
|
||||
return masked.geturl()
|
||||
return parsed.geturl()
|
||||
|
||||
def _mask_value(self, value: str) -> str:
|
||||
"""Mask a sensitive value for display"""
|
||||
if len(value) > MIN_UNMASK_LENGTH:
|
||||
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
|
||||
else:
|
||||
return MASK_CHAR * len(value)
|
||||
|
||||
def masked_headers(self) -> dict[str, str]:
|
||||
"""Masked headers for display"""
|
||||
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
|
||||
|
||||
def masked_credentials(self) -> dict[str, str]:
|
||||
"""Masked credentials for display"""
|
||||
credentials = self.decrypt_credentials()
|
||||
if not credentials:
|
||||
return {}
|
||||
|
||||
masked = {}
|
||||
|
||||
if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
|
||||
return {}
|
||||
client_info = credentials["client_information"]
|
||||
# Mask sensitive fields from nested structure
|
||||
if client_info.get("client_id"):
|
||||
masked["client_id"] = self._mask_value(client_info["client_id"])
|
||||
if client_info.get("encrypted_client_secret"):
|
||||
masked["client_secret"] = self._mask_value(
|
||||
encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
|
||||
)
|
||||
if client_info.get("client_secret"):
|
||||
masked["client_secret"] = self._mask_value(client_info["client_secret"])
|
||||
return masked
|
||||
|
||||
def decrypt_server_url(self) -> str:
|
||||
"""Decrypt server URL"""
|
||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||
|
||||
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Generic method to decrypt dictionary fields"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# Only decrypt fields that are actually encrypted
|
||||
# For nested structures, client_information is not encrypted as a whole
|
||||
encrypted_fields = []
|
||||
for key, value in data.items():
|
||||
# Skip nested objects - they are not encrypted
|
||||
if isinstance(value, dict):
|
||||
continue
|
||||
# Only process string values that might be encrypted
|
||||
if isinstance(value, str) and value:
|
||||
encrypted_fields.append(key)
|
||||
|
||||
if not encrypted_fields:
|
||||
return data
|
||||
|
||||
# Create dynamic config only for encrypted fields
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# Decrypt only the encrypted fields
|
||||
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
|
||||
|
||||
# Merge decrypted data with original data (preserving non-encrypted fields)
|
||||
result = data.copy()
|
||||
result.update(decrypted_data)
|
||||
|
||||
return result
|
||||
|
||||
def decrypt_headers(self) -> dict[str, Any]:
|
||||
"""Decrypt headers"""
|
||||
return self._decrypt_dict(self.headers)
|
||||
|
||||
def decrypt_credentials(self) -> dict[str, Any]:
|
||||
"""Decrypt credentials"""
|
||||
return self._decrypt_dict(self.credentials)
|
||||
|
||||
def decrypt_authentication(self) -> dict[str, Any]:
|
||||
"""Decrypt authentication"""
|
||||
# Option 1: if headers is provided, use it and don't need to get token
|
||||
headers = self.decrypt_headers()
|
||||
|
||||
# Option 2: Add OAuth token if authed and no headers provided
|
||||
if not self.headers and self.authed:
|
||||
token = self.retrieve_tokens()
|
||||
if token:
|
||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
return headers
|
||||
@ -6,11 +6,15 @@ import secrets
|
||||
import urllib.parse
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from httpx import ConnectError, HTTPStatusError, RequestError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
|
||||
from core.mcp.error import MCPRefreshTokenError
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
@ -19,21 +23,10 @@ from core.mcp.types import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_pkce_challenge() -> tuple[str, str]:
|
||||
"""Generate PKCE challenge and verifier."""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||
@ -80,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
|
||||
"""
|
||||
Handle the callback from the OAuth provider.
|
||||
|
||||
Returns:
|
||||
A tuple of (callback_state, tokens) that can be used by the caller to save data.
|
||||
"""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
@ -93,30 +91,32 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
||||
provider.save_tokens(tokens)
|
||||
return full_state_data
|
||||
|
||||
return full_state_data, tokens
|
||||
|
||||
|
||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
"""Check if the server supports OAuth 2.0 Resource Discovery."""
|
||||
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
|
||||
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
|
||||
if b_query:
|
||||
url_for_resource_discovery += f"?{b_query}"
|
||||
if b_fragment:
|
||||
url_for_resource_discovery += f"#{b_fragment}"
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
response = httpx.get(url_for_resource_discovery, headers=headers)
|
||||
response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
|
||||
if 200 <= response.status_code < 300:
|
||||
body = response.json()
|
||||
if "authorization_server_url" in body:
|
||||
# Support both singular and plural forms
|
||||
if body.get("authorization_servers"):
|
||||
return True, body["authorization_servers"][0]
|
||||
elif body.get("authorization_server_url"):
|
||||
return True, body["authorization_server_url"][0]
|
||||
else:
|
||||
return False, ""
|
||||
return False, ""
|
||||
except httpx.RequestError:
|
||||
except RequestError:
|
||||
# Not support resource discovery, fall back to well-known OAuth metadata
|
||||
return False, ""
|
||||
|
||||
@ -126,27 +126,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
|
||||
# First check if the server supports OAuth 2.0 Resource Discovery
|
||||
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
||||
if support_resource_discovery:
|
||||
url = oauth_discovery_url
|
||||
# The oauth_discovery_url is the authorization server base URL
|
||||
# Try OpenID Connect discovery first (more common), then OAuth 2.0
|
||||
urls_to_try = [
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
|
||||
]
|
||||
else:
|
||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
|
||||
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
response = httpx.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except httpx.RequestError as e:
|
||||
if isinstance(e, httpx.ConnectError):
|
||||
response = httpx.get(url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
|
||||
for url in urls_to_try:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
continue
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
response.raise_for_status()
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
raise
|
||||
except (RequestError, HTTPStatusError) as e:
|
||||
if isinstance(e, ConnectError):
|
||||
response = ssrf_proxy.get(url)
|
||||
if response.status_code == 404:
|
||||
continue # Try next URL
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
# For other errors, try next URL
|
||||
continue
|
||||
|
||||
return None # No metadata found
|
||||
|
||||
|
||||
def start_authorization(
|
||||
@ -213,7 +223,7 @@ def exchange_authorization(
|
||||
redirect_uri: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchanges an authorization code for an access token."""
|
||||
grant_type = "authorization_code"
|
||||
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
@ -233,7 +243,7 @@ def exchange_authorization(
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = httpx.post(token_url, data=params)
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
@ -246,7 +256,7 @@ def refresh_authorization(
|
||||
refresh_token: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange a refresh token for an updated access token."""
|
||||
grant_type = "refresh_token"
|
||||
grant_type = MCPSupportGrantType.REFRESH_TOKEN.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
@ -263,10 +273,55 @@ def refresh_authorization(
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = httpx.post(token_url, data=params)
|
||||
try:
|
||||
response = ssrf_proxy.post(token_url, data=params)
|
||||
except ssrf_proxy.MaxRetriesExceededError as e:
|
||||
raise MCPRefreshTokenError(e) from e
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
raise MCPRefreshTokenError(response.text)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def client_credentials_flow(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
client_information: OAuthClientInformation,
|
||||
scope: str | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Execute Client Credentials Flow to get access token."""
|
||||
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
# Support both Basic Auth and body parameters for client authentication
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
data = {"grant_type": grant_type}
|
||||
|
||||
if scope:
|
||||
data["scope"] = scope
|
||||
|
||||
# If client_secret is provided, use Basic Auth (preferred method)
|
||||
if client_information.client_secret:
|
||||
credentials = f"{client_information.client_id}:{client_information.client_secret}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded_credentials}"
|
||||
else:
|
||||
# Fall back to including credentials in the body
|
||||
data["client_id"] = client_information.client_id
|
||||
if client_information.client_secret:
|
||||
data["client_secret"] = client_information.client_secret
|
||||
|
||||
response = ssrf_proxy.post(token_url, headers=headers, data=data)
|
||||
if not response.is_success:
|
||||
raise ValueError(
|
||||
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
@ -283,7 +338,7 @@ def register_client(
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = httpx.post(
|
||||
response = ssrf_proxy.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
@ -294,28 +349,111 @@ def register_client(
|
||||
|
||||
|
||||
def auth(
|
||||
provider: OAuthClientProvider,
|
||||
server_url: str,
|
||||
provider: MCPProviderEntity,
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
for_list: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||
|
||||
This function performs only network operations and returns actions that need
|
||||
to be performed by the caller (such as saving data to database).
|
||||
|
||||
Args:
|
||||
provider: The MCP provider entity
|
||||
authorization_code: Optional authorization code from OAuth callback
|
||||
state_param: Optional state parameter from OAuth callback
|
||||
|
||||
Returns:
|
||||
AuthResult containing actions to be performed and response data
|
||||
"""
|
||||
actions: list[AuthAction] = []
|
||||
server_url = provider.decrypt_server_url()
|
||||
server_metadata = discover_oauth_metadata(server_url)
|
||||
client_metadata = provider.client_metadata
|
||||
provider_id = provider.id
|
||||
tenant_id = provider.tenant_id
|
||||
client_information = provider.retrieve_client_information()
|
||||
redirect_url = provider.redirect_url
|
||||
|
||||
# Determine grant type based on server metadata
|
||||
if not server_metadata:
|
||||
raise ValueError("Failed to discover OAuth metadata from server")
|
||||
|
||||
supported_grant_types = server_metadata.grant_types_supported or []
|
||||
|
||||
# Convert to lowercase for comparison
|
||||
supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
|
||||
|
||||
# Determine which grant type to use
|
||||
effective_grant_type = None
|
||||
if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
|
||||
effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
|
||||
else:
|
||||
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
# Get stored credentials
|
||||
credentials = provider.decrypt_credentials()
|
||||
|
||||
# Handle client registration if needed
|
||||
client_information = provider.client_information()
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
|
||||
# For client credentials flow, we don't need to register client dynamically
|
||||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Client should provide client_id and client_secret directly
|
||||
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
|
||||
|
||||
try:
|
||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
||||
except httpx.RequestError as e:
|
||||
full_information = register_client(server_url, server_metadata, client_metadata)
|
||||
except RequestError as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
provider.save_client_information(full_information)
|
||||
|
||||
# Return action to save client information
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_CLIENT_INFO,
|
||||
data={"client_information": full_information.model_dump()},
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
client_information = full_information
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
# Handle client credentials flow
|
||||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Direct token request without user interaction
|
||||
try:
|
||||
scope = credentials.get("scope")
|
||||
tokens = client_credentials_flow(
|
||||
server_url,
|
||||
server_metadata,
|
||||
client_information,
|
||||
scope,
|
||||
)
|
||||
|
||||
# Return action to save tokens and grant type
|
||||
token_data = tokens.model_dump()
|
||||
token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=token_data,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
# KeyError: Missing required fields in response
|
||||
raise ValueError(f"Client credentials flow failed: {e}")
|
||||
|
||||
# Exchange authorization code for tokens (Authorization Code flow)
|
||||
if authorization_code is not None:
|
||||
if not state_param:
|
||||
raise ValueError("State parameter is required when exchanging authorization code")
|
||||
@ -335,35 +473,69 @@ def auth(
|
||||
|
||||
tokens = exchange_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
server_metadata,
|
||||
client_information,
|
||||
authorization_code,
|
||||
code_verifier,
|
||||
redirect_uri,
|
||||
)
|
||||
provider.save_tokens(tokens)
|
||||
return {"result": "success"}
|
||||
|
||||
provider_tokens = provider.tokens()
|
||||
# Return action to save tokens
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=tokens.model_dump(),
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
|
||||
provider_tokens = provider.retrieve_tokens()
|
||||
|
||||
# Handle token refresh or new authorization
|
||||
if provider_tokens and provider_tokens.refresh_token:
|
||||
try:
|
||||
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
|
||||
provider.save_tokens(new_tokens)
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
new_tokens = refresh_authorization(
|
||||
server_url, server_metadata, client_information, provider_tokens.refresh_token
|
||||
)
|
||||
|
||||
# Return action to save new tokens
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_TOKENS,
|
||||
data=new_tokens.model_dump(),
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"result": "success"})
|
||||
except (RequestError, ValueError, KeyError) as e:
|
||||
# RequestError: HTTP request failed
|
||||
# ValueError: Invalid response data
|
||||
# KeyError: Missing required fields in response
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow
|
||||
# Start new authorization flow (only for authorization code flow)
|
||||
authorization_url, code_verifier = start_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
server_metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
provider.mcp_provider.id,
|
||||
provider.mcp_provider.tenant_id,
|
||||
redirect_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
return {"authorization_url": authorization_url}
|
||||
# Return action to save code verifier
|
||||
actions.append(
|
||||
AuthAction(
|
||||
action_type=AuthActionType.SAVE_CODE_VERIFIER,
|
||||
data={"code_verifier": code_verifier},
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
return AuthResult(actions=actions, response={"authorization_url": authorization_url})
|
||||
|
||||
@ -1,77 +0,0 @@
|
||||
from configs import dify_config
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
||||
if for_list:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
else:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""The URL to redirect the user agent to after authorization."""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=[self.redirect_url],
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
)
|
||||
|
||||
def client_information(self) -> OAuthClientInformation | None:
|
||||
"""Loads information about this OAuth client."""
|
||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
||||
if not client_information:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_information)
|
||||
|
||||
def save_client_information(self, client_information: OAuthClientInformationFull):
|
||||
"""Saves client information after dynamic registration."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.mcp_provider,
|
||||
{"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def tokens(self) -> OAuthTokens | None:
|
||||
"""Loads any existing OAuth tokens for the current session."""
|
||||
credentials = self.mcp_provider.decrypted_credentials
|
||||
if not credentials:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", "Bearer"),
|
||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def save_tokens(self, tokens: OAuthTokens):
|
||||
"""Stores new OAuth tokens for the current session."""
|
||||
# update mcp provider credentials
|
||||
token_dict = tokens.model_dump()
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||
|
||||
def save_code_verifier(self, code_verifier: str):
|
||||
"""Saves a PKCE code verifier for the current session."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||
|
||||
def code_verifier(self) -> str:
|
||||
"""Loads the PKCE code verifier for the current session."""
|
||||
# get code verifier from mcp provider credentials
|
||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
||||
191
api/core/mcp/auth_client.py
Normal file
191
api/core/mcp/auth_client.py
Normal file
@ -0,0 +1,191 @@
|
||||
"""
|
||||
MCP Client with Authentication Retry Support
|
||||
|
||||
This module provides an enhanced MCPClient that automatically handles
|
||||
authentication failures and retries operations after refreshing tokens.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.error import MCPAuthError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClientWithAuthRetry(MCPClient):
|
||||
"""
|
||||
An enhanced MCPClient that provides automatic authentication retry.
|
||||
|
||||
This class extends MCPClient and intercepts MCPAuthError exceptions
|
||||
to refresh authentication before retrying failed operations.
|
||||
|
||||
Note: This class uses lazy session creation - database sessions are only
|
||||
created when authentication retry is actually needed, not on every request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
provider_entity: MCPProviderEntity | None = None,
|
||||
authorization_code: str | None = None,
|
||||
by_server_id: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the MCP client with auth retry capability.
|
||||
|
||||
Args:
|
||||
server_url: The MCP server URL
|
||||
headers: Optional headers for requests
|
||||
timeout: Request timeout
|
||||
sse_read_timeout: SSE read timeout
|
||||
provider_entity: Provider entity for authentication
|
||||
authorization_code: Optional authorization code for initial auth
|
||||
by_server_id: Whether to look up provider by server ID
|
||||
"""
|
||||
super().__init__(server_url, headers, timeout, sse_read_timeout)
|
||||
|
||||
self.provider_entity = provider_entity
|
||||
self.authorization_code = authorization_code
|
||||
self.by_server_id = by_server_id
|
||||
self._has_retried = False
|
||||
|
||||
def _handle_auth_error(self, error: MCPAuthError) -> None:
|
||||
"""
|
||||
Handle authentication error by refreshing tokens.
|
||||
|
||||
This method creates a short-lived database session only when authentication
|
||||
retry is needed, minimizing database connection hold time.
|
||||
|
||||
Args:
|
||||
error: The authentication error
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails or max retries reached
|
||||
"""
|
||||
if not self.provider_entity:
|
||||
raise error
|
||||
if self._has_retried:
|
||||
raise error
|
||||
|
||||
self._has_retried = True
|
||||
|
||||
try:
|
||||
# Create a temporary session only for auth retry
|
||||
# This session is short-lived and only exists during the auth operation
|
||||
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
# Perform authentication using the service's auth method
|
||||
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
|
||||
|
||||
# Retrieve new tokens
|
||||
self.provider_entity = mcp_service.get_provider_entity(
|
||||
self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
|
||||
)
|
||||
|
||||
# Session is closed here, before we update headers
|
||||
token = self.provider_entity.retrieve_tokens()
|
||||
if not token:
|
||||
raise MCPAuthError("Authentication failed - no token received")
|
||||
|
||||
# Update headers with new token
|
||||
self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
|
||||
# Clear authorization code after first use
|
||||
self.authorization_code = None
|
||||
|
||||
except MCPAuthError:
|
||||
# Re-raise MCPAuthError as is
|
||||
raise
|
||||
except Exception as e:
|
||||
# Catch all exceptions during auth retry
|
||||
logger.exception("Authentication retry failed")
|
||||
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||
|
||||
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute a function with authentication retry logic.
|
||||
|
||||
Args:
|
||||
func: The function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
The result of the function call
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
Any other exceptions from the function
|
||||
"""
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except MCPAuthError as e:
|
||||
self._handle_auth_error(e)
|
||||
|
||||
# Re-initialize the connection with new headers
|
||||
if self._initialized:
|
||||
# Clean up existing connection
|
||||
self._exit_stack.close()
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
|
||||
# Re-initialize with new headers
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
# Reset retry flag after operation completes
|
||||
self._has_retried = False
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the context manager with retry support."""
|
||||
|
||||
def initialize_with_retry():
|
||||
super(MCPClientWithAuthRetry, self).__enter__()
|
||||
return self
|
||||
|
||||
return self._execute_with_retry(initialize_with_retry)
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
List available tools from the MCP server with auth retry.
|
||||
|
||||
Returns:
|
||||
List of available tools
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
"""
|
||||
return self._execute_with_retry(super().list_tools)
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""
|
||||
Invoke a tool on the MCP server with auth retry.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to invoke
|
||||
tool_args: Arguments for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool invocation
|
||||
|
||||
Raises:
|
||||
MCPAuthError: If authentication fails after retries
|
||||
"""
|
||||
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)
|
||||
0
api/core/mcp/auth_client_comparison.md
Normal file
0
api/core/mcp/auth_client_comparison.md
Normal file
@ -46,7 +46,7 @@ class SSETransport:
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
sse_read_timeout: float = 1 * 60,
|
||||
):
|
||||
"""Initialize the SSE transport.
|
||||
|
||||
@ -255,7 +255,7 @@ def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
sse_read_timeout: float = 1 * 60,
|
||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||
"""
|
||||
Client transport for SSE.
|
||||
@ -276,31 +276,34 @@ def sse_client(
|
||||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
executor = ThreadPoolExecutor()
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
yield read_queue, write_queue
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
# Shutdown executor without waiting to prevent hanging
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
|
||||
|
||||
@ -434,45 +434,48 @@ def streamablehttp_client(
|
||||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream():
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
executor = ThreadPoolExecutor(max_workers=2)
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream():
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
|
||||
# Shutdown executor without waiting to prevent hanging
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
||||
from pydantic import BaseModel
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
|
||||
|
||||
class AuthActionType(StrEnum):
|
||||
"""Types of actions that can be performed during auth flow."""
|
||||
|
||||
SAVE_CLIENT_INFO = "save_client_info"
|
||||
SAVE_TOKENS = "save_tokens"
|
||||
SAVE_CODE_VERIFIER = "save_code_verifier"
|
||||
START_AUTHORIZATION = "start_authorization"
|
||||
SUCCESS = "success"
|
||||
|
||||
|
||||
class AuthAction(BaseModel):
|
||||
"""Represents an action that needs to be performed as a result of auth flow."""
|
||||
|
||||
action_type: AuthActionType
|
||||
data: dict[str, Any]
|
||||
provider_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class AuthResult(BaseModel):
|
||||
"""Result of auth function containing actions to be performed and response data."""
|
||||
|
||||
actions: list[AuthAction]
|
||||
response: dict[str, str]
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
"""State data stored in Redis during OAuth callback flow."""
|
||||
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
@ -8,3 +8,7 @@ class MCPConnectionError(MCPError):
|
||||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class MCPRefreshTokenError(MCPError):
|
||||
pass
|
||||
|
||||
@ -7,9 +7,9 @@ from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.session.client_session import ClientSession
|
||||
from core.mcp.types import Tool
|
||||
from core.mcp.types import CallToolResult, Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -18,40 +18,18 @@ class MCPClient:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
authed: bool = True,
|
||||
authorization_code: str | None = None,
|
||||
for_list: bool = False,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.client_type = "streamable"
|
||||
self.server_url = server_url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
|
||||
# Authentication info
|
||||
self.authed = authed
|
||||
self.authorization_code = authorization_code
|
||||
if authed:
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
|
||||
self.token = self.provider.tokens()
|
||||
|
||||
# Initialize session and client objects
|
||||
self._session: ClientSession | None = None
|
||||
self._streams_context: AbstractContextManager[Any] | None = None
|
||||
self._session_context: ClientSession | None = None
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
# Whether the client has been initialized
|
||||
self._initialized = False
|
||||
|
||||
def __enter__(self):
|
||||
@ -85,61 +63,42 @@ class MCPClient:
|
||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
def connect_server(
|
||||
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
|
||||
):
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
|
||||
"""
|
||||
Connect to the MCP server using streamable http or sse.
|
||||
Default to streamable http.
|
||||
Args:
|
||||
client_factory: The client factory to use(streamablehttp_client or sse_client).
|
||||
method_name: The method name to use(mcp or sse).
|
||||
"""
|
||||
streams_context = client_factory(
|
||||
url=self.server_url,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
headers = (
|
||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||
if self.authed and self.token
|
||||
else self.headers
|
||||
)
|
||||
self._streams_context = client_factory(
|
||||
url=self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
if not self._streams_context:
|
||||
raise MCPConnectionError("Failed to create connection context")
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._exit_stack.enter_context(streams_context)
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._exit_stack.enter_context(self._streams_context)
|
||||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self._exit_stack.enter_context(self._session_context)
|
||||
self._session.initialize()
|
||||
return
|
||||
|
||||
except MCPAuthError:
|
||||
if not self.authed:
|
||||
raise
|
||||
try:
|
||||
auth(self.provider, self.server_url, self.authorization_code)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to authenticate: {e}")
|
||||
self.token = self.provider.tokens()
|
||||
if first_try:
|
||||
return self.connect_server(client_factory, method_name, first_try=False)
|
||||
session_context = ClientSession(*streams)
|
||||
self._session = self._exit_stack.enter_context(session_context)
|
||||
self._session.initialize()
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""Connect to an MCP server running with SSE transport"""
|
||||
# List available tools to verify connection
|
||||
if not self._initialized or not self._session:
|
||||
"""List available tools from the MCP server"""
|
||||
if not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
response = self._session.list_tools()
|
||||
tools = response.tools
|
||||
return tools
|
||||
return response.tools
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict):
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""Call a tool"""
|
||||
if not self._initialized or not self._session:
|
||||
if not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
return self._session.call_tool(tool_name, tool_args)
|
||||
|
||||
@ -153,6 +112,4 @@ class MCPClient:
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
finally:
|
||||
self._session = None
|
||||
self._session_context = None
|
||||
self._streams_context = None
|
||||
self._initialized = False
|
||||
|
||||
@ -201,11 +201,14 @@ class BaseSession(
|
||||
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
|
||||
except TimeoutError:
|
||||
# If the receiver loop is still running after timeout, we'll force shutdown
|
||||
pass
|
||||
# Cancel the future to interrupt the receiver loop
|
||||
self._receiver_future.cancel()
|
||||
|
||||
# Shutdown the executor
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=True)
|
||||
# Use non-blocking shutdown to prevent hanging
|
||||
# The receiver thread should have already exited due to the None message in the queue
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
|
||||
@ -284,7 +284,7 @@ class ClientSession(
|
||||
|
||||
def complete(
|
||||
self,
|
||||
ref: types.ResourceReference | types.PromptReference,
|
||||
ref: types.ResourceTemplateReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
|
||||
@ -1,13 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
||||
from pydantic.networks import AnyUrl, UrlConstraints
|
||||
@ -33,6 +26,7 @@ for reference.
|
||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||
# Server support 2024-11-05 to allow claude to use.
|
||||
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
|
||||
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
|
||||
ProgressToken = str | int
|
||||
Cursor = str
|
||||
Role = Literal["user", "assistant"]
|
||||
@ -55,14 +49,22 @@ class RequestParams(BaseModel):
|
||||
meta: Meta | None = Field(alias="_meta", default=None)
|
||||
|
||||
|
||||
class PaginatedRequestParams(RequestParams):
|
||||
cursor: Cursor | None = None
|
||||
"""
|
||||
An opaque token representing the current pagination position.
|
||||
If provided, the server should return results starting after this cursor.
|
||||
"""
|
||||
|
||||
|
||||
class NotificationParams(BaseModel):
|
||||
class Meta(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
meta: Meta | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
This parameter name is reserved by MCP to allow clients and servers to attach
|
||||
additional metadata to their notifications.
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
|
||||
|
||||
@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedRequest(Request[RequestParamsT, MethodT]):
|
||||
cursor: Cursor | None = None
|
||||
"""
|
||||
An opaque token representing the current pagination position.
|
||||
If provided, the server should return results starting after this cursor.
|
||||
"""
|
||||
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
|
||||
"""Base class for paginated requests,
|
||||
matching the schema's PaginatedRequest interface."""
|
||||
|
||||
params: PaginatedRequestParams | None = None
|
||||
|
||||
|
||||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
||||
@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
||||
class Result(BaseModel):
|
||||
"""Base class for JSON-RPC results."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
This result property is reserved by the protocol to allow clients and servers to
|
||||
attach additional metadata to their responses.
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedResult(Result):
|
||||
@ -186,10 +186,26 @@ class EmptyResult(Result):
|
||||
"""A response that indicates success but carries no data."""
|
||||
|
||||
|
||||
class Implementation(BaseModel):
|
||||
"""Describes the name and version of an MCP implementation."""
|
||||
class BaseMetadata(BaseModel):
|
||||
"""Base class for entities with name and optional title fields."""
|
||||
|
||||
name: str
|
||||
"""The programmatic name of the entity."""
|
||||
|
||||
title: str | None = None
|
||||
"""
|
||||
Intended for UI and end-user contexts — optimized to be human-readable and easily understood,
|
||||
even by those unfamiliar with domain-specific terminology.
|
||||
|
||||
If not provided, the name should be used for display (except for Tool,
|
||||
where `annotations.title` should be given precedence over using `name`,
|
||||
if present).
|
||||
"""
|
||||
|
||||
|
||||
class Implementation(BaseMetadata):
|
||||
"""Describes the name and version of an MCP implementation."""
|
||||
|
||||
version: str
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@ -203,7 +219,7 @@ class RootsCapability(BaseModel):
|
||||
|
||||
|
||||
class SamplingCapability(BaseModel):
|
||||
"""Capability for logging operations."""
|
||||
"""Capability for sampling operations."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@ -252,6 +268,12 @@ class LoggingCapability(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompletionsCapability(BaseModel):
|
||||
"""Capability for completions operations."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ServerCapabilities(BaseModel):
|
||||
"""Capabilities that a server may support."""
|
||||
|
||||
@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel):
|
||||
"""Present if the server offers any resources to read."""
|
||||
tools: ToolsCapability | None = None
|
||||
"""Present if the server offers any tools to call."""
|
||||
completions: CompletionsCapability | None = None
|
||||
"""Present if the server offers autocompletion suggestions for prompts and resources."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]])
|
||||
to begin initialization.
|
||||
"""
|
||||
|
||||
method: Literal["initialize"]
|
||||
method: Literal["initialize"] = "initialize"
|
||||
params: InitializeRequestParams
|
||||
|
||||
|
||||
@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n
|
||||
finished.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/initialized"]
|
||||
method: Literal["notifications/initialized"] = "notifications/initialized"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]):
|
||||
still alive.
|
||||
"""
|
||||
|
||||
method: Literal["ping"]
|
||||
method: Literal["ping"] = "ping"
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams):
|
||||
"""
|
||||
total: float | None = None
|
||||
"""Total number of items to process (or total progress required), if known."""
|
||||
message: str | None = None
|
||||
"""
|
||||
Message related to progress. This should provide relevant human readable
|
||||
progress information.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not
|
||||
long-running request.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/progress"]
|
||||
method: Literal["notifications/progress"] = "notifications/progress"
|
||||
params: ProgressNotificationParams
|
||||
|
||||
|
||||
class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
|
||||
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
|
||||
"""Sent from the client to request a list of resources the server has."""
|
||||
|
||||
method: Literal["resources/list"]
|
||||
params: RequestParams | None = None
|
||||
method: Literal["resources/list"] = "resources/list"
|
||||
|
||||
|
||||
class Annotations(BaseModel):
|
||||
@ -360,13 +388,11 @@ class Annotations(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
class Resource(BaseMetadata):
|
||||
"""A known resource that the server is capable of reading."""
|
||||
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
|
||||
"""The URI of this resource."""
|
||||
name: str
|
||||
"""A human-readable name for this resource."""
|
||||
description: str | None = None
|
||||
"""A description of what this resource represents."""
|
||||
mimeType: str | None = None
|
||||
@ -379,10 +405,15 @@ class Resource(BaseModel):
|
||||
This can be used by Hosts to display file sizes and estimate context window usage.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ResourceTemplate(BaseModel):
|
||||
class ResourceTemplate(BaseMetadata):
|
||||
"""A template description for resources available on the server."""
|
||||
|
||||
uriTemplate: str
|
||||
@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel):
|
||||
A URI template (according to RFC 6570) that can be used to construct resource
|
||||
URIs.
|
||||
"""
|
||||
name: str
|
||||
"""A human-readable name for the type of resource this template refers to."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of what this template is for."""
|
||||
mimeType: str | None = None
|
||||
@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel):
|
||||
included if all resources matching this template have the same type.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult):
|
||||
resources: list[Resource]
|
||||
|
||||
|
||||
class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
|
||||
class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
|
||||
"""Sent from the client to request a list of resource templates the server has."""
|
||||
|
||||
method: Literal["resources/templates/list"]
|
||||
params: RequestParams | None = None
|
||||
method: Literal["resources/templates/list"] = "resources/templates/list"
|
||||
|
||||
|
||||
class ListResourceTemplatesResult(PaginatedResult):
|
||||
@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams):
|
||||
class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
|
||||
"""Sent from the client to the server, to read a specific resource URI."""
|
||||
|
||||
method: Literal["resources/read"]
|
||||
method: Literal["resources/read"] = "resources/read"
|
||||
params: ReadResourceRequestParams
|
||||
|
||||
|
||||
@ -447,6 +480,11 @@ class ResourceContents(BaseModel):
|
||||
"""The URI of this resource."""
|
||||
mimeType: str | None = None
|
||||
"""The MIME type of this resource, if known."""
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -481,7 +519,7 @@ class ResourceListChangedNotification(
|
||||
of resources it can read from has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/resources/list_changed"]
|
||||
method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr
|
||||
whenever a particular resource changes.
|
||||
"""
|
||||
|
||||
method: Literal["resources/subscribe"]
|
||||
method: Literal["resources/subscribe"] = "resources/subscribe"
|
||||
params: SubscribeRequestParams
|
||||
|
||||
|
||||
@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un
|
||||
the server.
|
||||
"""
|
||||
|
||||
method: Literal["resources/unsubscribe"]
|
||||
method: Literal["resources/unsubscribe"] = "resources/unsubscribe"
|
||||
params: UnsubscribeRequestParams
|
||||
|
||||
|
||||
@ -543,15 +581,14 @@ class ResourceUpdatedNotification(
|
||||
changed and may need to be read again.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/resources/updated"]
|
||||
method: Literal["notifications/resources/updated"] = "notifications/resources/updated"
|
||||
params: ResourceUpdatedNotificationParams
|
||||
|
||||
|
||||
class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
|
||||
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
|
||||
"""Sent from the client to request a list of prompts and prompt templates."""
|
||||
|
||||
method: Literal["prompts/list"]
|
||||
params: RequestParams | None = None
|
||||
method: Literal["prompts/list"] = "prompts/list"
|
||||
|
||||
|
||||
class PromptArgument(BaseModel):
|
||||
@ -566,15 +603,18 @@ class PromptArgument(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
class Prompt(BaseMetadata):
|
||||
"""A prompt or prompt template that the server offers."""
|
||||
|
||||
name: str
|
||||
"""The name of the prompt or prompt template."""
|
||||
description: str | None = None
|
||||
"""An optional description of what this prompt provides."""
|
||||
arguments: list[PromptArgument] | None = None
|
||||
"""A list of arguments to use for templating the prompt."""
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams):
|
||||
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
|
||||
"""Used by the client to get a prompt provided by the server."""
|
||||
|
||||
method: Literal["prompts/get"]
|
||||
method: Literal["prompts/get"] = "prompts/get"
|
||||
params: GetPromptRequestParams
|
||||
|
||||
|
||||
@ -608,6 +648,11 @@ class TextContent(BaseModel):
|
||||
text: str
|
||||
"""The text content of the message."""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -623,6 +668,31 @@ class ImageContent(BaseModel):
|
||||
image types.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AudioContent(BaseModel):
|
||||
"""Audio content for a message."""
|
||||
|
||||
type: Literal["audio"]
|
||||
data: str
|
||||
"""The base64-encoded audio data."""
|
||||
mimeType: str
|
||||
"""
|
||||
The MIME type of the audio. Different providers may support different
|
||||
audio types.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -630,7 +700,7 @@ class SamplingMessage(BaseModel):
|
||||
"""Describes a message issued to or received from an LLM API."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent
|
||||
content: TextContent | ImageContent | AudioContent
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel):
|
||||
type: Literal["resource"]
|
||||
resource: TextResourceContents | BlobResourceContents
|
||||
annotations: Annotations | None = None
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ResourceLink(Resource):
|
||||
"""
|
||||
A resource that the server is capable of reading, included in a prompt or tool call result.
|
||||
|
||||
Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
|
||||
"""
|
||||
|
||||
type: Literal["resource_link"]
|
||||
|
||||
|
||||
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
|
||||
"""A content block that can be used in prompts and tool results."""
|
||||
|
||||
Content: TypeAlias = ContentBlock
|
||||
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
|
||||
|
||||
|
||||
class PromptMessage(BaseModel):
|
||||
"""Describes a message returned as part of a prompt."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent | EmbeddedResource
|
||||
content: ContentBlock
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -672,15 +764,14 @@ class PromptListChangedNotification(
|
||||
of prompts it offers has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/prompts/list_changed"]
|
||||
method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
|
||||
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
|
||||
"""Sent from the client to request a list of tools the server has."""
|
||||
|
||||
method: Literal["tools/list"]
|
||||
params: RequestParams | None = None
|
||||
method: Literal["tools/list"] = "tools/list"
|
||||
|
||||
|
||||
class ToolAnnotations(BaseModel):
|
||||
@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
class Tool(BaseMetadata):
|
||||
"""Definition for a tool the client can call."""
|
||||
|
||||
name: str
|
||||
"""The name of the tool."""
|
||||
description: str | None = None
|
||||
"""A human-readable description of the tool."""
|
||||
inputSchema: dict[str, Any]
|
||||
"""A JSON Schema object defining the expected parameters for the tool."""
|
||||
outputSchema: dict[str, Any] | None = None
|
||||
"""
|
||||
An optional JSON Schema object defining the structure of the tool's output
|
||||
returned in the structuredContent field of a CallToolResult.
|
||||
"""
|
||||
annotations: ToolAnnotations | None = None
|
||||
"""Optional additional tool information."""
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams):
|
||||
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
|
||||
"""Used by the client to invoke a tool provided by the server."""
|
||||
|
||||
method: Literal["tools/call"]
|
||||
method: Literal["tools/call"] = "tools/call"
|
||||
params: CallToolRequestParams
|
||||
|
||||
|
||||
class CallToolResult(Result):
|
||||
"""The server's response to a tool call."""
|
||||
|
||||
content: list[TextContent | ImageContent | EmbeddedResource]
|
||||
content: list[ContentBlock]
|
||||
structuredContent: dict[str, Any] | None = None
|
||||
"""An optional JSON object that represents the structured result of the tool call."""
|
||||
isError: bool = False
|
||||
|
||||
|
||||
@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera
|
||||
of tools it offers has changed.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/tools/list_changed"]
|
||||
method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams):
|
||||
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
|
||||
"""A request from the client to the server, to enable or adjust logging."""
|
||||
|
||||
method: Literal["logging/setLevel"]
|
||||
method: Literal["logging/setLevel"] = "logging/setLevel"
|
||||
params: SetLevelRequestParams
|
||||
|
||||
|
||||
@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams):
|
||||
"""The severity of this log message."""
|
||||
logger: str | None = None
|
||||
"""An optional name of the logger issuing this message."""
|
||||
data: Any = None
|
||||
data: Any
|
||||
"""
|
||||
The data to be logged, such as a string message or an object. Any JSON serializable
|
||||
type is allowed here.
|
||||
@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams):
|
||||
class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
|
||||
"""Notification of a log message passed from server to client."""
|
||||
|
||||
method: Literal["notifications/message"]
|
||||
method: Literal["notifications/message"] = "notifications/message"
|
||||
params: LoggingMessageNotificationParams
|
||||
|
||||
|
||||
@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams):
|
||||
class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
|
||||
"""A request from the server to sample an LLM via the client."""
|
||||
|
||||
method: Literal["sampling/createMessage"]
|
||||
method: Literal["sampling/createMessage"] = "sampling/createMessage"
|
||||
params: CreateMessageRequestParams
|
||||
|
||||
|
||||
@ -925,14 +1026,14 @@ class CreateMessageResult(Result):
|
||||
"""The client's response to a sampling/create_message request from the server."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent
|
||||
content: TextContent | ImageContent | AudioContent
|
||||
model: str
|
||||
"""The name of the model that generated the message."""
|
||||
stopReason: StopReason | None = None
|
||||
"""The reason why sampling stopped, if known."""
|
||||
|
||||
|
||||
class ResourceReference(BaseModel):
|
||||
class ResourceTemplateReference(BaseModel):
|
||||
"""A reference to a resource or resource template definition."""
|
||||
|
||||
type: Literal["ref/resource"]
|
||||
@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompletionContext(BaseModel):
|
||||
"""Additional, optional context for completions."""
|
||||
|
||||
arguments: dict[str, str] | None = None
|
||||
"""Previously-resolved variables in a URI template or prompt."""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompleteRequestParams(RequestParams):
|
||||
"""Parameters for completion requests."""
|
||||
|
||||
ref: ResourceReference | PromptReference
|
||||
ref: ResourceTemplateReference | PromptReference
|
||||
argument: CompletionArgument
|
||||
context: CompletionContext | None = None
|
||||
"""Additional, optional context for completions"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
|
||||
"""A request from the client to the server, to ask for completion options."""
|
||||
|
||||
method: Literal["completion/complete"]
|
||||
method: Literal["completion/complete"] = "completion/complete"
|
||||
params: CompleteRequestParams
|
||||
|
||||
|
||||
@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
|
||||
structure or access specific locations that the client has permission to read from.
|
||||
"""
|
||||
|
||||
method: Literal["roots/list"]
|
||||
method: Literal["roots/list"] = "roots/list"
|
||||
params: RequestParams | None = None
|
||||
|
||||
|
||||
@ -1029,6 +1140,11 @@ class Root(BaseModel):
|
||||
identifier for the root, which may be useful for display purposes or for
|
||||
referencing the root in other parts of the application.
|
||||
"""
|
||||
meta: dict[str, Any] | None = Field(alias="_meta", default=None)
|
||||
"""
|
||||
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
|
||||
for notes on _meta usage.
|
||||
"""
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@ -1054,7 +1170,7 @@ class RootsListChangedNotification(
|
||||
using the ListRootsRequest.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/roots/list_changed"]
|
||||
method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed"
|
||||
params: NotificationParams | None = None
|
||||
|
||||
|
||||
@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n
|
||||
previously-issued request.
|
||||
"""
|
||||
|
||||
method: Literal["notifications/cancelled"]
|
||||
method: Literal["notifications/cancelled"] = "notifications/cancelled"
|
||||
params: CancelledNotificationParams
|
||||
|
||||
|
||||
|
||||
@ -217,3 +217,16 @@ class Tool(ABC):
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
|
||||
)
|
||||
|
||||
def create_variable_message(
|
||||
self, variable_name: str, variable_value: Any, stream: bool = False
|
||||
) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a variable message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.VARIABLE,
|
||||
message=ToolInvokeMessage.VariableMessage(
|
||||
variable_name=variable_name, variable_value=variable_value, stream=stream
|
||||
),
|
||||
)
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -44,10 +45,14 @@ class ToolProviderApiEntity(BaseModel):
|
||||
server_url: str | None = Field(default="", description="The server url of the tool")
|
||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
|
||||
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
|
||||
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
|
||||
|
||||
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
|
||||
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
|
||||
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
|
||||
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
|
||||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
@ -70,8 +75,15 @@ class ToolProviderApiEntity(BaseModel):
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(self.optional_field("timeout", self.timeout))
|
||||
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
)
|
||||
)
|
||||
optional_fields.update(
|
||||
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
return {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from typing import Any, Self
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -52,18 +52,25 @@ class MCPToolProviderController(ToolProviderController):
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
tools = []
|
||||
tools_data = json.loads(db_provider.tools)
|
||||
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
|
||||
user = db_provider.load_user()
|
||||
# Convert to entity first
|
||||
provider_entity = db_provider.to_entity()
|
||||
return cls.from_entity(provider_entity)
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, entity: MCPProviderEntity) -> Self:
|
||||
"""
|
||||
create a MCPToolProviderController from a MCPProviderEntity
|
||||
"""
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
|
||||
|
||||
tools = [
|
||||
ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
author="Anonymous", # Tool level author is not stored
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=db_provider.server_identifier,
|
||||
icon=db_provider.icon,
|
||||
provider=entity.provider_id,
|
||||
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||
),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||
description=ToolDescription(
|
||||
@ -72,31 +79,32 @@ class MCPToolProviderController(ToolProviderController):
|
||||
),
|
||||
llm=remote_mcp_tool.description or "",
|
||||
),
|
||||
output_schema=remote_mcp_tool.outputSchema or {},
|
||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
if not db_provider.icon:
|
||||
if not entity.icon:
|
||||
raise ValueError("Database provider icon is required")
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
author="Anonymous", # Provider level author is not stored in entity
|
||||
name=entity.name,
|
||||
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
icon=db_provider.icon,
|
||||
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||
),
|
||||
plugin_id=None,
|
||||
credentials_schema=[],
|
||||
tools=tools,
|
||||
),
|
||||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
headers=db_provider.decrypted_headers or {},
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
provider_id=entity.provider_id,
|
||||
tenant_id=entity.tenant_id,
|
||||
server_url=entity.server_url,
|
||||
headers=entity.headers,
|
||||
timeout=entity.timeout,
|
||||
sse_read_timeout=entity.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
|
||||
@ -3,12 +3,13 @@ import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
@ -44,40 +45,32 @@ class MCPTool(Tool):
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
self.server_url,
|
||||
self.provider_id,
|
||||
self.tenant_id,
|
||||
authed=True,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
raise ToolInvokeError("Please auth the tool first") from e
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
# handle dify tool output
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
# handle MCP structured output
|
||||
if self.entity.output_schema and result.structuredContent:
|
||||
for k, v in result.structuredContent.items():
|
||||
yield self.create_variable_message(k, v)
|
||||
|
||||
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process text content and yield appropriate messages."""
|
||||
try:
|
||||
content_json = json.loads(content.text)
|
||||
yield from self._process_json_content(content_json)
|
||||
except json.JSONDecodeError:
|
||||
yield self.create_text_message(content.text)
|
||||
# Check if content looks like JSON before attempting to parse
|
||||
text = content.text.strip()
|
||||
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
|
||||
try:
|
||||
content_json = json.loads(text)
|
||||
yield from self._process_json_content(content_json)
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If not JSON or parsing failed, treat as plain text
|
||||
yield self.create_text_message(content.text)
|
||||
|
||||
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process JSON content based on its type."""
|
||||
@ -126,3 +119,44 @@ class MCPTool(Tool):
|
||||
for key, value in parameter.items()
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
|
||||
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
||||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
# Step 1: Load provider entity and credentials in a short-lived session
|
||||
# This minimizes database connection hold time
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
|
||||
# Decrypt and prepare all credentials before closing session
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_headers()
|
||||
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
@ -14,17 +14,32 @@ from sqlalchemy.orm import Session
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
@ -40,21 +55,11 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -719,7 +724,9 @@ class ToolManager:
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
if "mcp" in filters:
|
||||
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||
for mcp_provider in mcp_providers:
|
||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||
|
||||
@ -774,17 +781,12 @@ class ToolManager:
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(
|
||||
MCPToolProvider.server_identifier == provider_id,
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
try:
|
||||
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
|
||||
except ValueError:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
controller = MCPToolProviderController.from_db(provider)
|
||||
|
||||
@ -922,16 +924,15 @@ class ToolManager:
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
||||
try:
|
||||
mcp_provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mcp_provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
return mcp_provider.provider_icon
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
try:
|
||||
mcp_provider = mcp_service.get_provider_entity(
|
||||
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
|
||||
)
|
||||
return mcp_provider.provider_icon
|
||||
except ValueError:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user