refactor(mcp): clean the client service code

This commit is contained in:
Novice
2025-09-16 10:54:31 +08:00
parent f16151ea29
commit aed9955105
13 changed files with 858 additions and 530 deletions

View File

@ -0,0 +1,53 @@
"""
MCP OAuth Service - handles OAuth-related database operations
"""
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import OAuthClientInformationFull, OAuthTokens
from services.tools.mcp_tools_manage_service import MCPToolManageService
class MCPOAuthService:
"""Service for handling MCP OAuth operations"""
def __init__(self, session: Session):
self._session = session
self._mcp_service = MCPToolManageService(session=session)
def get_provider_entity(self, provider_id: str, tenant_id: str, by_server_id: bool = False) -> MCPProviderEntity:
"""Get provider entity by ID"""
if by_server_id:
db_provider = self._mcp_service.get_provider_by_server_identifier(provider_id, tenant_id)
else:
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
return db_provider.to_entity()
def save_client_information(
self, provider_id: str, tenant_id: str, client_information: OAuthClientInformationFull
) -> None:
"""Save OAuth client information"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.update_provider_credentials(
provider=db_provider,
credentials={"client_information": client_information.model_dump()},
)
def save_tokens(self, provider_id: str, tenant_id: str, tokens: OAuthTokens, authed: bool = True) -> None:
"""Save OAuth tokens"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
token_dict = tokens.model_dump()
self._mcp_service.update_provider_credentials(provider=db_provider, credentials=token_dict, authed=authed)
def save_code_verifier(self, provider_id: str, tenant_id: str, code_verifier: str) -> None:
"""Save PKCE code verifier"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.update_provider_credentials(
provider=db_provider, credentials={"code_verifier": code_verifier}
)
def clear_credentials(self, provider_id: str, tenant_id: str) -> None:
"""Clear provider credentials"""
db_provider = self._mcp_service.get_provider_by_id(provider_id, tenant_id)
self._mcp_service.clear_provider_credentials(provider=db_provider)

View File

@ -1,24 +1,27 @@
import hashlib
import json
import logging
from collections.abc import Callable
from datetime import datetime
from typing import Any
from typing import Any, Optional
from sqlalchemy import or_
from sqlalchemy import or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import OAuthTokens
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
@ -27,8 +30,10 @@ class MCPToolManageService:
Service class for managing mcp tools.
"""
@staticmethod
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
def __init__(self, session: Session):
self._session = session
def _encrypt_headers(self, headers: dict[str, str], tenant_id: str) -> dict[str, str]:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
@ -57,48 +62,53 @@ class MCPToolManageService:
return encrypter_instance.encrypt(headers)
@staticmethod
def _retrieve_remote_mcp_tools(server_url: str, headers: dict[str, str], timeout: float, sse_read_timeout: float):
with MCPClient(
def _retrieve_remote_mcp_tools(
self,
server_url: str,
headers: dict[str, str],
provider_entity: MCPProviderEntity,
auth_callback: Callable[[MCPProviderEntity, Optional[str]], dict[str, str]],
):
"""Retrieve tools from remote MCP server"""
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth_callback,
) as mcp_client:
tools = mcp_client.list_tools()
return tools
@staticmethod
def _process_headers(headers: dict[str, str], tokens: OAuthTokens | None = None):
headers = headers or {}
def _process_headers(self, headers: dict[str, str], tokens: OAuthTokens | None = None) -> dict[str, str]:
"""Process headers and add OAuth token if available"""
headers = headers.copy() if headers else {}
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
return headers
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.first()
)
if not res:
def get_provider_by_id(self, provider_id: str, tenant_id: str) -> MCPToolProvider:
"""Get MCP provider by ID"""
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return res
return provider
@staticmethod
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.first()
def get_provider_by_server_identifier(self, server_identifier: str, tenant_id: str) -> MCPToolProvider:
"""Get MCP provider by server identifier"""
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier
)
if not res:
provider = self._session.scalar(stmt)
if not provider:
raise ValueError("MCP tool not found")
return res
return provider
@staticmethod
def create_mcp_provider(
def create_provider(
self,
*,
tenant_id: str,
name: str,
server_url: str,
@ -111,19 +121,20 @@ class MCPToolManageService:
sse_read_timeout: float,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
"""Create a new MCP provider"""
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
db.session.query(MCPToolProvider)
.where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
.first()
# Check for existing provider
stmt = select(MCPToolProvider).where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
)
existing_provider = self._session.scalar(stmt)
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
@ -131,13 +142,17 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {server_url} already exists")
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
# Encrypt server URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)
# Create provider
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
@ -152,91 +167,68 @@ class MCPToolManageService:
sse_read_timeout=sse_read_timeout,
encrypted_headers=encrypted_headers,
)
db.session.add(mcp_tool)
db.session.commit()
self._session.add(mcp_tool)
self._session.commit()
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
@staticmethod
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
mcp_providers = (
db.session.query(MCPToolProvider)
.where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name)
.all()
)
def list_providers(self, *, tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
"""List all MCP providers for a tenant"""
stmt = select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
mcp_providers = self._session.scalars(stmt).all()
return [
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
for mcp_provider in mcp_providers
ToolTransformService.mcp_provider_to_user_provider(provider, for_list=for_list)
for provider in mcp_providers
]
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
def list_provider_tools(self, *, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
"""List tools from remote MCP server"""
from core.mcp.auth.auth_flow import auth
from core.mcp.auth.auth_provider import OAuthClientProvider
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
# Load provider and convert to entity
db_provider = self.get_provider_by_id(provider_id, tenant_id)
provider_entity = db_provider.to_entity()
# Handle authentication headers if authed
if not authed:
if not provider_entity.authed:
raise ValueError("Please auth the tool first")
provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
tokens = provider.tokens()
headers = cls._process_headers(headers, tokens)
tokens = provider_entity.retrieve_tokens()
headers = self._process_headers(provider_entity.headers, tokens)
server_url = provider_entity.decrypt_server_url()
try:
tools = cls._retrieve_remote_mcp_tools(server_url, headers, timeout, sse_read_timeout)
except MCPAuthError:
try:
auth(provider, server_url, None, None, False)
tokens = provider.tokens()
re_authed_headers = cls._process_headers(headers, tokens)
tools = cls._retrieve_remote_mcp_tools(server_url, re_authed_headers, timeout, sse_read_timeout)
except Exception:
raise ValueError("Please auth the tool first")
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity, auth)
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
try:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
mcp_provider.updated_at = datetime.now()
db.session.commit()
except Exception:
db.session.rollback()
raise
# Update database record with new tools
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.commit()
user = mcp_provider.load_user()
return ToolProviderApiEntity(
id=mcp_provider.id,
name=mcp_provider.name,
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
type=ToolProviderType.MCP,
icon=mcp_provider.icon,
author=user.name if user else "Anonymous",
server_url=mcp_provider.masked_server_url,
updated_at=int(mcp_provider.updated_at.timestamp()),
description=I18nObject(en_US="", zh_Hans=""),
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
plugin_unique_identifier=mcp_provider.server_identifier,
# Create API response using entity
user = db_provider.load_user()
response = provider_entity.to_api_response(
user_name=user.name if user else None,
)
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(db_provider, tools)
response["plugin_unique_identifier"] = provider_entity.provider_id
@classmethod
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return ToolProviderApiEntity(**response)
db.session.delete(mcp_tool)
db.session.commit()
def delete_provider(self, *, tenant_id: str, provider_id: str) -> None:
"""Delete an MCP provider"""
mcp_tool = self.get_provider_by_id(provider_id, tenant_id)
self._session.delete(mcp_tool)
self._session.commit()
@classmethod
def update_mcp_provider(
cls,
def update_provider(
self,
*,
tenant_id: str,
provider_id: str,
name: str,
@ -248,21 +240,27 @@ class MCPToolManageService:
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
) -> None:
"""Update an MCP provider"""
mcp_provider = self.get_provider_by_id(provider_id, tenant_id)
reconnect_result = None
encrypted_server_url = None
server_url_hash = None
# Handle server URL update
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash:
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
reconnect_result = self._reconnect_provider(
server_url=server_url,
provider=mcp_provider,
)
try:
# Update basic fields
mcp_provider.updated_at = datetime.now()
mcp_provider.name = name
mcp_provider.icon = (
@ -270,6 +268,7 @@ class MCPToolManageService:
)
mcp_provider.server_identifier = server_identifier
# Update server URL if changed
if encrypted_server_url is not None and server_url_hash is not None:
mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = server_url_hash
@ -279,6 +278,7 @@ class MCPToolManageService:
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
# Update optional fields
if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
@ -286,13 +286,15 @@ class MCPToolManageService:
if headers is not None:
# Encrypt headers
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers_dict = self._encrypt_headers(headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
mcp_provider.encrypted_headers = None
db.session.commit()
self._session.commit()
except IntegrityError as e:
db.session.rollback()
self._session.rollback()
error_msg = str(e.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
@ -302,54 +304,55 @@ class MCPToolManageService:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
except Exception:
db.session.rollback()
self._session.rollback()
raise
@classmethod
def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
def update_provider_credentials(
self, *, provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
) -> None:
"""Update provider credentials"""
from core.tools.mcp_tool.provider import MCPToolProviderController
provider_controller = MCPToolProviderController.from_db(mcp_provider)
provider_controller = MCPToolProviderController.from_db(provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
tenant_id=provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
mcp_provider.authed = authed
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.updated_at = datetime.now()
provider.encrypted_credentials = json.dumps({**provider.credentials, **encrypted_credentials})
provider.authed = authed
if not authed:
mcp_provider.tools = "[]"
db.session.commit()
provider.tools = "[]"
@classmethod
def clear_mcp_provider_credentials(
cls,
mcp_provider: MCPToolProvider,
):
mcp_provider.tools = "[]"
mcp_provider.encrypted_credentials = "{}"
mcp_provider.updated_at = datetime.now()
mcp_provider.authed = False
db.session.commit()
self._session.commit()
@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str) -> dict[str, Any]:
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
def clear_provider_credentials(self, *, provider: MCPToolProvider) -> None:
"""Clear provider credentials"""
provider.tools = "[]"
provider.encrypted_credentials = "{}"
provider.updated_at = datetime.now()
provider.authed = False
self._session.commit()
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> dict[str, Any]:
"""Attempt to reconnect to MCP provider with new server URL"""
from core.mcp.auth.auth_flow import auth
provider_entity = provider.to_entity()
headers = provider_entity.headers
timeout = provider_entity.timeout
sse_read_timeout = provider_entity.sse_read_timeout
try:
with MCPClient(
with MCPClientWithAuthRetry(
server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth,
) as mcp_client:
tools = mcp_client.list_tools()
return {

View File

@ -221,27 +221,20 @@ class ToolTransformService:
@staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
# Convert to entity and use its API response method
provider_entity = db_provider.to_entity()
user = db_provider.load_user()
return ToolProviderApiEntity(
id=db_provider.server_identifier if not for_list else db_provider.id,
author=user.name if user else "Anonymous",
name=db_provider.name,
icon=db_provider.provider_icon,
type=ToolProviderType.MCP,
is_team_authorization=db_provider.authed,
server_url=db_provider.masked_server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
),
updated_at=int(db_provider.updated_at.timestamp()),
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
masked_headers=db_provider.masked_headers,
original_headers=db_provider.decrypted_headers,
response = provider_entity.to_api_response(user_name=user.name if user else None)
# Add additional fields specific to the transform
response["id"] = db_provider.server_identifier if not for_list else db_provider.id
response["tools"] = ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
)
response["server_identifier"] = db_provider.server_identifier
return ToolProviderApiEntity(**response)
@staticmethod
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
@ -403,7 +396,7 @@ class ToolTransformService:
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
def convert_mcp_schema_to_parameter(schema: dict[str, Any]) -> list["ToolParameter"]:
"""
Convert MCP JSON schema to tool parameters
@ -412,7 +405,7 @@ class ToolTransformService:
"""
def create_parameter(
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
) -> ToolParameter:
"""Create a ToolParameter instance with given attributes"""
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
@ -427,7 +420,9 @@ class ToolTransformService:
**input_schema_dict,
)
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
def process_properties(
props: dict[str, dict[str, Any]], required: list[str], prefix: str = ""
) -> list[ToolParameter]:
"""Process properties recursively"""
TYPE_MAPPING = {"integer": "number", "float": "number"}
COMPLEX_TYPES = ["array", "object"]