mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
feat: add client credentials auth
This commit is contained in:
@ -9,6 +9,7 @@ from sqlalchemy import or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
@ -21,7 +22,12 @@ from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||
DEFAULT_GRANT_TYPE = "authorization_code"
|
||||
CLIENT_NAME = "Dify"
|
||||
EMPTY_TOOLS_JSON = "[]"
|
||||
EMPTY_CREDENTIALS_JSON = "{}"
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
@ -85,6 +91,10 @@ class MCPToolManageService:
|
||||
timeout: float,
|
||||
sse_read_timeout: float,
|
||||
headers: dict[str, str] | None = None,
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
grant_type: str = DEFAULT_GRANT_TYPE,
|
||||
scope: str | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider."""
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
@ -94,8 +104,14 @@ class MCPToolManageService:
|
||||
|
||||
# Encrypt sensitive data
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
encrypted_headers = self._prepare_encrypted_headers(headers, tenant_id) if headers else None
|
||||
|
||||
encrypted_headers = self._prepare_encrypted_dict(headers, tenant_id) if headers else None
|
||||
if client_id and client_secret:
|
||||
# Build the full credentials structure with encrypted client_id and client_secret
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
client_id, client_secret, grant_type, scope, tenant_id
|
||||
)
|
||||
else:
|
||||
encrypted_credentials = None
|
||||
# Create provider
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
@ -104,12 +120,13 @@ class MCPToolManageService:
|
||||
server_url_hash=server_url_hash,
|
||||
user_id=user_id,
|
||||
authed=False,
|
||||
tools="[]",
|
||||
tools=EMPTY_TOOLS_JSON,
|
||||
icon=self._prepare_icon(icon, icon_type, icon_background),
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
)
|
||||
|
||||
self._session.add(mcp_tool)
|
||||
@ -131,6 +148,10 @@ class MCPToolManageService:
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
grant_type: str | None = None,
|
||||
scope: str | None = None,
|
||||
) -> None:
|
||||
"""Update an MCP provider."""
|
||||
mcp_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
@ -176,11 +197,31 @@ class MCPToolManageService:
|
||||
if headers:
|
||||
# Build headers preserving unchanged masked values
|
||||
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
|
||||
encrypted_headers_dict = self._prepare_encrypted_headers(final_headers, tenant_id)
|
||||
encrypted_headers_dict = self._prepare_encrypted_dict(final_headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = encrypted_headers_dict
|
||||
else:
|
||||
# Clear headers if empty dict passed
|
||||
mcp_provider.encrypted_headers = None
|
||||
|
||||
# Update credentials if provided
|
||||
if client_id is not None and client_secret is not None:
|
||||
# Merge with existing credentials to handle masked values
|
||||
(
|
||||
final_client_id,
|
||||
final_client_secret,
|
||||
final_grant_type,
|
||||
final_scope,
|
||||
) = self._merge_credentials_with_masked(client_id, client_secret, grant_type, scope, mcp_provider)
|
||||
|
||||
# Use default grant_type if none found
|
||||
final_grant_type = final_grant_type or DEFAULT_GRANT_TYPE
|
||||
|
||||
# Build and encrypt new credentials
|
||||
encrypted_credentials = self._build_and_encrypt_credentials(
|
||||
final_client_id, final_client_secret, final_grant_type, final_scope, tenant_id
|
||||
)
|
||||
mcp_provider.encrypted_credentials = encrypted_credentials
|
||||
|
||||
self._session.commit()
|
||||
except IntegrityError as e:
|
||||
self._session.rollback()
|
||||
@ -271,7 +312,7 @@ class MCPToolManageService:
|
||||
if authed is not None:
|
||||
provider.authed = authed
|
||||
if not authed:
|
||||
provider.tools = "[]"
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
|
||||
self._session.commit()
|
||||
|
||||
@ -287,28 +328,15 @@ class MCPToolManageService:
|
||||
"""
|
||||
db_provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
credentials = {}
|
||||
authed = None
|
||||
# Determine if this makes the provider authenticated
|
||||
authed = data_type == "tokens" or (data_type == "mixed" and "access_token" in data) or None
|
||||
|
||||
if data_type == "tokens" or (data_type == "mixed" and "access_token" in data):
|
||||
# OAuth tokens
|
||||
credentials = data
|
||||
authed = True
|
||||
elif data_type == "client_info" or (data_type == "mixed" and "client_information" in data):
|
||||
# OAuth client information
|
||||
credentials = data
|
||||
elif data_type == "code_verifier" or (data_type == "mixed" and "code_verifier" in data):
|
||||
# PKCE code verifier
|
||||
credentials = data
|
||||
else:
|
||||
credentials = data
|
||||
|
||||
self.update_provider_credentials(provider=db_provider, credentials=credentials, authed=authed)
|
||||
self.update_provider_credentials(provider=db_provider, credentials=data, authed=authed)
|
||||
|
||||
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
|
||||
"""Clear all credentials for a provider."""
|
||||
provider.tools = "[]"
|
||||
provider.encrypted_credentials = "{}"
|
||||
provider.tools = EMPTY_TOOLS_JSON
|
||||
provider.encrypted_credentials = EMPTY_CREDENTIALS_JSON
|
||||
provider.updated_at = datetime.now()
|
||||
provider.authed = False
|
||||
self._session.commit()
|
||||
@ -341,13 +369,24 @@ class MCPToolManageService:
|
||||
return json.dumps({"content": icon, "background": icon_background})
|
||||
return icon
|
||||
|
||||
def _prepare_encrypted_headers(self, headers: dict[str, str], tenant_id: str) -> str:
|
||||
"""Encrypt headers and prepare for storage."""
|
||||
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> str:
|
||||
"""Encrypt specified fields in a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing data to encrypt
|
||||
secret_fields: List of field names to encrypt
|
||||
tenant_id: Tenant ID for encryption
|
||||
|
||||
Returns:
|
||||
JSON string of encrypted data
|
||||
"""
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
|
||||
# Create config for secret fields
|
||||
config = [
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=field) for field in secret_fields
|
||||
]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
@ -355,8 +394,13 @@ class MCPToolManageService:
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
encrypted_headers_dict = encrypter_instance.encrypt(headers)
|
||||
return json.dumps(encrypted_headers_dict)
|
||||
encrypted_data = encrypter_instance.encrypt(data)
|
||||
return json.dumps(encrypted_data)
|
||||
|
||||
def _prepare_encrypted_dict(self, headers: dict[str, str], tenant_id: str) -> str:
|
||||
"""Encrypt headers and prepare for storage."""
|
||||
# All headers are treated as secret
|
||||
return self._encrypt_dict_fields(headers, list(headers.keys()), tenant_id)
|
||||
|
||||
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
|
||||
"""Prepare headers with OAuth token if available."""
|
||||
@ -391,27 +435,18 @@ class MCPToolManageService:
|
||||
|
||||
provider_entity = provider.to_entity()
|
||||
headers = provider_entity.headers
|
||||
timeout = provider_entity.timeout
|
||||
sse_read_timeout = provider_entity.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=lambda p, s, c: auth(p, self, c),
|
||||
mcp_service=self,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return {
|
||||
"authed": True,
|
||||
"tools": json.dumps([tool.model_dump() for tool in tools]),
|
||||
"encrypted_credentials": "{}",
|
||||
}
|
||||
tools = self._retrieve_remote_mcp_tools(
|
||||
server_url, headers, provider_entity, lambda p, s, c: auth(p, self, c)
|
||||
)
|
||||
return {
|
||||
"authed": True,
|
||||
"tools": json.dumps([tool.model_dump() for tool in tools]),
|
||||
"encrypted_credentials": EMPTY_CREDENTIALS_JSON,
|
||||
}
|
||||
except MCPAuthError:
|
||||
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
|
||||
return {"authed": False, "tools": EMPTY_TOOLS_JSON, "encrypted_credentials": EMPTY_CREDENTIALS_JSON}
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
@ -461,3 +496,76 @@ class MCPToolManageService:
|
||||
for key, value in incoming_headers.items()
|
||||
if key in existing_decrypted or value != existing_masked.get(key)
|
||||
}
|
||||
|
||||
def _merge_credentials_with_masked(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
grant_type: str | None,
|
||||
scope: str | None,
|
||||
mcp_provider: MCPToolProvider,
|
||||
) -> tuple[str, str, str | None, str | None]:
|
||||
"""Merge incoming credentials with existing ones, preserving unchanged masked values.
|
||||
|
||||
Args:
|
||||
client_id: Client ID from frontend (may be masked)
|
||||
client_secret: Client secret from frontend (may be masked)
|
||||
grant_type: Grant type from frontend
|
||||
scope: OAuth scope from frontend
|
||||
mcp_provider: The MCP provider instance
|
||||
|
||||
Returns:
|
||||
Tuple of (final_client_id, final_client_secret, grant_type, scope)
|
||||
"""
|
||||
mcp_provider_entity = mcp_provider.to_entity()
|
||||
existing_decrypted = mcp_provider_entity.decrypt_credentials()
|
||||
existing_masked = mcp_provider_entity.masked_credentials()
|
||||
|
||||
# Check if client_id is masked and unchanged
|
||||
final_client_id = client_id
|
||||
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
|
||||
# Use existing decrypted value
|
||||
final_client_id = existing_decrypted.get("client_id", client_id)
|
||||
|
||||
# Check if client_secret is masked and unchanged
|
||||
final_client_secret = client_secret
|
||||
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
|
||||
# Use existing decrypted value
|
||||
final_client_secret = existing_decrypted.get("client_secret", client_secret)
|
||||
|
||||
# Grant type and scope are not masked, use as is
|
||||
final_grant_type = grant_type if grant_type is not None else existing_decrypted.get("grant_type")
|
||||
final_scope = scope if scope is not None else existing_decrypted.get("scope")
|
||||
|
||||
return final_client_id, final_client_secret, final_grant_type, final_scope
|
||||
|
||||
def _build_and_encrypt_credentials(
|
||||
self, client_id: str, client_secret: str, grant_type: str, scope: str | None, tenant_id: str
|
||||
) -> str:
|
||||
"""Build credentials and encrypt sensitive fields."""
|
||||
# Create a flat structure with all credential data
|
||||
credentials_data = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"grant_type": grant_type,
|
||||
"client_name": CLIENT_NAME,
|
||||
}
|
||||
|
||||
if scope:
|
||||
credentials_data["scope"] = scope
|
||||
|
||||
# Add grant types and response types based on grant_type
|
||||
if grant_type == "client_credentials":
|
||||
credentials_data["grant_types"] = json.dumps(["client_credentials"])
|
||||
credentials_data["response_types"] = json.dumps([])
|
||||
credentials_data["redirect_uris"] = json.dumps([])
|
||||
else:
|
||||
credentials_data["grant_types"] = json.dumps(["authorization_code", "refresh_token"])
|
||||
credentials_data["response_types"] = json.dumps(["code"])
|
||||
credentials_data["redirect_uris"] = json.dumps(
|
||||
[f"{dify_config.CONSOLE_API_URL}/console/api/mcp/oauth/callback"]
|
||||
)
|
||||
|
||||
# Only client_id and client_secret need encryption
|
||||
secret_fields = ["client_id", "client_secret"]
|
||||
return self._encrypt_dict_fields(credentials_data, secret_fields, tenant_id)
|
||||
|
||||
Reference in New Issue
Block a user