From 60fe5e7f00a6e0c2b195676505023fbfb3e3ec5b Mon Sep 17 00:00:00 2001 From: mahammadasim <135003320+mahammadasim@users.noreply.github.com> Date: Thu, 12 Mar 2026 09:14:56 +0530 Subject: [PATCH] test: added for core logging and core mcp (#32478) Co-authored-by: rajatagarwal-oss --- .../unit_tests/core/logging/test_filters.py | 178 +++ .../core/mcp/auth/test_auth_flow.py | 564 ++++++++ .../unit_tests/core/mcp/client/test_sse.py | 472 +++++++ .../core/mcp/client/test_streamable_http.py | 1195 ++++++++++++++++- .../core/mcp/session/test_base_session.py | 617 +++++++++ .../core/mcp/session/test_client_session.py | 576 ++++++++ .../unit_tests/core/mcp/test_mcp_client.py | 262 +++- 7 files changed, 3859 insertions(+), 5 deletions(-) create mode 100644 api/tests/unit_tests/core/mcp/session/test_base_session.py create mode 100644 api/tests/unit_tests/core/mcp/session/test_client_session.py diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py index 7c2767266f..a8b186ac8a 100644 --- a/api/tests/unit_tests/core/logging/test_filters.py +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -82,6 +82,68 @@ class TestTraceContextFilter: assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" assert log_record.span_id == "051581bf3bb55c45" + def test_otel_context_invalid_trace_id(self, log_record): + from core.logging.filters import TraceContextFilter + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0 + mock_context.is_valid = True + mock_span.get_span_context.return_value = mock_context + + # Use mocks for base context to ensure we can test the fallback + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + + def test_otel_context_invalid_span_id(self, log_record): + from core.logging.filters import TraceContextFilter + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2 + mock_context.span_id = 0 + mock_context.is_valid = True + mock_span.get_span_context.return_value = mock_context + + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" + assert log_record.span_id == "" + + def test_otel_context_span_none(self, log_record): + from core.logging.filters import TraceContextFilter + + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=None), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + + def test_otel_context_exception(self, log_record): + from core.logging.filters import TraceContextFilter + + # Trigger exception in OTEL block + with ( + mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + class TestIdentityContextFilter: def test_sets_empty_identity_without_request_context(self, log_record): @@ -114,3 +176,119 @@ class TestIdentityContextFilter: result = filter.filter(log_record) assert result is True assert log_record.tenant_id == "" + + def test_sets_empty_identity_unauthenticated(self, log_record): + from core.logging.filters import IdentityContextFilter + + mock_user = mock.MagicMock() + mock_user.is_authenticated = False + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + assert log_record.user_id == "" + + def test_sets_identity_for_account(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockAccount: + pass + + mock_user = MockAccount() + mock_user.id = "account_id" + mock_user.current_tenant_id = "tenant_id" + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.Account", MockAccount), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "account_id" + assert log_record.user_type == "account" + + def test_sets_identity_for_account_no_tenant(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockAccount: + pass + + mock_user = MockAccount() + mock_user.id = "account_id" + mock_user.current_tenant_id = None + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.Account", MockAccount), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "" + assert log_record.user_id == "account_id" + assert log_record.user_type == "account" + + def test_sets_identity_for_end_user(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockEndUser: + pass + + class AnotherClass: + pass + + mock_user = MockEndUser() + mock_user.id = "end_user_id" + mock_user.tenant_id = "tenant_id" + mock_user.type = "custom_type" + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.model.EndUser", MockEndUser), + mock.patch("models.Account", AnotherClass), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "end_user_id" + assert log_record.user_type == "custom_type" + + def test_sets_identity_for_end_user_default_type(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockEndUser: + pass + + class AnotherClass: + pass + + mock_user = MockEndUser() + mock_user.id = "end_user_id" + mock_user.tenant_id = "tenant_id" + mock_user.type = None + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.model.EndUser", MockEndUser), + mock.patch("models.Account", AnotherClass), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "end_user_id" + assert log_record.user_type == "end_user" diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index 60f37b6de0..abf3c60fe0 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -1,27 +1,39 @@ """Unit tests for MCP OAuth authentication flow.""" +import json from unittest.mock import Mock, patch +import httpx import pytest +from pydantic import ValidationError from core.entities.mcp_provider import MCPProviderEntity +from core.helper import ssrf_proxy from core.mcp.auth.auth_flow import ( OAUTH_STATE_EXPIRY_SECONDS, OAUTH_STATE_REDIS_KEY_PREFIX, OAuthCallbackState, _create_secure_redis_state, + _parse_token_response, _retrieve_redis_state, auth, + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, check_support_resource_discovery, + client_credentials_flow, + discover_oauth_authorization_server_metadata, discover_oauth_metadata, + discover_protected_resource_metadata, exchange_authorization, generate_pkce_challenge, + get_effective_scope, handle_callback, refresh_authorization, register_client, start_authorization, ) from core.mcp.entities import AuthActionType, AuthResult +from core.mcp.error import MCPRefreshTokenError from core.mcp.types import ( LATEST_PROTOCOL_VERSION, OAuthClientInformation, @@ -764,3 +776,555 @@ class TestAuthOrchestration: auth(mock_provider, authorization_code="auth-code") assert "Existing OAuth client information is required" in str(exc_info.value) + + def test_generate_pkce_challenge(self): + verifier, challenge = generate_pkce_challenge() + assert verifier + assert challenge + assert "=" not in verifier + assert "=" not in challenge + + def test_build_protected_resource_metadata_discovery_urls(self): + # Case 1: WWW-Auth URL provided + urls = build_protected_resource_metadata_discovery_urls( + "https://auth.example.com/prm", "https://api.example.com" + ) + assert "https://auth.example.com/prm" in urls + assert "https://api.example.com/.well-known/oauth-protected-resource" in urls + + # Case 2: No WWW-Auth URL, with path + urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com/v1") + assert "https://api.example.com/.well-known/oauth-protected-resource/v1" in urls + assert "https://api.example.com/.well-known/oauth-protected-resource" in urls + + # Case 3: No path + urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com") + assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"] + + def test_build_oauth_authorization_server_metadata_discovery_urls(self): + # Case 1: with auth_server_url + urls = build_oauth_authorization_server_metadata_discovery_urls( + "https://auth.example.com", "https://api.example.com" + ) + assert "https://auth.example.com/.well-known/oauth-authorization-server" in urls + assert "https://auth.example.com/.well-known/openid-configuration" in urls + + # Case 2: with path + urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://api.example.com/tenant") + assert "https://api.example.com/.well-known/oauth-authorization-server/tenant" in urls + assert "https://api.example.com/tenant/.well-known/openid-configuration" in urls + + @patch("core.helper.ssrf_proxy.get") + def test_discover_protected_resource_metadata(self, mock_get): + # Success + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "resource": "https://api.example.com", + "authorization_servers": ["https://auth"], + } + mock_get.return_value = mock_response + result = discover_protected_resource_metadata(None, "https://api.example.com") + assert result is not None + assert result.resource == "https://api.example.com" + + # 404 then Success + res404 = Mock() + res404.status_code = 404 + mock_get.side_effect = [res404, mock_response] + result = discover_protected_resource_metadata(None, "https://api.example.com/path") + assert result is not None + assert result.resource == "https://api.example.com" + + # Error handling + mock_get.side_effect = httpx.RequestError("Error") + result = discover_protected_resource_metadata(None, "https://api.example.com") + assert result is None + + @patch("core.helper.ssrf_proxy.get") + def test_discover_oauth_authorization_server_metadata(self, mock_get): + # Success + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "authorization_endpoint": "https://auth.example.com/auth", + "token_endpoint": "https://auth.example.com/token", + "response_types_supported": ["code"], + } + mock_get.return_value = mock_response + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") + assert result is not None + assert result.authorization_endpoint == "https://auth.example.com/auth" + + # 404 + res404 = Mock() + res404.status_code = 404 + mock_get.side_effect = [res404, mock_response] + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com/tenant") + assert result is not None + assert result.authorization_endpoint == "https://auth.example.com/auth" + + # ValidationError + mock_response.json.return_value = {"invalid": "data"} + mock_get.side_effect = None + mock_get.return_value = mock_response + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") + assert result is None + + def test_get_effective_scope(self): + prm = ProtectedResourceMetadata( + resource="https://api.example.com", + authorization_servers=["https://auth"], + scopes_supported=["read", "write"], + ) + asm = OAuthMetadata( + authorization_endpoint="https://auth.example.com/auth", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + scopes_supported=["openid", "profile"], + ) + + # 1. WWW-Auth priority + assert get_effective_scope("scope1", prm, asm, "client") == "scope1" + # 2. PRM priority + assert get_effective_scope(None, prm, asm, "client") == "read write" + # 3. ASM priority + assert get_effective_scope(None, None, asm, "client") == "openid profile" + # 4. Client configured + assert get_effective_scope(None, None, None, "client") == "client" + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_redis_state_management(self, mock_redis): + state_data = OAuthCallbackState( + provider_id="p1", + tenant_id="t1", + server_url="https://api", + metadata=None, + client_information=OAuthClientInformation(client_id="c1"), + code_verifier="cv", + redirect_uri="https://re", + ) + + # Create + state_key = _create_secure_redis_state(state_data) + assert state_key + mock_redis.setex.assert_called_once() + + # Retrieve Success + mock_redis.get.return_value = state_data.model_dump_json() + retrieved = _retrieve_redis_state(state_key) + assert retrieved.provider_id == "p1" + mock_redis.delete.assert_called_once() + + # Retrieve Failure - Not found + mock_redis.get.return_value = None + with pytest.raises(ValueError, match="expired or does not exist"): + _retrieve_redis_state("absent") + + # Retrieve Failure - Invalid JSON + mock_redis.get.return_value = "invalid" + with pytest.raises(ValueError, match="Invalid state parameter"): + _retrieve_redis_state("invalid") + + @patch("core.mcp.auth.auth_flow._retrieve_redis_state") + @patch("core.mcp.auth.auth_flow.exchange_authorization") + def test_handle_callback(self, mock_exchange, mock_retrieve): + state = Mock(spec=OAuthCallbackState) + state.server_url = "https://api" + state.metadata = None + state.client_information = Mock() + state.code_verifier = "cv" + state.redirect_uri = "https://re" + mock_retrieve.return_value = state + + tokens = Mock(spec=OAuthTokens) + mock_exchange.return_value = tokens + + s, t = handle_callback("key", "code") + assert s == state + assert t == tokens + + @patch("core.helper.ssrf_proxy.get") + def test_check_support_resource_discovery(self, mock_get): + # Case 1: authorization_servers (plural) + res = Mock() + res.status_code = 200 + res.json.return_value = {"authorization_servers": ["https://auth1"]} + mock_get.return_value = res + supported, url = check_support_resource_discovery("https://api") + assert supported is True + assert url == "https://auth1" + + # Case 2: authorization_server_url (singular alias) + res.json.return_value = {"authorization_server_url": ["https://auth2"]} + supported, url = check_support_resource_discovery("https://api") + assert supported is True + assert url == "https://auth2" + + # Case 3: Missing fields + res.json.return_value = {"nothing": []} + supported, url = check_support_resource_discovery("https://api") + assert supported is False + + # Case 4: 404 + res.status_code = 404 + supported, url = check_support_resource_discovery("https://api") + assert supported is False + + # Case 5: RequestError + mock_get.side_effect = httpx.RequestError("Error") + supported, url = check_support_resource_discovery("https://api") + assert supported is False + + def test_discover_oauth_metadata(self): + with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm: + with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm: + mock_prm.return_value = ProtectedResourceMetadata( + resource="https://api", authorization_servers=["https://auth"] + ) + mock_asm.return_value = Mock(spec=OAuthMetadata) + + asm, prm, hint = discover_oauth_metadata("https://api") + assert asm == mock_asm.return_value + assert prm == mock_prm.return_value + mock_asm.assert_called_with("https://auth", "https://api", None) + + def test_start_authorization(self): + metadata = OAuthMetadata( + authorization_endpoint="https://auth/authorize", + token_endpoint="https://auth/token", + response_types_supported=["code"], + ) + client_info = OAuthClientInformation(client_id="c1") + + with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create: + mock_create.return_value = "state-key" + + # Success with scope + url, verifier = start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1", "read") + assert "scope=read" in url + assert "state=state-key" in url + + # Success without metadata + url, verifier = start_authorization("https://api", None, client_info, "https://re", "p1", "t1") + assert "https://api/authorize" in url + + # Failure: incompatible auth server + metadata.response_types_supported = ["implicit"] + with pytest.raises(ValueError, match="Incompatible auth server"): + start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1") + + def test_parse_token_response(self): + # Case 1: JSON + res = Mock() + res.headers = {"content-type": "application/json"} + res.json.return_value = {"access_token": "at", "token_type": "Bearer"} + tokens = _parse_token_response(res) + assert tokens.access_token == "at" + + # Case 2: Form-urlencoded + res.headers = {"content-type": "application/x-www-form-urlencoded"} + res.text = "access_token=at2&token_type=Bearer" + tokens = _parse_token_response(res) + assert tokens.access_token == "at2" + + # Case 3: No content-type, but JSON + res.headers = {} + res.json.return_value = {"access_token": "at3", "token_type": "Bearer"} + tokens = _parse_token_response(res) + assert tokens.access_token == "at3" + + # Case 4: No content-type, not JSON, but Form + res.json.side_effect = json.JSONDecodeError("msg", "doc", 0) + res.text = "access_token=at4&token_type=Bearer" + tokens = _parse_token_response(res) + assert tokens.access_token == "at4" + + # Case 5: Validation Error fallback + res.json.side_effect = ValidationError.from_exception_data("error", []) + res.text = "access_token=at5&token_type=Bearer" + tokens = _parse_token_response(res) + assert tokens.access_token == "at5" + + @patch("core.helper.ssrf_proxy.post") + def test_exchange_authorization(self, mock_post): + client_info = OAuthClientInformation(client_id="c1", client_secret="s1") + metadata = OAuthMetadata( + authorization_endpoint="https://auth/authorize", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + # Success + res = Mock() + res.is_success = True + res.headers = {"content-type": "application/json"} + res.json.return_value = {"access_token": "at", "token_type": "Bearer"} + mock_post.return_value = res + + tokens = exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re") + assert tokens.access_token == "at" + + # Failure: Unsupported grant type + metadata.grant_types_supported = ["client_credentials"] + with pytest.raises(ValueError, match="Incompatible auth server"): + exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re") + + # Failure: HTTP error + metadata.grant_types_supported = ["authorization_code"] + res.is_success = False + res.status_code = 400 + with pytest.raises(ValueError, match="Token exchange failed"): + exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re") + + @patch("core.helper.ssrf_proxy.post") + def test_refresh_authorization(self, mock_post): + # Case 1: with client_secret + client_info = OAuthClientInformation(client_id="c1", client_secret="s1") + + # Success + res = Mock() + res.is_success = True + res.headers = {"content-type": "application/json"} + res.json.return_value = {"access_token": "at_new", "token_type": "Bearer"} + mock_post.return_value = res + + tokens = refresh_authorization("https://api", None, client_info, "rt") + assert tokens.access_token == "at_new" + assert mock_post.call_args[1]["data"]["client_secret"] == "s1" + + # Failure: MaxRetriesExceededError + mock_post.side_effect = ssrf_proxy.MaxRetriesExceededError("Too many retries") + with pytest.raises(MCPRefreshTokenError): + refresh_authorization("https://api", None, client_info, "rt") + + # Failure: HTTP error + mock_post.side_effect = None + res.is_success = False + res.text = "error_msg" + with pytest.raises(MCPRefreshTokenError, match="error_msg"): + refresh_authorization("https://api", None, client_info, "rt") + + # Failure: Incompatible metadata + metadata = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + with pytest.raises(ValueError, match="Incompatible auth server"): + refresh_authorization("https://api", metadata, client_info, "rt") + + @patch("core.helper.ssrf_proxy.post") + def test_client_credentials_flow(self, mock_post): + client_info = OAuthClientInformation(client_id="c1", client_secret="s1") + + # Success with secret + res = Mock() + res.is_success = True + res.headers = {"content-type": "application/json"} + res.json.return_value = {"access_token": "at_cc", "token_type": "Bearer"} + mock_post.return_value = res + + tokens = client_credentials_flow("https://api", None, client_info, "read") + assert tokens.access_token == "at_cc" + args, kwargs = mock_post.call_args + assert "Authorization" in kwargs["headers"] + + # Success without secret + client_info_no_secret = OAuthClientInformation(client_id="c2") + tokens = client_credentials_flow("https://api", None, client_info_no_secret) + args, kwargs = mock_post.call_args + assert kwargs["data"]["client_id"] == "c2" + + # Failure: Incompatible metadata + metadata = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + with pytest.raises(ValueError, match="Incompatible auth server"): + client_credentials_flow("https://api", metadata, client_info) + + # Failure: HTTP error + res.is_success = False + res.status_code = 401 + res.text = "Unauthorized" + with pytest.raises(ValueError, match="Client credentials token request failed"): + client_credentials_flow("https://api", None, client_info) + + @patch("core.helper.ssrf_proxy.post") + def test_register_client(self, mock_post): + # Case 1: Success with metadata + metadata = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + registration_endpoint="https://auth/register", + response_types_supported=["code"], + ) + client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://re"]) + + res = Mock() + res.is_success = True + res.json.return_value = { + "client_id": "c_new", + "client_secret": "s_new", + "client_name": "Dify", + "redirect_uris": ["https://re"], + } + mock_post.return_value = res + + info = register_client("https://api", metadata, client_metadata) + assert info.client_id == "c_new" + + # Case 2: Success without metadata + info = register_client("https://api", None, client_metadata) + assert mock_post.call_args[0][0] == "https://api/register" + + # Case 3: Metadata provided but no endpoint + metadata.registration_endpoint = None + with pytest.raises(ValueError, match="does not support dynamic client registration"): + register_client("https://api", metadata, client_metadata) + + # Failure: HTTP + res.is_success = False + res.raise_for_status = Mock() + res.status_code = 400 + # If is_success is false, it should call raise_for_status + register_client("https://api", None, client_metadata) + res.raise_for_status.assert_called_once() + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_failures(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + + # Case 1: No server metadata + mock_discover.return_value = (None, None, None) + with pytest.raises(ValueError, match="Failed to discover OAuth metadata"): + auth(provider) + + # Case 2: No client info, exchange code provided + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + ) + mock_discover.return_value = (asm, None, None) + provider.retrieve_client_information.return_value = None + with pytest.raises(ValueError, match="Existing OAuth client information is required"): + auth(provider, authorization_code="code") + + # Case 3: CLIENT_CREDENTIALS but client must provide info + asm.grant_types_supported = ["client_credentials"] + with pytest.raises(ValueError, match="requires client_id and client_secret"): + auth(provider) + + # Case 4: Client registration fails + asm.grant_types_supported = ["authorization_code"] + with patch("core.mcp.auth.auth_flow.register_client") as mock_reg: + mock_reg.side_effect = httpx.RequestError("Reg failed") + with pytest.raises(ValueError, match="Could not register OAuth client"): + auth(provider) + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_client_credentials(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1", client_secret="s1") + provider.decrypt_credentials.return_value = {"scope": "read"} + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["client_credentials"], + ) + mock_discover.return_value = (asm, None, None) + + with patch("core.mcp.auth.auth_flow.client_credentials_flow") as mock_cc: + mock_cc.return_value = OAuthTokens(access_token="at_cc", token_type="Bearer") + + result = auth(provider) + assert result.response == {"result": "success"} + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data["grant_type"] == "client_credentials" + + # Failure in CC flow + mock_cc.side_effect = ValueError("CC Failed") + with pytest.raises(ValueError, match="Client credentials flow failed"): + auth(provider) + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_authorization_code(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1") + provider.decrypt_credentials.return_value = {} + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + mock_discover.return_value = (asm, None, None) + + # Case 1: Exchange code + with patch("core.mcp.auth.auth_flow._retrieve_redis_state") as mock_retrieve: + state = Mock(spec=OAuthCallbackState) + state.code_verifier = "cv" + state.redirect_uri = "https://re" + mock_retrieve.return_value = state + + with patch("core.mcp.auth.auth_flow.exchange_authorization") as mock_exchange: + mock_exchange.return_value = OAuthTokens(access_token="at_code", token_type="Bearer") + + # Success + result = auth(provider, authorization_code="code", state_param="sp") + assert result.response == {"result": "success"} + + # Missing state_param + with pytest.raises(ValueError, match="State parameter is required"): + auth(provider, authorization_code="code") + + # Missing verifier in state + state.code_verifier = None + with pytest.raises(ValueError, match="Missing code_verifier"): + auth(provider, authorization_code="code", state_param="sp") + + # Invalid state + mock_retrieve.side_effect = ValueError("Invalid") + with pytest.raises(ValueError, match="Invalid state parameter"): + auth(provider, authorization_code="code", state_param="sp") + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_refresh_failure(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1") + provider.decrypt_credentials.return_value = {} + provider.retrieve_tokens.return_value = OAuthTokens(access_token="at", token_type="Bearer", refresh_token="rt") + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + mock_discover.return_value = (asm, None, None) + + with patch("core.mcp.auth.auth_flow.refresh_authorization") as mock_refresh: + mock_refresh.side_effect = ValueError("Refresh Failed") + with pytest.raises(ValueError, match="Could not refresh OAuth tokens"): + auth(provider) diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py index 490a647025..e6eeb6cd59 100644 --- a/api/tests/unit_tests/core/mcp/client/test_sse.py +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -322,3 +322,475 @@ def test_sse_client_concurrent_access(): assert len(received_messages) == 10 for i in range(10): assert f"message_{i}" in received_messages + + +class TestStatusClasses: + """Tests for _StatusReady and _StatusError data containers.""" + + def test_status_ready_stores_endpoint(self): + from core.mcp.client.sse_client import _StatusReady + + status = _StatusReady("http://example.com/messages/") + assert status.endpoint_url == "http://example.com/messages/" + + def test_status_error_stores_exception(self): + from core.mcp.client.sse_client import _StatusError + + exc = ValueError("bad endpoint") + status = _StatusError(exc) + assert status.exc is exc + + +class TestSSETransportInit: + """Tests for SSETransport default and explicit init values.""" + + def test_defaults(self): + from core.mcp.client.sse_client import SSETransport + + t = SSETransport("http://example.com/sse") + assert t.url == "http://example.com/sse" + assert t.headers == {} + assert t.timeout == 5.0 + assert t.sse_read_timeout == 60.0 + assert t.endpoint_url is None + assert t.event_source is None + + def test_explicit_headers_not_mutated(self): + from core.mcp.client.sse_client import SSETransport + + hdrs = {"X-Foo": "bar"} + t = SSETransport("http://example.com/sse", headers=hdrs) + assert t.headers is hdrs + + +class TestHandleEndpointEvent: + """Tests for SSETransport._handle_endpoint_event covering the invalid-origin branch.""" + + def test_invalid_origin_puts_status_error(self): + from core.mcp.client.sse_client import SSETransport, _StatusError + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() + + # Provide a full URL with a different origin so urljoin keeps it as-is + transport._handle_endpoint_event("http://evil.com/messages/", status_queue) + + result = status_queue.get_nowait() + assert isinstance(result, _StatusError) + assert "does not match" in str(result.exc) + + def test_valid_origin_puts_status_ready(self): + from core.mcp.client.sse_client import SSETransport, _StatusReady + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() + + transport._handle_endpoint_event("/messages/?session_id=abc", status_queue) + + result = status_queue.get_nowait() + assert isinstance(result, _StatusReady) + assert "example.com" in result.endpoint_url + + +class TestHandleSSEEvent: + """Tests for SSETransport._handle_sse_event covering all match branches.""" + + def _make_sse(self, event_type: str, data: str): + sse = Mock() + sse.event = event_type + sse.data = data + return sse + + def test_message_event_dispatched(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + status_queue: queue.Queue = queue.Queue() + + valid_msg = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + transport._handle_sse_event(self._make_sse("message", valid_msg), read_queue, status_queue) + + item = read_queue.get_nowait() + assert hasattr(item, "message") + + def test_unknown_event_logs_warning_and_does_nothing(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + status_queue: queue.Queue = queue.Queue() + + transport._handle_sse_event(self._make_sse("ping", "{}"), read_queue, status_queue) + + assert read_queue.empty() + assert status_queue.empty() + + +class TestSSEReader: + """Tests for SSETransport.sse_reader exception branches.""" + + def test_read_error_closes_cleanly(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + status_queue: queue.Queue = queue.Queue() + + event_source = Mock() + event_source.iter_sse.side_effect = httpx.ReadError("connection reset") + + transport.sse_reader(event_source, read_queue, status_queue) + + # Finally block always puts None as sentinel + sentinel = read_queue.get_nowait() + assert sentinel is None + + def test_generic_exception_puts_exc_then_none(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + status_queue: queue.Queue = queue.Queue() + + boom = RuntimeError("unexpected!") + event_source = Mock() + event_source.iter_sse.side_effect = boom + + transport.sse_reader(event_source, read_queue, status_queue) + + exc_item = read_queue.get_nowait() + assert exc_item is boom + + sentinel = read_queue.get_nowait() + assert sentinel is None + + +class TestSendMessage: + """Tests for SSETransport._send_message.""" + + def _make_session_message(self): + msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + msg = types.JSONRPCMessage.model_validate_json(msg_json) + return types.SessionMessage(msg) + + def test_sends_post_and_raises_for_status(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + + mock_response = Mock() + mock_response.status_code = 200 + mock_client = Mock() + mock_client.post.return_value = mock_response + + session_msg = self._make_session_message() + transport._send_message(mock_client, "http://example.com/messages/", session_msg) + + mock_client.post.assert_called_once() + mock_response.raise_for_status.assert_called_once() + + +class TestPostWriter: + """Tests for SSETransport.post_writer exception branches.""" + + def _make_session_message(self): + msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + msg = types.JSONRPCMessage.model_validate_json(msg_json) + return types.SessionMessage(msg) + + def test_none_message_exits_loop(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + write_queue.put(None) # Signal shutdown immediately + + mock_client = Mock() + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + # Should put final None sentinel + sentinel = write_queue.get_nowait() + assert sentinel is None + + def test_exception_in_message_put_back_to_queue(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + + exc = ValueError("some error") + write_queue.put(exc) # Exception goes in first + write_queue.put(None) # Then shutdown signal + + mock_client = Mock() + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + # The exception should be re-queued, then None from loop exit, then None from finally + item1 = write_queue.get_nowait() + assert isinstance(item1, Exception) + + def test_read_error_shuts_down_cleanly(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + + session_msg = self._make_session_message() + write_queue.put(session_msg) + + mock_response = Mock() + mock_response.status_code = 200 + mock_client = Mock() + mock_client.post.side_effect = httpx.ReadError("connection dropped") + + # post_writer calls _send_message which calls client.post → ReadError propagates + # The ReadError is raised inside _send_message → propagates out of the while loop + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + # finally always puts None + sentinel = write_queue.get_nowait() + assert sentinel is None + + def test_generic_exception_puts_exc_in_queue(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + + session_msg = self._make_session_message() + write_queue.put(session_msg) + + mock_client = Mock() + boom = RuntimeError("boom") + mock_client.post.side_effect = boom + + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + exc_item = write_queue.get_nowait() + assert isinstance(exc_item, Exception) + + sentinel = write_queue.get_nowait() + assert sentinel is None + + def test_queue_empty_timeout_continues_loop(self): + """Cover the 'except queue.Empty: continue' branch (line 188) in post_writer.""" + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + + mock_client = Mock() + + # Patch queue.Queue.get so it raises Empty first, then returns None (shutdown) + call_count = {"n": 0} + original_get = write_queue.get + + def patched_get(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + raise queue.Empty + + write_queue.get = patched_get # type: ignore[method-assign] + + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + # finally always puts None sentinel + sentinel = write_queue.get_nowait() + assert sentinel is None + assert call_count["n"] >= 2 # Empty on first, None on second (and possibly more retries) + + +class TestWaitForEndpoint: + """Tests for SSETransport._wait_for_endpoint edge cases.""" + + def test_raises_on_empty_queue(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() # empty + + with pytest.raises(ValueError, match="failed to get endpoint URL"): + transport._wait_for_endpoint(status_queue) + + def test_raises_status_error_exception(self): + from core.mcp.client.sse_client import SSETransport, _StatusError + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() + + exc = ValueError("malicious endpoint") + status_queue.put(_StatusError(exc)) + + with pytest.raises(ValueError, match="malicious endpoint"): + transport._wait_for_endpoint(status_queue) + + def test_raises_on_unknown_status_type(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() + + # Put an object that is neither _StatusReady nor _StatusError + status_queue.put("unexpected_value") + + with pytest.raises(ValueError, match="failed to get endpoint URL"): + transport._wait_for_endpoint(status_queue) + + +class TestSSEClientRuntimeError: + """Test sse_client context manager handles RuntimeError on close().""" + + def test_runtime_error_on_close_is_suppressed(self): + """Ensure RuntimeError raised by event_source.response.close() is caught.""" + test_url = "http://test.example/sse" + + class MockSSEEvent: + def __init__(self, event_type: str, data: str): + self.event = event_type + self.data = data + + endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123") + + with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sc: + mock_client = Mock() + mock_cf.return_value.__enter__.return_value = mock_client + + mock_es = Mock() + mock_es.response.raise_for_status.return_value = None + mock_es.iter_sse.return_value = [endpoint_event] + # Make close() raise RuntimeError to exercise line 307-308 + mock_es.response.close.side_effect = RuntimeError("already closed") + mock_sc.return_value.__enter__.return_value = mock_es + + # Should NOT raise even though close() raises RuntimeError + with contextlib.suppress(Exception): + with sse_client(test_url) as (rq, wq): + pass + + +class TestStandaloneSendMessage: + """Tests for the module-level send_message() function.""" + + def _make_session_message(self): + msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + msg = types.JSONRPCMessage.model_validate_json(msg_json) + return types.SessionMessage(msg) + + def test_send_message_success(self): + from core.mcp.client.sse_client import send_message + + mock_response = Mock() + mock_response.status_code = 200 + mock_http_client = Mock() + mock_http_client.post.return_value = mock_response + + session_msg = self._make_session_message() + send_message(mock_http_client, "http://example.com/messages/", session_msg) + + mock_http_client.post.assert_called_once() + mock_response.raise_for_status.assert_called_once() + + def test_send_message_raises_on_http_error(self): + from core.mcp.client.sse_client import send_message + + mock_http_client = Mock() + mock_http_client.post.side_effect = httpx.ConnectError("refused") + + session_msg = self._make_session_message() + + with pytest.raises(httpx.ConnectError): + send_message(mock_http_client, "http://example.com/messages/", session_msg) + + def test_send_message_raises_for_status_failure(self): + from core.mcp.client.sse_client import send_message + + mock_response = Mock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not Found", request=Mock(), response=Mock(status_code=404) + ) + mock_http_client = Mock() + mock_http_client.post.return_value = mock_response + + session_msg = self._make_session_message() + + with pytest.raises(httpx.HTTPStatusError): + send_message(mock_http_client, "http://example.com/messages/", session_msg) + + +class TestReadMessages: + """Tests for the module-level read_messages() generator.""" + + def _make_mock_sse_event(self, event_type: str, data: str): + ev = Mock() + ev.event = event_type + ev.data = data + return ev + + def test_valid_message_event_yields_session_message(self): + from core.mcp.client.sse_client import read_messages + + valid_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + mock_sse_event = self._make_mock_sse_event("message", valid_json) + + mock_client = Mock() + mock_client.events.return_value = [mock_sse_event] + + results = list(read_messages(mock_client)) + assert len(results) == 1 + assert hasattr(results[0], "message") + + def test_invalid_json_yields_exception(self): + from core.mcp.client.sse_client import read_messages + + mock_sse_event = self._make_mock_sse_event("message", "{not valid json}") + + mock_client = Mock() + mock_client.events.return_value = [mock_sse_event] + + results = list(read_messages(mock_client)) + assert len(results) == 1 + assert isinstance(results[0], Exception) + + def test_non_message_event_is_skipped(self): + from core.mcp.client.sse_client import read_messages + + mock_sse_event = self._make_mock_sse_event("endpoint", "/messages/") + + mock_client = Mock() + mock_client.events.return_value = [mock_sse_event] + + results = list(read_messages(mock_client)) + # Non-message events produce no output + assert results == [] + + def test_outer_exception_yields_exc(self): + from core.mcp.client.sse_client import read_messages + + boom = RuntimeError("stream broken") + mock_client = Mock() + mock_client.events.side_effect = boom + + results = list(read_messages(mock_client)) + assert len(results) == 1 + assert results[0] is boom + + def test_multiple_events_mixed(self): + from core.mcp.client.sse_client import read_messages + + valid_json = '{"jsonrpc": "2.0", "id": 2, "result": {}}' + events = [ + self._make_mock_sse_event("endpoint", "/messages/"), + self._make_mock_sse_event("message", valid_json), + self._make_mock_sse_event("message", "{bad json}"), + ] + + mock_client = Mock() + mock_client.events.return_value = events + + results = list(read_messages(mock_client)) + # endpoint is skipped; 1 valid SessionMessage + 1 Exception + assert len(results) == 2 + assert hasattr(results[0], "message") + assert isinstance(results[1], Exception) diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py index 9a30a35a49..81f8da9a62 100644 --- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py @@ -4,14 +4,39 @@ Tests for the StreamableHTTP client transport. Contains tests for only the client side of the StreamableHTTP transport. """ +import json import queue import threading import time +from contextlib import contextmanager +from datetime import timedelta from typing import Any -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch + +import httpx +import pytest +from httpx_sse import ServerSentEvent from core.mcp import types -from core.mcp.client.streamable_client import streamablehttp_client +from core.mcp.client.streamable_client import ( + LAST_EVENT_ID, + MCP_SESSION_ID, + RequestContext, + ResumptionError, + StreamableHTTPError, + StreamableHTTPTransport, + streamablehttp_client, +) +from core.mcp.types import ( + ClientMessageMetadata, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + SessionMessage, +) # Test constants SERVER_NAME = "test_streamable_http_server" @@ -448,3 +473,1169 @@ def test_streamablehttp_client_resumption_token_handling(): assert write_queue is not None except Exception: pass # Expected due to mocking + + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def _make_request_msg(method: str = "ping", req_id: int = 1) -> JSONRPCMessage: + return JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=req_id, method=method)) + + +def _make_response_msg(req_id: int = 1, result: dict | None = None) -> JSONRPCMessage: + return JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=req_id, result=result or {})) + + +def _make_error_msg(req_id: int = 1, code: int = -32600) -> JSONRPCMessage: + return JSONRPCMessage(root=JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=code, message="err"))) + + +def _make_notification_msg(method: str = "notifications/initialized") -> JSONRPCMessage: + return JSONRPCMessage(root=JSONRPCNotification(jsonrpc="2.0", method=method)) + + +def _make_sse_mock(event: str = "message", data: str = "", sse_id: str = "") -> ServerSentEvent: + # Use real ServerSentEvent since StreamableHTTPTransport requires its structure + return ServerSentEvent(event=event, data=data, id=sse_id, retry=None) + + +def _new_transport(url: str = "http://example.com/mcp", **kwargs) -> StreamableHTTPTransport: + return StreamableHTTPTransport(url, **kwargs) + + +# ── StreamableHTTPTransport.__init__ ───────────────────────────────────────── + + +class TestStreamableHTTPTransportInit: + def test_defaults(self): + t = _new_transport() + assert t.url == "http://example.com/mcp" + assert t.headers == {} + assert t.timeout == 30 + assert t.sse_read_timeout == 300 + assert t.session_id is None + assert t.stop_event is not None + assert t._active_responses == [] + + def test_timedelta_timeout_and_sse_read_timeout(self): + t = _new_transport(timeout=timedelta(seconds=10), sse_read_timeout=timedelta(seconds=120)) + assert t.timeout == 10.0 + assert t.sse_read_timeout == 120.0 + + def test_custom_headers_merged_into_request_headers(self): + t = _new_transport(headers={"Authorization": "Bearer tok"}) + assert t.request_headers["Authorization"] == "Bearer tok" + assert "Accept" in t.request_headers + assert "content-type" in t.request_headers + + +# ── _update_headers_with_session ───────────────────────────────────────────── + + +class TestUpdateHeadersWithSession: + def test_no_session_id_returns_copy_without_session_header(self): + t = _new_transport() + t.session_id = None + result = t._update_headers_with_session({"X-Foo": "bar"}) + assert result == {"X-Foo": "bar"} + assert MCP_SESSION_ID not in result + + def test_with_session_id_adds_header(self): + t = _new_transport() + t.session_id = "sess-abc" + result = t._update_headers_with_session({"X-Foo": "bar"}) + assert result[MCP_SESSION_ID] == "sess-abc" + assert result["X-Foo"] == "bar" + + +# ── _register_response / _unregister_response / close_active_responses ──────── + + +class TestResponseRegistry: + def test_register_and_unregister(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + t._register_response(resp) + assert resp in t._active_responses + t._unregister_response(resp) + assert resp not in t._active_responses + + def test_unregister_not_registered_does_not_raise(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + t._unregister_response(resp) # Should swallow ValueError silently + + def test_close_active_responses_calls_close(self): + t = _new_transport() + resp1 = MagicMock(spec=httpx.Response) + resp2 = MagicMock(spec=httpx.Response) + t._register_response(resp1) + t._register_response(resp2) + t.close_active_responses() + resp1.close.assert_called_once() + resp2.close.assert_called_once() + assert t._active_responses == [] + + def test_close_active_responses_swallows_runtime_error(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + resp.close.side_effect = RuntimeError("already closed") + t._register_response(resp) + t.close_active_responses() # Should not raise + + +# ── _is_initialization_request / _is_initialized_notification ──────────────── + + +class TestMessageClassifiers: + def test_is_initialization_request_true(self): + t = _new_transport() + assert t._is_initialization_request(_make_request_msg("initialize")) is True + + def test_is_initialization_request_false_other_method(self): + t = _new_transport() + assert t._is_initialization_request(_make_request_msg("tools/list")) is False + + def test_is_initialization_request_false_not_request(self): + t = _new_transport() + assert t._is_initialization_request(_make_response_msg()) is False + + def test_is_initialized_notification_true(self): + t = _new_transport() + assert t._is_initialized_notification(_make_notification_msg("notifications/initialized")) is True + + def test_is_initialized_notification_false_other_method(self): + t = _new_transport() + assert t._is_initialized_notification(_make_notification_msg("notifications/cancelled")) is False + + def test_is_initialized_notification_false_not_notification(self): + t = _new_transport() + assert t._is_initialized_notification(_make_request_msg("notifications/initialized")) is False + + +# ── _maybe_extract_session_id_from_response ─────────────────────────────────── + + +class TestMaybeExtractSessionIdNew: + def test_extracts_session_id_when_present(self): + t = _new_transport() + resp = MagicMock() + resp.headers = {MCP_SESSION_ID: "new-session-99"} + t._maybe_extract_session_id_from_response(resp) + assert t.session_id == "new-session-99" + + def test_no_session_id_header_leaves_none(self): + t = _new_transport() + resp = MagicMock() + resp.headers = MagicMock() + resp.headers.get = MagicMock(return_value=None) + t._maybe_extract_session_id_from_response(resp) + assert t.session_id is None + + +# ── _handle_sse_event ───────────────────────────────────────────────────────── + + +class TestHandleSseEventNew: + def test_message_event_response_returns_true(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}})) + assert t._handle_sse_event(sse, q) is True + assert isinstance(q.get_nowait(), SessionMessage) + + def test_message_event_error_returns_true(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "bad"}}) + sse = _make_sse_mock("message", data) + assert t._handle_sse_event(sse, q) is True + + def test_message_event_notification_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "method": "notifications/something"}) + sse = _make_sse_mock("message", data) + assert t._handle_sse_event(sse, q) is False + assert isinstance(q.get_nowait(), SessionMessage) + + def test_message_event_empty_data_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", " ") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + def test_message_event_invalid_json_puts_exception(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", "{bad json}") + assert t._handle_sse_event(sse, q) is False + assert isinstance(q.get_nowait(), Exception) + + def test_message_event_replaces_original_request_id(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="") + t._handle_sse_event(sse, q, original_request_id=999) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert item.message.root.id == 999 + + def test_message_event_calls_resumption_callback_when_sse_id_present(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="token-abc") + callback = MagicMock() + t._handle_sse_event(sse, q, resumption_callback=callback) + callback.assert_called_once_with("token-abc") + + def test_message_event_no_callback_when_no_sse_id(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="") + callback = MagicMock() + t._handle_sse_event(sse, q, resumption_callback=callback) + callback.assert_not_called() + + def test_ping_event_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("ping", "") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + def test_unknown_event_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("custom_event", "{}") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + +# ── handle_get_stream ───────────────────────────────────────────────────────── + + +class TestHandleGetStreamNew: + def test_skips_when_no_session_id(self): + t = _new_transport() + t.session_id = None + q: queue.Queue = queue.Queue() + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + t.handle_get_stream(MagicMock(), q) + mock_connect.assert_not_called() + + def test_handles_messages_via_sse(self): + t = _new_transport() + t.session_id = "sess-1" + q: queue.Queue = queue.Queue() + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t.handle_get_stream(MagicMock(), q) + + assert isinstance(q.get_nowait(), SessionMessage) + + def test_stops_when_stop_event_set(self): + t = _new_transport() + t.session_id = "sess-1" + t.stop_event.set() + q: queue.Queue = queue.Queue() + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t.handle_get_stream(MagicMock(), q) + + assert q.empty() + + def test_exception_when_not_stopped_is_logged(self): + t = _new_transport() + t.session_id = "sess-1" + q: queue.Queue = queue.Queue() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.side_effect = Exception("connection error") + t.handle_get_stream(MagicMock(), q) # Should not raise + + def test_exception_when_stopped_is_suppressed(self): + t = _new_transport() + t.session_id = "sess-1" + t.stop_event.set() + q: queue.Queue = queue.Queue() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.side_effect = Exception("connection error") + t.handle_get_stream(MagicMock(), q) # Should not raise or log + + +# ── _handle_resumption_request ──────────────────────────────────────────────── + + +class TestHandleResumptionRequestNew: + def _make_ctx(self, transport, q, resumption_token="token-123", message=None) -> RequestContext: + if message is None: + message = _make_request_msg("tools/list", req_id=42) + session_msg = SessionMessage(message) + metadata = None + if resumption_token: + metadata = MagicMock(spec=ClientMessageMetadata) + metadata.resumption_token = resumption_token + metadata.on_resumption_token_update = MagicMock() + return RequestContext( + client=MagicMock(), + headers=transport.request_headers, + session_id=transport.session_id, + session_message=session_msg, + metadata=metadata, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + def test_raises_resumption_error_without_token(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + metadata = MagicMock(spec=ClientMessageMetadata) + metadata.resumption_token = None + ctx = RequestContext( + client=MagicMock(), + headers=t.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=metadata, + server_to_client_queue=q, + sse_read_timeout=60, + ) + with pytest.raises(ResumptionError): + t._handle_resumption_request(ctx) + + def test_raises_resumption_error_without_metadata(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = RequestContext( + client=MagicMock(), + headers=t.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=60, + ) + with pytest.raises(ResumptionError): + t._handle_resumption_request(ctx) + + def test_sets_last_event_id_header(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q, resumption_token="resume-999") + + captured_headers: dict = {} + data = json.dumps({"jsonrpc": "2.0", "id": 42, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + def fake_connect(url, headers, **kwargs): + captured_headers.update(headers) + + @contextmanager + def _ctx(): + yield mock_event_source + + return _ctx() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect", side_effect=fake_connect): + t._handle_resumption_request(ctx) + + assert captured_headers.get(LAST_EVENT_ID) == "resume-999" + + def test_stops_when_response_complete(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q, message=_make_request_msg("tools/list", 42)) + + data1 = json.dumps({"jsonrpc": "2.0", "id": 42, "result": {}}) + data2 = json.dumps({"jsonrpc": "2.0", "id": 43, "result": {}}) + sse1 = _make_sse_mock("message", data1) + sse2 = _make_sse_mock("message", data2) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [sse1, sse2] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t._handle_resumption_request(ctx) + + # Only the first event was processed (loop breaks on completion) + assert q.qsize() == 1 + + def test_stops_when_stop_event_set(self): + t = _new_transport() + t.stop_event.set() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t._handle_resumption_request(ctx) + + assert q.empty() + + +# ── _handle_post_request ────────────────────────────────────────────────────── + + +class TestHandlePostRequestNew: + def _make_ctx(self, transport, q, message=None) -> RequestContext: + if message is None: + message = _make_request_msg("tools/list", 1) + return RequestContext( + client=MagicMock(), + headers=transport.request_headers, + session_id=transport.session_id, + session_message=SessionMessage(message), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + def _stream_ctx(self, mock_response): + @contextmanager + def _stream(*args, **kwargs): + yield mock_response + + return _stream + + def test_202_returns_immediately_no_queue(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + mock_resp = MagicMock() + mock_resp.status_code = 202 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + assert q.empty() + + def test_204_returns_immediately_no_queue(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + mock_resp = MagicMock() + mock_resp.status_code = 204 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + assert q.empty() + + def test_404_sends_session_terminated_error_for_request(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + msg = _make_request_msg("tools/list", 77) + ctx = self._make_ctx(t, q, message=msg) + mock_resp = MagicMock() + mock_resp.status_code = 404 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert isinstance(item.message.root, JSONRPCError) + assert item.message.root.id == 77 + + def test_404_for_notification_no_error_sent(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + msg = _make_notification_msg("some/notification") + ctx = self._make_ctx(t, q, message=msg) + mock_resp = MagicMock() + mock_resp.status_code = 404 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + assert q.empty() + + def test_json_response_puts_session_message(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}).encode() + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.raise_for_status.return_value = None + mock_resp.read.return_value = response_data + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + assert isinstance(q.get_nowait(), SessionMessage) + + def test_json_response_invalid_json_puts_exception(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.raise_for_status.return_value = None + mock_resp.read.return_value = b"{bad json!" + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + assert isinstance(q.get_nowait(), Exception) + + def test_unexpected_content_type_puts_value_error(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "text/plain"} + mock_resp.raise_for_status.return_value = None + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + item = q.get_nowait() + assert isinstance(item, ValueError) + assert "Unexpected content type" in str(item) + + def test_initialization_request_extracts_session_id(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + msg = _make_request_msg("initialize", 1) + ctx = self._make_ctx(t, q, message=msg) + + response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode() + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = MagicMock() + headers_dict = {"content-type": "application/json", MCP_SESSION_ID: "new-sid"} + mock_resp.headers.__getitem__ = lambda self, k: headers_dict[k] + mock_resp.headers.get = lambda k, default=None: headers_dict.get(k, default) + mock_resp.raise_for_status.return_value = None + mock_resp.read.return_value = response_data + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + assert t.session_id == "new-sid" + + def test_notification_skips_response_processing(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + msg = _make_notification_msg("notifications/something") + ctx = self._make_ctx(t, q, message=msg) + + response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode() + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.raise_for_status.return_value = None + mock_resp.read.return_value = response_data + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + assert q.empty() + + def test_sse_response_handles_stream(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "text/event-stream"} + mock_resp.raise_for_status.return_value = None + ctx.client.stream = self._stream_ctx(mock_resp) + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [mock_sse_event] + MockEventSource.return_value = mock_es_instance + t._handle_post_request(ctx) + + assert isinstance(q.get_nowait(), SessionMessage) + + +# ── _handle_json_response ───────────────────────────────────────────────────── + + +class TestHandleJsonResponseNew: + def test_valid_json_puts_session_message(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode() + mock_response = MagicMock() + mock_response.read.return_value = data + t._handle_json_response(mock_response, q) + assert isinstance(q.get_nowait(), SessionMessage) + + def test_invalid_json_puts_exception(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + mock_response = MagicMock() + mock_response.read.return_value = b"{ invalid }" + t._handle_json_response(mock_response, q) + assert isinstance(q.get_nowait(), Exception) + + +# ── _handle_sse_response ────────────────────────────────────────────────────── + + +class TestHandleSseResponseNew: + def _ctx(self, transport, q) -> RequestContext: + return RequestContext( + client=MagicMock(), + headers=transport.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + def test_processes_sse_events(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [mock_sse_event] + MockEventSource.return_value = mock_es_instance + t._handle_sse_response(mock_response, ctx) + + assert isinstance(q.get_nowait(), SessionMessage) + + def test_stops_when_stop_event_set(self): + t = _new_transport() + t.stop_event.set() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [mock_sse_event] + MockEventSource.return_value = mock_es_instance + t._handle_sse_response(mock_response, ctx) + + assert q.empty() + + def test_stops_when_complete(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + + data1 = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + data2 = json.dumps({"jsonrpc": "2.0", "id": 2, "result": {}}) + sse1 = _make_sse_mock("message", data1) + sse2 = _make_sse_mock("message", data2) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [sse1, sse2] + MockEventSource.return_value = mock_es_instance + t._handle_sse_response(mock_response, ctx) + + assert q.qsize() == 1 # Only the first completion item + + def test_exception_outside_stop_puts_to_queue(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + MockEventSource.side_effect = RuntimeError("EventSource error") + t._handle_sse_response(mock_response, ctx) + + assert isinstance(q.get_nowait(), Exception) + + def test_exception_suppressed_when_stopped(self): + t = _new_transport() + t.stop_event.set() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + MockEventSource.side_effect = RuntimeError("EventSource error") + t._handle_sse_response(mock_response, ctx) + + assert q.empty() + + def test_with_metadata_resumption_callback(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + metadata = MagicMock(spec=ClientMessageMetadata) + callback = MagicMock() + metadata.on_resumption_token_update = callback + + ctx = RequestContext( + client=MagicMock(), + headers=t.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=metadata, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="resume-token") + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [sse] + MockEventSource.return_value = mock_es_instance + t._handle_sse_response(mock_response, ctx) + + callback.assert_called_once_with("resume-token") + + +# ── _handle_unexpected_content_type ────────────────────────────────────────── + + +class TestHandleUnexpectedContentTypeNew: + def test_puts_value_error_with_message(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + t._handle_unexpected_content_type("text/html", q) + item = q.get_nowait() + assert isinstance(item, ValueError) + assert "text/html" in str(item) + + +# ── _send_session_terminated_error ──────────────────────────────────────────── + + +class TestSendSessionTerminatedErrorNew: + def test_puts_jsonrpc_error(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + t._send_session_terminated_error(q, 42) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert isinstance(item.message.root, JSONRPCError) + assert item.message.root.id == 42 + assert item.message.root.error.code == 32600 + assert "terminated" in item.message.root.error.message.lower() + + +# ── post_writer ─────────────────────────────────────────────────────────────── + + +class TestPostWriterNew: + def test_none_message_exits_loop(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + c2s.put(None) + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + def test_stop_event_exits_loop(self): + t = _new_transport() + t.stop_event.set() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + def test_initialized_notification_calls_start_get_stream(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + start_get_stream = MagicMock() + + notif_msg = _make_notification_msg("notifications/initialized") + c2s.put(SessionMessage(notif_msg)) + c2s.put(None) + + with patch.object(t, "_handle_post_request"): + t.post_writer(MagicMock(), c2s, s2c, start_get_stream) + + start_get_stream.assert_called_once() + + def test_resumption_message_calls_handle_resumption_request(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + start_get_stream = MagicMock() + + msg = SessionMessage(_make_request_msg("tools/list", 10)) + metadata = MagicMock(spec=ClientMessageMetadata) + metadata.resumption_token = "resume-abc" + msg.metadata = metadata + c2s.put(msg) + c2s.put(None) + + with patch.object(t, "_handle_resumption_request") as mock_resumption: + t.post_writer(MagicMock(), c2s, s2c, start_get_stream) + + mock_resumption.assert_called_once() + + def test_regular_message_calls_handle_post_request(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + + msg = SessionMessage(_make_request_msg("tools/list", 5)) + c2s.put(msg) + c2s.put(None) + + with patch.object(t, "_handle_post_request") as mock_post: + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + mock_post.assert_called_once() + + def test_exception_in_handler_put_to_s2c_when_not_stopped(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + + msg = SessionMessage(_make_request_msg("tools/list", 5)) + c2s.put(msg) + c2s.put(None) + + boom = RuntimeError("oops") + with patch.object(t, "_handle_post_request", side_effect=boom): + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + item = s2c.get_nowait() + assert item is boom + + def test_exception_suppressed_when_stopped(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + + msg = SessionMessage(_make_request_msg("tools/list", 5)) + c2s.put(msg) + c2s.put(None) + t.stop_event.set() + + boom = RuntimeError("oops") + with patch.object(t, "_handle_post_request", side_effect=boom): + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + assert s2c.empty() + + def test_queue_empty_timeout_continues_loop(self): + """Cover the 'except queue.Empty: continue' branch in post_writer.""" + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + call_count = {"n": 0} + + original_get = c2s.get + + def patched_get(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + raise queue.Empty + + c2s.get = patched_get # type: ignore[method-assign] + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + assert call_count["n"] >= 2 + + def test_non_client_metadata_treated_as_none(self): + """session_message.metadata that's not ClientMessageMetadata → metadata is None.""" + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + + msg = SessionMessage(_make_request_msg("tools/list", 5)) + msg.metadata = "not-a-client-metadata" + c2s.put(msg) + c2s.put(None) + + with patch.object(t, "_handle_post_request") as mock_post: + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + ctx = mock_post.call_args[0][0] + assert ctx.metadata is None + + +# ── terminate_session ───────────────────────────────────────────────────────── + + +class TestTerminateSessionNew: + def test_no_session_id_skips(self): + t = _new_transport() + t.session_id = None + mock_client = MagicMock() + t.terminate_session(mock_client) + mock_client.delete.assert_not_called() + + def test_200_response_is_success(self): + t = _new_transport() + t.session_id = "sess-1" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.delete.return_value = mock_response + t.terminate_session(mock_client) + mock_client.delete.assert_called_once() + + def test_405_does_not_raise(self): + t = _new_transport() + t.session_id = "sess-1" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 405 + mock_client.delete.return_value = mock_response + t.terminate_session(mock_client) # Should not raise + + def test_non_200_logs_warning_does_not_raise(self): + t = _new_transport() + t.session_id = "sess-1" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 500 + mock_client.delete.return_value = mock_response + t.terminate_session(mock_client) # Should not raise + + def test_exception_is_swallowed(self): + t = _new_transport() + t.session_id = "sess-1" + mock_client = MagicMock() + mock_client.delete.side_effect = httpx.ConnectError("refused") + t.terminate_session(mock_client) # Should not raise + + +# ── get_session_id ──────────────────────────────────────────────────────────── + + +class TestGetSessionIdNew: + def test_returns_none_when_no_session(self): + t = _new_transport() + assert t.get_session_id() is None + + def test_returns_session_id_when_set(self): + t = _new_transport() + t.session_id = "my-session" + assert t.get_session_id() == "my-session" + + +# ── streamablehttp_client context manager ───────────────────────────────────── + + +class TestStreamablehttpClientContextManagerNew: + def test_yields_queues_and_callback(self): + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid): + assert s2c is not None + assert c2s is not None + assert callable(get_sid) + + def test_terminate_on_close_false_does_not_delete(self): + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with streamablehttp_client("http://example.com/mcp", terminate_on_close=False) as (s2c, c2s, get_sid): + pass + mock_client.delete.assert_not_called() + + def test_queue_cleanup_on_outer_exception(self): + """Verify cleanup in finally block runs even when create_ssrf raises.""" + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_cf.side_effect = RuntimeError("connection failed") + + with pytest.raises(RuntimeError): + with streamablehttp_client("http://example.com/mcp"): + pass # pragma: no cover + + def test_timedelta_args_accepted(self): + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with streamablehttp_client( + "http://example.com/mcp", + timeout=timedelta(seconds=15), + sse_read_timeout=timedelta(seconds=60), + ) as (s2c, c2s, get_sid): + assert callable(get_sid) + + def test_start_get_stream_submits_to_executor(self): + """When context starts, post_writer is submitted to executor.""" + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + submitted_calls = [] + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + + def capture_submit(fn, *args, **kwargs): + submitted_calls.append((fn, args)) + + mock_executor.submit.side_effect = capture_submit + mock_exec.return_value = mock_executor + + with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid): + pass + + # post_writer was submitted + assert len(submitted_calls) >= 1 + + def test_cleanup_puts_none_sentinels_to_queues(self): + """After context exit, None sentinels are put into both queues.""" + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid): + pass + + # After context exit, None sentinel should be in c2s queue from cleanup + val = c2s.get_nowait() + assert val is None + + def test_terminate_called_when_session_id_set(self): + """When session_id is set and terminate_on_close=True, terminate_session is called.""" + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + mock_delete_resp = MagicMock() + mock_delete_resp.status_code = 200 + mock_client.delete.return_value = mock_delete_resp + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with patch("core.mcp.client.streamable_client.StreamableHTTPTransport") as MockTransport: + mock_transport = MockTransport.return_value + mock_transport.request_headers = { + "Accept": "application/json, text/event-stream", + "content-type": "application/json", + } + mock_transport.timeout = 30 + mock_transport.sse_read_timeout = 300 + mock_transport.session_id = "active-session" + mock_transport.stop_event = MagicMock() + mock_transport.get_session_id = MagicMock(return_value="active-session") + + with streamablehttp_client("http://example.com/mcp", terminate_on_close=True) as ( + s2c, + c2s, + get_sid, + ): + pass + + mock_transport.terminate_session.assert_called_once_with(mock_client) + + +# ── Exception hierarchy ─────────────────────────────────────────────────────── + + +class TestExceptionHierarchyNew: + def test_streamable_http_error_is_exception(self): + err = StreamableHTTPError("test") + assert isinstance(err, Exception) + + def test_resumption_error_is_streamable_http_error(self): + err = ResumptionError("test") + assert isinstance(err, StreamableHTTPError) + assert isinstance(err, Exception) + + +# ── RequestContext dataclass ────────────────────────────────────────────────── + + +class TestRequestContextNew: + def test_creation(self): + import queue + + q: queue.Queue = queue.Queue() + ctx = RequestContext( + client=MagicMock(), + headers={"X-Test": "val"}, + session_id="sid", + session_message=SessionMessage(_make_request_msg()), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=30.0, + ) + assert ctx.session_id == "sid" + assert ctx.sse_read_timeout == 30.0 + assert ctx.metadata is None diff --git a/api/tests/unit_tests/core/mcp/session/test_base_session.py b/api/tests/unit_tests/core/mcp/session/test_base_session.py new file mode 100644 index 0000000000..1dd916bcf1 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/session/test_base_session.py @@ -0,0 +1,617 @@ +import queue +import time +from concurrent.futures import Future, ThreadPoolExecutor +from datetime import timedelta +from typing import Union +from unittest.mock import MagicMock, patch + +import pytest +from httpx import HTTPStatusError, Request, Response +from pydantic import BaseModel, ConfigDict, RootModel + +from core.mcp.error import MCPAuthError, MCPConnectionError +from core.mcp.session.base_session import BaseSession, RequestResponder +from core.mcp.types import ( + CancelledNotification, + ClientNotification, + ClientRequest, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCResponse, + Notification, + RequestParams, + SessionMessage, +) +from core.mcp.types import ( + Request as MCPRequest, +) + + +class MockRequestParams(RequestParams): + name: str = "default" + model_config = ConfigDict(extra="allow") + + +class MockRequest(MCPRequest[MockRequestParams, str]): + method: str = "test/request" + params: MockRequestParams = MockRequestParams() + + +class MockResult(BaseModel): + result: str + + +class MockNotificationParams(BaseModel): + message: str + + +class MockNotification(Notification[MockNotificationParams, str]): + method: str = "test/notification" + params: MockNotificationParams + + +class ReceiveRequest(RootModel[Union[MockRequest, ClientRequest]]): + pass + + +class ReceiveNotification(RootModel[Union[CancelledNotification, MockNotification, JSONRPCNotification]]): + pass + + +class MockSession(BaseSession[MockRequest, MockNotification, MockResult, ReceiveRequest, ReceiveNotification]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.received_requests = [] + self.received_notifications = [] + self.handled_incoming = [] + + def _received_request(self, responder): + self.received_requests.append(responder) + + def _received_notification(self, notification): + self.received_notifications.append(notification) + + def _handle_incoming(self, item): + self.handled_incoming.append(item) + + +@pytest.fixture +def streams(): + return queue.Queue(), queue.Queue() + + +@pytest.mark.timeout(5) +def test_request_responder_respond(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + on_complete = MagicMock() + request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test"))) + + responder = RequestResponder( + request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete + ) + + with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"): + responder.respond(MockResult(result="ok")) + + with responder as r: + r.respond(MockResult(result="ok")) + with pytest.raises(AssertionError, match="Request already responded to"): + r.respond(MockResult(result="error")) + + assert responder.completed is True + on_complete.assert_called_once_with(responder) + + msg = write_stream.get_nowait() + assert isinstance(msg.message.root, JSONRPCResponse) + assert msg.message.root.result == {"result": "ok"} + + +@pytest.mark.timeout(5) +def test_request_responder_cancel(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + on_complete = MagicMock() + request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test"))) + + responder = RequestResponder( + request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete + ) + + with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"): + responder.cancel() + + with responder as r: + r.cancel() + + assert responder.completed is True + on_complete.assert_called_once_with(responder) + + msg = write_stream.get_nowait() + assert isinstance(msg.message.root, JSONRPCError) + assert msg.message.root.error.message == "Request cancelled" + + +@pytest.mark.timeout(10) +def test_base_session_lifecycle(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session as s: + assert isinstance(s, MockSession) + assert s._executor is not None + assert s._receiver_future is not None + + session._receiver_future.result(timeout=5.0) + assert session._receiver_future.done() + + +@pytest.mark.timeout(5) +def test_send_request_success(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_response(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "hello world"}) + read_stream.put(SessionMessage(message=JSONRPCMessage(response))) + except Exception: + pass + + import threading + + t = threading.Thread(target=mock_response, daemon=True) + t.start() + + with session: + result = session.send_request(request, MockResult) + assert result.result == "hello world" + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_retry_loop_coverage(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_delayed_response(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + time.sleep(0.2) + response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "slow"}) + read_stream.put(SessionMessage(message=JSONRPCMessage(response))) + except: + pass + + import threading + + t = threading.Thread(target=mock_delayed_response, daemon=True) + t.start() + + with session: + result = session.send_request(request, MockResult, request_read_timeout_seconds=timedelta(seconds=0.1)) + assert result.result == "slow" + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_jsonrpc_error(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_error(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=-32000, message="Error")) + read_stream.put(SessionMessage(message=JSONRPCMessage(error))) + except: + pass + + import threading + + t = threading.Thread(target=mock_error, daemon=True) + t.start() + + with session: + with pytest.raises(MCPConnectionError) as exc: + session.send_request(request, MockResult) + assert exc.value.args[0].message == "Error" + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_auth_error(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_error(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=401, message="Unauthorized")) + read_stream.put(SessionMessage(message=JSONRPCMessage(error))) + except: + pass + + import threading + + t = threading.Thread(target=mock_error, daemon=True) + t.start() + + with session: + with pytest.raises(MCPAuthError): + session.send_request(request, MockResult) + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_http_status_error_coverage(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_direct_http_error(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + # To cover line 263 in base_session.py, we MUST put non-401 HTTPStatusError + # DIRECTLY into response_streams, as _receive_loop would convert it to JSONRPCError. + response = Response(status_code=403, request=Request("GET", "http://test")) + error = HTTPStatusError("Forbidden", request=response.request, response=response) + session._response_streams[req_id].put(error) + except: + pass + + import threading + + t = threading.Thread(target=mock_direct_http_error, daemon=True) + t.start() + + # We still need the session for request ID generation and queue setup + with session: + with pytest.raises(MCPConnectionError) as exc: + session.send_request(request, MockResult) + assert exc.value.args[0].code == 403 + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_http_status_auth_error(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_error(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + response = Response(status_code=401, request=Request("GET", "http://test")) + error = HTTPStatusError("Unauthorized", request=response.request, response=response) + read_stream.put(error) + except: + pass + + import threading + + t = threading.Thread(target=mock_error, daemon=True) + t.start() + + with session: + with pytest.raises(MCPAuthError): + session.send_request(request, MockResult) + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_notification(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + notification = MockNotification(method="notify", params=MockNotificationParams(message="hi")) + + session.send_notification(notification, related_request_id="rel-1") + + msg = write_stream.get_nowait() + assert isinstance(msg.message.root, JSONRPCNotification) + assert msg.message.root.method == "notify" + assert msg.message.root.params == {"message": "hi"} + assert msg.metadata.related_request_id == "rel-1" + + +@pytest.mark.timeout(10) +def test_receive_loop_request(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + req_payload = {"jsonrpc": "2.0", "id": 1, "method": "test/request", "params": {"name": "test"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload))) + + for _ in range(30): + if session.received_requests: + break + time.sleep(0.1) + + assert len(session.received_requests) == 1 + responder = session.received_requests[0] + assert responder.request_id == 1 + assert responder.request.root.method == "test/request" + + +@pytest.mark.timeout(10) +def test_receive_loop_notification(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + notif_payload = {"jsonrpc": "2.0", "method": "test/notification", "params": {"message": "hello"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload))) + + for _ in range(30): + if session.received_notifications: + break + time.sleep(0.1) + + assert len(session.received_notifications) == 1 + assert isinstance(session.received_notifications[0].root, MockNotification) + assert session.received_notifications[0].root.method == "test/notification" + + +@pytest.mark.timeout(15) +def test_receive_loop_cancel_notification(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ClientNotification) + + with session: + req_payload = {"jsonrpc": "2.0", "id": "req-1", "method": "test/request", "params": {"name": "test"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload))) + + for _ in range(30): + if "req-1" in session._in_flight: + break + time.sleep(0.1) + + assert "req-1" in session._in_flight + responder = session._in_flight["req-1"] + + with responder: + cancel_payload = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": {"requestId": "req-1"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(cancel_payload))) + + for _ in range(30): + if responder.completed: + break + time.sleep(0.1) + + assert responder.completed is True + msg = write_stream.get(timeout=2) + assert isinstance(msg.message.root, JSONRPCError) + assert msg.message.root.id == "req-1" + + +@pytest.mark.timeout(10) +def test_receive_loop_exception(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + read_stream.put(Exception("Unexpected error")) + for _ in range(30): + if any(isinstance(x, Exception) for x in session.handled_incoming): + break + time.sleep(0.1) + + assert any(isinstance(x, Exception) and str(x) == "Unexpected error" for x in session.handled_incoming) + + +@pytest.mark.timeout(10) +def test_receive_loop_http_status_error(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + session._request_id = 1 + resp_queue = queue.Queue() + session._response_streams[0] = resp_queue + + response = Response(status_code=401, request=Request("GET", "http://test")) + # Using 401 specifically as _receive_loop preserves it + error = HTTPStatusError("Unauthorized", request=response.request, response=response) + read_stream.put(error) + + got = resp_queue.get(timeout=2) + assert isinstance(got, HTTPStatusError) + + +@pytest.mark.timeout(10) +def test_receive_loop_http_status_error_non_401(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + session._request_id = 1 + resp_queue = queue.Queue() + session._response_streams[0] = resp_queue + + response = Response(status_code=500, request=Request("GET", "http://test")) + error = HTTPStatusError("Server Error", request=response.request, response=response) + read_stream.put(error) + + got = resp_queue.get(timeout=2) + assert isinstance(got, JSONRPCError) + assert got.error.code == 500 + + +@pytest.mark.timeout(5) +def test_check_receiver_status_fail(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + executor = ThreadPoolExecutor(max_workers=1) + + def raise_err(): + raise RuntimeError("Receiver failed") + + future = executor.submit(raise_err) + session._receiver_future = future + + try: + future.result() + except: + pass + + with pytest.raises(RuntimeError, match="Receiver failed"): + session.check_receiver_status() + executor.shutdown() + + +@pytest.mark.timeout(10) +def test_receive_loop_unknown_request_id(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + resp = JSONRPCResponse(jsonrpc="2.0", id=999, result={"ok": True}) + read_stream.put(SessionMessage(message=JSONRPCMessage(resp))) + + for _ in range(30): + if any(isinstance(x, RuntimeError) and "Server Error" in str(x) for x in session.handled_incoming): + break + time.sleep(0.1) + + assert any("Server Error" in str(x) for x in session.handled_incoming) + + +@pytest.mark.timeout(10) +def test_receive_loop_http_error_unknown_id(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + response = Response(status_code=401, request=Request("GET", "http://test")) + error = HTTPStatusError("Unauthorized", request=response.request, response=response) + read_stream.put(error) + + for _ in range(30): + if any(isinstance(x, RuntimeError) and "unknown request ID" in str(x) for x in session.handled_incoming): + break + time.sleep(0.1) + + assert any("unknown request ID" in str(x) for x in session.handled_incoming) + + +@pytest.mark.timeout(10) +def test_receive_loop_validation_error_notification(streams): + from core.mcp.session.base_session import logger + + with patch.object(logger, "warning") as mock_warning: + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, RootModel[MockNotification]) + + with session: + notif_payload = {"jsonrpc": "2.0", "method": "bad", "params": {"some": "data"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload))) + time.sleep(1.0) + + assert mock_warning.called + + +@pytest.mark.timeout(5) +def test_send_request_none_response(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_none(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + session._response_streams[req_id].put(None) + except: + pass + + import threading + + t = threading.Thread(target=mock_none, daemon=True) + t.start() + + with session: + with pytest.raises(MCPConnectionError) as exc: + session.send_request(request, MockResult) + assert exc.value.args[0].message == "No response received" + t.join(timeout=1) + + +@pytest.mark.timeout(15) +def test_session_exit_timeout(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + mock_future = MagicMock(spec=Future) + mock_future.result.side_effect = TimeoutError() + mock_future.done.return_value = False + + session._receiver_future = mock_future + session._executor = MagicMock(spec=ThreadPoolExecutor) + + session.__exit__(None, None, None) + + mock_future.cancel.assert_called_once() + session._executor.shutdown.assert_called_once_with(wait=False) + + +@pytest.mark.timeout(10) +def test_receive_loop_fatal_exception(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with patch.object(read_stream, "get", side_effect=RuntimeError("Fatal loop error")): + with patch("core.mcp.session.base_session.logger") as mock_logger: + with pytest.raises(RuntimeError, match="Fatal loop error"): + with session: + pass + mock_logger.exception.assert_called_with("Error in message processing loop") + + +@pytest.mark.timeout(5) +def test_receive_loop_empty_coverage(streams): + with patch("core.mcp.session.base_session.DEFAULT_RESPONSE_READ_TIMEOUT", 0.1): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + with session: + time.sleep(0.3) + + +@pytest.mark.timeout(2) +def test_base_methods_noop(streams): + read_stream, write_stream = streams + session = BaseSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + session._received_request(MagicMock()) + session._received_notification(MagicMock()) + session.send_progress_notification("token", 0.5) + session._handle_incoming(MagicMock()) + + +@pytest.mark.timeout(5) +def test_send_request_session_timeout_retry_6(streams): + read_stream, write_stream = streams + session = MockSession( + read_stream, write_stream, ReceiveRequest, ReceiveNotification, read_timeout_seconds=timedelta(seconds=0.1) + ) + + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + with patch.object(session, "check_receiver_status", side_effect=[None, RuntimeError("timeout_broken")]): + with pytest.raises(RuntimeError, match="timeout_broken"): + session.send_request(request, MockResult) diff --git a/api/tests/unit_tests/core/mcp/session/test_client_session.py b/api/tests/unit_tests/core/mcp/session/test_client_session.py new file mode 100644 index 0000000000..c7b9d3cfa9 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/session/test_client_session.py @@ -0,0 +1,576 @@ +import queue +from unittest.mock import MagicMock + +import pytest +from pydantic import AnyUrl + +from core.mcp import types +from core.mcp.session.base_session import RequestResponder, SessionMessage +from core.mcp.session.client_session import ( + ClientSession, + _default_list_roots_callback, + _default_logging_callback, + _default_message_handler, + _default_sampling_callback, +) + + +@pytest.fixture +def streams(): + return queue.Queue(), queue.Queue() + + +def test_client_session_init(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + assert session._client_info.name == "Dify" + assert session._sampling_callback == _default_sampling_callback + assert session._list_roots_callback == _default_list_roots_callback + assert session._logging_callback == _default_logging_callback + assert session._message_handler == _default_message_handler + + +def test_client_session_init_custom(streams): + read_stream, write_stream = streams + sampling_cb = MagicMock() + list_roots_cb = MagicMock() + logging_cb = MagicMock() + msg_handler = MagicMock() + client_info = types.Implementation(name="Custom", version="1.0") + + session = ClientSession( + read_stream, + write_stream, + sampling_callback=sampling_cb, + list_roots_callback=list_roots_cb, + logging_callback=logging_cb, + message_handler=msg_handler, + client_info=client_info, + ) + + assert session._client_info == client_info + assert session._sampling_callback == sampling_cb + assert session._list_roots_callback == list_roots_cb + assert session._logging_callback == logging_cb + assert session._message_handler == msg_handler + + +def test_initialize_success(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + expected_result = types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ServerCapabilities(), + serverInfo=types.Implementation(name="test-server", version="1.0"), + ) + + def mock_server(): + # Handle initialize request + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result=expected_result.model_dump()) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + # Expect initialized notification + notif = write_stream.get(timeout=2) + assert notif.message.root.method == "notifications/initialized" + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.initialize() + assert result.protocolVersion == types.LATEST_PROTOCOL_VERSION + assert result.serverInfo.name == "test-server" + + t.join(timeout=1) + + +def test_initialize_custom_capabilities(streams): + read_stream, write_stream = streams + session = ClientSession( + read_stream, write_stream, sampling_callback=lambda c, p: None, list_roots_callback=lambda c: None + ) + + def mock_server(): + msg = write_stream.get(timeout=2) + params = msg.message.root.params + # Check that capabilities are set because we provided custom callbacks + assert params["capabilities"]["sampling"] is not None + assert params["capabilities"]["roots"]["listChanged"] is True + + req_id = msg.message.root.id + resp = types.JSONRPCResponse( + jsonrpc="2.0", + id=req_id, + result={ + "protocolVersion": types.LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "serverInfo": {"name": "test", "version": "1.0"}, + }, + ) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + write_stream.get(timeout=2) # initialized notif + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.initialize() + t.join(timeout=1) + + +def test_initialize_unsupported_version(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + resp = types.JSONRPCResponse( + jsonrpc="2.0", + id=req_id, + result={ + "protocolVersion": "0.0.1", # Unsupported + "capabilities": {}, + "serverInfo": {"name": "test", "version": "1.0"}, + }, + ) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + with pytest.raises(RuntimeError, match="Unsupported protocol version"): + session.initialize() + t.join(timeout=1) + + +def test_send_ping(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "ping" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.send_ping() + t.join(timeout=1) + + +def test_send_progress_notification(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + session.send_progress_notification(progress_token="token", progress=50.0, total=100.0) + + msg = write_stream.get_nowait() + assert msg.message.root.method == "notifications/progress" + assert msg.message.root.params["progressToken"] == "token" + assert msg.message.root.params["progress"] == 50.0 + assert msg.message.root.params["total"] == 100.0 + + +def test_set_logging_level(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "logging/setLevel" + assert msg.message.root.params["level"] == "debug" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.set_logging_level("debug") + t.join(timeout=1) + + +def test_list_resources(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/list" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resources": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.list_resources() + assert result.resources == [] + t.join(timeout=1) + + +def test_list_resource_templates(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/templates/list" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resourceTemplates": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.list_resource_templates() + assert result.resourceTemplates == [] + t.join(timeout=1) + + +def test_read_resource(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + uri = AnyUrl("file:///test") + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/read" + assert msg.message.root.params["uri"] == str(uri) + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"contents": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.read_resource(uri) + assert result.contents == [] + t.join(timeout=1) + + +def test_subscribe_resource(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + uri = AnyUrl("file:///test") + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/subscribe" + assert msg.message.root.params["uri"] == str(uri) + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.subscribe_resource(uri) + t.join(timeout=1) + + +def test_unsubscribe_resource(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + uri = AnyUrl("file:///test") + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/unsubscribe" + assert msg.message.root.params["uri"] == str(uri) + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.unsubscribe_resource(uri) + t.join(timeout=1) + + +def test_call_tool(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "tools/call" + assert msg.message.root.params["name"] == "test-tool" + assert msg.message.root.params["arguments"] == {"arg": 1} + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"content": [], "isError": False}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.call_tool("test-tool", arguments={"arg": 1}) + assert result.isError is False + t.join(timeout=1) + + +def test_list_prompts(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "prompts/list" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"prompts": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.list_prompts() + assert result.prompts == [] + t.join(timeout=1) + + +def test_get_prompt(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "prompts/get" + assert msg.message.root.params["name"] == "test-prompt" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"messages": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.get_prompt("test-prompt") + assert result.messages == [] + t.join(timeout=1) + + +def test_complete(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + ref = types.PromptReference(type="ref/prompt", name="test") + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "completion/complete" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"completion": {"values": [], "hasMore": False}}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.complete(ref, argument={"name": "val", "value": "x"}) + assert result.completion.hasMore is False + t.join(timeout=1) + + +def test_list_tools(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "tools/list" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"tools": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.list_tools() + assert result.tools == [] + t.join(timeout=1) + + +def test_send_roots_list_changed(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + session.send_roots_list_changed() + + msg = write_stream.get_nowait() + assert msg.message.root.method == "notifications/roots/list_changed" + + +def test_received_request_sampling(streams): + read_stream, write_stream = streams + sampling_cb = MagicMock( + return_value=types.CreateMessageResult( + role="assistant", content=types.TextContent(type="text", text="hello"), model="gpt-4" + ) + ) + session = ClientSession(read_stream, write_stream, sampling_callback=sampling_cb) + + req = types.ServerRequest( + root=types.CreateMessageRequest( + method="sampling/createMessage", params=types.CreateMessageRequestParams(messages=[], maxTokens=100) + ) + ) + + responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock()) + + session._received_request(responder) + + msg = write_stream.get_nowait() + assert msg.message.root.result["model"] == "gpt-4" + sampling_cb.assert_called_once() + + +def test_received_request_list_roots(streams): + read_stream, write_stream = streams + list_roots_cb = MagicMock(return_value=types.ListRootsResult(roots=[])) + session = ClientSession(read_stream, write_stream, list_roots_callback=list_roots_cb) + + req = types.ServerRequest(root=types.ListRootsRequest(method="roots/list")) + + responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock()) + + session._received_request(responder) + + msg = write_stream.get_nowait() + assert msg.message.root.result["roots"] == [] + list_roots_cb.assert_called_once() + + +def test_received_request_ping(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + req = types.ServerRequest(root=types.PingRequest(method="ping")) + + responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock()) + + session._received_request(responder) + + msg = write_stream.get_nowait() + assert msg.message.root.result == {} + + +def test_handle_incoming(streams): + read_stream, write_stream = streams + msg_handler = MagicMock() + session = ClientSession(read_stream, write_stream, message_handler=msg_handler) + + item = MagicMock() + session._handle_incoming(item) + msg_handler.assert_called_once_with(item) + + +def test_received_notification_logging(streams): + read_stream, write_stream = streams + logging_cb = MagicMock() + session = ClientSession(read_stream, write_stream, logging_callback=logging_cb) + + notif = types.ServerNotification( + root=types.LoggingMessageNotification( + method="notifications/message", + params=types.LoggingMessageNotificationParams(level="info", data={"msg": "test"}), + ) + ) + + session._received_notification(notif) + logging_cb.assert_called_once() + assert logging_cb.call_args[0][0].level == "info" + + +def test_default_message_handler(): + # Exception case + with pytest.raises(ValueError, match="test error"): + _default_message_handler(Exception("test error")) + + # Notification case - should do nothing + _default_message_handler(MagicMock(spec=types.ServerNotification)) + + # RequestResponder case - should do nothing + _default_message_handler(MagicMock(spec=RequestResponder)) + + +def test_default_sampling_callback(): + ctx = MagicMock() + params = MagicMock() + res = _default_sampling_callback(ctx, params) + assert res.code == types.INVALID_REQUEST + assert "not supported" in res.message + + +def test_default_list_roots_callback(): + ctx = MagicMock() + res = _default_list_roots_callback(ctx) + assert res.code == types.INVALID_REQUEST + assert "not supported" in res.message + + +def test_default_logging_callback(): + params = MagicMock() + _default_logging_callback(params) # Should do nothing + + +def test_received_notification_unknown(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + # Use a notification type that is NOT LoggingMessageNotification + notif = types.ServerNotification( + root=types.ResourceListChangedNotification(method="notifications/resources/list_changed") + ) + + session._received_notification(notif) + # Should just pass (case _:) diff --git a/api/tests/unit_tests/core/mcp/test_mcp_client.py b/api/tests/unit_tests/core/mcp/test_mcp_client.py index c0420d3371..c245b4a77e 100644 --- a/api/tests/unit_tests/core/mcp/test_mcp_client.py +++ b/api/tests/unit_tests/core/mcp/test_mcp_client.py @@ -2,13 +2,16 @@ from contextlib import ExitStack from types import TracebackType -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from sqlalchemy.orm import Session -from core.mcp.error import MCPConnectionError +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient -from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations +from core.mcp.types import CallToolResult, ListToolsResult, OAuthTokens, TextContent, Tool, ToolAnnotations class TestMCPClient: @@ -380,3 +383,256 @@ class TestMCPClient: timeout=30.0, sse_read_timeout=60.0, ) + + +class TestMCPClientWithAuthRetry: + """Test suite for MCPClientWithAuthRetry.""" + + @pytest.fixture + def mock_provider(self): + provider = MagicMock(spec=MCPProviderEntity) + provider.id = "test-provider-id" + provider.tenant_id = "test-tenant-id" + provider.retrieve_tokens.return_value = OAuthTokens( + access_token="new-token", + token_type="Bearer", + expires_in=3600, + refresh_token="refresh-token", + ) + return provider + + @pytest.fixture + def auth_client(self, mock_provider): + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + headers={"Authorization": "Bearer old-token"}, + provider_entity=mock_provider, + authorization_code="test-code", + by_server_id=True, + ) + return client + + def test_init(self, mock_provider): + """Test initialization.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + headers={"Authorization": "Bearer test"}, + timeout=30.0, + provider_entity=mock_provider, + authorization_code="initial-code", + by_server_id=True, + ) + + assert client.server_url == "http://test.example.com" + assert client.headers == {"Authorization": "Bearer test"} + assert client.timeout == 30.0 + assert client.provider_entity == mock_provider + assert client.authorization_code == "initial-code" + assert client.by_server_id is True + assert client._has_retried is False + + @patch("core.mcp.auth_client.db") + @patch("core.mcp.auth_client.Session") + @patch("services.tools.mcp_tools_manage_service.MCPToolManageService") + def test_handle_auth_error_success( + self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider + ): + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_service = mock_service_class.return_value + new_provider = MagicMock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = OAuthTokens( + access_token="new-access-token", + token_type="Bearer", + expires_in=3600, + refresh_token="new-refresh-token", + ) + mock_service.get_provider_entity.return_value = new_provider + + # MCPAuthError parses resource_metadata and scope from www_authenticate_header + www_auth = 'Bearer resource_metadata="http://meta", scope="read"' + error = MCPAuthError("Auth failed", www_authenticate_header=www_auth) + + auth_client._handle_auth_error(error) + + # Verify service calls - error.resource_metadata_url and error.scope_hint are parsed from header + mock_service.auth_with_actions.assert_called_once_with( + mock_provider, + "test-code", + resource_metadata_url="http://meta", + scope_hint="read", + ) + mock_service.get_provider_entity.assert_called_once_with( + mock_provider.id, mock_provider.tenant_id, by_server_id=True + ) + + # Verify client updates + assert auth_client.headers["Authorization"] == "Bearer new-access-token" + assert auth_client.authorization_code is None + assert auth_client._has_retried is True + assert auth_client.provider_entity == new_provider + + def test_handle_auth_error_no_provider(self, auth_client): + """Test auth error handling when no provider entity is set.""" + auth_client.provider_entity = None + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert exc_info.value == error + + def test_handle_auth_error_already_retried(self, auth_client): + """Test auth error handling when already retried.""" + auth_client._has_retried = True + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert exc_info.value == error + + @patch("core.mcp.auth_client.db") + @patch("core.mcp.auth_client.Session") + @patch("services.tools.mcp_tools_manage_service.MCPToolManageService") + def test_handle_auth_error_no_token( + self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider + ): + """Test auth error handling when no token is received.""" + mock_session_class.return_value.__enter__.return_value = MagicMock() + mock_service = mock_service_class.return_value + + new_provider = MagicMock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = None + mock_service.get_provider_entity.return_value = new_provider + + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert "Authentication failed - no token received" in str(exc_info.value) + + @patch("core.mcp.auth_client.db") + @patch("core.mcp.auth_client.Session") + @patch("services.tools.mcp_tools_manage_service.MCPToolManageService") + def test_handle_auth_error_generic_exception(self, mock_service_class, mock_session_class, mock_db, auth_client): + """Test auth error handling when a generic exception occurs.""" + mock_session_class.side_effect = Exception("DB error") + + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert "Authentication retry failed: DB error" in str(exc_info.value) + + @patch("core.mcp.auth_client.db") + @patch("core.mcp.auth_client.Session") + @patch("services.tools.mcp_tools_manage_service.MCPToolManageService") + def test_handle_auth_error_mcp_auth_error_propagation( + self, mock_service_class, mock_session_class, mock_db, auth_client + ): + """Test that MCPAuthError during refresh is propagated as is.""" + mock_session_class.return_value.__enter__.return_value = MagicMock() + mock_service = mock_service_class.return_value + mock_service.auth_with_actions.side_effect = MCPAuthError("Refresh failed") + + error = MCPAuthError("Initial auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert "Refresh failed" in str(exc_info.value) + + def test_execute_with_retry_success_first_try(self, auth_client): + """Test execution success on first try.""" + mock_func = MagicMock(return_value="success") + + result = auth_client._execute_with_retry(mock_func, "arg1", kwarg1="val1") + + assert result == "success" + mock_func.assert_called_once_with("arg1", kwarg1="val1") + assert auth_client._has_retried is False + + @patch.object(MCPClientWithAuthRetry, "_handle_auth_error") + @patch.object(MCPClientWithAuthRetry, "_initialize") + def test_execute_with_retry_success_on_retry_initialized(self, mock_initialize, mock_handle_auth, auth_client): + """Test execution success on retry after auth error when client was already initialized.""" + mock_func = MagicMock() + mock_func.side_effect = [MCPAuthError("Auth failed"), "success"] + + auth_client._initialized = True + auth_client._exit_stack = MagicMock() + + result = auth_client._execute_with_retry(mock_func, "arg") + + assert result == "success" + assert mock_func.call_count == 2 + mock_handle_auth.assert_called_once() + mock_initialize.assert_called_once() + auth_client._exit_stack.close.assert_called_once() + assert auth_client._has_retried is False + + @patch.object(MCPClientWithAuthRetry, "_handle_auth_error") + @patch.object(MCPClientWithAuthRetry, "_initialize") + def test_execute_with_retry_success_on_retry_not_initialized(self, mock_initialize, mock_handle_auth, auth_client): + """Test retry when client was NOT initialized (skips cleanup/re-init).""" + mock_func = MagicMock() + mock_func.side_effect = [MCPAuthError("Auth failed"), "result"] + + auth_client._initialized = False + + result = auth_client._execute_with_retry(mock_func, "arg") + + assert result == "result" + assert mock_func.call_count == 2 + mock_handle_auth.assert_called_once() + mock_initialize.assert_not_called() + assert auth_client._has_retried is False + + @patch.object(MCPClientWithAuthRetry, "_handle_auth_error") + def test_execute_with_retry_failure_on_retry(self, mock_handle_auth, auth_client): + """Test execution failure even after retry.""" + mock_func = MagicMock() + mock_func.side_effect = [MCPAuthError("First fail"), MCPAuthError("Second fail")] + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._execute_with_retry(mock_func, "arg") + + assert "Second fail" in str(exc_info.value) + assert mock_func.call_count == 2 + mock_handle_auth.assert_called_once() + assert auth_client._has_retried is False + + @patch.object(MCPClientWithAuthRetry, "_execute_with_retry") + def test_auth_client_context_manager_enter(self, mock_execute_retry, auth_client): + """Test context manager __enter__.""" + auth_client.__enter__() + + mock_execute_retry.assert_called_once() + func = mock_execute_retry.call_args[0][0] + + with patch("core.mcp.mcp_client.MCPClient.__enter__") as mock_base_enter: + result = func() + assert result == auth_client + mock_base_enter.assert_called_once() + + @patch.object(MCPClientWithAuthRetry, "_execute_with_retry") + def test_auth_client_list_tools(self, mock_execute_retry, auth_client): + """Test list_tools with retry.""" + auth_client.list_tools() + + mock_execute_retry.assert_called_once() + assert mock_execute_retry.call_args[0][0].__name__ == "list_tools" + + @patch.object(MCPClientWithAuthRetry, "_execute_with_retry") + def test_auth_client_invoke_tool(self, mock_execute_retry, auth_client): + """Test invoke_tool with retry.""" + auth_client.invoke_tool("test-tool", {"arg": "val"}) + + mock_execute_retry.assert_called_once() + assert mock_execute_retry.call_args[0][0].__name__ == "invoke_tool" + assert mock_execute_retry.call_args[0][1] == "test-tool" + assert mock_execute_retry.call_args[0][2] == {"arg": "val"}