feat: add unit test

This commit is contained in:
Novice
2025-09-16 16:18:41 +08:00
parent f137af4ec5
commit e2fd3f2983
16 changed files with 2945 additions and 68 deletions

View File

@ -1,13 +1,12 @@
"""
MCP Client with Authentication Retry Support
This module provides a wrapper around MCPClient that automatically handles
This module provides an enhanced MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""
import logging
from collections.abc import Callable
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional
from core.entities.mcp_provider import MCPProviderEntity
@ -21,12 +20,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class MCPClientWithAuthRetry:
class MCPClientWithAuthRetry(MCPClient):
"""
A wrapper around MCPClient that provides automatic authentication retry.
An enhanced MCPClient that provides automatic authentication retry.
This class intercepts MCPAuthError exceptions and attempts to refresh
authentication before retrying the failed operation.
This class extends MCPClient and intercepts MCPAuthError exceptions
to refresh authentication before retrying failed operations.
"""
def __init__(
@ -53,27 +52,17 @@ class MCPClientWithAuthRetry:
provider_entity: Provider entity for authentication
auth_callback: Authentication callback function
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
mcp_service: MCP service instance
"""
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
super().__init__(server_url, headers, timeout, sse_read_timeout)
self.provider_entity = provider_entity
self.auth_callback = auth_callback
self.authorization_code = authorization_code
self._has_retried = False
self._client: MCPClient | None = None
self.by_server_id = by_server_id
self.mcp_service = mcp_service
def _create_client(self) -> MCPClient:
"""Create a new MCPClient instance with current headers."""
return MCPClient(
server_url=self.server_url,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
self._has_retried = False
def _handle_auth_error(self, error: MCPAuthError) -> None:
"""
@ -134,38 +123,35 @@ class MCPClientWithAuthRetry:
return func(*args, **kwargs)
except MCPAuthError as e:
self._handle_auth_error(e)
# Recreate client with new headers
if self._client:
self._client.cleanup()
self._client = self._create_client()
self._client.__enter__()
# Re-initialize the connection with new headers
if self._initialized:
# Clean up existing connection
self._exit_stack.close()
self._session = None
self._initialized = False
# Re-initialize with new headers
self._initialize()
self._initialized = True
return func(*args, **kwargs)
finally:
# Reset retry flag after operation completes
self._has_retried = False
def __enter__(self):
"""Enter the context manager."""
self._client = self._create_client()
"""Enter the context manager with retry support."""
# Try to initialize with retry
def initialize():
if self._client is None:
raise ValueError("Client not created")
self._client.__enter__()
def initialize_with_retry():
super(MCPClientWithAuthRetry, self).__enter__()
return self
return self._execute_with_retry(initialize)
def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
"""Exit the context manager."""
if self._client:
self._client.__exit__(exc_type, exc_value, traceback)
self._client = None
return self._execute_with_retry(initialize_with_retry)
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server.
List available tools from the MCP server with auth retry.
Returns:
List of available tools
@ -173,13 +159,11 @@ class MCPClientWithAuthRetry:
Raises:
MCPAuthError: If authentication fails after retries
"""
if not self._client:
raise ValueError("Client not initialized. Use within a context manager.")
return self._execute_with_retry(self._client.list_tools)
return self._execute_with_retry(super().list_tools)
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server.
Invoke a tool on the MCP server with auth retry.
Args:
tool_name: Name of the tool to invoke
@ -191,12 +175,4 @@ class MCPClientWithAuthRetry:
Raises:
MCPAuthError: If authentication fails after retries
"""
if not self._client:
raise ValueError("Client not initialized. Use within a context manager.")
return self._execute_with_retry(self._client.invoke_tool, tool_name, tool_args)
def cleanup(self):
"""Clean up resources."""
if self._client:
self._client.cleanup()
self._client = None
return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)

View File

@ -0,0 +1 @@

View File

@ -46,7 +46,7 @@ class SSETransport:
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
):
"""Initialize the SSE transport.
@ -255,7 +255,7 @@ def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
sse_read_timeout: float = 1 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.

View File

@ -30,7 +30,7 @@ DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = Annotated[int, Field(strict=True)] | str
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
AnyFunction: TypeAlias = Callable[..., Any]

View File

@ -162,7 +162,6 @@ class MCPTool(Tool):
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
auth_callback=auth if mcp_service else None,
by_server_id=True,
mcp_service=mcp_service,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)