mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 14:38:06 +08:00
refactor(mcp): clean the auth code
This commit is contained in:
@ -3,7 +3,6 @@ import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||
@ -125,71 +124,39 @@ class MCPTool(Tool):
|
||||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
# Get provider entity to access tokens
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Get MCP service from invoke parameters or create new one
|
||||
provider_entity = None
|
||||
mcp_service = None
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
# Check if mcp_service is passed in tool_parameters
|
||||
if "_mcp_service" in tool_parameters:
|
||||
mcp_service = tool_parameters.pop("_mcp_service")
|
||||
if mcp_service:
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
headers = provider_entity.decrypt_headers()
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
# Step 1: Load provider entity and credentials in a short-lived session
|
||||
# This minimizes database connection hold time
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth if mcp_service else None,
|
||||
mcp_service=mcp_service,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except (ValueError, TypeError, KeyError) as e:
|
||||
# Catch specific exceptions that might occur during tool invocation
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
else:
|
||||
# Fallback to creating service with database session
|
||||
from sqlalchemy.orm import Session
|
||||
# Decrypt and prepare all credentials before closing session
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_headers()
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
headers = provider_entity.decrypt_headers()
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# Use MCPClientWithAuthRetry to handle authentication automatically
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=provider_entity.decrypt_server_url() if provider_entity else self.server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
auth_callback=auth if mcp_service else None,
|
||||
mcp_service=mcp_service,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
Reference in New Issue
Block a user