From aa4526266f3a6ed6cc264a661eb81c19e1dcb8d8 Mon Sep 17 00:00:00 2001 From: buua436 Date: Thu, 23 Apr 2026 12:51:27 +0800 Subject: [PATCH] Refa: migrate MCP APIs to RESTful api (#14317) ### What problem does this PR solve? migrate MCP APIs to RESTful api ### Type of change - [x] Refactoring --- api/apps/restful_apis/mcp_api.py | 331 +++++++++++++++++ .../test_mcp_server_app_unit.py | 348 +++++------------- web/src/hooks/use-mcp-request.ts | 27 +- web/src/interfaces/database/mcp.ts | 7 +- web/src/services/mcp-server-service.ts | 68 +--- web/src/utils/api.ts | 17 +- 6 files changed, 481 insertions(+), 317 deletions(-) create mode 100644 api/apps/restful_apis/mcp_api.py diff --git a/api/apps/restful_apis/mcp_api.py b/api/apps/restful_apis/mcp_api.py new file mode 100644 index 0000000000..ec384f6074 --- /dev/null +++ b/api/apps/restful_apis/mcp_api.py @@ -0,0 +1,331 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from quart import Response, request + +from api.apps import current_user, login_required +from api.db.db_models import MCPServer +from api.db.services.mcp_server_service import MCPServerService +from api.db.services.user_service import TenantService +from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request +from api.utils.web_utils import get_float, safe_json_parse +from common.constants import VALID_MCP_SERVER_TYPES +from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from common.misc_utils import get_uuid, thread_pool_exec + + +def _get_mcp_ids_from_args() -> list[str]: + mcp_ids = request.args.getlist("mcp_ids") + if mcp_ids: + return [mcp_id for item in mcp_ids for mcp_id in item.split(",") if mcp_id] + mcp_ids = request.args.get("mcp_id", "") + return [mcp_id for mcp_id in mcp_ids.split(",") if mcp_id] + + +def _export_mcp_servers(mcp_ids: list[str]) -> dict | None: + exported_servers = {} + for mcp_id in mcp_ids: + e, mcp_server = MCPServerService.get_by_id(mcp_id) + if e and mcp_server.tenant_id == current_user.id: + server_key = mcp_server.name + exported_servers[server_key] = { + "type": mcp_server.server_type, + "url": mcp_server.url, + "name": mcp_server.name, + "authorization_token": mcp_server.variables.get("authorization_token", ""), + "tools": mcp_server.variables.get("tools", {}), + } + + if not exported_servers: + return None + + return {"mcpServers": exported_servers} + + +@manager.route("/mcp/servers", methods=["GET"]) # noqa: F821 +@login_required +async def list_mcp() -> Response: + keywords = request.args.get("keywords", "") + page_number = int(request.args.get("page", 0)) + items_per_page = int(request.args.get("page_size", 0)) + orderby = request.args.get("orderby", "create_time") + if request.args.get("desc", "true").lower() == "false": + desc = False + else: + desc = True + + mcp_ids = _get_mcp_ids_from_args() + try: + servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] + total = len(servers) + + if page_number and items_per_page: + servers = servers[(page_number - 1) * items_per_page : page_number * items_per_page] + + return get_json_result(data={"mcp_servers": servers, "total": total}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/mcp/servers/", methods=["GET"]) # noqa: F821 +@login_required +def detail(mcp_id: str) -> Response: + try: + if request.args.get("mode") == "download": + exported_servers = _export_mcp_servers([mcp_id]) + if exported_servers is None: + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") + return get_json_result(data=exported_servers) + + mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id) + + if mcp_server is None: + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") + + return get_json_result(data=mcp_server.to_dict()) + except Exception as e: + return server_error_response(e) + + +@manager.route("/mcp/servers", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("name", "url", "server_type") +async def create() -> Response: + req = await get_request_json() + + server_type = req.get("server_type", "") + if server_type not in VALID_MCP_SERVER_TYPES: + return get_data_error_result(message="Unsupported MCP server type.") + + server_name = req.get("name", "") + if not server_name or len(server_name.encode("utf-8")) > 255: + return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.") + + e, _ = MCPServerService.get_by_name_and_tenant(name=server_name, tenant_id=current_user.id) + if e: + return get_data_error_result(message="Duplicated MCP server name.") + + url = req.get("url", "") + if not url: + return get_data_error_result(message="Invalid url.") + + headers = safe_json_parse(req.get("headers", {})) + req["headers"] = headers + variables = safe_json_parse(req.get("variables", {})) + variables.pop("tools", None) + + timeout = get_float(req, "timeout", 10) + + try: + req["id"] = get_uuid() + req["tenant_id"] = current_user.id + + e, _ = TenantService.get_by_id(current_user.id) + if not e: + return get_data_error_result(message="Tenant not found.") + + mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) + server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) + if err_message: + return get_data_error_result(message=err_message) + + tools = server_tools[server_name] + tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} + variables["tools"] = tools + req["variables"] = variables + + if not MCPServerService.insert(**req): + return get_data_error_result(message="Failed to create MCP server.") + + return get_json_result(data=req) + except Exception as e: + return server_error_response(e) + + +@manager.route("/mcp/servers/", methods=["PUT"]) # noqa: F821 +@login_required +async def update(mcp_id: str) -> Response: + req = await get_request_json() + + e, mcp_server = MCPServerService.get_by_id(mcp_id) + if not e or mcp_server.tenant_id != current_user.id: + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") + + server_type = req.get("server_type", mcp_server.server_type) + if server_type and server_type not in VALID_MCP_SERVER_TYPES: + return get_data_error_result(message="Unsupported MCP server type.") + server_name = req.get("name", mcp_server.name) + if server_name and len(server_name.encode("utf-8")) > 255: + return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.") + url = req.get("url", mcp_server.url) + if not url: + return get_data_error_result(message="Invalid url.") + + headers = safe_json_parse(req.get("headers", mcp_server.headers)) + req["headers"] = headers + + variables = safe_json_parse(req.get("variables", mcp_server.variables)) + variables.pop("tools", None) + + timeout = get_float(req, "timeout", 10) + + try: + req["tenant_id"] = current_user.id + req["id"] = mcp_id + + mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) + server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) + if err_message: + return get_data_error_result(message=err_message) + + tools = server_tools[server_name] + tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} + variables["tools"] = tools + req["variables"] = variables + + if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req): + return get_data_error_result(message="Failed to updated MCP server.") + + e, updated_mcp = MCPServerService.get_by_id(req["id"]) + if not e: + return get_data_error_result(message="Failed to fetch updated MCP server.") + + return get_json_result(data=updated_mcp.to_dict()) + except Exception as e: + return server_error_response(e) + + +@manager.route("/mcp/servers/", methods=["DELETE"]) # noqa: F821 +@login_required +async def rm(mcp_id: str) -> Response: + try: + e, mcp_server = MCPServerService.get_by_id(mcp_id) + if not e or mcp_server.tenant_id != current_user.id: + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") + if not MCPServerService.delete_by_ids([mcp_id]): + return get_data_error_result(message=f"Failed to delete MCP servers {[mcp_id]}") + + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) + + +@manager.route("/mcp/servers/import", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("mcpServers") +async def import_multiple() -> Response: + req = await get_request_json() + servers = req.get("mcpServers", {}) + if not servers: + return get_data_error_result(message="No MCP servers provided.") + + timeout = get_float(req, "timeout", 10) + + results = [] + try: + for server_name, config in servers.items(): + if not all(key in config for key in {"type", "url"}): + results.append({"server": server_name, "success": False, "message": "Missing required fields (type or url)"}) + continue + + if not server_name or len(server_name.encode("utf-8")) > 255: + results.append({"server": server_name, "success": False, "message": f"Invalid MCP name or length is {len(server_name)} which is large than 255."}) + continue + + base_name = server_name + new_name = base_name + counter = 0 + + while True: + e, _ = MCPServerService.get_by_name_and_tenant(name=new_name, tenant_id=current_user.id) + if not e: + break + new_name = f"{base_name}_{counter}" + counter += 1 + + create_data = { + "id": get_uuid(), + "tenant_id": current_user.id, + "name": new_name, + "url": config["url"], + "server_type": config["type"], + "variables": {"authorization_token": config.get("authorization_token", "")}, + } + + headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {} + variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}} + mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers) + server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) + if err_message: + results.append({"server": base_name, "success": False, "message": err_message}) + continue + + tools = server_tools[new_name] + tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} + create_data["variables"]["tools"] = tools + + if MCPServerService.insert(**create_data): + result = {"server": server_name, "success": True, "action": "created", "id": create_data["id"], "new_name": new_name} + if new_name != base_name: + result["message"] = f"Renamed from '{base_name}' to '{new_name}' avoid duplication" + results.append(result) + else: + results.append({"server": server_name, "success": False, "message": "Failed to create MCP server."}) + + return get_json_result(data={"results": results}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/mcp/servers//test", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("url", "server_type") +async def test_mcp(mcp_id: str) -> Response: + req = await get_request_json() + + url = req.get("url", "") + if not url: + return get_data_error_result(message="Invalid MCP url.") + + server_type = req.get("server_type", "") + if server_type not in VALID_MCP_SERVER_TYPES: + return get_data_error_result(message="Unsupported MCP server type.") + + timeout = get_float(req, "timeout", 10) + headers = safe_json_parse(req.get("headers", {})) + variables = safe_json_parse(req.get("variables", {})) + + mcp_server = MCPServer(id=mcp_id, server_type=server_type, url=url, headers=headers, variables=variables) + + result = [] + try: + tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) + + try: + tools = await thread_pool_exec(tool_call_session.get_tools, timeout) + except Exception as e: + return get_data_error_result(message=f"Test MCP error: {e}") + finally: + await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session]) + + for tool in tools: + tool_dict = tool.model_dump() + tool_dict["enabled"] = True + result.append(tool_dict) + + return get_json_result(data=result) + except Exception as e: + return server_error_response(e) diff --git a/test/testcases/test_web_api/test_mcp_server_app/test_mcp_server_app_unit.py b/test/testcases/test_web_api/test_mcp_server_app/test_mcp_server_app_unit.py index 9aad0e34eb..ac8a580c38 100644 --- a/test/testcases/test_web_api/test_mcp_server_app/test_mcp_server_app_unit.py +++ b/test/testcases/test_web_api/test_mcp_server_app/test_mcp_server_app_unit.py @@ -33,6 +33,14 @@ class _DummyManager: return decorator +class _Args(dict): + def getlist(self, key): + value = self.get(key, []) + if isinstance(value, list): + return value + return [value] + + class _Field: def __init__(self, name): self.name = name @@ -142,13 +150,22 @@ def set_tenant_info(): return None -def _load_mcp_server_app(monkeypatch): +def _load_mcp_api(monkeypatch): repo_root = Path(__file__).resolve().parents[4] + quart_mod = ModuleType("quart") + quart_mod.Response = object + quart_mod.request = SimpleNamespace(args=_Args({})) + monkeypatch.setitem(sys.modules, "quart", quart_mod) + common_pkg = ModuleType("common") common_pkg.__path__ = [str(repo_root / "common")] monkeypatch.setitem(sys.modules, "common", common_pkg) + constants_mod = ModuleType("common.constants") + constants_mod.VALID_MCP_SERVER_TYPES = {"sse", "streamable-http"} + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + apps_mod = ModuleType("api.apps") apps_mod.current_user = SimpleNamespace(id="tenant_1") apps_mod.login_required = lambda func: func @@ -230,8 +247,8 @@ def _load_mcp_server_app(monkeypatch): web_utils_mod.safe_json_parse = _safe_json_parse monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod) - module_name = "test_mcp_server_app_unit_module" - module_path = repo_root / "api" / "apps" / "mcp_server_app.py" + module_name = "test_mcp_api_unit_module" + module_path = repo_root / "api" / "apps" / "restful_apis" / "mcp_api.py" spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() @@ -242,12 +259,12 @@ def _load_mcp_server_app(monkeypatch): @pytest.mark.p2 def test_list_mcp_desc_pagination_and_exception(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) monkeypatch.setattr( module, "request", - SimpleNamespace(args={"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"}), + SimpleNamespace(args=_Args({"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"})), ) _set_request_json(monkeypatch, module, {"mcp_ids": []}) monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}]) @@ -257,7 +274,7 @@ def test_list_mcp_desc_pagination_and_exception(monkeypatch): assert res["data"]["total"] == 2 assert res["data"]["mcp_servers"] == [{"id": "b"}] - monkeypatch.setattr(module, "request", SimpleNamespace(args={})) + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({}))) _set_request_json(monkeypatch, module, {"mcp_ids": []}) def _raise_list(*_args, **_kwargs): @@ -271,19 +288,20 @@ def test_list_mcp_desc_pagination_and_exception(monkeypatch): @pytest.mark.p2 def test_detail_not_found_success_and_exception(monkeypatch): - module = _load_mcp_server_app(monkeypatch) - monkeypatch.setattr(module, "request", SimpleNamespace(args={"mcp_id": "mcp-1"})) + module = _load_mcp_api(monkeypatch) + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({}))) monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None) - res = module.detail() - assert res["code"] == module.RetCode.NOT_FOUND + res = module.detail("mcp-1") + assert res["code"] == 102 + assert "Cannot find MCP server mcp-1 for user tenant_1" in res["message"] monkeypatch.setattr( module.MCPServerService, "get_or_none", lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"), ) - res = module.detail() + res = module.detail("mcp-1") assert res["code"] == 0 assert res["data"]["id"] == "mcp-1" @@ -291,14 +309,14 @@ def test_detail_not_found_success_and_exception(monkeypatch): raise RuntimeError("detail explode") monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail) - res = module.detail() + res = module.detail("mcp-1") assert res["code"] == 100 assert "detail explode" in res["message"] @pytest.mark.p2 def test_create_validation_guards(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None)) @@ -323,7 +341,7 @@ def test_create_validation_guards(monkeypatch): @pytest.mark.p2 def test_create_service_paths(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) base_payload = { "name": "srv", @@ -350,8 +368,8 @@ def test_create_service_paths(monkeypatch): monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error) res = _run(module.create.__wrapped__()) - assert res["code"] == "tools error" - assert "Sorry! Data missing!" in res["message"] + assert res["code"] == 102 + assert "tools error" in res["message"] _set_request_json(monkeypatch, module, dict(base_payload)) @@ -361,8 +379,8 @@ def test_create_service_paths(monkeypatch): monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok) monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False) res = _run(module.create.__wrapped__()) - assert res["code"] == "Failed to create MCP server." - assert "Sorry! Data missing!" in res["message"] + assert res["code"] == 102 + assert "Failed to create MCP server" in res["message"] _set_request_json(monkeypatch, module, dict(base_payload)) monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True) @@ -385,13 +403,13 @@ def test_create_service_paths(monkeypatch): @pytest.mark.p2 def test_update_validation_guards(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={}) _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"}) monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None)) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert "Cannot find MCP server" in res["message"] _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"}) @@ -400,26 +418,26 @@ def test_update_validation_guards(monkeypatch): "get_by_id", lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})), ) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert "Cannot find MCP server" in res["message"] _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"}) monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing)) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert "Unsupported MCP server type" in res["message"] _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256}) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert "Invalid MCP name" in res["message"] _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""}) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert "Invalid url" in res["message"] @pytest.mark.p2 def test_update_service_paths(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) existing = _DummyMCPServer( id="mcp-1", @@ -457,9 +475,9 @@ def test_update_service_paths(monkeypatch): return None, "update tools error" monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error) - res = _run(module.update.__wrapped__()) - assert res["code"] == "update tools error" - assert "Sorry! Data missing!" in res["message"] + res = _run(module.update("mcp-1")) + assert res["code"] == 102 + assert "update tools error" in res["message"] _set_request_json(monkeypatch, module, dict(base_payload)) @@ -468,7 +486,7 @@ def test_update_service_paths(monkeypatch): monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok) monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert "Failed to updated MCP server" in res["message"] _set_request_json(monkeypatch, module, dict(base_payload)) @@ -482,7 +500,7 @@ def test_update_service_paths(monkeypatch): _get_by_id_fetch_fail.calls = 0 monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert "Failed to fetch updated MCP server" in res["message"] _set_request_json(monkeypatch, module, dict(base_payload)) @@ -495,7 +513,7 @@ def test_update_service_paths(monkeypatch): _get_by_id_success.calls = 0 monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert res["code"] == 0 assert res["data"]["id"] == "mcp-1" @@ -506,23 +524,25 @@ def test_update_service_paths(monkeypatch): raise RuntimeError("update explode") monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises) - res = _run(module.update.__wrapped__()) + res = _run(module.update("mcp-1")) assert res["code"] == 100 assert "update explode" in res["message"] @pytest.mark.p2 def test_rm_failure_success_and_exception(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) + server = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={}) + monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server)) _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]}) monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: False) - res = _run(module.rm.__wrapped__()) + res = _run(module.rm("id1")) assert "Failed to delete MCP servers" in res["message"] _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]}) monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: True) - res = _run(module.rm.__wrapped__()) + res = _run(module.rm("id1")) assert res["code"] == 0 assert res["data"] is True @@ -532,14 +552,14 @@ def test_rm_failure_success_and_exception(monkeypatch): raise RuntimeError("rm explode") monkeypatch.setattr(module.MCPServerService, "delete_by_ids", _raise_rm) - res = _run(module.rm.__wrapped__()) + res = _run(module.rm("id1")) assert res["code"] == 100 assert "rm explode" in res["message"] @pytest.mark.p2 def test_import_multiple_missing_servers_and_exception(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) _set_request_json(monkeypatch, module, {"mcpServers": {}}) res = _run(module.import_multiple.__wrapped__()) @@ -558,7 +578,7 @@ def test_import_multiple_missing_servers_and_exception(monkeypatch): @pytest.mark.p2 def test_import_multiple_mixed_results(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) payload = { "mcpServers": { @@ -614,244 +634,72 @@ def test_import_multiple_mixed_results(monkeypatch): @pytest.mark.p2 -def test_export_multiple_missing_ids_success_and_exception(monkeypatch): - module = _load_mcp_server_app(monkeypatch) +def test_detail_download_success_and_exception(monkeypatch): + module = _load_mcp_api(monkeypatch) + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({"mode": "download"}))) - _set_request_json(monkeypatch, module, {"mcp_ids": []}) - res = _run(module.export_multiple.__wrapped__()) - assert "No MCP server IDs provided" in res["message"] - - _set_request_json(monkeypatch, module, {"mcp_ids": ["id1", "id2", "id3"]}) - - def _get_by_id(mcp_id): - if mcp_id == "id1": - return True, _DummyMCPServer( + monkeypatch.setattr( + module.MCPServerService, + "get_by_id", + lambda _mcp_id: ( + True, + _DummyMCPServer( id="id1", name="srv-one", url="http://one", server_type="sse", tenant_id="tenant_1", variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}}, - ) - if mcp_id == "id2": - return True, _DummyMCPServer( + ), + ), + ) + res = module.detail("id1") + assert res["code"] == 0 + assert list(res["data"]["mcpServers"].keys()) == ["srv-one"] + + monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None)) + res = module.detail("missing") + assert res["code"] == 102 + assert "Cannot find MCP server missing for user tenant_1" in res["message"] + + monkeypatch.setattr( + module.MCPServerService, + "get_by_id", + lambda _mcp_id: ( + True, + _DummyMCPServer( id="id2", name="srv-two", url="http://two", server_type="sse", tenant_id="other", variables={}, - ) - return False, None - - monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id) - res = _run(module.export_multiple.__wrapped__()) - assert res["code"] == 0 - assert list(res["data"]["mcpServers"].keys()) == ["srv-one"] - - _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"]}) + ), + ), + ) + res = module.detail("id2") + assert res["code"] == 102 + assert "Cannot find MCP server id2 for user tenant_1" in res["message"] def _raise_export(_mcp_id): raise RuntimeError("export explode") monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export) - res = _run(module.export_multiple.__wrapped__()) + res = module.detail("id1") assert res["code"] == 100 assert "export explode" in res["message"] -@pytest.mark.p2 -def test_list_tools_missing_ids_success_inner_error_outer_error_and_finally_cleanup(monkeypatch): - module = _load_mcp_server_app(monkeypatch) - - _set_request_json(monkeypatch, module, {"mcp_ids": []}) - res = _run(module.list_tools.__wrapped__()) - assert "No MCP server IDs provided" in res["message"] - - server = _DummyMCPServer( - id="id1", - name="srv-tools", - url="http://tools", - server_type="sse", - tenant_id="tenant_1", - variables={"tools": {"tool_a": {"enabled": False}}}, - ) - - _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"}) - monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server)) - - close_calls = [] - - async def _thread_pool_exec_success(func, *args): - if func is module.close_multiple_mcp_toolcall_sessions: - close_calls.append(args[0]) - return None - return func(*args) - - monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success) - res = _run(module.list_tools.__wrapped__()) - assert res["code"] == 0 - assert res["data"]["id1"][0]["name"] == "tool_a" - assert res["data"]["id1"][0]["enabled"] is False - assert res["data"]["id1"][1]["enabled"] is True - assert close_calls and len(close_calls[-1]) == 1 - - _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"}) - close_calls_inner = [] - - async def _thread_pool_exec_inner_error(func, *args): - if func is module.close_multiple_mcp_toolcall_sessions: - close_calls_inner.append(args[0]) - return None - raise RuntimeError("inner tools explode") - - monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error) - res = _run(module.list_tools.__wrapped__()) - assert res["code"] == 102 - assert "MCP list tools error" in res["message"] - assert close_calls_inner and len(close_calls_inner[-1]) == 1 - - _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"}) - close_calls_outer = [] - - def _raise_get_by_id(_mcp_id): - raise RuntimeError("outer explode") - - monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_get_by_id) - - async def _thread_pool_exec_outer(func, *args): - if func is module.close_multiple_mcp_toolcall_sessions: - close_calls_outer.append(args[0]) - return None - return func(*args) - - monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_outer) - res = _run(module.list_tools.__wrapped__()) - assert res["code"] == 100 - assert "outer explode" in res["message"] - assert close_calls_outer - - -@pytest.mark.p2 -def test_test_tool_missing_mcp_id(monkeypatch): - module = _load_mcp_server_app(monkeypatch) - - _set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}}) - res = _run(module.test_tool.__wrapped__()) - assert "No MCP server ID provided" in res["message"] - - -@pytest.mark.p2 -def test_test_tool_route_matrix_unit(monkeypatch): - module = _load_mcp_server_app(monkeypatch) - - _set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}}) - res = _run(module.test_tool.__wrapped__()) - assert "No MCP server ID provided" in res["message"] - - _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "", "arguments": {"x": 1}}) - res = _run(module.test_tool.__wrapped__()) - assert "Require provide tool name and arguments" in res["message"] - - _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "tool_a", "arguments": {}}) - res = _run(module.test_tool.__wrapped__()) - assert "Require provide tool name and arguments" in res["message"] - - _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "tool_a", "arguments": {"x": 1}}) - monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None)) - res = _run(module.test_tool.__wrapped__()) - assert "Cannot find MCP server id1 for user tenant_1" in res["message"] - - server_other = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="other", variables={}) - monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_other)) - res = _run(module.test_tool.__wrapped__()) - assert "Cannot find MCP server id1 for user tenant_1" in res["message"] - - server_ok = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={}) - monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_ok)) - close_calls = [] - - async def _thread_pool_exec_success(func, *args): - if func is module.close_multiple_mcp_toolcall_sessions: - close_calls.append(args[0]) - return None - return func(*args) - - monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success) - res = _run(module.test_tool.__wrapped__()) - assert res["code"] == 0 - assert res["data"] == "ok" - assert close_calls and len(close_calls[-1]) == 1 - - async def _thread_pool_exec_raise(func, *args): - if func is module.close_multiple_mcp_toolcall_sessions: - return None - raise RuntimeError("tool call explode") - - monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_raise) - res = _run(module.test_tool.__wrapped__()) - assert res["code"] == 100 - assert "tool call explode" in res["message"] - - -@pytest.mark.p2 -def test_cache_tool_route_matrix_unit(monkeypatch): - module = _load_mcp_server_app(monkeypatch) - - _set_request_json(monkeypatch, module, {"mcp_id": "", "tools": [{"name": "tool_a"}]}) - res = _run(module.cache_tool.__wrapped__()) - assert "No MCP server ID provided" in res["message"] - - _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tools": [{"name": "tool_a"}]}) - monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None)) - res = _run(module.cache_tool.__wrapped__()) - assert "Cannot find MCP server id1 for user tenant_1" in res["message"] - - server_other = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="other", variables={}) - monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_other)) - res = _run(module.cache_tool.__wrapped__()) - assert "Cannot find MCP server id1 for user tenant_1" in res["message"] - - server_fail = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={}) - monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_fail)) - monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False) - res = _run(module.cache_tool.__wrapped__()) - assert "Failed to updated MCP server" in res["message"] - - server_ok = _DummyMCPServer( - id="id1", - name="srv", - url="http://a", - server_type="sse", - tenant_id="tenant_1", - variables={"tools": {"old_tool": {"name": "old_tool"}}}, - ) - monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_ok)) - monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True) - _set_request_json( - monkeypatch, - module, - { - "mcp_id": "id1", - "tools": [{"name": "tool_a", "enabled": True}, {"bad": 1}, "x", {"name": "tool_b", "enabled": False}], - }, - ) - res = _run(module.cache_tool.__wrapped__()) - assert res["code"] == 0 - assert sorted(res["data"].keys()) == ["tool_a", "tool_b"] - assert server_ok.variables["tools"]["tool_b"]["enabled"] is False - - @pytest.mark.p2 def test_test_mcp_route_matrix_unit(monkeypatch): - module = _load_mcp_server_app(monkeypatch) + module = _load_mcp_api(monkeypatch) _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"}) - res = _run(module.test_mcp.__wrapped__()) + res = _run(module.test_mcp("mcp-1")) assert "Invalid MCP url" in res["message"] _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"}) - res = _run(module.test_mcp.__wrapped__()) + res = _run(module.test_mcp("mcp-1")) assert "Unsupported MCP server type" in res["message"] close_calls = [] @@ -866,7 +714,7 @@ def test_test_mcp_route_matrix_unit(monkeypatch): monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error) _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"}) - res = _run(module.test_mcp.__wrapped__()) + res = _run(module.test_mcp("mcp-1")) assert res["code"] == 102 assert "Test MCP error: get tools explode" in res["message"] assert close_calls and len(close_calls[-1]) == 1 @@ -881,7 +729,7 @@ def test_test_mcp_route_matrix_unit(monkeypatch): monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success) _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"}) - res = _run(module.test_mcp.__wrapped__()) + res = _run(module.test_mcp("mcp-1")) assert res["code"] == 0 assert res["data"][0]["name"] == "tool_a" assert all(tool["enabled"] is True for tool in res["data"]) @@ -892,6 +740,6 @@ def test_test_mcp_route_matrix_unit(monkeypatch): monkeypatch.setattr(module, "MCPToolCallSession", _raise_session) _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"}) - res = _run(module.test_mcp.__wrapped__()) + res = _run(module.test_mcp("mcp-1")) assert res["code"] == 100 assert "session explode" in res["message"] diff --git a/web/src/hooks/use-mcp-request.ts b/web/src/hooks/use-mcp-request.ts index f76811802d..051bab5987 100644 --- a/web/src/hooks/use-mcp-request.ts +++ b/web/src/hooks/use-mcp-request.ts @@ -141,8 +141,12 @@ export const useDeleteMcpServer = () => { } = useMutation({ mutationKey: [McpApiAction.DeleteMcpServer], mutationFn: async (ids: string[]) => { - const { data = {} } = await mcpServerService.delete({ mcp_ids: ids }); - if (data.code === 0) { + const results = await Promise.all( + ids.map((id) => mcpServerService.delete({ mcp_id: id })), + ); + const failed = results.find(({ data = {} }) => data.code !== 0); + const data = failed?.data ?? { code: 0, data: true }; + if (!failed) { message.success(i18n.t(`message.deleted`)); queryClient.invalidateQueries({ @@ -188,8 +192,23 @@ export const useExportMcpServer = () => { } = useMutation, Error, string[]>({ mutationKey: [McpApiAction.ExportMcpServer], mutationFn: async (ids) => { - const { data = {} } = await mcpServerService.export({ mcp_ids: ids }); - if (data.code === 0) { + const results = await Promise.all( + ids.map((id) => mcpServerService.export({ mcp_id: id })), + ); + const failed = results.find(({ data = {} }) => data.code !== 0); + const data = (failed?.data ?? { + code: 0, + data: results.reduce( + (acc, result) => ({ + mcpServers: { + ...acc.mcpServers, + ...(result.data?.data?.mcpServers ?? {}), + }, + }), + { mcpServers: {} }, + ), + }) as ResponseType; + if (!failed) { message.success(i18n.t(`message.operated`)); } return data; diff --git a/web/src/interfaces/database/mcp.ts b/web/src/interfaces/database/mcp.ts index 143cf8cb48..d489dfaec5 100644 --- a/web/src/interfaces/database/mcp.ts +++ b/web/src/interfaces/database/mcp.ts @@ -43,12 +43,7 @@ interface ISymbol { } export interface IExportedMcpServers { - mcpServers: McpServers; -} - -interface McpServers { - fetch_2: IExportedMcpServer; - github_1: IExportedMcpServer; + mcpServers: Record; } export interface IExportedMcpServer { diff --git a/web/src/services/mcp-server-service.ts b/web/src/services/mcp-server-service.ts index fbdf232fb2..d0a49d2c74 100644 --- a/web/src/services/mcp-server-service.ts +++ b/web/src/services/mcp-server-service.ts @@ -1,57 +1,27 @@ import { IPaginationRequestBody } from '@/interfaces/request/base'; import api from '@/utils/api'; -import registerServer from '@/utils/register-server'; import request from '@/utils/request'; -const { - listMcpServer, - createMcpServer, - updateMcpServer, - deleteMcpServer, - getMcpServer, - importMcpServer, - exportMcpServer, - testMcpServer, -} = api; - -const methods = { - list: { - url: listMcpServer, - method: 'post', - }, - get: { - url: getMcpServer, - method: 'get', - }, - create: { - url: createMcpServer, - method: 'post', - }, - update: { - url: updateMcpServer, - method: 'post', - }, - delete: { - url: deleteMcpServer, - method: 'post', - }, - import: { - url: importMcpServer, - method: 'post', - }, - export: { - url: exportMcpServer, - method: 'post', - }, - test: { - url: testMcpServer, - method: 'post', - }, -} as const; - -const mcpServerService = registerServer(methods, request); +const mcpServerService = { + get: (params: { mcp_id: string }) => + request.get(api.getMcpServer(params.mcp_id), { + params: { mode: 'preview' }, + }), + create: (params?: Record) => + request.post(api.createMcpServer, { data: params }), + update: ({ mcp_id, ...params }: Record) => + request.put(api.updateMcpServer(mcp_id), { data: params }), + delete: ({ mcp_id }: { mcp_id: string }) => + request.delete(api.deleteMcpServer(mcp_id)), + import: (params?: Record) => + request.post(api.importMcpServer, { data: params }), + export: ({ mcp_id }: { mcp_id: string }) => + request.get(api.exportMcpServer(mcp_id)), + test: (params: Record) => + request.post(api.testMcpServer(params.name || 'preview'), { data: params }), +}; export default mcpServerService; export const listMcpServers = (params?: IPaginationRequestBody, body?: any) => - request.post(api.listMcpServer, { data: body || {}, params }); + request.get(api.listMcpServer, { params: { ...params, ...(body || {}) } }); diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts index 982a24871e..691ae9e7bd 100644 --- a/web/src/utils/api.ts +++ b/web/src/utils/api.ts @@ -220,14 +220,15 @@ export default { `${webAPI}/canvas/${canvasId}/completion`, // mcp server - listMcpServer: `${webAPI}/mcp_server/list`, - getMcpServer: `${webAPI}/mcp_server/detail`, - createMcpServer: `${webAPI}/mcp_server/create`, - updateMcpServer: `${webAPI}/mcp_server/update`, - deleteMcpServer: `${webAPI}/mcp_server/rm`, - importMcpServer: `${webAPI}/mcp_server/import`, - exportMcpServer: `${webAPI}/mcp_server/export`, - testMcpServer: `${webAPI}/mcp_server/test_mcp`, + listMcpServer: `${restAPIv1}/mcp/servers`, + getMcpServer: (id: string) => `${restAPIv1}/mcp/servers/${id}`, + createMcpServer: `${restAPIv1}/mcp/servers`, + updateMcpServer: (id: string) => `${restAPIv1}/mcp/servers/${id}`, + deleteMcpServer: (id: string) => `${restAPIv1}/mcp/servers/${id}`, + importMcpServer: `${restAPIv1}/mcp/servers/import`, + exportMcpServer: (id: string) => + `${restAPIv1}/mcp/servers/${id}?mode=download`, + testMcpServer: (id: string) => `${restAPIv1}/mcp/servers/${id}/test`, // next-search createSearch: `${restAPIv1}/searches`,