diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 8735370f49..aef1afb235 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -97,19 +97,21 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st parsed = urlparse(base_url) base = f"{parsed.scheme}://{parsed.netloc}" path = parsed.path.rstrip("/") - # OAuth 2.0 Authorization Server Metadata at root(MCP-03-26) + # OAuth 2.0 Authorization Server Metadata at root (MCP-03-26) urls.append(f"{base}/.well-known/oauth-authorization-server") # OpenID Connect Discovery at root urls.append(f"{base}/.well-known/openid-configuration") - # OpenID Connect Discovery with path insertion - urls.append(f"{base}/.well-known/openid-configuration{path}") - # OpenID Connect Discovery path appending - urls.append(f"{base}{path}/.well-known/openid-configuration") + if path: + # OpenID Connect Discovery with path insertion + urls.append(f"{base}/.well-known/openid-configuration{path}") - # OAuth 2.0 Authorization Server Metadata with path insertion - urls.append(f"{base}/.well-known/oauth-authorization-server{path}") + # OpenID Connect Discovery path appending + urls.append(f"{base}{path}/.well-known/openid-configuration") + + # OAuth 2.0 Authorization Server Metadata with path insertion + urls.append(f"{base}/.well-known/oauth-authorization-server{path}") return urls diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 8c190762cf..6284c634d6 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -1308,18 +1308,17 @@ class TestMCPToolManageService: type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(), ] - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: # Setup mock client mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1337,8 +1336,12 @@ class TestMCPToolManageService: assert tools_data[1]["name"] == "test_tool_2" # Verify mock interactions - provider_entity = mcp_provider.to_entity() - mock_mcp_client.assert_called_once() + mock_mcp_client.assert_called_once_with( + server_url="https://example.com/mcp", + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, + ) def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -1361,19 +1364,18 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise authentication error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: from core.mcp.error import MCPAuthError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1404,18 +1406,17 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise connection error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client: from core.mcp.error import MCPError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"): - service._reconnect_provider( + MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, )