mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 14:38:06 +08:00
refactor(mcp): clean the client service code
This commit is contained in:
@ -7,6 +7,7 @@ from flask_restx import (
|
||||
Resource,
|
||||
reqparse,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@ -17,13 +18,13 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
@ -870,8 +871,9 @@ class ToolProviderMCPApi(Resource):
|
||||
user = current_user
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
return jsonable_encoder(
|
||||
MCPToolManageService.create_mcp_provider(
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
@ -884,7 +886,8 @@ class ToolProviderMCPApi(Resource):
|
||||
sse_read_timeout=args["sse_read_timeout"],
|
||||
headers=args["headers"],
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -907,20 +910,23 @@ class ToolProviderMCPApi(Resource):
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
MCPToolManageService.update_mcp_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=args["provider_id"],
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args.get("timeout"),
|
||||
sse_read_timeout=args.get("sse_read_timeout"),
|
||||
headers=args.get("headers"),
|
||||
)
|
||||
return {"result": "success"}
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=args["provider_id"],
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
timeout=args.get("timeout"),
|
||||
sse_read_timeout=args.get("sse_read_timeout"),
|
||||
headers=args.get("headers"),
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -929,8 +935,11 @@ class ToolProviderMCPApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ToolMCPAuthApi(Resource):
|
||||
@ -944,45 +953,50 @@ class ToolMCPAuthApi(Resource):
|
||||
args = parser.parse_args()
|
||||
provider_id = args["provider_id"]
|
||||
tenant_id = current_user.current_tenant_id
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if not provider:
|
||||
raise ValueError("provider not found")
|
||||
# headers1: if headers is provided, use it and don't need to get token
|
||||
headers = provider.decrypted_headers or {}
|
||||
|
||||
# headers2: Add OAuth token if authed and no headers provided
|
||||
if not provider.decrypted_headers and provider.authed:
|
||||
token = OAuthClientProvider(provider_id, tenant_id, for_list=True).tokens()
|
||||
if token:
|
||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
try:
|
||||
# try to connect to MCP server with headers
|
||||
with MCPClient(
|
||||
provider.decrypted_server_url,
|
||||
headers=headers,
|
||||
timeout=provider.timeout,
|
||||
sse_read_timeout=provider.sse_read_timeout,
|
||||
):
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
mcp_provider=provider,
|
||||
credentials=provider.decrypted_credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider_by_id(provider_id, tenant_id)
|
||||
if not db_provider:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
except MCPAuthError as e:
|
||||
# Convert to entity
|
||||
provider_entity = db_provider.to_entity()
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
|
||||
# Option 1: if headers is provided, use it and don't need to get token
|
||||
headers = provider_entity.decrypt_headers()
|
||||
|
||||
# Option 2: Add OAuth token if authed and no headers provided
|
||||
if not provider_entity.headers and provider_entity.authed:
|
||||
token = provider_entity.retrieve_tokens()
|
||||
if token:
|
||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
try:
|
||||
if provider.decrypted_headers:
|
||||
raise ValueError(f"Failed to authenticate, please check your headers: {e}") from e
|
||||
# if auth failed, try to auth with OAuth or exchange token
|
||||
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
||||
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
|
||||
except Exception as e:
|
||||
MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider)
|
||||
raise ValueError(f"Failed to authenticate, please try again: {e}") from e
|
||||
except MCPError as e:
|
||||
MCPToolManageService.clear_mcp_provider_credentials(mcp_provider=provider)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
provider_entity=provider_entity
|
||||
if not provider_entity.headers
|
||||
else None, # Only use auth retry if no custom headers
|
||||
auth_callback=auth if not provider_entity.headers else None,
|
||||
authorization_code=args.get("authorization_code"),
|
||||
):
|
||||
service.update_provider_credentials(
|
||||
provider=db_provider,
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
except MCPError as e:
|
||||
service.clear_provider_credentials(provider=db_provider)
|
||||
session.commit()
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
class ToolMCPDetailApi(Resource):
|
||||
@ -991,8 +1005,10 @@ class ToolMCPDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
user = current_user
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider_by_id(provider_id, user.current_tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
class ToolMCPListAllApi(Resource):
|
||||
@ -1003,9 +1019,11 @@ class ToolMCPListAllApi(Resource):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_providers(tenant_id=tenant_id)
|
||||
|
||||
return [tool.to_dict() for tool in tools]
|
||||
return [tool.to_dict() for tool in tools]
|
||||
|
||||
|
||||
class ToolMCPUpdateApi(Resource):
|
||||
@ -1014,11 +1032,13 @@ class ToolMCPUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return jsonable_encoder(tools)
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_provider_tools(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return jsonable_encoder(tools)
|
||||
|
||||
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
|
||||
Reference in New Issue
Block a user