mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
refactor(mcp): clean the client code
This commit is contained in:
@ -3,12 +3,14 @@ import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
@ -44,26 +46,7 @@ class MCPTool(Tool):
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
self.server_url,
|
||||
self.provider_id,
|
||||
self.tenant_id,
|
||||
authed=True,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
raise ToolInvokeError("Please auth the tool first") from e
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
# handle dify tool output
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
@ -95,7 +78,7 @@ class MCPTool(Tool):
|
||||
|
||||
def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process a list of JSON items."""
|
||||
if any(not isinstance(item, dict[str, Any]) for item in json_list):
|
||||
if any(not isinstance(item, dict) for item in json_list):
|
||||
# If the list contains any non-dict item, treat the entire list as a text message.
|
||||
yield self.create_text_message(str(json_list))
|
||||
return
|
||||
@ -130,3 +113,65 @@ class MCPTool(Tool):
|
||||
for key, value in parameter.items()
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
|
||||
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
||||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
# Initialize auth provider
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
provider = None
|
||||
|
||||
try:
|
||||
provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=False)
|
||||
except Exception as e:
|
||||
# If provider initialization fails, continue without auth
|
||||
pass
|
||||
|
||||
# Try to get existing token and add to headers
|
||||
if provider:
|
||||
try:
|
||||
token = provider.tokens()
|
||||
if token:
|
||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
except Exception:
|
||||
# If token retrieval fails, continue without auth header
|
||||
pass
|
||||
|
||||
# Define a helper function to invoke the tool
|
||||
def _invoke_with_client(client_headers: dict[str, str]) -> CallToolResult:
|
||||
with MCPClient(
|
||||
self.server_url,
|
||||
headers=client_headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
|
||||
try:
|
||||
# First attempt with current headers
|
||||
return _invoke_with_client(headers)
|
||||
except MCPAuthError as e:
|
||||
# Authentication required - try to authenticate
|
||||
if not provider:
|
||||
raise ToolInvokeError("Authentication required but no auth provider available") from e
|
||||
|
||||
try:
|
||||
# Perform authentication flow
|
||||
auth(provider, self.server_url, None, None, False)
|
||||
token = provider.tokens()
|
||||
if not token:
|
||||
raise ToolInvokeError("Authentication failed - no token received")
|
||||
|
||||
# Update headers with new token while preserving other headers
|
||||
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
|
||||
|
||||
# Retry with authenticated headers
|
||||
return _invoke_with_client(headers)
|
||||
except MCPAuthError as auth_error:
|
||||
raise ToolInvokeError("Authentication failed") from auth_error
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
Reference in New Issue
Block a user