mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
refactor(mcp): clean the client service code
This commit is contained in:
202
api/core/entities/mcp_provider.py
Normal file
202
api/core/entities/mcp_provider.py
Normal file
@ -0,0 +1,202 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.tools import MCPToolProvider
|
||||
|
||||
|
||||
class MCPProviderEntity(BaseModel):
|
||||
"""MCP Provider domain entity for business logic operations"""
|
||||
|
||||
# Basic identification
|
||||
id: str
|
||||
provider_id: str # server_identifier
|
||||
name: str
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
|
||||
# Server connection info
|
||||
server_url: str # encrypted URL
|
||||
headers: dict[str, str] # encrypted headers
|
||||
timeout: float
|
||||
sse_read_timeout: float
|
||||
|
||||
# Authentication related
|
||||
authed: bool
|
||||
credentials: dict[str, Any] # encrypted credentials
|
||||
code_verifier: Optional[str] = None # for OAuth
|
||||
|
||||
# Tools and display info
|
||||
tools: list[dict[str, Any]] # parsed tools list
|
||||
icon: str | dict[str, str] # parsed icon
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
|
||||
"""Create entity from database model with decryption"""
|
||||
|
||||
return cls(
|
||||
id=db_provider.id,
|
||||
provider_id=db_provider.server_identifier,
|
||||
name=db_provider.name,
|
||||
tenant_id=db_provider.tenant_id,
|
||||
user_id=db_provider.user_id,
|
||||
server_url=db_provider.server_url,
|
||||
headers=db_provider.headers,
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
authed=db_provider.authed,
|
||||
credentials=db_provider.credentials,
|
||||
tools=db_provider.tool_dict,
|
||||
icon=db_provider.icon or "",
|
||||
created_at=db_provider.created_at,
|
||||
updated_at=db_provider.updated_at,
|
||||
)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""OAuth redirect URL"""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=[self.redirect_url],
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_icon(self) -> dict[str, str] | str:
|
||||
"""Get provider icon, handling both dict and string formats"""
|
||||
if isinstance(self.icon, dict):
|
||||
return self.icon
|
||||
try:
|
||||
return json.loads(self.icon)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If not JSON, assume it's a file path
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
|
||||
def to_api_response(self, user_name: Optional[str] = None) -> dict[str, Any]:
|
||||
"""Convert to API response format"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": user_name or "Anonymous",
|
||||
"name": self.name,
|
||||
"icon": self.provider_icon,
|
||||
"type": ToolProviderType.MCP.value,
|
||||
"is_team_authorization": self.authed,
|
||||
"server_url": self.masked_server_url(),
|
||||
"server_identifier": self.provider_id,
|
||||
"timeout": self.timeout,
|
||||
"sse_read_timeout": self.sse_read_timeout,
|
||||
"masked_headers": self.masked_headers(),
|
||||
"updated_at": int(self.updated_at.timestamp()),
|
||||
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||
}
|
||||
|
||||
def retrieve_client_information(self) -> Optional[OAuthClientInformation]:
|
||||
"""OAuth client information if available"""
|
||||
client_info = self.decrypt_credentials().get("client_information", {})
|
||||
if not client_info:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_info)
|
||||
|
||||
def retrieve_tokens(self) -> Optional[OAuthTokens]:
|
||||
"""OAuth tokens if available"""
|
||||
if not self.credentials:
|
||||
return None
|
||||
credentials = self.decrypt_credentials()
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", "Bearer"),
|
||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def masked_server_url(self) -> str:
|
||||
"""Masked server URL for display"""
|
||||
parsed = urlparse(self.decrypt_server_url())
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
if parsed.path and parsed.path != "/":
|
||||
return f"{base_url}/******"
|
||||
return base_url
|
||||
|
||||
def masked_headers(self) -> dict[str, str]:
|
||||
"""Masked headers for display"""
|
||||
masked: dict[str, str] = {}
|
||||
for key, value in self.decrypt_headers().items():
|
||||
if len(value) > 6:
|
||||
masked[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
|
||||
else:
|
||||
masked[key] = "*" * len(value)
|
||||
return masked
|
||||
|
||||
def decrypt_server_url(self) -> str:
|
||||
"""Decrypt server URL"""
|
||||
|
||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||
|
||||
def decrypt_headers(self) -> dict[str, Any]:
|
||||
"""Decrypt headers"""
|
||||
|
||||
try:
|
||||
if not self.headers:
|
||||
return {}
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in self.headers]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
result = encrypter_instance.decrypt(self.headers)
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def decrypt_credentials(
|
||||
self,
|
||||
) -> dict[str, Any]:
|
||||
"""Decrypt credentials"""
|
||||
try:
|
||||
if not self.credentials:
|
||||
return {}
|
||||
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=[
|
||||
BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key)
|
||||
for key in self.credentials
|
||||
],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter.decrypt(self.credentials)
|
||||
except Exception:
|
||||
return {}
|
||||
Reference in New Issue
Block a user