feat: add client credentials auth

This commit is contained in:
Novice
2025-10-09 17:54:46 +08:00
parent 3592240d14
commit 740f970041
10 changed files with 609 additions and 142 deletions

View File

@ -9,6 +9,7 @@ from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.mcp_provider import MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
@ -21,7 +22,12 @@ from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
# Constants
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
DEFAULT_GRANT_TYPE = "authorization_code"
CLIENT_NAME = "Dify"
EMPTY_TOOLS_JSON = "[]"
EMPTY_CREDENTIALS_JSON = "{}"
class MCPToolManageService:
@ -85,6 +91,10 @@ class MCPToolManageService:
timeout: float,
sse_read_timeout: float,
headers: dict[str, str] | None = None,
client_id: str | None = None,
client_secret: str | None = None,
grant_type: str = DEFAULT_GRANT_TYPE,
scope: str | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
@ -94,8 +104,14 @@ class MCPToolManageService:
# Encrypt sensitive data
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
encrypted_headers = self._prepare_encrypted_headers(headers, tenant_id) if headers else None
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
if client_id and client_secret:
# Build the full credentials structure with encrypted client_id and client_secret
encrypted_credentials = self._build_and_encrypt_credentials(
client_id, client_secret, grant_type, scope, tenant_id
)
else:
encrypted_credentials = None
# Create provider
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
@ -104,12 +120,13 @@ class MCPToolManageService:
server_url_hash=server_url_hash,
user_id=user_id,
authed=False,
tools="[]",
tools=EMPTY_TOOLS_JSON,
icon=self._prepare_icon(icon, icon_type, icon_background),
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
encrypted_headers=encrypted_headers,
encrypted_credentials=encrypted_credentials,
)
self._session.add(mcp_tool)
@ -131,6 +148,10 @@ class MCPToolManageService:
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
client_id: str | None = None,
client_secret: str | None = None,
grant_type: str | None = None,
scope: str | None = None,
) -> None:
"""Update an MCP provider."""
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
@ -176,11 +197,31 @@ class MCPToolManageService:
if headers:
# Build headers preserving unchanged masked values
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
encrypted_headers_dict = self._prepare_encrypted_headers(final_headers, tenant_id)
encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id)
mcp_provider.encrypted_headers = encrypted_headers_dict
else:
# Clear headers if empty dict passed
mcp_provider.encrypted_headers = None
# Update credentials if provided
if client_id is not None and client_secret is not None:
# Merge with existing credentials to handle masked values
(
final_client_id,
final_client_secret,
final_grant_type,
final_scope,
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, scope, mcp_provider)
# Use default grant_type if none found
final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE
# Build and encrypt new credentials
encrypted_credentials = self._build_and_encrypt_credentials(
final_client_id, final_client_secret, final_grant_type, final_scope, tenant_id
)
mcp_provider.encrypted_credentials = encrypted_credentials
self._session.commit()
except IntegrityError as e:
self._session.rollback()
@ -271,7 +312,7 @@ class MCPToolManageService:
if authed is not None:
provider.authed = authed
if not authed:
provider.tools = "[]"
provider.tools = EMPTY_TOOLS_JSON
self._session.commit()
@ -287,28 +328,15 @@ class MCPToolManageService:
"""
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
credentials = {}
authed = None
# Determine if this makes the provider authenticated
authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None
if data_type == "tokens" or (data_type == "mixed" and "access_token" in data):
# OAuth tokens
credentials = data
authed = True
elif data_type == "client_info" or (data_type == "mixed" and "client_information" in data):
# OAuth client information
credentials = data
elif data_type == "code_verifier" or (data_type == "mixed" and "code_verifier" in data):
# PKCE code verifier
credentials = data
else:
credentials = data
self.update_provider_credentials(provider=db_provider, credentials=credentials, authed=authed)
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
"""Clear all credentials for a provider."""
provider.tools = "[]"
provider.encrypted_credentials = "{}"
provider.tools = EMPTY_TOOLS_JSON
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
provider.updated_at = datetime.now()
provider.authed = False
self._session.commit()
@ -341,13 +369,24 @@ class MCPToolManageService:
return json.dumps({"content": icon, "background": icon_background})
return icon
def _prepare_encrypted_headers(self, headers: dict[str, str], tenant_id: str) -> str:
"""Encrypt headers and prepare for storage."""
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> str:
"""Encrypt specified fields in a dictionary.
Args:
data: Dictionary containing data to encrypt
secret_fields: List of field names to encrypt
tenant_id: Tenant ID for encryption
Returns:
JSON string of encrypted data
"""
from core.entities.provider_entities import BasicProviderConfig
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
# Create config for secret fields
config = [
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
@ -355,8 +394,13 @@ class MCPToolManageService:
cache=NoOpProviderCredentialCache(),
)
encrypted_headers_dict = encrypter_instance.encrypt(headers)
return json.dumps(encrypted_headers_dict)
encrypted_data = encrypter_instance.encrypt(data)
return json.dumps(encrypted_data)
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
"""Encrypt headers and prepare for storage."""
# All headers are treated as secret
return self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id)
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
"""Prepare headers with OAuth token if available."""
@ -391,27 +435,18 @@ class MCPToolManageService:
provider_entity = provider.to_entity()
headers = provider_entity.headers
timeout = provider_entity.timeout
sse_read_timeout = provider_entity.sse_read_timeout
try:
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
provider_entity=provider_entity,
auth_callback=lambda p, s, c: auth(p, self, c),
mcp_service=self,
) as mcp_client:
tools = mcp_client.list_tools()
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": "{}",
}
tools = self._retrieve_remote_mcp_tools(
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
)
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": EMPTY_CREDENTIALS_JSON,
}
except MCPAuthError:
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON}
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
@ -461,3 +496,76 @@ class MCPToolManageService:
for key, value in incoming_headers.items()
if key in existing_decrypted or value != existing_masked.get(key)
}
def _merge_credentials_with_masked(
self,
client_id: str,
client_secret: str,
grant_type: str | None,
scope: str | None,
mcp_provider: MCPToolProvider,
) -> tuple[str, str, str | None, str | None]:
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
Args:
client_id: Client ID from frontend (may be masked)
client_secret: Client secret from frontend (may be masked)
grant_type: Grant type from frontend
scope: OAuth scope from frontend
mcp_provider: The MCP provider instance
Returns:
Tuple of (final_client_id, final_client_secret, grant_type, scope)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_credentials()
existing_masked = mcp_provider_entity.masked_credentials()
# Check if client_id is masked and unchanged
final_client_id = client_id
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
# Use existing decrypted value
final_client_id = existing_decrypted.get("client_id", client_id)
# Check if client_secret is masked and unchanged
final_client_secret = client_secret
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
# Use existing decrypted value
final_client_secret = existing_decrypted.get("client_secret", client_secret)
# Grant type and scope are not masked, use as is
final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type")
final_scope = scope if scope is not None else existing_decrypted.get("scope")
return final_client_id, final_client_secret, final_grant_type, final_scope
def _build_and_encrypt_credentials(
self, client_id: str, client_secret: str, grant_type: str, scope: str | None, tenant_id: str
) -> str:
"""Build credentials and encrypt sensitive fields."""
# Create a flat structure with all credential data
credentials_data = {
"client_id": client_id,
"client_secret": client_secret,
"grant_type": grant_type,
"client_name": CLIENT_NAME,
}
if scope:
credentials_data["scope"] = scope
# Add grant types and response types based on grant_type
if grant_type == "client_credentials":
credentials_data["grant_types"] = json.dumps(["client_credentials"])
credentials_data["response_types"] = json.dumps([])
credentials_data["redirect_uris"] = json.dumps([])
else:
credentials_data["grant_types"] = json.dumps(["authorization_code", "refresh_token"])
credentials_data["response_types"] = json.dumps(["code"])
credentials_data["redirect_uris"] = json.dumps(
[f"{dify_config.CONSOLE_API_URL}/console/api/mcp/oauth/callback"]
)
# Only client_id and client_secret need encryption
secret_fields = ["client_id", "client_secret"]
return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)