mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
feat: add client credentials auth
This commit is contained in:
@ -18,6 +18,15 @@ from core.tools.utils.encryption import create_provider_encrypter
|
||||
if TYPE_CHECKING:
|
||||
from models.tools import MCPToolProvider
|
||||
|
||||
# Constants
|
||||
DEFAULT_GRANT_TYPE = "authorization_code"
|
||||
CLIENT_NAME = "Dify"
|
||||
CLIENT_URI = "https://github.com/langgenius/dify"
|
||||
DEFAULT_TOKEN_TYPE = "Bearer"
|
||||
DEFAULT_EXPIRES_IN = 3600
|
||||
MASK_CHAR = "*"
|
||||
MIN_UNMASK_LENGTH = 6
|
||||
|
||||
|
||||
class MCPProviderEntity(BaseModel):
|
||||
"""MCP Provider domain entity for business logic operations"""
|
||||
@ -78,13 +87,38 @@ class MCPProviderEntity(BaseModel):
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
# Get grant type from credentials
|
||||
credentials = self.decrypt_credentials()
|
||||
|
||||
# Try to get grant_type from different locations
|
||||
grant_type = credentials.get("grant_type", DEFAULT_GRANT_TYPE)
|
||||
|
||||
# For nested structure, check if client_information has grant_types
|
||||
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
|
||||
client_info = credentials["client_information"]
|
||||
# If grant_types is specified in client_information, use it to determine grant_type
|
||||
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
|
||||
if "client_credentials" in client_info["grant_types"]:
|
||||
grant_type = "client_credentials"
|
||||
elif "authorization_code" in client_info["grant_types"]:
|
||||
grant_type = "authorization_code"
|
||||
|
||||
# Configure based on grant type
|
||||
is_client_credentials = grant_type == "client_credentials"
|
||||
|
||||
grant_types = ["refresh_token"]
|
||||
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
|
||||
|
||||
response_types = [] if is_client_credentials else ["code"]
|
||||
redirect_uris = [] if is_client_credentials else [self.redirect_url]
|
||||
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=[self.redirect_url],
|
||||
redirect_uris=redirect_uris,
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
grant_types=grant_types,
|
||||
response_types=response_types,
|
||||
client_name=CLIENT_NAME,
|
||||
client_uri=CLIENT_URI,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -100,7 +134,7 @@ class MCPProviderEntity(BaseModel):
|
||||
|
||||
def to_api_response(self, user_name: str | None = None) -> dict[str, Any]:
|
||||
"""Convert to API response format"""
|
||||
return {
|
||||
response = {
|
||||
"id": self.id,
|
||||
"author": user_name or "Anonymous",
|
||||
"name": self.name,
|
||||
@ -117,11 +151,50 @@ class MCPProviderEntity(BaseModel):
|
||||
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||
}
|
||||
|
||||
# Add masked credentials if they exist
|
||||
masked_creds = self.masked_credentials()
|
||||
if masked_creds:
|
||||
response.update(masked_creds)
|
||||
|
||||
return response
|
||||
|
||||
def retrieve_client_information(self) -> OAuthClientInformation | None:
|
||||
"""OAuth client information if available"""
|
||||
client_info = self.decrypt_credentials().get("client_information", {})
|
||||
if not client_info:
|
||||
credentials = self.decrypt_credentials()
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
# Check if we have nested client_information structure
|
||||
if "client_information" in credentials:
|
||||
# Handle nested structure (Authorization Code flow)
|
||||
client_info_data = credentials["client_information"]
|
||||
if isinstance(client_info_data, dict):
|
||||
return OAuthClientInformation.model_validate(client_info_data)
|
||||
return None
|
||||
|
||||
# Handle flat structure (Client Credentials flow)
|
||||
if "client_id" not in credentials:
|
||||
return None
|
||||
|
||||
# Build client information from flat structure
|
||||
client_info = {
|
||||
"client_id": credentials.get("client_id", ""),
|
||||
"client_secret": credentials.get("client_secret", ""),
|
||||
"client_name": credentials.get("client_name", CLIENT_NAME),
|
||||
}
|
||||
|
||||
# Parse JSON fields if they exist
|
||||
json_fields = ["redirect_uris", "grant_types", "response_types"]
|
||||
for field in json_fields:
|
||||
if field in credentials:
|
||||
try:
|
||||
client_info[field] = json.loads(credentials[field])
|
||||
except:
|
||||
client_info[field] = []
|
||||
|
||||
if "scope" in credentials:
|
||||
client_info["scope"] = credentials["scope"]
|
||||
|
||||
return OAuthClientInformation.model_validate(client_info)
|
||||
|
||||
def retrieve_tokens(self) -> OAuthTokens | None:
|
||||
@ -131,8 +204,8 @@ class MCPProviderEntity(BaseModel):
|
||||
credentials = self.decrypt_credentials()
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", "Bearer"),
|
||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
|
||||
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
@ -144,30 +217,77 @@ class MCPProviderEntity(BaseModel):
|
||||
return f"{base_url}/******"
|
||||
return base_url
|
||||
|
||||
def _mask_value(self, value: str) -> str:
|
||||
"""Mask a sensitive value for display"""
|
||||
if len(value) > MIN_UNMASK_LENGTH:
|
||||
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
|
||||
else:
|
||||
return MASK_CHAR * len(value)
|
||||
|
||||
def masked_headers(self) -> dict[str, str]:
|
||||
"""Masked headers for display"""
|
||||
masked: dict[str, str] = {}
|
||||
for key, value in self.decrypt_headers().items():
|
||||
if len(value) > 6:
|
||||
masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
|
||||
else:
|
||||
masked[key] = "*" * len(value)
|
||||
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
|
||||
|
||||
def masked_credentials(self) -> dict[str, str]:
|
||||
"""Masked credentials for display"""
|
||||
credentials = self.decrypt_credentials()
|
||||
if not credentials:
|
||||
return {}
|
||||
|
||||
masked = {}
|
||||
|
||||
# Check if we have nested client_information structure
|
||||
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
|
||||
client_info = credentials["client_information"]
|
||||
# Mask sensitive fields from nested structure
|
||||
if client_info.get("client_id"):
|
||||
masked["client_id"] = self._mask_value(client_info["client_id"])
|
||||
if client_info.get("client_secret"):
|
||||
masked["client_secret"] = self._mask_value(client_info["client_secret"])
|
||||
else:
|
||||
# Handle flat structure
|
||||
# Mask sensitive fields
|
||||
sensitive_fields = ["client_id", "client_secret"]
|
||||
for field in sensitive_fields:
|
||||
if credentials.get(field):
|
||||
masked[field] = self._mask_value(credentials[field])
|
||||
|
||||
# Include non-sensitive fields (check both flat and nested structures)
|
||||
if "grant_type" in credentials:
|
||||
masked["grant_type"] = credentials["grant_type"]
|
||||
if "scope" in credentials:
|
||||
masked["scope"] = credentials["scope"]
|
||||
|
||||
return masked
|
||||
|
||||
def decrypt_server_url(self) -> str:
|
||||
"""Decrypt server URL"""
|
||||
|
||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||
|
||||
def decrypt_headers(self) -> dict[str, Any]:
|
||||
"""Decrypt headers"""
|
||||
|
||||
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Generic method to decrypt dictionary fields"""
|
||||
try:
|
||||
if not self.headers:
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in self.headers]
|
||||
# Only decrypt fields that are actually encrypted
|
||||
# For nested structures, client_information is not encrypted as a whole
|
||||
encrypted_fields = []
|
||||
for key, value in data.items():
|
||||
# Skip nested objects - they are not encrypted
|
||||
if isinstance(value, dict):
|
||||
continue
|
||||
# Only process string values that might be encrypted
|
||||
if isinstance(value, str) and value:
|
||||
encrypted_fields.append(key)
|
||||
|
||||
if not encrypted_fields:
|
||||
return data
|
||||
|
||||
# Create dynamic config only for encrypted fields
|
||||
config = [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields
|
||||
]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -175,28 +295,21 @@ class MCPProviderEntity(BaseModel):
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
result = encrypter_instance.decrypt(self.headers)
|
||||
# Decrypt only the encrypted fields
|
||||
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
|
||||
|
||||
# Merge decrypted data with original data (preserving non-encrypted fields)
|
||||
result = data.copy()
|
||||
result.update(decrypted_data)
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def decrypt_credentials(
|
||||
self,
|
||||
) -> dict[str, Any]:
|
||||
def decrypt_headers(self) -> dict[str, Any]:
|
||||
"""Decrypt headers"""
|
||||
return self._decrypt_dict(self.headers)
|
||||
|
||||
def decrypt_credentials(self) -> dict[str, Any]:
|
||||
"""Decrypt credentials"""
|
||||
try:
|
||||
if not self.credentials:
|
||||
return {}
|
||||
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=[
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key)
|
||||
for key in self.credentials
|
||||
],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter.decrypt(self.credentials)
|
||||
except Exception:
|
||||
return {}
|
||||
return self._decrypt_dict(self.credentials)
|
||||
|
||||
@ -106,8 +106,8 @@ def handle_callback(state_key: str, authorization_code: str, mcp_service: "MCPTo
|
||||
|
||||
def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
"""Check if the server supports OAuth 2.0 Resource Discovery."""
|
||||
b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
|
||||
b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
|
||||
url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
|
||||
if b_query:
|
||||
url_for_resource_discovery += f"?{b_query}"
|
||||
if b_fragment:
|
||||
@ -117,7 +117,10 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
response = httpx.get(url_for_resource_discovery, headers=headers)
|
||||
if 200 <= response.status_code < 300:
|
||||
body = response.json()
|
||||
if "authorization_server_url" in body:
|
||||
# Support both singular and plural forms
|
||||
if body.get("authorization_servers"):
|
||||
return True, body["authorization_servers"][0]
|
||||
elif body.get("authorization_server_url"):
|
||||
return True, body["authorization_server_url"][0]
|
||||
else:
|
||||
return False, ""
|
||||
@ -132,27 +135,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None
|
||||
# First check if the server supports OAuth 2.0 Resource Discovery
|
||||
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
||||
if support_resource_discovery:
|
||||
url = oauth_discovery_url
|
||||
# The oauth_discovery_url is the authorization server base URL
|
||||
# Try OpenID Connect discovery first (more common), then OAuth 2.0
|
||||
urls_to_try = [
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
|
||||
]
|
||||
else:
|
||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
|
||||
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
response = httpx.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except httpx.RequestError as e:
|
||||
if isinstance(e, httpx.ConnectError):
|
||||
response = httpx.get(url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
|
||||
for url in urls_to_try:
|
||||
try:
|
||||
response = httpx.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
continue # Try next URL
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
raise
|
||||
except httpx.RequestError as e:
|
||||
if isinstance(e, httpx.ConnectError):
|
||||
response = httpx.get(url)
|
||||
if response.status_code == 404:
|
||||
continue # Try next URL
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
# For other errors, try next URL
|
||||
continue
|
||||
|
||||
return None # No metadata found
|
||||
|
||||
|
||||
def start_authorization(
|
||||
@ -276,6 +289,49 @@ def refresh_authorization(
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def client_credentials_flow(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
client_information: OAuthClientInformation,
|
||||
scope: str | None = None,
|
||||
) -> OAuthTokens:
|
||||
"""Execute Client Credentials Flow to get access token."""
|
||||
grant_type = "client_credentials"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
# Support both Basic Auth and body parameters for client authentication
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
data = {"grant_type": grant_type}
|
||||
|
||||
if scope:
|
||||
data["scope"] = scope
|
||||
|
||||
# If client_secret is provided, use Basic Auth (preferred method)
|
||||
if client_information.client_secret:
|
||||
credentials = f"{client_information.client_id}:{client_information.client_secret}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded_credentials}"
|
||||
else:
|
||||
# Fall back to including credentials in the body
|
||||
data["client_id"] = client_information.client_id
|
||||
if client_information.client_secret:
|
||||
data["client_secret"] = client_information.client_secret
|
||||
|
||||
response = httpx.post(token_url, headers=headers, data=data)
|
||||
if not response.is_success:
|
||||
raise ValueError(
|
||||
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def register_client(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
@ -304,6 +360,7 @@ def auth(
|
||||
mcp_service: "MCPToolManageService",
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
grant_type: str = "authorization_code",
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
server_url = provider.decrypt_server_url()
|
||||
@ -314,9 +371,22 @@ def auth(
|
||||
client_information = provider.retrieve_client_information()
|
||||
redirect_url = provider.redirect_url
|
||||
|
||||
# Check if we should use client credentials flow
|
||||
credentials = provider.decrypt_credentials()
|
||||
stored_grant_type = credentials.get("grant_type", "authorization_code")
|
||||
|
||||
# Use stored grant type if available, otherwise use parameter
|
||||
effective_grant_type = stored_grant_type or grant_type
|
||||
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
|
||||
# For client credentials flow, we don't need to register client dynamically
|
||||
if effective_grant_type == "client_credentials":
|
||||
# Client should provide client_id and client_secret directly
|
||||
raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
|
||||
|
||||
try:
|
||||
full_information = register_client(server_url, server_metadata, client_metadata)
|
||||
except httpx.RequestError as e:
|
||||
@ -329,7 +399,28 @@ def auth(
|
||||
|
||||
client_information = full_information
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
# Handle client credentials flow
|
||||
if effective_grant_type == "client_credentials":
|
||||
# Direct token request without user interaction
|
||||
try:
|
||||
scope = credentials.get("scope")
|
||||
tokens = client_credentials_flow(
|
||||
server_url,
|
||||
server_metadata,
|
||||
client_information,
|
||||
scope,
|
||||
)
|
||||
|
||||
# Save tokens and grant type
|
||||
token_data = tokens.model_dump()
|
||||
token_data["grant_type"] = "client_credentials"
|
||||
mcp_service.save_oauth_data(provider_id, tenant_id, token_data, "tokens")
|
||||
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
raise ValueError(f"Client credentials flow failed: {e}")
|
||||
|
||||
# Exchange authorization code for tokens (Authorization Code flow)
|
||||
if authorization_code is not None:
|
||||
if not state_param:
|
||||
raise ValueError("State parameter is required when exchanging authorization code")
|
||||
@ -377,7 +468,7 @@ def auth(
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow
|
||||
# Start new authorization flow (only for authorization code flow)
|
||||
authorization_url, code_verifier = start_authorization(
|
||||
server_url,
|
||||
server_metadata,
|
||||
|
||||
@ -47,6 +47,11 @@ class ToolProviderApiEntity(BaseModel):
|
||||
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
|
||||
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
|
||||
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
|
||||
# MCP OAuth credentials
|
||||
client_id: str | None = Field(default=None, description="The masked client ID for OAuth")
|
||||
client_secret: str | None = Field(default=None, description="The masked client secret for OAuth")
|
||||
grant_type: str | None = Field(default=None, description="The OAuth grant type")
|
||||
scope: str | None = Field(default=None, description="The OAuth scope")
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
@ -72,6 +77,10 @@ class ToolProviderApiEntity(BaseModel):
|
||||
optional_fields.update(self.optional_field("timeout", self.timeout))
|
||||
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("client_id", self.client_id))
|
||||
optional_fields.update(self.optional_field("client_secret", self.client_secret))
|
||||
optional_fields.update(self.optional_field("grant_type", self.grant_type))
|
||||
optional_fields.update(self.optional_field("scope", self.scope))
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
|
||||
Reference in New Issue
Block a user