test: added for core logging and core mcp (#32478)

Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
mahammadasim
2026-03-12 09:14:56 +05:30
committed by GitHub
parent 245f6b824d
commit 60fe5e7f00
7 changed files with 3859 additions and 5 deletions

View File

@ -82,6 +82,68 @@ class TestTraceContextFilter:
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_record.span_id == "051581bf3bb55c45"
def test_otel_context_invalid_trace_id(self, log_record):
from core.logging.filters import TraceContextFilter
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0
mock_context.is_valid = True
mock_span.get_span_context.return_value = mock_context
# Use mocks for base context to ensure we can test the fallback
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("core.logging.filters.get_trace_id", return_value=""),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == ""
def test_otel_context_invalid_span_id(self, log_record):
from core.logging.filters import TraceContextFilter
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
mock_context.span_id = 0
mock_context.is_valid = True
mock_span.get_span_context.return_value = mock_context
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_record.span_id == ""
def test_otel_context_span_none(self, log_record):
from core.logging.filters import TraceContextFilter
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=None),
mock.patch("core.logging.filters.get_trace_id", return_value=""),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == ""
def test_otel_context_exception(self, log_record):
from core.logging.filters import TraceContextFilter
# Trigger exception in OTEL block
with (
mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception),
mock.patch("core.logging.filters.get_trace_id", return_value=""),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == ""
class TestIdentityContextFilter:
def test_sets_empty_identity_without_request_context(self, log_record):
@ -114,3 +176,119 @@ class TestIdentityContextFilter:
result = filter.filter(log_record)
assert result is True
assert log_record.tenant_id == ""
def test_sets_empty_identity_unauthenticated(self, log_record):
from core.logging.filters import IdentityContextFilter
mock_user = mock.MagicMock()
mock_user.is_authenticated = False
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.user_id == ""
def test_sets_identity_for_account(self, log_record):
from core.logging.filters import IdentityContextFilter
class MockAccount:
pass
mock_user = MockAccount()
mock_user.id = "account_id"
mock_user.current_tenant_id = "tenant_id"
mock_user.is_authenticated = True
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("models.Account", MockAccount),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.tenant_id == "tenant_id"
assert log_record.user_id == "account_id"
assert log_record.user_type == "account"
def test_sets_identity_for_account_no_tenant(self, log_record):
from core.logging.filters import IdentityContextFilter
class MockAccount:
pass
mock_user = MockAccount()
mock_user.id = "account_id"
mock_user.current_tenant_id = None
mock_user.is_authenticated = True
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("models.Account", MockAccount),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.tenant_id == ""
assert log_record.user_id == "account_id"
assert log_record.user_type == "account"
def test_sets_identity_for_end_user(self, log_record):
from core.logging.filters import IdentityContextFilter
class MockEndUser:
pass
class AnotherClass:
pass
mock_user = MockEndUser()
mock_user.id = "end_user_id"
mock_user.tenant_id = "tenant_id"
mock_user.type = "custom_type"
mock_user.is_authenticated = True
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("models.model.EndUser", MockEndUser),
mock.patch("models.Account", AnotherClass),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.tenant_id == "tenant_id"
assert log_record.user_id == "end_user_id"
assert log_record.user_type == "custom_type"
def test_sets_identity_for_end_user_default_type(self, log_record):
from core.logging.filters import IdentityContextFilter
class MockEndUser:
pass
class AnotherClass:
pass
mock_user = MockEndUser()
mock_user.id = "end_user_id"
mock_user.tenant_id = "tenant_id"
mock_user.type = None
mock_user.is_authenticated = True
with (
mock.patch("flask.has_request_context", return_value=True),
mock.patch("models.model.EndUser", MockEndUser),
mock.patch("models.Account", AnotherClass),
mock.patch("flask_login.current_user", mock_user),
):
filter = IdentityContextFilter()
filter.filter(log_record)
assert log_record.tenant_id == "tenant_id"
assert log_record.user_id == "end_user_id"
assert log_record.user_type == "end_user"

View File

@ -1,27 +1,39 @@
"""Unit tests for MCP OAuth authentication flow."""
import json
from unittest.mock import Mock, patch
import httpx
import pytest
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity
from core.helper import ssrf_proxy
from core.mcp.auth.auth_flow import (
OAUTH_STATE_EXPIRY_SECONDS,
OAUTH_STATE_REDIS_KEY_PREFIX,
OAuthCallbackState,
_create_secure_redis_state,
_parse_token_response,
_retrieve_redis_state,
auth,
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
check_support_resource_discovery,
client_credentials_flow,
discover_oauth_authorization_server_metadata,
discover_oauth_metadata,
discover_protected_resource_metadata,
exchange_authorization,
generate_pkce_challenge,
get_effective_scope,
handle_callback,
refresh_authorization,
register_client,
start_authorization,
)
from core.mcp.entities import AuthActionType, AuthResult
from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
@ -764,3 +776,555 @@ class TestAuthOrchestration:
auth(mock_provider, authorization_code="auth-code")
assert "Existing OAuth client information is required" in str(exc_info.value)
def test_generate_pkce_challenge(self):
verifier, challenge = generate_pkce_challenge()
assert verifier
assert challenge
assert "=" not in verifier
assert "=" not in challenge
def test_build_protected_resource_metadata_discovery_urls(self):
# Case 1: WWW-Auth URL provided
urls = build_protected_resource_metadata_discovery_urls(
"https://auth.example.com/prm", "https://api.example.com"
)
assert "https://auth.example.com/prm" in urls
assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
# Case 2: No WWW-Auth URL, with path
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com/v1")
assert "https://api.example.com/.well-known/oauth-protected-resource/v1" in urls
assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
# Case 3: No path
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com")
assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]
def test_build_oauth_authorization_server_metadata_discovery_urls(self):
# Case 1: with auth_server_url
urls = build_oauth_authorization_server_metadata_discovery_urls(
"https://auth.example.com", "https://api.example.com"
)
assert "https://auth.example.com/.well-known/oauth-authorization-server" in urls
assert "https://auth.example.com/.well-known/openid-configuration" in urls
# Case 2: with path
urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://api.example.com/tenant")
assert "https://api.example.com/.well-known/oauth-authorization-server/tenant" in urls
assert "https://api.example.com/tenant/.well-known/openid-configuration" in urls
@patch("core.helper.ssrf_proxy.get")
def test_discover_protected_resource_metadata(self, mock_get):
# Success
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"resource": "https://api.example.com",
"authorization_servers": ["https://auth"],
}
mock_get.return_value = mock_response
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is not None
assert result.resource == "https://api.example.com"
# 404 then Success
res404 = Mock()
res404.status_code = 404
mock_get.side_effect = [res404, mock_response]
result = discover_protected_resource_metadata(None, "https://api.example.com/path")
assert result is not None
assert result.resource == "https://api.example.com"
# Error handling
mock_get.side_effect = httpx.RequestError("Error")
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_authorization_server_metadata(self, mock_get):
# Success
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"authorization_endpoint": "https://auth.example.com/auth",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is not None
assert result.authorization_endpoint == "https://auth.example.com/auth"
# 404
res404 = Mock()
res404.status_code = 404
mock_get.side_effect = [res404, mock_response]
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com/tenant")
assert result is not None
assert result.authorization_endpoint == "https://auth.example.com/auth"
# ValidationError
mock_response.json.return_value = {"invalid": "data"}
mock_get.side_effect = None
mock_get.return_value = mock_response
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None
def test_get_effective_scope(self):
prm = ProtectedResourceMetadata(
resource="https://api.example.com",
authorization_servers=["https://auth"],
scopes_supported=["read", "write"],
)
asm = OAuthMetadata(
authorization_endpoint="https://auth.example.com/auth",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
scopes_supported=["openid", "profile"],
)
# 1. WWW-Auth priority
assert get_effective_scope("scope1", prm, asm, "client") == "scope1"
# 2. PRM priority
assert get_effective_scope(None, prm, asm, "client") == "read write"
# 3. ASM priority
assert get_effective_scope(None, None, asm, "client") == "openid profile"
# 4. Client configured
assert get_effective_scope(None, None, None, "client") == "client"
@patch("core.mcp.auth.auth_flow.redis_client")
def test_redis_state_management(self, mock_redis):
state_data = OAuthCallbackState(
provider_id="p1",
tenant_id="t1",
server_url="https://api",
metadata=None,
client_information=OAuthClientInformation(client_id="c1"),
code_verifier="cv",
redirect_uri="https://re",
)
# Create
state_key = _create_secure_redis_state(state_data)
assert state_key
mock_redis.setex.assert_called_once()
# Retrieve Success
mock_redis.get.return_value = state_data.model_dump_json()
retrieved = _retrieve_redis_state(state_key)
assert retrieved.provider_id == "p1"
mock_redis.delete.assert_called_once()
# Retrieve Failure - Not found
mock_redis.get.return_value = None
with pytest.raises(ValueError, match="expired or does not exist"):
_retrieve_redis_state("absent")
# Retrieve Failure - Invalid JSON
mock_redis.get.return_value = "invalid"
with pytest.raises(ValueError, match="Invalid state parameter"):
_retrieve_redis_state("invalid")
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_handle_callback(self, mock_exchange, mock_retrieve):
state = Mock(spec=OAuthCallbackState)
state.server_url = "https://api"
state.metadata = None
state.client_information = Mock()
state.code_verifier = "cv"
state.redirect_uri = "https://re"
mock_retrieve.return_value = state
tokens = Mock(spec=OAuthTokens)
mock_exchange.return_value = tokens
s, t = handle_callback("key", "code")
assert s == state
assert t == tokens
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery(self, mock_get):
# Case 1: authorization_servers (plural)
res = Mock()
res.status_code = 200
res.json.return_value = {"authorization_servers": ["https://auth1"]}
mock_get.return_value = res
supported, url = check_support_resource_discovery("https://api")
assert supported is True
assert url == "https://auth1"
# Case 2: authorization_server_url (singular alias)
res.json.return_value = {"authorization_server_url": ["https://auth2"]}
supported, url = check_support_resource_discovery("https://api")
assert supported is True
assert url == "https://auth2"
# Case 3: Missing fields
res.json.return_value = {"nothing": []}
supported, url = check_support_resource_discovery("https://api")
assert supported is False
# Case 4: 404
res.status_code = 404
supported, url = check_support_resource_discovery("https://api")
assert supported is False
# Case 5: RequestError
mock_get.side_effect = httpx.RequestError("Error")
supported, url = check_support_resource_discovery("https://api")
assert supported is False
def test_discover_oauth_metadata(self):
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
mock_prm.return_value = ProtectedResourceMetadata(
resource="https://api", authorization_servers=["https://auth"]
)
mock_asm.return_value = Mock(spec=OAuthMetadata)
asm, prm, hint = discover_oauth_metadata("https://api")
assert asm == mock_asm.return_value
assert prm == mock_prm.return_value
mock_asm.assert_called_with("https://auth", "https://api", None)
def test_start_authorization(self):
metadata = OAuthMetadata(
authorization_endpoint="https://auth/authorize",
token_endpoint="https://auth/token",
response_types_supported=["code"],
)
client_info = OAuthClientInformation(client_id="c1")
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create:
mock_create.return_value = "state-key"
# Success with scope
url, verifier = start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1", "read")
assert "scope=read" in url
assert "state=state-key" in url
# Success without metadata
url, verifier = start_authorization("https://api", None, client_info, "https://re", "p1", "t1")
assert "https://api/authorize" in url
# Failure: incompatible auth server
metadata.response_types_supported = ["implicit"]
with pytest.raises(ValueError, match="Incompatible auth server"):
start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1")
def test_parse_token_response(self):
# Case 1: JSON
res = Mock()
res.headers = {"content-type": "application/json"}
res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
tokens = _parse_token_response(res)
assert tokens.access_token == "at"
# Case 2: Form-urlencoded
res.headers = {"content-type": "application/x-www-form-urlencoded"}
res.text = "access_token=at2&token_type=Bearer"
tokens = _parse_token_response(res)
assert tokens.access_token == "at2"
# Case 3: No content-type, but JSON
res.headers = {}
res.json.return_value = {"access_token": "at3", "token_type": "Bearer"}
tokens = _parse_token_response(res)
assert tokens.access_token == "at3"
# Case 4: No content-type, not JSON, but Form
res.json.side_effect = json.JSONDecodeError("msg", "doc", 0)
res.text = "access_token=at4&token_type=Bearer"
tokens = _parse_token_response(res)
assert tokens.access_token == "at4"
# Case 5: Validation Error fallback
res.json.side_effect = ValidationError.from_exception_data("error", [])
res.text = "access_token=at5&token_type=Bearer"
tokens = _parse_token_response(res)
assert tokens.access_token == "at5"
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization(self, mock_post):
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
metadata = OAuthMetadata(
authorization_endpoint="https://auth/authorize",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
# Success
res = Mock()
res.is_success = True
res.headers = {"content-type": "application/json"}
res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
mock_post.return_value = res
tokens = exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
assert tokens.access_token == "at"
# Failure: Unsupported grant type
metadata.grant_types_supported = ["client_credentials"]
with pytest.raises(ValueError, match="Incompatible auth server"):
exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
# Failure: HTTP error
metadata.grant_types_supported = ["authorization_code"]
res.is_success = False
res.status_code = 400
with pytest.raises(ValueError, match="Token exchange failed"):
exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
@patch("core.helper.ssrf_proxy.post")
def test_refresh_authorization(self, mock_post):
# Case 1: with client_secret
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
# Success
res = Mock()
res.is_success = True
res.headers = {"content-type": "application/json"}
res.json.return_value = {"access_token": "at_new", "token_type": "Bearer"}
mock_post.return_value = res
tokens = refresh_authorization("https://api", None, client_info, "rt")
assert tokens.access_token == "at_new"
assert mock_post.call_args[1]["data"]["client_secret"] == "s1"
# Failure: MaxRetriesExceededError
mock_post.side_effect = ssrf_proxy.MaxRetriesExceededError("Too many retries")
with pytest.raises(MCPRefreshTokenError):
refresh_authorization("https://api", None, client_info, "rt")
# Failure: HTTP error
mock_post.side_effect = None
res.is_success = False
res.text = "error_msg"
with pytest.raises(MCPRefreshTokenError, match="error_msg"):
refresh_authorization("https://api", None, client_info, "rt")
# Failure: Incompatible metadata
metadata = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
with pytest.raises(ValueError, match="Incompatible auth server"):
refresh_authorization("https://api", metadata, client_info, "rt")
@patch("core.helper.ssrf_proxy.post")
def test_client_credentials_flow(self, mock_post):
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
# Success with secret
res = Mock()
res.is_success = True
res.headers = {"content-type": "application/json"}
res.json.return_value = {"access_token": "at_cc", "token_type": "Bearer"}
mock_post.return_value = res
tokens = client_credentials_flow("https://api", None, client_info, "read")
assert tokens.access_token == "at_cc"
args, kwargs = mock_post.call_args
assert "Authorization" in kwargs["headers"]
# Success without secret
client_info_no_secret = OAuthClientInformation(client_id="c2")
tokens = client_credentials_flow("https://api", None, client_info_no_secret)
args, kwargs = mock_post.call_args
assert kwargs["data"]["client_id"] == "c2"
# Failure: Incompatible metadata
metadata = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
with pytest.raises(ValueError, match="Incompatible auth server"):
client_credentials_flow("https://api", metadata, client_info)
# Failure: HTTP error
res.is_success = False
res.status_code = 401
res.text = "Unauthorized"
with pytest.raises(ValueError, match="Client credentials token request failed"):
client_credentials_flow("https://api", None, client_info)
@patch("core.helper.ssrf_proxy.post")
def test_register_client(self, mock_post):
# Case 1: Success with metadata
metadata = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
registration_endpoint="https://auth/register",
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://re"])
res = Mock()
res.is_success = True
res.json.return_value = {
"client_id": "c_new",
"client_secret": "s_new",
"client_name": "Dify",
"redirect_uris": ["https://re"],
}
mock_post.return_value = res
info = register_client("https://api", metadata, client_metadata)
assert info.client_id == "c_new"
# Case 2: Success without metadata
info = register_client("https://api", None, client_metadata)
assert mock_post.call_args[0][0] == "https://api/register"
# Case 3: Metadata provided but no endpoint
metadata.registration_endpoint = None
with pytest.raises(ValueError, match="does not support dynamic client registration"):
register_client("https://api", metadata, client_metadata)
# Failure: HTTP
res.is_success = False
res.raise_for_status = Mock()
res.status_code = 400
# If is_success is false, it should call raise_for_status
register_client("https://api", None, client_metadata)
res.raise_for_status.assert_called_once()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_orchestration_failures(self, mock_discover):
provider = Mock(spec=MCPProviderEntity)
provider.decrypt_server_url.return_value = "https://api"
provider.id = "p1"
provider.tenant_id = "t1"
# Case 1: No server metadata
mock_discover.return_value = (None, None, None)
with pytest.raises(ValueError, match="Failed to discover OAuth metadata"):
auth(provider)
# Case 2: No client info, exchange code provided
asm = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
)
mock_discover.return_value = (asm, None, None)
provider.retrieve_client_information.return_value = None
with pytest.raises(ValueError, match="Existing OAuth client information is required"):
auth(provider, authorization_code="code")
# Case 3: CLIENT_CREDENTIALS but client must provide info
asm.grant_types_supported = ["client_credentials"]
with pytest.raises(ValueError, match="requires client_id and client_secret"):
auth(provider)
# Case 4: Client registration fails
asm.grant_types_supported = ["authorization_code"]
with patch("core.mcp.auth.auth_flow.register_client") as mock_reg:
mock_reg.side_effect = httpx.RequestError("Reg failed")
with pytest.raises(ValueError, match="Could not register OAuth client"):
auth(provider)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_orchestration_client_credentials(self, mock_discover):
provider = Mock(spec=MCPProviderEntity)
provider.decrypt_server_url.return_value = "https://api"
provider.id = "p1"
provider.tenant_id = "t1"
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1", client_secret="s1")
provider.decrypt_credentials.return_value = {"scope": "read"}
asm = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["client_credentials"],
)
mock_discover.return_value = (asm, None, None)
with patch("core.mcp.auth.auth_flow.client_credentials_flow") as mock_cc:
mock_cc.return_value = OAuthTokens(access_token="at_cc", token_type="Bearer")
result = auth(provider)
assert result.response == {"result": "success"}
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
assert result.actions[0].data["grant_type"] == "client_credentials"
# Failure in CC flow
mock_cc.side_effect = ValueError("CC Failed")
with pytest.raises(ValueError, match="Client credentials flow failed"):
auth(provider)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_orchestration_authorization_code(self, mock_discover):
provider = Mock(spec=MCPProviderEntity)
provider.decrypt_server_url.return_value = "https://api"
provider.id = "p1"
provider.tenant_id = "t1"
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1")
provider.decrypt_credentials.return_value = {}
asm = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_discover.return_value = (asm, None, None)
# Case 1: Exchange code
with patch("core.mcp.auth.auth_flow._retrieve_redis_state") as mock_retrieve:
state = Mock(spec=OAuthCallbackState)
state.code_verifier = "cv"
state.redirect_uri = "https://re"
mock_retrieve.return_value = state
with patch("core.mcp.auth.auth_flow.exchange_authorization") as mock_exchange:
mock_exchange.return_value = OAuthTokens(access_token="at_code", token_type="Bearer")
# Success
result = auth(provider, authorization_code="code", state_param="sp")
assert result.response == {"result": "success"}
# Missing state_param
with pytest.raises(ValueError, match="State parameter is required"):
auth(provider, authorization_code="code")
# Missing verifier in state
state.code_verifier = None
with pytest.raises(ValueError, match="Missing code_verifier"):
auth(provider, authorization_code="code", state_param="sp")
# Invalid state
mock_retrieve.side_effect = ValueError("Invalid")
with pytest.raises(ValueError, match="Invalid state parameter"):
auth(provider, authorization_code="code", state_param="sp")
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_orchestration_refresh_failure(self, mock_discover):
provider = Mock(spec=MCPProviderEntity)
provider.decrypt_server_url.return_value = "https://api"
provider.id = "p1"
provider.tenant_id = "t1"
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1")
provider.decrypt_credentials.return_value = {}
provider.retrieve_tokens.return_value = OAuthTokens(access_token="at", token_type="Bearer", refresh_token="rt")
asm = OAuthMetadata(
authorization_endpoint="https://auth/auth",
token_endpoint="https://auth/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_discover.return_value = (asm, None, None)
with patch("core.mcp.auth.auth_flow.refresh_authorization") as mock_refresh:
mock_refresh.side_effect = ValueError("Refresh Failed")
with pytest.raises(ValueError, match="Could not refresh OAuth tokens"):
auth(provider)

View File

@ -322,3 +322,475 @@ def test_sse_client_concurrent_access():
assert len(received_messages) == 10
for i in range(10):
assert f"message_{i}" in received_messages
class TestStatusClasses:
"""Tests for _StatusReady and _StatusError data containers."""
def test_status_ready_stores_endpoint(self):
from core.mcp.client.sse_client import _StatusReady
status = _StatusReady("http://example.com/messages/")
assert status.endpoint_url == "http://example.com/messages/"
def test_status_error_stores_exception(self):
from core.mcp.client.sse_client import _StatusError
exc = ValueError("bad endpoint")
status = _StatusError(exc)
assert status.exc is exc
class TestSSETransportInit:
"""Tests for SSETransport default and explicit init values."""
def test_defaults(self):
from core.mcp.client.sse_client import SSETransport
t = SSETransport("http://example.com/sse")
assert t.url == "http://example.com/sse"
assert t.headers == {}
assert t.timeout == 5.0
assert t.sse_read_timeout == 60.0
assert t.endpoint_url is None
assert t.event_source is None
def test_explicit_headers_not_mutated(self):
from core.mcp.client.sse_client import SSETransport
hdrs = {"X-Foo": "bar"}
t = SSETransport("http://example.com/sse", headers=hdrs)
assert t.headers is hdrs
class TestHandleEndpointEvent:
"""Tests for SSETransport._handle_endpoint_event covering the invalid-origin branch."""
def test_invalid_origin_puts_status_error(self):
from core.mcp.client.sse_client import SSETransport, _StatusError
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue()
# Provide a full URL with a different origin so urljoin keeps it as-is
transport._handle_endpoint_event("http://evil.com/messages/", status_queue)
result = status_queue.get_nowait()
assert isinstance(result, _StatusError)
assert "does not match" in str(result.exc)
def test_valid_origin_puts_status_ready(self):
from core.mcp.client.sse_client import SSETransport, _StatusReady
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue()
transport._handle_endpoint_event("/messages/?session_id=abc", status_queue)
result = status_queue.get_nowait()
assert isinstance(result, _StatusReady)
assert "example.com" in result.endpoint_url
class TestHandleSSEEvent:
"""Tests for SSETransport._handle_sse_event covering all match branches."""
def _make_sse(self, event_type: str, data: str):
sse = Mock()
sse.event = event_type
sse.data = data
return sse
def test_message_event_dispatched(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
read_queue: queue.Queue = queue.Queue()
status_queue: queue.Queue = queue.Queue()
valid_msg = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
transport._handle_sse_event(self._make_sse("message", valid_msg), read_queue, status_queue)
item = read_queue.get_nowait()
assert hasattr(item, "message")
def test_unknown_event_logs_warning_and_does_nothing(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
read_queue: queue.Queue = queue.Queue()
status_queue: queue.Queue = queue.Queue()
transport._handle_sse_event(self._make_sse("ping", "{}"), read_queue, status_queue)
assert read_queue.empty()
assert status_queue.empty()
class TestSSEReader:
"""Tests for SSETransport.sse_reader exception branches."""
def test_read_error_closes_cleanly(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
read_queue: queue.Queue = queue.Queue()
status_queue: queue.Queue = queue.Queue()
event_source = Mock()
event_source.iter_sse.side_effect = httpx.ReadError("connection reset")
transport.sse_reader(event_source, read_queue, status_queue)
# Finally block always puts None as sentinel
sentinel = read_queue.get_nowait()
assert sentinel is None
def test_generic_exception_puts_exc_then_none(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
read_queue: queue.Queue = queue.Queue()
status_queue: queue.Queue = queue.Queue()
boom = RuntimeError("unexpected!")
event_source = Mock()
event_source.iter_sse.side_effect = boom
transport.sse_reader(event_source, read_queue, status_queue)
exc_item = read_queue.get_nowait()
assert exc_item is boom
sentinel = read_queue.get_nowait()
assert sentinel is None
class TestSendMessage:
"""Tests for SSETransport._send_message."""
def _make_session_message(self):
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
msg = types.JSONRPCMessage.model_validate_json(msg_json)
return types.SessionMessage(msg)
def test_sends_post_and_raises_for_status(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
mock_response = Mock()
mock_response.status_code = 200
mock_client = Mock()
mock_client.post.return_value = mock_response
session_msg = self._make_session_message()
transport._send_message(mock_client, "http://example.com/messages/", session_msg)
mock_client.post.assert_called_once()
mock_response.raise_for_status.assert_called_once()
class TestPostWriter:
"""Tests for SSETransport.post_writer exception branches."""
def _make_session_message(self):
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
msg = types.JSONRPCMessage.model_validate_json(msg_json)
return types.SessionMessage(msg)
def test_none_message_exits_loop(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
write_queue.put(None) # Signal shutdown immediately
mock_client = Mock()
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
# Should put final None sentinel
sentinel = write_queue.get_nowait()
assert sentinel is None
def test_exception_in_message_put_back_to_queue(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
exc = ValueError("some error")
write_queue.put(exc) # Exception goes in first
write_queue.put(None) # Then shutdown signal
mock_client = Mock()
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
# The exception should be re-queued, then None from loop exit, then None from finally
item1 = write_queue.get_nowait()
assert isinstance(item1, Exception)
def test_read_error_shuts_down_cleanly(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
session_msg = self._make_session_message()
write_queue.put(session_msg)
mock_response = Mock()
mock_response.status_code = 200
mock_client = Mock()
mock_client.post.side_effect = httpx.ReadError("connection dropped")
# post_writer calls _send_message which calls client.post → ReadError propagates
# The ReadError is raised inside _send_message → propagates out of the while loop
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
# finally always puts None
sentinel = write_queue.get_nowait()
assert sentinel is None
def test_generic_exception_puts_exc_in_queue(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
session_msg = self._make_session_message()
write_queue.put(session_msg)
mock_client = Mock()
boom = RuntimeError("boom")
mock_client.post.side_effect = boom
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
exc_item = write_queue.get_nowait()
assert isinstance(exc_item, Exception)
sentinel = write_queue.get_nowait()
assert sentinel is None
def test_queue_empty_timeout_continues_loop(self):
"""Cover the 'except queue.Empty: continue' branch (line 188) in post_writer."""
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
write_queue: queue.Queue = queue.Queue()
mock_client = Mock()
# Patch queue.Queue.get so it raises Empty first, then returns None (shutdown)
call_count = {"n": 0}
original_get = write_queue.get
def patched_get(*args, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
raise queue.Empty
write_queue.get = patched_get # type: ignore[method-assign]
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
# finally always puts None sentinel
sentinel = write_queue.get_nowait()
assert sentinel is None
assert call_count["n"] >= 2 # Empty on first, None on second (and possibly more retries)
class TestWaitForEndpoint:
"""Tests for SSETransport._wait_for_endpoint edge cases."""
def test_raises_on_empty_queue(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue() # empty
with pytest.raises(ValueError, match="failed to get endpoint URL"):
transport._wait_for_endpoint(status_queue)
def test_raises_status_error_exception(self):
from core.mcp.client.sse_client import SSETransport, _StatusError
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue()
exc = ValueError("malicious endpoint")
status_queue.put(_StatusError(exc))
with pytest.raises(ValueError, match="malicious endpoint"):
transport._wait_for_endpoint(status_queue)
def test_raises_on_unknown_status_type(self):
from core.mcp.client.sse_client import SSETransport
transport = SSETransport("http://example.com/sse")
status_queue: queue.Queue = queue.Queue()
# Put an object that is neither _StatusReady nor _StatusError
status_queue.put("unexpected_value")
with pytest.raises(ValueError, match="failed to get endpoint URL"):
transport._wait_for_endpoint(status_queue)
class TestSSEClientRuntimeError:
"""Test sse_client context manager handles RuntimeError on close()."""
def test_runtime_error_on_close_is_suppressed(self):
"""Ensure RuntimeError raised by event_source.response.close() is caught."""
test_url = "http://test.example/sse"
class MockSSEEvent:
def __init__(self, event_type: str, data: str):
self.event = event_type
self.data = data
endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123")
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sc:
mock_client = Mock()
mock_cf.return_value.__enter__.return_value = mock_client
mock_es = Mock()
mock_es.response.raise_for_status.return_value = None
mock_es.iter_sse.return_value = [endpoint_event]
# Make close() raise RuntimeError to exercise line 307-308
mock_es.response.close.side_effect = RuntimeError("already closed")
mock_sc.return_value.__enter__.return_value = mock_es
# Should NOT raise even though close() raises RuntimeError
with contextlib.suppress(Exception):
with sse_client(test_url) as (rq, wq):
pass
class TestStandaloneSendMessage:
"""Tests for the module-level send_message() function."""
def _make_session_message(self):
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
msg = types.JSONRPCMessage.model_validate_json(msg_json)
return types.SessionMessage(msg)
def test_send_message_success(self):
from core.mcp.client.sse_client import send_message
mock_response = Mock()
mock_response.status_code = 200
mock_http_client = Mock()
mock_http_client.post.return_value = mock_response
session_msg = self._make_session_message()
send_message(mock_http_client, "http://example.com/messages/", session_msg)
mock_http_client.post.assert_called_once()
mock_response.raise_for_status.assert_called_once()
def test_send_message_raises_on_http_error(self):
from core.mcp.client.sse_client import send_message
mock_http_client = Mock()
mock_http_client.post.side_effect = httpx.ConnectError("refused")
session_msg = self._make_session_message()
with pytest.raises(httpx.ConnectError):
send_message(mock_http_client, "http://example.com/messages/", session_msg)
def test_send_message_raises_for_status_failure(self):
from core.mcp.client.sse_client import send_message
mock_response = Mock()
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Not Found", request=Mock(), response=Mock(status_code=404)
)
mock_http_client = Mock()
mock_http_client.post.return_value = mock_response
session_msg = self._make_session_message()
with pytest.raises(httpx.HTTPStatusError):
send_message(mock_http_client, "http://example.com/messages/", session_msg)
class TestReadMessages:
"""Tests for the module-level read_messages() generator."""
def _make_mock_sse_event(self, event_type: str, data: str):
ev = Mock()
ev.event = event_type
ev.data = data
return ev
def test_valid_message_event_yields_session_message(self):
from core.mcp.client.sse_client import read_messages
valid_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
mock_sse_event = self._make_mock_sse_event("message", valid_json)
mock_client = Mock()
mock_client.events.return_value = [mock_sse_event]
results = list(read_messages(mock_client))
assert len(results) == 1
assert hasattr(results[0], "message")
def test_invalid_json_yields_exception(self):
from core.mcp.client.sse_client import read_messages
mock_sse_event = self._make_mock_sse_event("message", "{not valid json}")
mock_client = Mock()
mock_client.events.return_value = [mock_sse_event]
results = list(read_messages(mock_client))
assert len(results) == 1
assert isinstance(results[0], Exception)
def test_non_message_event_is_skipped(self):
from core.mcp.client.sse_client import read_messages
mock_sse_event = self._make_mock_sse_event("endpoint", "/messages/")
mock_client = Mock()
mock_client.events.return_value = [mock_sse_event]
results = list(read_messages(mock_client))
# Non-message events produce no output
assert results == []
def test_outer_exception_yields_exc(self):
from core.mcp.client.sse_client import read_messages
boom = RuntimeError("stream broken")
mock_client = Mock()
mock_client.events.side_effect = boom
results = list(read_messages(mock_client))
assert len(results) == 1
assert results[0] is boom
def test_multiple_events_mixed(self):
from core.mcp.client.sse_client import read_messages
valid_json = '{"jsonrpc": "2.0", "id": 2, "result": {}}'
events = [
self._make_mock_sse_event("endpoint", "/messages/"),
self._make_mock_sse_event("message", valid_json),
self._make_mock_sse_event("message", "{bad json}"),
]
mock_client = Mock()
mock_client.events.return_value = events
results = list(read_messages(mock_client))
# endpoint is skipped; 1 valid SessionMessage + 1 Exception
assert len(results) == 2
assert hasattr(results[0], "message")
assert isinstance(results[1], Exception)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,617 @@
import queue
import time
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import timedelta
from typing import Union
from unittest.mock import MagicMock, patch
import pytest
from httpx import HTTPStatusError, Request, Response
from pydantic import BaseModel, ConfigDict, RootModel
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.session.base_session import BaseSession, RequestResponder
from core.mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCResponse,
Notification,
RequestParams,
SessionMessage,
)
from core.mcp.types import (
Request as MCPRequest,
)
class MockRequestParams(RequestParams):
name: str = "default"
model_config = ConfigDict(extra="allow")
class MockRequest(MCPRequest[MockRequestParams, str]):
method: str = "test/request"
params: MockRequestParams = MockRequestParams()
class MockResult(BaseModel):
result: str
class MockNotificationParams(BaseModel):
message: str
class MockNotification(Notification[MockNotificationParams, str]):
method: str = "test/notification"
params: MockNotificationParams
class ReceiveRequest(RootModel[Union[MockRequest, ClientRequest]]):
pass
class ReceiveNotification(RootModel[Union[CancelledNotification, MockNotification, JSONRPCNotification]]):
pass
class MockSession(BaseSession[MockRequest, MockNotification, MockResult, ReceiveRequest, ReceiveNotification]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.received_requests = []
self.received_notifications = []
self.handled_incoming = []
def _received_request(self, responder):
self.received_requests.append(responder)
def _received_notification(self, notification):
self.received_notifications.append(notification)
def _handle_incoming(self, item):
self.handled_incoming.append(item)
@pytest.fixture
def streams():
return queue.Queue(), queue.Queue()
@pytest.mark.timeout(5)
def test_request_responder_respond(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
on_complete = MagicMock()
request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
responder = RequestResponder(
request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
)
with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
responder.respond(MockResult(result="ok"))
with responder as r:
r.respond(MockResult(result="ok"))
with pytest.raises(AssertionError, match="Request already responded to"):
r.respond(MockResult(result="error"))
assert responder.completed is True
on_complete.assert_called_once_with(responder)
msg = write_stream.get_nowait()
assert isinstance(msg.message.root, JSONRPCResponse)
assert msg.message.root.result == {"result": "ok"}
@pytest.mark.timeout(5)
def test_request_responder_cancel(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
on_complete = MagicMock()
request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
responder = RequestResponder(
request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
)
with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
responder.cancel()
with responder as r:
r.cancel()
assert responder.completed is True
on_complete.assert_called_once_with(responder)
msg = write_stream.get_nowait()
assert isinstance(msg.message.root, JSONRPCError)
assert msg.message.root.error.message == "Request cancelled"
@pytest.mark.timeout(10)
def test_base_session_lifecycle(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session as s:
assert isinstance(s, MockSession)
assert s._executor is not None
assert s._receiver_future is not None
session._receiver_future.result(timeout=5.0)
assert session._receiver_future.done()
@pytest.mark.timeout(5)
def test_send_request_success(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_response():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "hello world"})
read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
except Exception:
pass
import threading
t = threading.Thread(target=mock_response, daemon=True)
t.start()
with session:
result = session.send_request(request, MockResult)
assert result.result == "hello world"
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_retry_loop_coverage(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_delayed_response():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
time.sleep(0.2)
response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "slow"})
read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
except:
pass
import threading
t = threading.Thread(target=mock_delayed_response, daemon=True)
t.start()
with session:
result = session.send_request(request, MockResult, request_read_timeout_seconds=timedelta(seconds=0.1))
assert result.result == "slow"
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_jsonrpc_error(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_error():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=-32000, message="Error"))
read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
except:
pass
import threading
t = threading.Thread(target=mock_error, daemon=True)
t.start()
with session:
with pytest.raises(MCPConnectionError) as exc:
session.send_request(request, MockResult)
assert exc.value.args[0].message == "Error"
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_auth_error(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_error():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=401, message="Unauthorized"))
read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
except:
pass
import threading
t = threading.Thread(target=mock_error, daemon=True)
t.start()
with session:
with pytest.raises(MCPAuthError):
session.send_request(request, MockResult)
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_http_status_error_coverage(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_direct_http_error():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
# To cover line 263 in base_session.py, we MUST put non-401 HTTPStatusError
# DIRECTLY into response_streams, as _receive_loop would convert it to JSONRPCError.
response = Response(status_code=403, request=Request("GET", "http://test"))
error = HTTPStatusError("Forbidden", request=response.request, response=response)
session._response_streams[req_id].put(error)
except:
pass
import threading
t = threading.Thread(target=mock_direct_http_error, daemon=True)
t.start()
# We still need the session for request ID generation and queue setup
with session:
with pytest.raises(MCPConnectionError) as exc:
session.send_request(request, MockResult)
assert exc.value.args[0].code == 403
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_request_http_status_auth_error(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_error():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
response = Response(status_code=401, request=Request("GET", "http://test"))
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
read_stream.put(error)
except:
pass
import threading
t = threading.Thread(target=mock_error, daemon=True)
t.start()
with session:
with pytest.raises(MCPAuthError):
session.send_request(request, MockResult)
t.join(timeout=1)
@pytest.mark.timeout(5)
def test_send_notification(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
notification = MockNotification(method="notify", params=MockNotificationParams(message="hi"))
session.send_notification(notification, related_request_id="rel-1")
msg = write_stream.get_nowait()
assert isinstance(msg.message.root, JSONRPCNotification)
assert msg.message.root.method == "notify"
assert msg.message.root.params == {"message": "hi"}
assert msg.metadata.related_request_id == "rel-1"
@pytest.mark.timeout(10)
def test_receive_loop_request(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session:
req_payload = {"jsonrpc": "2.0", "id": 1, "method": "test/request", "params": {"name": "test"}}
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
for _ in range(30):
if session.received_requests:
break
time.sleep(0.1)
assert len(session.received_requests) == 1
responder = session.received_requests[0]
assert responder.request_id == 1
assert responder.request.root.method == "test/request"
@pytest.mark.timeout(10)
def test_receive_loop_notification(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session:
notif_payload = {"jsonrpc": "2.0", "method": "test/notification", "params": {"message": "hello"}}
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
for _ in range(30):
if session.received_notifications:
break
time.sleep(0.1)
assert len(session.received_notifications) == 1
assert isinstance(session.received_notifications[0].root, MockNotification)
assert session.received_notifications[0].root.method == "test/notification"
@pytest.mark.timeout(15)
def test_receive_loop_cancel_notification(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ClientNotification)
with session:
req_payload = {"jsonrpc": "2.0", "id": "req-1", "method": "test/request", "params": {"name": "test"}}
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
for _ in range(30):
if "req-1" in session._in_flight:
break
time.sleep(0.1)
assert "req-1" in session._in_flight
responder = session._in_flight["req-1"]
with responder:
cancel_payload = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": {"requestId": "req-1"}}
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(cancel_payload)))
for _ in range(30):
if responder.completed:
break
time.sleep(0.1)
assert responder.completed is True
msg = write_stream.get(timeout=2)
assert isinstance(msg.message.root, JSONRPCError)
assert msg.message.root.id == "req-1"
@pytest.mark.timeout(10)
def test_receive_loop_exception(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session:
read_stream.put(Exception("Unexpected error"))
for _ in range(30):
if any(isinstance(x, Exception) for x in session.handled_incoming):
break
time.sleep(0.1)
assert any(isinstance(x, Exception) and str(x) == "Unexpected error" for x in session.handled_incoming)
@pytest.mark.timeout(10)
def test_receive_loop_http_status_error(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session:
session._request_id = 1
resp_queue = queue.Queue()
session._response_streams[0] = resp_queue
response = Response(status_code=401, request=Request("GET", "http://test"))
# Using 401 specifically as _receive_loop preserves it
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
read_stream.put(error)
got = resp_queue.get(timeout=2)
assert isinstance(got, HTTPStatusError)
@pytest.mark.timeout(10)
def test_receive_loop_http_status_error_non_401(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session:
session._request_id = 1
resp_queue = queue.Queue()
session._response_streams[0] = resp_queue
response = Response(status_code=500, request=Request("GET", "http://test"))
error = HTTPStatusError("Server Error", request=response.request, response=response)
read_stream.put(error)
got = resp_queue.get(timeout=2)
assert isinstance(got, JSONRPCError)
assert got.error.code == 500
@pytest.mark.timeout(5)
def test_check_receiver_status_fail(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
executor = ThreadPoolExecutor(max_workers=1)
def raise_err():
raise RuntimeError("Receiver failed")
future = executor.submit(raise_err)
session._receiver_future = future
try:
future.result()
except:
pass
with pytest.raises(RuntimeError, match="Receiver failed"):
session.check_receiver_status()
executor.shutdown()
@pytest.mark.timeout(10)
def test_receive_loop_unknown_request_id(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session:
resp = JSONRPCResponse(jsonrpc="2.0", id=999, result={"ok": True})
read_stream.put(SessionMessage(message=JSONRPCMessage(resp)))
for _ in range(30):
if any(isinstance(x, RuntimeError) and "Server Error" in str(x) for x in session.handled_incoming):
break
time.sleep(0.1)
assert any("Server Error" in str(x) for x in session.handled_incoming)
@pytest.mark.timeout(10)
def test_receive_loop_http_error_unknown_id(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session:
response = Response(status_code=401, request=Request("GET", "http://test"))
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
read_stream.put(error)
for _ in range(30):
if any(isinstance(x, RuntimeError) and "unknown request ID" in str(x) for x in session.handled_incoming):
break
time.sleep(0.1)
assert any("unknown request ID" in str(x) for x in session.handled_incoming)
@pytest.mark.timeout(10)
def test_receive_loop_validation_error_notification(streams):
from core.mcp.session.base_session import logger
with patch.object(logger, "warning") as mock_warning:
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, RootModel[MockNotification])
with session:
notif_payload = {"jsonrpc": "2.0", "method": "bad", "params": {"some": "data"}}
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
time.sleep(1.0)
assert mock_warning.called
@pytest.mark.timeout(5)
def test_send_request_none_response(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
def mock_none():
try:
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
session._response_streams[req_id].put(None)
except:
pass
import threading
t = threading.Thread(target=mock_none, daemon=True)
t.start()
with session:
with pytest.raises(MCPConnectionError) as exc:
session.send_request(request, MockResult)
assert exc.value.args[0].message == "No response received"
t.join(timeout=1)
@pytest.mark.timeout(15)
def test_session_exit_timeout(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
mock_future = MagicMock(spec=Future)
mock_future.result.side_effect = TimeoutError()
mock_future.done.return_value = False
session._receiver_future = mock_future
session._executor = MagicMock(spec=ThreadPoolExecutor)
session.__exit__(None, None, None)
mock_future.cancel.assert_called_once()
session._executor.shutdown.assert_called_once_with(wait=False)
@pytest.mark.timeout(10)
def test_receive_loop_fatal_exception(streams):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with patch.object(read_stream, "get", side_effect=RuntimeError("Fatal loop error")):
with patch("core.mcp.session.base_session.logger") as mock_logger:
with pytest.raises(RuntimeError, match="Fatal loop error"):
with session:
pass
mock_logger.exception.assert_called_with("Error in message processing loop")
@pytest.mark.timeout(5)
def test_receive_loop_empty_coverage(streams):
with patch("core.mcp.session.base_session.DEFAULT_RESPONSE_READ_TIMEOUT", 0.1):
read_stream, write_stream = streams
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
with session:
time.sleep(0.3)
@pytest.mark.timeout(2)
def test_base_methods_noop(streams):
read_stream, write_stream = streams
session = BaseSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
session._received_request(MagicMock())
session._received_notification(MagicMock())
session.send_progress_notification("token", 0.5)
session._handle_incoming(MagicMock())
@pytest.mark.timeout(5)
def test_send_request_session_timeout_retry_6(streams):
read_stream, write_stream = streams
session = MockSession(
read_stream, write_stream, ReceiveRequest, ReceiveNotification, read_timeout_seconds=timedelta(seconds=0.1)
)
request = MockRequest(method="test", params=MockRequestParams(name="world"))
with patch.object(session, "check_receiver_status", side_effect=[None, RuntimeError("timeout_broken")]):
with pytest.raises(RuntimeError, match="timeout_broken"):
session.send_request(request, MockResult)

View File

@ -0,0 +1,576 @@
import queue
from unittest.mock import MagicMock
import pytest
from pydantic import AnyUrl
from core.mcp import types
from core.mcp.session.base_session import RequestResponder, SessionMessage
from core.mcp.session.client_session import (
ClientSession,
_default_list_roots_callback,
_default_logging_callback,
_default_message_handler,
_default_sampling_callback,
)
@pytest.fixture
def streams():
return queue.Queue(), queue.Queue()
def test_client_session_init(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
assert session._client_info.name == "Dify"
assert session._sampling_callback == _default_sampling_callback
assert session._list_roots_callback == _default_list_roots_callback
assert session._logging_callback == _default_logging_callback
assert session._message_handler == _default_message_handler
def test_client_session_init_custom(streams):
read_stream, write_stream = streams
sampling_cb = MagicMock()
list_roots_cb = MagicMock()
logging_cb = MagicMock()
msg_handler = MagicMock()
client_info = types.Implementation(name="Custom", version="1.0")
session = ClientSession(
read_stream,
write_stream,
sampling_callback=sampling_cb,
list_roots_callback=list_roots_cb,
logging_callback=logging_cb,
message_handler=msg_handler,
client_info=client_info,
)
assert session._client_info == client_info
assert session._sampling_callback == sampling_cb
assert session._list_roots_callback == list_roots_cb
assert session._logging_callback == logging_cb
assert session._message_handler == msg_handler
def test_initialize_success(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
expected_result = types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ServerCapabilities(),
serverInfo=types.Implementation(name="test-server", version="1.0"),
)
def mock_server():
# Handle initialize request
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result=expected_result.model_dump())
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
# Expect initialized notification
notif = write_stream.get(timeout=2)
assert notif.message.root.method == "notifications/initialized"
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.initialize()
assert result.protocolVersion == types.LATEST_PROTOCOL_VERSION
assert result.serverInfo.name == "test-server"
t.join(timeout=1)
def test_initialize_custom_capabilities(streams):
read_stream, write_stream = streams
session = ClientSession(
read_stream, write_stream, sampling_callback=lambda c, p: None, list_roots_callback=lambda c: None
)
def mock_server():
msg = write_stream.get(timeout=2)
params = msg.message.root.params
# Check that capabilities are set because we provided custom callbacks
assert params["capabilities"]["sampling"] is not None
assert params["capabilities"]["roots"]["listChanged"] is True
req_id = msg.message.root.id
resp = types.JSONRPCResponse(
jsonrpc="2.0",
id=req_id,
result={
"protocolVersion": types.LATEST_PROTOCOL_VERSION,
"capabilities": {},
"serverInfo": {"name": "test", "version": "1.0"},
},
)
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
write_stream.get(timeout=2) # initialized notif
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
session.initialize()
t.join(timeout=1)
def test_initialize_unsupported_version(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
req_id = msg.message.root.id
resp = types.JSONRPCResponse(
jsonrpc="2.0",
id=req_id,
result={
"protocolVersion": "0.0.1", # Unsupported
"capabilities": {},
"serverInfo": {"name": "test", "version": "1.0"},
},
)
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
session.initialize()
t.join(timeout=1)
def test_send_ping(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "ping"
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
session.send_ping()
t.join(timeout=1)
def test_send_progress_notification(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
session.send_progress_notification(progress_token="token", progress=50.0, total=100.0)
msg = write_stream.get_nowait()
assert msg.message.root.method == "notifications/progress"
assert msg.message.root.params["progressToken"] == "token"
assert msg.message.root.params["progress"] == 50.0
assert msg.message.root.params["total"] == 100.0
def test_set_logging_level(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "logging/setLevel"
assert msg.message.root.params["level"] == "debug"
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
session.set_logging_level("debug")
t.join(timeout=1)
def test_list_resources(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "resources/list"
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resources": []})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.list_resources()
assert result.resources == []
t.join(timeout=1)
def test_list_resource_templates(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "resources/templates/list"
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resourceTemplates": []})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.list_resource_templates()
assert result.resourceTemplates == []
t.join(timeout=1)
def test_read_resource(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
uri = AnyUrl("file:///test")
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "resources/read"
assert msg.message.root.params["uri"] == str(uri)
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"contents": []})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.read_resource(uri)
assert result.contents == []
t.join(timeout=1)
def test_subscribe_resource(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
uri = AnyUrl("file:///test")
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "resources/subscribe"
assert msg.message.root.params["uri"] == str(uri)
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
session.subscribe_resource(uri)
t.join(timeout=1)
def test_unsubscribe_resource(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
uri = AnyUrl("file:///test")
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "resources/unsubscribe"
assert msg.message.root.params["uri"] == str(uri)
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
session.unsubscribe_resource(uri)
t.join(timeout=1)
def test_call_tool(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "tools/call"
assert msg.message.root.params["name"] == "test-tool"
assert msg.message.root.params["arguments"] == {"arg": 1}
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"content": [], "isError": False})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.call_tool("test-tool", arguments={"arg": 1})
assert result.isError is False
t.join(timeout=1)
def test_list_prompts(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "prompts/list"
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"prompts": []})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.list_prompts()
assert result.prompts == []
t.join(timeout=1)
def test_get_prompt(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "prompts/get"
assert msg.message.root.params["name"] == "test-prompt"
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"messages": []})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.get_prompt("test-prompt")
assert result.messages == []
t.join(timeout=1)
def test_complete(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
ref = types.PromptReference(type="ref/prompt", name="test")
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "completion/complete"
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"completion": {"values": [], "hasMore": False}})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.complete(ref, argument={"name": "val", "value": "x"})
assert result.completion.hasMore is False
t.join(timeout=1)
def test_list_tools(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
def mock_server():
msg = write_stream.get(timeout=2)
assert msg.message.root.method == "tools/list"
req_id = msg.message.root.id
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"tools": []})
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
import threading
t = threading.Thread(target=mock_server, daemon=True)
t.start()
with session:
result = session.list_tools()
assert result.tools == []
t.join(timeout=1)
def test_send_roots_list_changed(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
session.send_roots_list_changed()
msg = write_stream.get_nowait()
assert msg.message.root.method == "notifications/roots/list_changed"
def test_received_request_sampling(streams):
read_stream, write_stream = streams
sampling_cb = MagicMock(
return_value=types.CreateMessageResult(
role="assistant", content=types.TextContent(type="text", text="hello"), model="gpt-4"
)
)
session = ClientSession(read_stream, write_stream, sampling_callback=sampling_cb)
req = types.ServerRequest(
root=types.CreateMessageRequest(
method="sampling/createMessage", params=types.CreateMessageRequestParams(messages=[], maxTokens=100)
)
)
responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
session._received_request(responder)
msg = write_stream.get_nowait()
assert msg.message.root.result["model"] == "gpt-4"
sampling_cb.assert_called_once()
def test_received_request_list_roots(streams):
read_stream, write_stream = streams
list_roots_cb = MagicMock(return_value=types.ListRootsResult(roots=[]))
session = ClientSession(read_stream, write_stream, list_roots_callback=list_roots_cb)
req = types.ServerRequest(root=types.ListRootsRequest(method="roots/list"))
responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
session._received_request(responder)
msg = write_stream.get_nowait()
assert msg.message.root.result["roots"] == []
list_roots_cb.assert_called_once()
def test_received_request_ping(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
req = types.ServerRequest(root=types.PingRequest(method="ping"))
responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
session._received_request(responder)
msg = write_stream.get_nowait()
assert msg.message.root.result == {}
def test_handle_incoming(streams):
read_stream, write_stream = streams
msg_handler = MagicMock()
session = ClientSession(read_stream, write_stream, message_handler=msg_handler)
item = MagicMock()
session._handle_incoming(item)
msg_handler.assert_called_once_with(item)
def test_received_notification_logging(streams):
read_stream, write_stream = streams
logging_cb = MagicMock()
session = ClientSession(read_stream, write_stream, logging_callback=logging_cb)
notif = types.ServerNotification(
root=types.LoggingMessageNotification(
method="notifications/message",
params=types.LoggingMessageNotificationParams(level="info", data={"msg": "test"}),
)
)
session._received_notification(notif)
logging_cb.assert_called_once()
assert logging_cb.call_args[0][0].level == "info"
def test_default_message_handler():
# Exception case
with pytest.raises(ValueError, match="test error"):
_default_message_handler(Exception("test error"))
# Notification case - should do nothing
_default_message_handler(MagicMock(spec=types.ServerNotification))
# RequestResponder case - should do nothing
_default_message_handler(MagicMock(spec=RequestResponder))
def test_default_sampling_callback():
ctx = MagicMock()
params = MagicMock()
res = _default_sampling_callback(ctx, params)
assert res.code == types.INVALID_REQUEST
assert "not supported" in res.message
def test_default_list_roots_callback():
ctx = MagicMock()
res = _default_list_roots_callback(ctx)
assert res.code == types.INVALID_REQUEST
assert "not supported" in res.message
def test_default_logging_callback():
params = MagicMock()
_default_logging_callback(params) # Should do nothing
def test_received_notification_unknown(streams):
read_stream, write_stream = streams
session = ClientSession(read_stream, write_stream)
# Use a notification type that is NOT LoggingMessageNotification
notif = types.ServerNotification(
root=types.ResourceListChangedNotification(method="notifications/resources/list_changed")
)
session._received_notification(notif)
# Should just pass (case _:)

View File

@ -2,13 +2,16 @@
from contextlib import ExitStack
from types import TracebackType
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.mcp.error import MCPConnectionError
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
from core.mcp.types import CallToolResult, ListToolsResult, OAuthTokens, TextContent, Tool, ToolAnnotations
class TestMCPClient:
@ -380,3 +383,256 @@ class TestMCPClient:
timeout=30.0,
sse_read_timeout=60.0,
)
class TestMCPClientWithAuthRetry:
"""Test suite for MCPClientWithAuthRetry."""
@pytest.fixture
def mock_provider(self):
provider = MagicMock(spec=MCPProviderEntity)
provider.id = "test-provider-id"
provider.tenant_id = "test-tenant-id"
provider.retrieve_tokens.return_value = OAuthTokens(
access_token="new-token",
token_type="Bearer",
expires_in=3600,
refresh_token="refresh-token",
)
return provider
@pytest.fixture
def auth_client(self, mock_provider):
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer old-token"},
provider_entity=mock_provider,
authorization_code="test-code",
by_server_id=True,
)
return client
def test_init(self, mock_provider):
"""Test initialization."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
timeout=30.0,
provider_entity=mock_provider,
authorization_code="initial-code",
by_server_id=True,
)
assert client.server_url == "http://test.example.com"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.provider_entity == mock_provider
assert client.authorization_code == "initial-code"
assert client.by_server_id is True
assert client._has_retried is False
@patch("core.mcp.auth_client.db")
@patch("core.mcp.auth_client.Session")
@patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
def test_handle_auth_error_success(
self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider
):
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_service = mock_service_class.return_value
new_provider = MagicMock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = OAuthTokens(
access_token="new-access-token",
token_type="Bearer",
expires_in=3600,
refresh_token="new-refresh-token",
)
mock_service.get_provider_entity.return_value = new_provider
# MCPAuthError parses resource_metadata and scope from www_authenticate_header
www_auth = 'Bearer resource_metadata="http://meta", scope="read"'
error = MCPAuthError("Auth failed", www_authenticate_header=www_auth)
auth_client._handle_auth_error(error)
# Verify service calls - error.resource_metadata_url and error.scope_hint are parsed from header
mock_service.auth_with_actions.assert_called_once_with(
mock_provider,
"test-code",
resource_metadata_url="http://meta",
scope_hint="read",
)
mock_service.get_provider_entity.assert_called_once_with(
mock_provider.id, mock_provider.tenant_id, by_server_id=True
)
# Verify client updates
assert auth_client.headers["Authorization"] == "Bearer new-access-token"
assert auth_client.authorization_code is None
assert auth_client._has_retried is True
assert auth_client.provider_entity == new_provider
def test_handle_auth_error_no_provider(self, auth_client):
"""Test auth error handling when no provider entity is set."""
auth_client.provider_entity = None
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
auth_client._handle_auth_error(error)
assert exc_info.value == error
def test_handle_auth_error_already_retried(self, auth_client):
"""Test auth error handling when already retried."""
auth_client._has_retried = True
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
auth_client._handle_auth_error(error)
assert exc_info.value == error
@patch("core.mcp.auth_client.db")
@patch("core.mcp.auth_client.Session")
@patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
def test_handle_auth_error_no_token(
self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider
):
"""Test auth error handling when no token is received."""
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_service = mock_service_class.return_value
new_provider = MagicMock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = None
mock_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
auth_client._handle_auth_error(error)
assert "Authentication failed - no token received" in str(exc_info.value)
@patch("core.mcp.auth_client.db")
@patch("core.mcp.auth_client.Session")
@patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
def test_handle_auth_error_generic_exception(self, mock_service_class, mock_session_class, mock_db, auth_client):
"""Test auth error handling when a generic exception occurs."""
mock_session_class.side_effect = Exception("DB error")
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
auth_client._handle_auth_error(error)
assert "Authentication retry failed: DB error" in str(exc_info.value)
@patch("core.mcp.auth_client.db")
@patch("core.mcp.auth_client.Session")
@patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
def test_handle_auth_error_mcp_auth_error_propagation(
self, mock_service_class, mock_session_class, mock_db, auth_client
):
"""Test that MCPAuthError during refresh is propagated as is."""
mock_session_class.return_value.__enter__.return_value = MagicMock()
mock_service = mock_service_class.return_value
mock_service.auth_with_actions.side_effect = MCPAuthError("Refresh failed")
error = MCPAuthError("Initial auth failed")
with pytest.raises(MCPAuthError) as exc_info:
auth_client._handle_auth_error(error)
assert "Refresh failed" in str(exc_info.value)
def test_execute_with_retry_success_first_try(self, auth_client):
"""Test execution success on first try."""
mock_func = MagicMock(return_value="success")
result = auth_client._execute_with_retry(mock_func, "arg1", kwarg1="val1")
assert result == "success"
mock_func.assert_called_once_with("arg1", kwarg1="val1")
assert auth_client._has_retried is False
@patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
@patch.object(MCPClientWithAuthRetry, "_initialize")
def test_execute_with_retry_success_on_retry_initialized(self, mock_initialize, mock_handle_auth, auth_client):
"""Test execution success on retry after auth error when client was already initialized."""
mock_func = MagicMock()
mock_func.side_effect = [MCPAuthError("Auth failed"), "success"]
auth_client._initialized = True
auth_client._exit_stack = MagicMock()
result = auth_client._execute_with_retry(mock_func, "arg")
assert result == "success"
assert mock_func.call_count == 2
mock_handle_auth.assert_called_once()
mock_initialize.assert_called_once()
auth_client._exit_stack.close.assert_called_once()
assert auth_client._has_retried is False
@patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
@patch.object(MCPClientWithAuthRetry, "_initialize")
def test_execute_with_retry_success_on_retry_not_initialized(self, mock_initialize, mock_handle_auth, auth_client):
"""Test retry when client was NOT initialized (skips cleanup/re-init)."""
mock_func = MagicMock()
mock_func.side_effect = [MCPAuthError("Auth failed"), "result"]
auth_client._initialized = False
result = auth_client._execute_with_retry(mock_func, "arg")
assert result == "result"
assert mock_func.call_count == 2
mock_handle_auth.assert_called_once()
mock_initialize.assert_not_called()
assert auth_client._has_retried is False
@patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
def test_execute_with_retry_failure_on_retry(self, mock_handle_auth, auth_client):
"""Test execution failure even after retry."""
mock_func = MagicMock()
mock_func.side_effect = [MCPAuthError("First fail"), MCPAuthError("Second fail")]
with pytest.raises(MCPAuthError) as exc_info:
auth_client._execute_with_retry(mock_func, "arg")
assert "Second fail" in str(exc_info.value)
assert mock_func.call_count == 2
mock_handle_auth.assert_called_once()
assert auth_client._has_retried is False
@patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
def test_auth_client_context_manager_enter(self, mock_execute_retry, auth_client):
"""Test context manager __enter__."""
auth_client.__enter__()
mock_execute_retry.assert_called_once()
func = mock_execute_retry.call_args[0][0]
with patch("core.mcp.mcp_client.MCPClient.__enter__") as mock_base_enter:
result = func()
assert result == auth_client
mock_base_enter.assert_called_once()
@patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
def test_auth_client_list_tools(self, mock_execute_retry, auth_client):
"""Test list_tools with retry."""
auth_client.list_tools()
mock_execute_retry.assert_called_once()
assert mock_execute_retry.call_args[0][0].__name__ == "list_tools"
@patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
def test_auth_client_invoke_tool(self, mock_execute_retry, auth_client):
"""Test invoke_tool with retry."""
auth_client.invoke_tool("test-tool", {"arg": "val"})
mock_execute_retry.assert_called_once()
assert mock_execute_retry.call_args[0][0].__name__ == "invoke_tool"
assert mock_execute_retry.call_args[0][1] == "test-tool"
assert mock_execute_retry.call_args[0][2] == {"arg": "val"}