mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-29 20:17:35 +08:00
### What problem does this PR solve? When multiple MCP servers expose tools with the same name, the agent currently registers those tools using their original MCP names. This can lead to two issues: - later MCP tools may overwrite earlier ones in the agent tool map - duplicate function names may be exposed to the LLM This PR fixes duplicate MCP tool-name handling by applying the same indexed naming strategy already used for native agent tools. Native tools are exposed with generated names such as `<tool_name>_<index>` to avoid collisions, and MCP tools now follow the same convention for consistency. Specifically, this PR: - assigns unique indexed function names to MCP tools exposed to the LLM - preserves each MCP tool's original server-side name in an `MCPToolBinding` - dispatches MCP calls using the original MCP tool name while keeping the indexed name in the agent tool map - allows MCP metadata conversion to override only the OpenAI function name without modifying the original MCP tool metadata ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) ### Validation The validation was performed using two MCP servers. Both servers exposed a tool with the same name: `mcp0`. Both tools take no input parameters. **MCP Server One:** <img width="1780" height="625" alt="ONE" src="https://github.com/user-attachments/assets/801a2654-fc10-4b71-b31c-81841fd40c55" /> **MCP Server Two:** <img width="1777" height="624" alt="Second" src="https://github.com/user-attachments/assets/c095151d-7bdf-47c8-9bfe-6aaf4a01b944" /> **Before the fix:** When invoking `mcp0`, only the `mcp0` tool from the MCP server injected later could be called successfully. As shown below, both `mcp0` tools were present, but only the later-registered one was actually invokable. <img width="694" height="935" alt="Three" src="https://github.com/user-attachments/assets/3b9d7ab2-1765-492c-b8e0-bf05a69933ca" /> **After the fix:** Both `mcp0` tools can now be invoked correctly. 
<img width="737" height="1095" alt="F" src="https://github.com/user-attachments/assets/6e896627-2b7f-41bb-becc-daa0c73ff58f" /> <img width="730" height="1090" alt="six" src="https://github.com/user-attachments/assets/aba75593-26ae-4e3b-951d-b45ff177fd32" />
345 lines
14 KiB
Python
345 lines
14 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import asyncio
|
|
import logging
|
|
import threading
|
|
import weakref
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
|
from dataclasses import dataclass
|
|
from string import Template
|
|
from typing import Any, Literal, Protocol
|
|
|
|
from typing_extensions import override
|
|
|
|
from common.constants import MCPServerType
|
|
from mcp.client.session import ClientSession
|
|
from mcp.client.sse import sse_client
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
|
|
|
# Kinds of work the per-session background loop understands.
MCPTaskType = Literal["list_tools", "tool_call"]

# One queued unit of work: (task kind, keyword arguments for the task,
# queue on which the result -- or the exception raised -- is delivered).
MCPTask = tuple[
    MCPTaskType,
    dict[str, Any],
    asyncio.Queue[Any],
]
|
|
|
|
|
|
class ToolCallSession(Protocol):
|
|
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: ...
|
|
|
|
|
|
@dataclass(frozen=True)
class MCPToolBinding:
    """Immutable link between an MCP tool exposed to the LLM and how to call it.

    The agent may expose the tool under a generated, collision-free function name;
    this record keeps the session that executes the call together with the name
    the MCP server itself uses for the tool.
    """

    # Session through which the tool call is dispatched.
    session: ToolCallSession
    # Tool name as reported by the MCP server.
    original_name: str
|
|
|
|
|
|
class MCPToolCallSession(ToolCallSession):
    """Blocking client for one MCP server, backed by a private event loop.

    Each instance creates its own asyncio event loop and runs it on a single-worker
    ``ThreadPoolExecutor``. It then schedules ``_mcp_server_loop`` on that loop to
    connect to the server over SSE or streamable HTTP.

    Requests reach that coroutine through ``self._queue`` as
    ``(task_type, kwargs, result_queue)`` tuples. The synchronous methods
    ``get_tools`` and ``tool_call`` wait for the answer with a timeout.

    Every instance is registered in ``_ALL_INSTANCES``, a ``WeakSet`` that does not
    keep sessions alive, so that all live sessions can be shut down together.
    """

    _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()

    def __init__(
        self,
        mcp_server: Any,
        server_variables: dict[str, Any] | None = None,
        custom_header: dict[str, str] | None = None,
    ) -> None:
        """Start the background event loop and begin connecting to ``mcp_server``.

        Args:
            mcp_server: Server record; its ``url``, ``headers``, ``server_type`` and
                ``id`` attributes are read.
            server_variables: Values substituted into ``$name`` placeholders of the
                server's configured headers.
            custom_header: Extra headers sent in addition to the server's headers.
        """
        self.__class__._ALL_INSTANCES.add(self)

        self._custom_header = custom_header
        self._mcp_server = mcp_server
        self._server_variables = server_variables or {}
        self._queue = asyncio.Queue()
        self._close = False

        # Dedicated loop on a single worker thread; all MCP I/O of this session runs there.
        self._event_loop = asyncio.new_event_loop()
        self._thread_pool = ThreadPoolExecutor(max_workers=1)
        self._thread_pool.submit(self._event_loop.run_forever)

        asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)

    def _build_headers(self) -> dict[str, str]:
        """Render the HTTP headers sent to the MCP server.

        Server-configured headers are rendered with ``server_variables``. A header is
        skipped when its name is blank, or when its value is blank or only the word
        ``Bearer`` (e.g. a token placeholder that rendered empty).

        Custom headers are then rendered against the custom-header mapping itself.
        They are added last, so they override server headers with the same name.
        """
        raw_headers: dict[str, str] = self._mcp_server.headers or {}
        custom_header: dict[str, str] = self._custom_header or {}
        headers: dict[str, str] = {}

        for h, v in raw_headers.items():
            nh = Template(h).safe_substitute(self._server_variables)
            nv = Template(v).safe_substitute(self._server_variables)
            # Remove only a literal "Bearer" prefix: str.strip("Bearer") would strip any
            # of the characters B/e/a/r from both ends and discard values such as "area".
            if nh.strip() and nv.strip().removeprefix("Bearer").strip():
                headers[nh] = nv

        for h, v in custom_header.items():
            nh = Template(h).safe_substitute(custom_header)
            nv = Template(v).safe_substitute(custom_header)
            headers[nh] = nv

        return headers

    async def _serve_client_session(self, client_session: ClientSession) -> None:
        """Initialize ``client_session`` and then serve queued tasks with it.

        If initialization does not complete within 5 seconds, the queue is still
        served, but every task is answered with a timeout error.
        """
        try:
            await asyncio.wait_for(client_session.initialize(), timeout=5)
            logging.info("client_session initialized successfully")
            await self._process_mcp_tasks(client_session)
        except asyncio.TimeoutError:
            msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
            logging.error(msg)
            await self._process_mcp_tasks(None, msg)

    async def _mcp_server_loop(self) -> None:
        """Connect to the MCP server and serve queued tasks until the session closes.

        On connection failure or an unsupported server type, the task queue is still
        drained, but each task is answered with an error message instead of a result.
        """
        url = self._mcp_server.url.strip()
        headers = self._build_headers()

        if self._mcp_server.server_type == MCPServerType.SSE:
            # SSE transport
            try:
                async with sse_client(url, headers) as stream:
                    async with ClientSession(*stream) as client_session:
                        await self._serve_client_session(client_session)
            except asyncio.CancelledError:
                logging.warning(f"SSE transport MCP session cancelled for server {self._mcp_server.id}")
                return
            except Exception as e:
                # Record the real cause; callers only receive the generic message below.
                logging.exception(e)
                msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
                await self._process_mcp_tasks(None, msg)

        elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
            # Streamable HTTP transport
            try:
                async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
                    async with ClientSession(read_stream, write_stream) as client_session:
                        await self._serve_client_session(client_session)
            except asyncio.CancelledError:
                logging.warning(f"STREAMABLE_HTTP MCP session cancelled for server {self._mcp_server.id}")
                return
            except Exception as e:
                logging.exception(e)
                msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
                await self._process_mcp_tasks(None, msg)

        else:
            await self._process_mcp_tasks(None,
                                          f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")

    async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
        """Answer tasks from ``self._queue`` until the session is marked closed.

        The queue is polled with a 1-second timeout so that ``self._close`` is noticed.

        Without a usable ``client_session``, or when ``error_message`` is set, every
        task is answered with ``ValueError(error_message)``. Otherwise ``list_tools``
        and ``tool_call`` are forwarded to the client session. Any ``Exception`` raised
        there is put on the task's result queue instead of propagating.
        """
        while not self._close:
            try:
                mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

            logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")

            r: Any = None

            if not client_session or error_message:
                r = ValueError(error_message)
                try:
                    await result_queue.put(r)
                except asyncio.CancelledError:
                    break
                continue

            try:
                if mcp_task == "list_tools":
                    r = await client_session.list_tools()
                elif mcp_task == "tool_call":
                    r = await client_session.call_tool(**arguments)
                else:
                    r = ValueError(f"Unknown MCP task {mcp_task}")
            except Exception as e:
                r = e
            except asyncio.CancelledError:
                break

            try:
                await result_queue.put(r)
            except asyncio.CancelledError:
                break

    async def _call_mcp_server(self, task_type: MCPTaskType, request_timeout: float | int = 8, **kwargs) -> Any:
        """Queue a task for the server loop and wait for its result.

        Raises:
            ValueError: If the session is already closed.
            asyncio.TimeoutError: If no result arrives within ``request_timeout``.
            BaseException: Whatever exception object the server loop delivered as
                the result, re-raised here.
        """
        if self._close:
            raise ValueError("Session is closed")

        results = asyncio.Queue()
        await self._queue.put((task_type, kwargs, results))

        try:
            result: Any = await asyncio.wait_for(results.get(), timeout=request_timeout)
            # close() answers pending tasks with asyncio.CancelledError, which derives from
            # BaseException rather than Exception. Checking only Exception would return it
            # as if it were a real result.
            if isinstance(result, BaseException):
                raise result
            return result
        except asyncio.TimeoutError:
            raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {request_timeout}s")

    async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10) -> str:
        """Call tool ``name`` on the server and render its result as text.

        Only the first content item is used, and only text content is supported.
        Server-side errors, empty content and unsupported content are reported as
        strings rather than raised.
        """
        result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments,
                                                             request_timeout=request_timeout)

        if result.isError:
            return f"MCP server error: {result.content}"

        # For now, we only support text content
        if not result.content:
            return "MCP server returned empty content."
        first_item = result.content[0]
        if isinstance(first_item, TextContent):
            return first_item.text
        # Report the type of the offending item, not of the surrounding list.
        return f"Unsupported content type {type(first_item)}"

    async def _get_tools_from_mcp_server(self, request_timeout: float | int = 8) -> list[Tool]:
        """Return the tool list advertised by the MCP server."""
        result: ListToolsResult = await self._call_mcp_server("list_tools", request_timeout=request_timeout)
        return result.tools

    def get_tools(self, timeout: float | int = 10) -> list[Tool]:
        """Blocking wrapper around ``_get_tools_from_mcp_server``.

        Raises:
            ValueError: If the session is closed.
            RuntimeError: If no answer arrives within ``timeout`` seconds.
        """
        if self._close:
            raise ValueError("Session is closed")

        future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(request_timeout=timeout), self._event_loop)
        try:
            return future.result(timeout=timeout)
        except FuturesTimeoutError:
            # Cancel the abandoned coroutine instead of leaving it pending on the loop.
            future.cancel()
            msg = f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})"
            logging.error(msg)
            raise RuntimeError(msg)
        except Exception:
            logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
            raise

    @override
    def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
        """Blocking tool invocation; failures are returned as error strings, never raised."""
        if self._close:
            return "Error: Session is closed"

        future = asyncio.run_coroutine_threadsafe(
            self._call_mcp_tool(name, arguments, request_timeout=timeout),
            self._event_loop,
        )
        try:
            return future.result(timeout=timeout)
        except FuturesTimeoutError:
            # Cancel the abandoned coroutine instead of leaving it pending on the loop.
            future.cancel()
            logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id} (timeout={timeout})")
            return f"Timeout calling tool '{name}' (timeout={timeout})."
        except Exception as e:
            logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
            return f"Error calling tool '{name}': {e}."

    async def close(self) -> None:
        """Close the session (best effort; idempotent).

        Marks the session closed and answers still-queued tasks with
        ``asyncio.CancelledError``. It then stops the private event loop, shuts down
        its thread pool and unregisters the instance. Errors during these steps are
        suppressed.
        """
        if self._close:
            return

        self._close = True

        while not self._queue.empty():
            try:
                _, _, result_queue = self._queue.get_nowait()
                try:
                    await result_queue.put(asyncio.CancelledError("Session is closing"))
                except Exception:
                    pass
            except asyncio.QueueEmpty:
                break
            except Exception:
                break

        try:
            self._event_loop.call_soon_threadsafe(self._event_loop.stop)
        except Exception:
            pass

        try:
            # NOTE(review): when close() runs on this session's own loop (close_sync), this
            # executes inside the pool's worker thread, so joining it presumably fails and
            # is swallowed here.
            self._thread_pool.shutdown(wait=True)
        except Exception:
            pass

        self.__class__._ALL_INSTANCES.discard(self)

    def close_sync(self, timeout: float | int = 5) -> None:
        """Run ``close()`` on the session's own loop and wait up to ``timeout`` seconds."""
        if not self._event_loop.is_running():
            logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
            return

        try:
            future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop)
            try:
                future.result(timeout=timeout)
            except FuturesTimeoutError:
                logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})")
            except Exception:
                logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
        except Exception:
            logging.exception(f"Exception while scheduling close for server {self._mcp_server.id}")
|
|
|
|
|
|
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
    """Close several MCP sessions concurrently and block until they are done.

    A temporary event loop runs on a daemon thread. The ``close()`` coroutines of all
    non-``None`` sessions are gathered on it, with exceptions collected rather than
    raised, after which the loop stops itself.

    The helper thread is always joined, and the temporary loop is closed so that its
    resources (selector, self-pipe) are not leaked. Errors are logged, not raised.
    """
    logging.info(f"Want to clean up {len(sessions)} MCP sessions")

    loop: asyncio.AbstractEventLoop | None = None
    thread: threading.Thread | None = None

    async def _gather_and_stop() -> None:
        try:
            await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True)
        except Exception:
            logging.exception("Exception during MCP session cleanup")
        finally:
            try:
                loop.call_soon_threadsafe(loop.stop)
            except Exception:
                pass

    try:
        loop = asyncio.new_event_loop()
        thread = threading.Thread(target=loop.run_forever, daemon=True)
        thread.start()

        asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()
        thread.join()
    except Exception:
        logging.exception("Exception during MCP session cleanup thread management")
    finally:
        if loop is not None:
            # The loop must be stopped (and its thread finished) before it can be closed.
            if thread is not None and thread.is_alive():
                loop.call_soon_threadsafe(loop.stop)
                thread.join()
            if not loop.is_running():
                loop.close()

    logging.info(
        f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
|
|
|
|
|
|
def shutdown_all_mcp_sessions():
    """Gracefully shutdown all active MCPToolCallSession instances.

    Takes a snapshot of the sessions still registered on the class and closes them
    together. When none are registered, it only logs that fact.
    """
    live_sessions = [*MCPToolCallSession._ALL_INSTANCES]
    if live_sessions:
        logging.info(f"Shutting down {len(live_sessions)} MCPToolCallSession instances...")
        close_multiple_mcp_toolcall_sessions(live_sessions)
        logging.info("All MCPToolCallSession instances have been closed.")
    else:
        logging.info("No MCPToolCallSession instances to close.")
|
|
|
|
|
|
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict, function_name: str | None = None) -> dict[str, Any]:
    """Convert MCP tool metadata into an OpenAI-style function-tool schema.

    Args:
        mcp_tool: Either an MCP ``Tool`` object or a dict with ``name``,
            ``inputSchema`` and (optionally) ``description`` keys.
        function_name: If truthy, used as the exposed function name instead of the
            tool's own name. The tool metadata itself is not modified.

    Returns:
        ``{"type": "function", "function": {"name", "description", "parameters"}}``.
        ``description`` is ``None`` when the tool has none.
    """
    if isinstance(mcp_tool, dict):
        # ``name`` is only looked up when no override is given.
        name = function_name or mcp_tool["name"]
        # A tool's description is optional; a dict without it maps to None,
        # matching the Tool-object path instead of raising KeyError.
        description = mcp_tool.get("description")
        parameters = mcp_tool["inputSchema"]
    else:
        name = function_name or mcp_tool.name
        description = mcp_tool.description
        parameters = mcp_tool.inputSchema

    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }
|