refactor(mcp): clean the client service code

This commit is contained in:
Novice
2025-09-16 10:54:31 +08:00
parent f16151ea29
commit aed9955105
13 changed files with 858 additions and 530 deletions

View File

@ -7,6 +7,7 @@ from flask_restx import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -17,13 +18,13 @@ from controllers.console.wraps import (
setup_required,
)
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
@ -870,8 +871,9 @@ class ToolProviderMCPApi(Resource):
user = current_user
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=user.current_tenant_id,
server_url=args["server_url"],
name=args["name"],
@ -884,7 +886,8 @@ class ToolProviderMCPApi(Resource):
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
)
)
session.commit()
return jsonable_encoder(result)
@setup_required
@login_required
@ -907,20 +910,23 @@ class ToolProviderMCPApi(Resource):
pass
else:
raise ValueError("Server URL is not valid.")
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
return {"result": "success"}
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_user.current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
session.commit()
return {"result": "success"}
@setup_required
@login_required
@ -929,8 +935,11 @@ class ToolProviderMCPApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
session.commit()
return {"result": "success"}
class ToolMCPAuthApi(Resource):
@ -944,45 +953,50 @@ class ToolMCPAuthApi(Resource):
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
# headers1: if headers is provided, use it and don't need to get token
headers = provider.decrypted_headers or {}
# headers2: Add OAuth token if authed and no headers provided
if not provider.decrypted_headers and provider.authed:
token = OAuthClientProvider(provider_id, tenant_id, for_list=True).tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
try:
# try to connect to MCP server with headers
with MCPClient(
provider.decrypted_server_url,
headers=headers,
timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout,
):
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials=provider.decrypted_credentials,
authed=True,
)
return {"result": "success"}
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
db_provider = service.get_provider_by_id(provider_id, tenant_id)
if not db_provider:
raise ValueError("provider not found")
except MCPAuthError as e:
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
# Option 1: if headers is provided, use it and don't need to get token
headers = provider_entity.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not provider_entity.headers and provider_entity.authed:
token = provider_entity.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
try:
if provider.decrypted_headers:
raise ValueError(f"Failed to authenticate, please check your headers: {e}") from e
# if auth failed, try to auth with OAuth or exchange token
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
except Exception as e:
MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider)
raise ValueError(f"Failed to authenticate, please try again: {e}") from e
except MCPError as e:
MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClientWithAuthRetry(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
provider_entity=provider_entity
if not provider_entity.headers
else None, # Only use auth retry if no custom headers
auth_callback=auth if not provider_entity.headers else None,
authorization_code=args.get("authorization_code"),
):
service.update_provider_credentials(
provider=db_provider,
credentials=provider_entity.credentials,
authed=True,
)
session.commit()
return {"result": "success"}
except MCPError as e:
service.clear_provider_credentials(provider=db_provider)
session.commit()
raise ValueError(f"Failed to connect to MCP server: {e}") from e
class ToolMCPDetailApi(Resource):
@ -991,8 +1005,10 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
user = current_user
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
provider = service.get_provider_by_id(provider_id, user.current_tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
class ToolMCPListAllApi(Resource):
@ -1003,9 +1019,11 @@ class ToolMCPListAllApi(Resource):
user = current_user
tenant_id = user.current_tenant_id
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
tools = service.list_providers(tenant_id=tenant_id)
return [tool.to_dict() for tool in tools]
return [tool.to_dict() for tool in tools]
class ToolMCPUpdateApi(Resource):
@ -1014,11 +1032,13 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
class ToolMCPCallbackApi(Resource):