mirror of
https://github.com/langgenius/dify.git
synced 2026-03-14 03:18:36 +08:00
test: added for core logging and core mcp (#32478)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
@ -82,6 +82,68 @@ class TestTraceContextFilter:
|
||||
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_record.span_id == "051581bf3bb55c45"
|
||||
|
||||
def test_otel_context_invalid_trace_id(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.trace_id = 0
|
||||
mock_context.is_valid = True
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
# Use mocks for base context to ensure we can test the fallback
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("core.logging.filters.get_trace_id", return_value=""),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.trace_id == ""
|
||||
|
||||
def test_otel_context_invalid_span_id(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
|
||||
mock_context.span_id = 0
|
||||
mock_context.is_valid = True
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_record.span_id == ""
|
||||
|
||||
def test_otel_context_span_none(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=None),
|
||||
mock.patch("core.logging.filters.get_trace_id", return_value=""),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.trace_id == ""
|
||||
|
||||
def test_otel_context_exception(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
# Trigger exception in OTEL block
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception),
|
||||
mock.patch("core.logging.filters.get_trace_id", return_value=""),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.trace_id == ""
|
||||
|
||||
|
||||
class TestIdentityContextFilter:
|
||||
def test_sets_empty_identity_without_request_context(self, log_record):
|
||||
@ -114,3 +176,119 @@ class TestIdentityContextFilter:
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
assert log_record.tenant_id == ""
|
||||
|
||||
def test_sets_empty_identity_unauthenticated(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
mock_user = mock.MagicMock()
|
||||
mock_user.is_authenticated = False
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
assert log_record.user_id == ""
|
||||
|
||||
def test_sets_identity_for_account(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
class MockAccount:
|
||||
pass
|
||||
|
||||
mock_user = MockAccount()
|
||||
mock_user.id = "account_id"
|
||||
mock_user.current_tenant_id = "tenant_id"
|
||||
mock_user.is_authenticated = True
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("models.Account", MockAccount),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.tenant_id == "tenant_id"
|
||||
assert log_record.user_id == "account_id"
|
||||
assert log_record.user_type == "account"
|
||||
|
||||
def test_sets_identity_for_account_no_tenant(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
class MockAccount:
|
||||
pass
|
||||
|
||||
mock_user = MockAccount()
|
||||
mock_user.id = "account_id"
|
||||
mock_user.current_tenant_id = None
|
||||
mock_user.is_authenticated = True
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("models.Account", MockAccount),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.tenant_id == ""
|
||||
assert log_record.user_id == "account_id"
|
||||
assert log_record.user_type == "account"
|
||||
|
||||
def test_sets_identity_for_end_user(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
class MockEndUser:
|
||||
pass
|
||||
|
||||
class AnotherClass:
|
||||
pass
|
||||
|
||||
mock_user = MockEndUser()
|
||||
mock_user.id = "end_user_id"
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
mock_user.type = "custom_type"
|
||||
mock_user.is_authenticated = True
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("models.model.EndUser", MockEndUser),
|
||||
mock.patch("models.Account", AnotherClass),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.tenant_id == "tenant_id"
|
||||
assert log_record.user_id == "end_user_id"
|
||||
assert log_record.user_type == "custom_type"
|
||||
|
||||
def test_sets_identity_for_end_user_default_type(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
class MockEndUser:
|
||||
pass
|
||||
|
||||
class AnotherClass:
|
||||
pass
|
||||
|
||||
mock_user = MockEndUser()
|
||||
mock_user.id = "end_user_id"
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
mock_user.type = None
|
||||
mock_user.is_authenticated = True
|
||||
|
||||
with (
|
||||
mock.patch("flask.has_request_context", return_value=True),
|
||||
mock.patch("models.model.EndUser", MockEndUser),
|
||||
mock.patch("models.Account", AnotherClass),
|
||||
mock.patch("flask_login.current_user", mock_user),
|
||||
):
|
||||
filter = IdentityContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.tenant_id == "tenant_id"
|
||||
assert log_record.user_id == "end_user_id"
|
||||
assert log_record.user_type == "end_user"
|
||||
|
||||
@ -1,27 +1,39 @@
|
||||
"""Unit tests for MCP OAuth authentication flow."""
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.helper import ssrf_proxy
|
||||
from core.mcp.auth.auth_flow import (
|
||||
OAUTH_STATE_EXPIRY_SECONDS,
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX,
|
||||
OAuthCallbackState,
|
||||
_create_secure_redis_state,
|
||||
_parse_token_response,
|
||||
_retrieve_redis_state,
|
||||
auth,
|
||||
build_oauth_authorization_server_metadata_discovery_urls,
|
||||
build_protected_resource_metadata_discovery_urls,
|
||||
check_support_resource_discovery,
|
||||
client_credentials_flow,
|
||||
discover_oauth_authorization_server_metadata,
|
||||
discover_oauth_metadata,
|
||||
discover_protected_resource_metadata,
|
||||
exchange_authorization,
|
||||
generate_pkce_challenge,
|
||||
get_effective_scope,
|
||||
handle_callback,
|
||||
refresh_authorization,
|
||||
register_client,
|
||||
start_authorization,
|
||||
)
|
||||
from core.mcp.entities import AuthActionType, AuthResult
|
||||
from core.mcp.error import MCPRefreshTokenError
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
@ -764,3 +776,555 @@ class TestAuthOrchestration:
|
||||
auth(mock_provider, authorization_code="auth-code")
|
||||
|
||||
assert "Existing OAuth client information is required" in str(exc_info.value)
|
||||
|
||||
def test_generate_pkce_challenge(self):
|
||||
verifier, challenge = generate_pkce_challenge()
|
||||
assert verifier
|
||||
assert challenge
|
||||
assert "=" not in verifier
|
||||
assert "=" not in challenge
|
||||
|
||||
def test_build_protected_resource_metadata_discovery_urls(self):
|
||||
# Case 1: WWW-Auth URL provided
|
||||
urls = build_protected_resource_metadata_discovery_urls(
|
||||
"https://auth.example.com/prm", "https://api.example.com"
|
||||
)
|
||||
assert "https://auth.example.com/prm" in urls
|
||||
assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
|
||||
|
||||
# Case 2: No WWW-Auth URL, with path
|
||||
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com/v1")
|
||||
assert "https://api.example.com/.well-known/oauth-protected-resource/v1" in urls
|
||||
assert "https://api.example.com/.well-known/oauth-protected-resource" in urls
|
||||
|
||||
# Case 3: No path
|
||||
urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com")
|
||||
assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"]
|
||||
|
||||
def test_build_oauth_authorization_server_metadata_discovery_urls(self):
|
||||
# Case 1: with auth_server_url
|
||||
urls = build_oauth_authorization_server_metadata_discovery_urls(
|
||||
"https://auth.example.com", "https://api.example.com"
|
||||
)
|
||||
assert "https://auth.example.com/.well-known/oauth-authorization-server" in urls
|
||||
assert "https://auth.example.com/.well-known/openid-configuration" in urls
|
||||
|
||||
# Case 2: with path
|
||||
urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://api.example.com/tenant")
|
||||
assert "https://api.example.com/.well-known/oauth-authorization-server/tenant" in urls
|
||||
assert "https://api.example.com/tenant/.well-known/openid-configuration" in urls
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_protected_resource_metadata(self, mock_get):
|
||||
# Success
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"resource": "https://api.example.com",
|
||||
"authorization_servers": ["https://auth"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
result = discover_protected_resource_metadata(None, "https://api.example.com")
|
||||
assert result is not None
|
||||
assert result.resource == "https://api.example.com"
|
||||
|
||||
# 404 then Success
|
||||
res404 = Mock()
|
||||
res404.status_code = 404
|
||||
mock_get.side_effect = [res404, mock_response]
|
||||
result = discover_protected_resource_metadata(None, "https://api.example.com/path")
|
||||
assert result is not None
|
||||
assert result.resource == "https://api.example.com"
|
||||
|
||||
# Error handling
|
||||
mock_get.side_effect = httpx.RequestError("Error")
|
||||
result = discover_protected_resource_metadata(None, "https://api.example.com")
|
||||
assert result is None
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_authorization_server_metadata(self, mock_get):
|
||||
# Success
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://auth.example.com/auth",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"response_types_supported": ["code"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
|
||||
assert result is not None
|
||||
assert result.authorization_endpoint == "https://auth.example.com/auth"
|
||||
|
||||
# 404
|
||||
res404 = Mock()
|
||||
res404.status_code = 404
|
||||
mock_get.side_effect = [res404, mock_response]
|
||||
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com/tenant")
|
||||
assert result is not None
|
||||
assert result.authorization_endpoint == "https://auth.example.com/auth"
|
||||
|
||||
# ValidationError
|
||||
mock_response.json.return_value = {"invalid": "data"}
|
||||
mock_get.side_effect = None
|
||||
mock_get.return_value = mock_response
|
||||
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
|
||||
assert result is None
|
||||
|
||||
def test_get_effective_scope(self):
|
||||
prm = ProtectedResourceMetadata(
|
||||
resource="https://api.example.com",
|
||||
authorization_servers=["https://auth"],
|
||||
scopes_supported=["read", "write"],
|
||||
)
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/auth",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
scopes_supported=["openid", "profile"],
|
||||
)
|
||||
|
||||
# 1. WWW-Auth priority
|
||||
assert get_effective_scope("scope1", prm, asm, "client") == "scope1"
|
||||
# 2. PRM priority
|
||||
assert get_effective_scope(None, prm, asm, "client") == "read write"
|
||||
# 3. ASM priority
|
||||
assert get_effective_scope(None, None, asm, "client") == "openid profile"
|
||||
# 4. Client configured
|
||||
assert get_effective_scope(None, None, None, "client") == "client"
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.redis_client")
|
||||
def test_redis_state_management(self, mock_redis):
|
||||
state_data = OAuthCallbackState(
|
||||
provider_id="p1",
|
||||
tenant_id="t1",
|
||||
server_url="https://api",
|
||||
metadata=None,
|
||||
client_information=OAuthClientInformation(client_id="c1"),
|
||||
code_verifier="cv",
|
||||
redirect_uri="https://re",
|
||||
)
|
||||
|
||||
# Create
|
||||
state_key = _create_secure_redis_state(state_data)
|
||||
assert state_key
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
# Retrieve Success
|
||||
mock_redis.get.return_value = state_data.model_dump_json()
|
||||
retrieved = _retrieve_redis_state(state_key)
|
||||
assert retrieved.provider_id == "p1"
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
# Retrieve Failure - Not found
|
||||
mock_redis.get.return_value = None
|
||||
with pytest.raises(ValueError, match="expired or does not exist"):
|
||||
_retrieve_redis_state("absent")
|
||||
|
||||
# Retrieve Failure - Invalid JSON
|
||||
mock_redis.get.return_value = "invalid"
|
||||
with pytest.raises(ValueError, match="Invalid state parameter"):
|
||||
_retrieve_redis_state("invalid")
|
||||
|
||||
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
|
||||
@patch("core.mcp.auth.auth_flow.exchange_authorization")
|
||||
def test_handle_callback(self, mock_exchange, mock_retrieve):
|
||||
state = Mock(spec=OAuthCallbackState)
|
||||
state.server_url = "https://api"
|
||||
state.metadata = None
|
||||
state.client_information = Mock()
|
||||
state.code_verifier = "cv"
|
||||
state.redirect_uri = "https://re"
|
||||
mock_retrieve.return_value = state
|
||||
|
||||
tokens = Mock(spec=OAuthTokens)
|
||||
mock_exchange.return_value = tokens
|
||||
|
||||
s, t = handle_callback("key", "code")
|
||||
assert s == state
|
||||
assert t == tokens
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_check_support_resource_discovery(self, mock_get):
|
||||
# Case 1: authorization_servers (plural)
|
||||
res = Mock()
|
||||
res.status_code = 200
|
||||
res.json.return_value = {"authorization_servers": ["https://auth1"]}
|
||||
mock_get.return_value = res
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is True
|
||||
assert url == "https://auth1"
|
||||
|
||||
# Case 2: authorization_server_url (singular alias)
|
||||
res.json.return_value = {"authorization_server_url": ["https://auth2"]}
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is True
|
||||
assert url == "https://auth2"
|
||||
|
||||
# Case 3: Missing fields
|
||||
res.json.return_value = {"nothing": []}
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
|
||||
# Case 4: 404
|
||||
res.status_code = 404
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
|
||||
# Case 5: RequestError
|
||||
mock_get.side_effect = httpx.RequestError("Error")
|
||||
supported, url = check_support_resource_discovery("https://api")
|
||||
assert supported is False
|
||||
|
||||
def test_discover_oauth_metadata(self):
|
||||
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
|
||||
mock_prm.return_value = ProtectedResourceMetadata(
|
||||
resource="https://api", authorization_servers=["https://auth"]
|
||||
)
|
||||
mock_asm.return_value = Mock(spec=OAuthMetadata)
|
||||
|
||||
asm, prm, hint = discover_oauth_metadata("https://api")
|
||||
assert asm == mock_asm.return_value
|
||||
assert prm == mock_prm.return_value
|
||||
mock_asm.assert_called_with("https://auth", "https://api", None)
|
||||
|
||||
def test_start_authorization(self):
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/authorize",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
client_info = OAuthClientInformation(client_id="c1")
|
||||
|
||||
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create:
|
||||
mock_create.return_value = "state-key"
|
||||
|
||||
# Success with scope
|
||||
url, verifier = start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1", "read")
|
||||
assert "scope=read" in url
|
||||
assert "state=state-key" in url
|
||||
|
||||
# Success without metadata
|
||||
url, verifier = start_authorization("https://api", None, client_info, "https://re", "p1", "t1")
|
||||
assert "https://api/authorize" in url
|
||||
|
||||
# Failure: incompatible auth server
|
||||
metadata.response_types_supported = ["implicit"]
|
||||
with pytest.raises(ValueError, match="Incompatible auth server"):
|
||||
start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1")
|
||||
|
||||
def test_parse_token_response(self):
|
||||
# Case 1: JSON
|
||||
res = Mock()
|
||||
res.headers = {"content-type": "application/json"}
|
||||
res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at"
|
||||
|
||||
# Case 2: Form-urlencoded
|
||||
res.headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
res.text = "access_token=at2&token_type=Bearer"
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at2"
|
||||
|
||||
# Case 3: No content-type, but JSON
|
||||
res.headers = {}
|
||||
res.json.return_value = {"access_token": "at3", "token_type": "Bearer"}
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at3"
|
||||
|
||||
# Case 4: No content-type, not JSON, but Form
|
||||
res.json.side_effect = json.JSONDecodeError("msg", "doc", 0)
|
||||
res.text = "access_token=at4&token_type=Bearer"
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at4"
|
||||
|
||||
# Case 5: Validation Error fallback
|
||||
res.json.side_effect = ValidationError.from_exception_data("error", [])
|
||||
res.text = "access_token=at5&token_type=Bearer"
|
||||
tokens = _parse_token_response(res)
|
||||
assert tokens.access_token == "at5"
|
||||
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_exchange_authorization(self, mock_post):
|
||||
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/authorize",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
|
||||
# Success
|
||||
res = Mock()
|
||||
res.is_success = True
|
||||
res.headers = {"content-type": "application/json"}
|
||||
res.json.return_value = {"access_token": "at", "token_type": "Bearer"}
|
||||
mock_post.return_value = res
|
||||
|
||||
tokens = exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
|
||||
assert tokens.access_token == "at"
|
||||
|
||||
# Failure: Unsupported grant type
|
||||
metadata.grant_types_supported = ["client_credentials"]
|
||||
with pytest.raises(ValueError, match="Incompatible auth server"):
|
||||
exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
|
||||
|
||||
# Failure: HTTP error
|
||||
metadata.grant_types_supported = ["authorization_code"]
|
||||
res.is_success = False
|
||||
res.status_code = 400
|
||||
with pytest.raises(ValueError, match="Token exchange failed"):
|
||||
exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re")
|
||||
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_refresh_authorization(self, mock_post):
|
||||
# Case 1: with client_secret
|
||||
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
|
||||
|
||||
# Success
|
||||
res = Mock()
|
||||
res.is_success = True
|
||||
res.headers = {"content-type": "application/json"}
|
||||
res.json.return_value = {"access_token": "at_new", "token_type": "Bearer"}
|
||||
mock_post.return_value = res
|
||||
|
||||
tokens = refresh_authorization("https://api", None, client_info, "rt")
|
||||
assert tokens.access_token == "at_new"
|
||||
assert mock_post.call_args[1]["data"]["client_secret"] == "s1"
|
||||
|
||||
# Failure: MaxRetriesExceededError
|
||||
mock_post.side_effect = ssrf_proxy.MaxRetriesExceededError("Too many retries")
|
||||
with pytest.raises(MCPRefreshTokenError):
|
||||
refresh_authorization("https://api", None, client_info, "rt")
|
||||
|
||||
# Failure: HTTP error
|
||||
mock_post.side_effect = None
|
||||
res.is_success = False
|
||||
res.text = "error_msg"
|
||||
with pytest.raises(MCPRefreshTokenError, match="error_msg"):
|
||||
refresh_authorization("https://api", None, client_info, "rt")
|
||||
|
||||
# Failure: Incompatible metadata
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
with pytest.raises(ValueError, match="Incompatible auth server"):
|
||||
refresh_authorization("https://api", metadata, client_info, "rt")
|
||||
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_client_credentials_flow(self, mock_post):
|
||||
client_info = OAuthClientInformation(client_id="c1", client_secret="s1")
|
||||
|
||||
# Success with secret
|
||||
res = Mock()
|
||||
res.is_success = True
|
||||
res.headers = {"content-type": "application/json"}
|
||||
res.json.return_value = {"access_token": "at_cc", "token_type": "Bearer"}
|
||||
mock_post.return_value = res
|
||||
|
||||
tokens = client_credentials_flow("https://api", None, client_info, "read")
|
||||
assert tokens.access_token == "at_cc"
|
||||
args, kwargs = mock_post.call_args
|
||||
assert "Authorization" in kwargs["headers"]
|
||||
|
||||
# Success without secret
|
||||
client_info_no_secret = OAuthClientInformation(client_id="c2")
|
||||
tokens = client_credentials_flow("https://api", None, client_info_no_secret)
|
||||
args, kwargs = mock_post.call_args
|
||||
assert kwargs["data"]["client_id"] == "c2"
|
||||
|
||||
# Failure: Incompatible metadata
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
with pytest.raises(ValueError, match="Incompatible auth server"):
|
||||
client_credentials_flow("https://api", metadata, client_info)
|
||||
|
||||
# Failure: HTTP error
|
||||
res.is_success = False
|
||||
res.status_code = 401
|
||||
res.text = "Unauthorized"
|
||||
with pytest.raises(ValueError, match="Client credentials token request failed"):
|
||||
client_credentials_flow("https://api", None, client_info)
|
||||
|
||||
@patch("core.helper.ssrf_proxy.post")
|
||||
def test_register_client(self, mock_post):
|
||||
# Case 1: Success with metadata
|
||||
metadata = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
registration_endpoint="https://auth/register",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://re"])
|
||||
|
||||
res = Mock()
|
||||
res.is_success = True
|
||||
res.json.return_value = {
|
||||
"client_id": "c_new",
|
||||
"client_secret": "s_new",
|
||||
"client_name": "Dify",
|
||||
"redirect_uris": ["https://re"],
|
||||
}
|
||||
mock_post.return_value = res
|
||||
|
||||
info = register_client("https://api", metadata, client_metadata)
|
||||
assert info.client_id == "c_new"
|
||||
|
||||
# Case 2: Success without metadata
|
||||
info = register_client("https://api", None, client_metadata)
|
||||
assert mock_post.call_args[0][0] == "https://api/register"
|
||||
|
||||
# Case 3: Metadata provided but no endpoint
|
||||
metadata.registration_endpoint = None
|
||||
with pytest.raises(ValueError, match="does not support dynamic client registration"):
|
||||
register_client("https://api", metadata, client_metadata)
|
||||
|
||||
# Failure: HTTP
|
||||
res.is_success = False
|
||||
res.raise_for_status = Mock()
|
||||
res.status_code = 400
|
||||
# If is_success is false, it should call raise_for_status
|
||||
register_client("https://api", None, client_metadata)
|
||||
res.raise_for_status.assert_called_once()
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_orchestration_failures(self, mock_discover):
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.decrypt_server_url.return_value = "https://api"
|
||||
provider.id = "p1"
|
||||
provider.tenant_id = "t1"
|
||||
|
||||
# Case 1: No server metadata
|
||||
mock_discover.return_value = (None, None, None)
|
||||
with pytest.raises(ValueError, match="Failed to discover OAuth metadata"):
|
||||
auth(provider)
|
||||
|
||||
# Case 2: No client info, exchange code provided
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
mock_discover.return_value = (asm, None, None)
|
||||
provider.retrieve_client_information.return_value = None
|
||||
with pytest.raises(ValueError, match="Existing OAuth client information is required"):
|
||||
auth(provider, authorization_code="code")
|
||||
|
||||
# Case 3: CLIENT_CREDENTIALS but client must provide info
|
||||
asm.grant_types_supported = ["client_credentials"]
|
||||
with pytest.raises(ValueError, match="requires client_id and client_secret"):
|
||||
auth(provider)
|
||||
|
||||
# Case 4: Client registration fails
|
||||
asm.grant_types_supported = ["authorization_code"]
|
||||
with patch("core.mcp.auth.auth_flow.register_client") as mock_reg:
|
||||
mock_reg.side_effect = httpx.RequestError("Reg failed")
|
||||
with pytest.raises(ValueError, match="Could not register OAuth client"):
|
||||
auth(provider)
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_orchestration_client_credentials(self, mock_discover):
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.decrypt_server_url.return_value = "https://api"
|
||||
provider.id = "p1"
|
||||
provider.tenant_id = "t1"
|
||||
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1", client_secret="s1")
|
||||
provider.decrypt_credentials.return_value = {"scope": "read"}
|
||||
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["client_credentials"],
|
||||
)
|
||||
mock_discover.return_value = (asm, None, None)
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.client_credentials_flow") as mock_cc:
|
||||
mock_cc.return_value = OAuthTokens(access_token="at_cc", token_type="Bearer")
|
||||
|
||||
result = auth(provider)
|
||||
assert result.response == {"result": "success"}
|
||||
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||
assert result.actions[0].data["grant_type"] == "client_credentials"
|
||||
|
||||
# Failure in CC flow
|
||||
mock_cc.side_effect = ValueError("CC Failed")
|
||||
with pytest.raises(ValueError, match="Client credentials flow failed"):
|
||||
auth(provider)
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_orchestration_authorization_code(self, mock_discover):
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.decrypt_server_url.return_value = "https://api"
|
||||
provider.id = "p1"
|
||||
provider.tenant_id = "t1"
|
||||
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1")
|
||||
provider.decrypt_credentials.return_value = {}
|
||||
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
mock_discover.return_value = (asm, None, None)
|
||||
|
||||
# Case 1: Exchange code
|
||||
with patch("core.mcp.auth.auth_flow._retrieve_redis_state") as mock_retrieve:
|
||||
state = Mock(spec=OAuthCallbackState)
|
||||
state.code_verifier = "cv"
|
||||
state.redirect_uri = "https://re"
|
||||
mock_retrieve.return_value = state
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.exchange_authorization") as mock_exchange:
|
||||
mock_exchange.return_value = OAuthTokens(access_token="at_code", token_type="Bearer")
|
||||
|
||||
# Success
|
||||
result = auth(provider, authorization_code="code", state_param="sp")
|
||||
assert result.response == {"result": "success"}
|
||||
|
||||
# Missing state_param
|
||||
with pytest.raises(ValueError, match="State parameter is required"):
|
||||
auth(provider, authorization_code="code")
|
||||
|
||||
# Missing verifier in state
|
||||
state.code_verifier = None
|
||||
with pytest.raises(ValueError, match="Missing code_verifier"):
|
||||
auth(provider, authorization_code="code", state_param="sp")
|
||||
|
||||
# Invalid state
|
||||
mock_retrieve.side_effect = ValueError("Invalid")
|
||||
with pytest.raises(ValueError, match="Invalid state parameter"):
|
||||
auth(provider, authorization_code="code", state_param="sp")
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_orchestration_refresh_failure(self, mock_discover):
|
||||
provider = Mock(spec=MCPProviderEntity)
|
||||
provider.decrypt_server_url.return_value = "https://api"
|
||||
provider.id = "p1"
|
||||
provider.tenant_id = "t1"
|
||||
provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1")
|
||||
provider.decrypt_credentials.return_value = {}
|
||||
provider.retrieve_tokens.return_value = OAuthTokens(access_token="at", token_type="Bearer", refresh_token="rt")
|
||||
|
||||
asm = OAuthMetadata(
|
||||
authorization_endpoint="https://auth/auth",
|
||||
token_endpoint="https://auth/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
)
|
||||
mock_discover.return_value = (asm, None, None)
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.refresh_authorization") as mock_refresh:
|
||||
mock_refresh.side_effect = ValueError("Refresh Failed")
|
||||
with pytest.raises(ValueError, match="Could not refresh OAuth tokens"):
|
||||
auth(provider)
|
||||
|
||||
@ -322,3 +322,475 @@ def test_sse_client_concurrent_access():
|
||||
assert len(received_messages) == 10
|
||||
for i in range(10):
|
||||
assert f"message_{i}" in received_messages
|
||||
|
||||
|
||||
class TestStatusClasses:
|
||||
"""Tests for _StatusReady and _StatusError data containers."""
|
||||
|
||||
def test_status_ready_stores_endpoint(self):
|
||||
from core.mcp.client.sse_client import _StatusReady
|
||||
|
||||
status = _StatusReady("http://example.com/messages/")
|
||||
assert status.endpoint_url == "http://example.com/messages/"
|
||||
|
||||
def test_status_error_stores_exception(self):
|
||||
from core.mcp.client.sse_client import _StatusError
|
||||
|
||||
exc = ValueError("bad endpoint")
|
||||
status = _StatusError(exc)
|
||||
assert status.exc is exc
|
||||
|
||||
|
||||
class TestSSETransportInit:
|
||||
"""Tests for SSETransport default and explicit init values."""
|
||||
|
||||
def test_defaults(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
t = SSETransport("http://example.com/sse")
|
||||
assert t.url == "http://example.com/sse"
|
||||
assert t.headers == {}
|
||||
assert t.timeout == 5.0
|
||||
assert t.sse_read_timeout == 60.0
|
||||
assert t.endpoint_url is None
|
||||
assert t.event_source is None
|
||||
|
||||
def test_explicit_headers_not_mutated(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
hdrs = {"X-Foo": "bar"}
|
||||
t = SSETransport("http://example.com/sse", headers=hdrs)
|
||||
assert t.headers is hdrs
|
||||
|
||||
|
||||
class TestHandleEndpointEvent:
|
||||
"""Tests for SSETransport._handle_endpoint_event covering the invalid-origin branch."""
|
||||
|
||||
def test_invalid_origin_puts_status_error(self):
|
||||
from core.mcp.client.sse_client import SSETransport, _StatusError
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
# Provide a full URL with a different origin so urljoin keeps it as-is
|
||||
transport._handle_endpoint_event("http://evil.com/messages/", status_queue)
|
||||
|
||||
result = status_queue.get_nowait()
|
||||
assert isinstance(result, _StatusError)
|
||||
assert "does not match" in str(result.exc)
|
||||
|
||||
def test_valid_origin_puts_status_ready(self):
|
||||
from core.mcp.client.sse_client import SSETransport, _StatusReady
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
transport._handle_endpoint_event("/messages/?session_id=abc", status_queue)
|
||||
|
||||
result = status_queue.get_nowait()
|
||||
assert isinstance(result, _StatusReady)
|
||||
assert "example.com" in result.endpoint_url
|
||||
|
||||
|
||||
class TestHandleSSEEvent:
|
||||
"""Tests for SSETransport._handle_sse_event covering all match branches."""
|
||||
|
||||
def _make_sse(self, event_type: str, data: str):
|
||||
sse = Mock()
|
||||
sse.event = event_type
|
||||
sse.data = data
|
||||
return sse
|
||||
|
||||
def test_message_event_dispatched(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
valid_msg = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
transport._handle_sse_event(self._make_sse("message", valid_msg), read_queue, status_queue)
|
||||
|
||||
item = read_queue.get_nowait()
|
||||
assert hasattr(item, "message")
|
||||
|
||||
def test_unknown_event_logs_warning_and_does_nothing(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
transport._handle_sse_event(self._make_sse("ping", "{}"), read_queue, status_queue)
|
||||
|
||||
assert read_queue.empty()
|
||||
assert status_queue.empty()
|
||||
|
||||
|
||||
class TestSSEReader:
|
||||
"""Tests for SSETransport.sse_reader exception branches."""
|
||||
|
||||
def test_read_error_closes_cleanly(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
event_source = Mock()
|
||||
event_source.iter_sse.side_effect = httpx.ReadError("connection reset")
|
||||
|
||||
transport.sse_reader(event_source, read_queue, status_queue)
|
||||
|
||||
# Finally block always puts None as sentinel
|
||||
sentinel = read_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
def test_generic_exception_puts_exc_then_none(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
boom = RuntimeError("unexpected!")
|
||||
event_source = Mock()
|
||||
event_source.iter_sse.side_effect = boom
|
||||
|
||||
transport.sse_reader(event_source, read_queue, status_queue)
|
||||
|
||||
exc_item = read_queue.get_nowait()
|
||||
assert exc_item is boom
|
||||
|
||||
sentinel = read_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
|
||||
class TestSendMessage:
|
||||
"""Tests for SSETransport._send_message."""
|
||||
|
||||
def _make_session_message(self):
|
||||
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
msg = types.JSONRPCMessage.model_validate_json(msg_json)
|
||||
return types.SessionMessage(msg)
|
||||
|
||||
def test_sends_post_and_raises_for_status(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_client = Mock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
transport._send_message(mock_client, "http://example.com/messages/", session_msg)
|
||||
|
||||
mock_client.post.assert_called_once()
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
class TestPostWriter:
|
||||
"""Tests for SSETransport.post_writer exception branches."""
|
||||
|
||||
def _make_session_message(self):
|
||||
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
msg = types.JSONRPCMessage.model_validate_json(msg_json)
|
||||
return types.SessionMessage(msg)
|
||||
|
||||
def test_none_message_exits_loop(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
write_queue.put(None) # Signal shutdown immediately
|
||||
|
||||
mock_client = Mock()
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
# Should put final None sentinel
|
||||
sentinel = write_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
def test_exception_in_message_put_back_to_queue(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
exc = ValueError("some error")
|
||||
write_queue.put(exc) # Exception goes in first
|
||||
write_queue.put(None) # Then shutdown signal
|
||||
|
||||
mock_client = Mock()
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
# The exception should be re-queued, then None from loop exit, then None from finally
|
||||
item1 = write_queue.get_nowait()
|
||||
assert isinstance(item1, Exception)
|
||||
|
||||
def test_read_error_shuts_down_cleanly(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
write_queue.put(session_msg)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_client = Mock()
|
||||
mock_client.post.side_effect = httpx.ReadError("connection dropped")
|
||||
|
||||
# post_writer calls _send_message which calls client.post → ReadError propagates
|
||||
# The ReadError is raised inside _send_message → propagates out of the while loop
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
# finally always puts None
|
||||
sentinel = write_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
def test_generic_exception_puts_exc_in_queue(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
write_queue.put(session_msg)
|
||||
|
||||
mock_client = Mock()
|
||||
boom = RuntimeError("boom")
|
||||
mock_client.post.side_effect = boom
|
||||
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
exc_item = write_queue.get_nowait()
|
||||
assert isinstance(exc_item, Exception)
|
||||
|
||||
sentinel = write_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
|
||||
def test_queue_empty_timeout_continues_loop(self):
|
||||
"""Cover the 'except queue.Empty: continue' branch (line 188) in post_writer."""
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
mock_client = Mock()
|
||||
|
||||
# Patch queue.Queue.get so it raises Empty first, then returns None (shutdown)
|
||||
call_count = {"n": 0}
|
||||
original_get = write_queue.get
|
||||
|
||||
def patched_get(*args, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise queue.Empty
|
||||
|
||||
write_queue.get = patched_get # type: ignore[method-assign]
|
||||
|
||||
transport.post_writer(mock_client, "http://example.com/messages/", write_queue)
|
||||
|
||||
# finally always puts None sentinel
|
||||
sentinel = write_queue.get_nowait()
|
||||
assert sentinel is None
|
||||
assert call_count["n"] >= 2 # Empty on first, None on second (and possibly more retries)
|
||||
|
||||
|
||||
class TestWaitForEndpoint:
|
||||
"""Tests for SSETransport._wait_for_endpoint edge cases."""
|
||||
|
||||
def test_raises_on_empty_queue(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue() # empty
|
||||
|
||||
with pytest.raises(ValueError, match="failed to get endpoint URL"):
|
||||
transport._wait_for_endpoint(status_queue)
|
||||
|
||||
def test_raises_status_error_exception(self):
|
||||
from core.mcp.client.sse_client import SSETransport, _StatusError
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
exc = ValueError("malicious endpoint")
|
||||
status_queue.put(_StatusError(exc))
|
||||
|
||||
with pytest.raises(ValueError, match="malicious endpoint"):
|
||||
transport._wait_for_endpoint(status_queue)
|
||||
|
||||
def test_raises_on_unknown_status_type(self):
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
status_queue: queue.Queue = queue.Queue()
|
||||
|
||||
# Put an object that is neither _StatusReady nor _StatusError
|
||||
status_queue.put("unexpected_value")
|
||||
|
||||
with pytest.raises(ValueError, match="failed to get endpoint URL"):
|
||||
transport._wait_for_endpoint(status_queue)
|
||||
|
||||
|
||||
class TestSSEClientRuntimeError:
|
||||
"""Test sse_client context manager handles RuntimeError on close()."""
|
||||
|
||||
def test_runtime_error_on_close_is_suppressed(self):
|
||||
"""Ensure RuntimeError raised by event_source.response.close() is caught."""
|
||||
test_url = "http://test.example/sse"
|
||||
|
||||
class MockSSEEvent:
|
||||
def __init__(self, event_type: str, data: str):
|
||||
self.event = event_type
|
||||
self.data = data
|
||||
|
||||
endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123")
|
||||
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_cf:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sc:
|
||||
mock_client = Mock()
|
||||
mock_cf.return_value.__enter__.return_value = mock_client
|
||||
|
||||
mock_es = Mock()
|
||||
mock_es.response.raise_for_status.return_value = None
|
||||
mock_es.iter_sse.return_value = [endpoint_event]
|
||||
# Make close() raise RuntimeError to exercise line 307-308
|
||||
mock_es.response.close.side_effect = RuntimeError("already closed")
|
||||
mock_sc.return_value.__enter__.return_value = mock_es
|
||||
|
||||
# Should NOT raise even though close() raises RuntimeError
|
||||
with contextlib.suppress(Exception):
|
||||
with sse_client(test_url) as (rq, wq):
|
||||
pass
|
||||
|
||||
|
||||
class TestStandaloneSendMessage:
|
||||
"""Tests for the module-level send_message() function."""
|
||||
|
||||
def _make_session_message(self):
|
||||
msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
msg = types.JSONRPCMessage.model_validate_json(msg_json)
|
||||
return types.SessionMessage(msg)
|
||||
|
||||
def test_send_message_success(self):
|
||||
from core.mcp.client.sse_client import send_message
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_http_client = Mock()
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
send_message(mock_http_client, "http://example.com/messages/", session_msg)
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
def test_send_message_raises_on_http_error(self):
|
||||
from core.mcp.client.sse_client import send_message
|
||||
|
||||
mock_http_client = Mock()
|
||||
mock_http_client.post.side_effect = httpx.ConnectError("refused")
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
send_message(mock_http_client, "http://example.com/messages/", session_msg)
|
||||
|
||||
def test_send_message_raises_for_status_failure(self):
|
||||
from core.mcp.client.sse_client import send_message
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=Mock(), response=Mock(status_code=404)
|
||||
)
|
||||
mock_http_client = Mock()
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
session_msg = self._make_session_message()
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
send_message(mock_http_client, "http://example.com/messages/", session_msg)
|
||||
|
||||
|
||||
class TestReadMessages:
|
||||
"""Tests for the module-level read_messages() generator."""
|
||||
|
||||
def _make_mock_sse_event(self, event_type: str, data: str):
|
||||
ev = Mock()
|
||||
ev.event = event_type
|
||||
ev.data = data
|
||||
return ev
|
||||
|
||||
def test_valid_message_event_yields_session_message(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
valid_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
mock_sse_event = self._make_mock_sse_event("message", valid_json)
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.events.return_value = [mock_sse_event]
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
assert len(results) == 1
|
||||
assert hasattr(results[0], "message")
|
||||
|
||||
def test_invalid_json_yields_exception(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
mock_sse_event = self._make_mock_sse_event("message", "{not valid json}")
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.events.return_value = [mock_sse_event]
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], Exception)
|
||||
|
||||
def test_non_message_event_is_skipped(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
mock_sse_event = self._make_mock_sse_event("endpoint", "/messages/")
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.events.return_value = [mock_sse_event]
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
# Non-message events produce no output
|
||||
assert results == []
|
||||
|
||||
def test_outer_exception_yields_exc(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
boom = RuntimeError("stream broken")
|
||||
mock_client = Mock()
|
||||
mock_client.events.side_effect = boom
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
assert len(results) == 1
|
||||
assert results[0] is boom
|
||||
|
||||
def test_multiple_events_mixed(self):
|
||||
from core.mcp.client.sse_client import read_messages
|
||||
|
||||
valid_json = '{"jsonrpc": "2.0", "id": 2, "result": {}}'
|
||||
events = [
|
||||
self._make_mock_sse_event("endpoint", "/messages/"),
|
||||
self._make_mock_sse_event("message", valid_json),
|
||||
self._make_mock_sse_event("message", "{bad json}"),
|
||||
]
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.events.return_value = events
|
||||
|
||||
results = list(read_messages(mock_client))
|
||||
# endpoint is skipped; 1 valid SessionMessage + 1 Exception
|
||||
assert len(results) == 2
|
||||
assert hasattr(results[0], "message")
|
||||
assert isinstance(results[1], Exception)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
617
api/tests/unit_tests/core/mcp/session/test_base_session.py
Normal file
617
api/tests/unit_tests/core/mcp/session/test_base_session.py
Normal file
@ -0,0 +1,617 @@
|
||||
import queue
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import Union
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import HTTPStatusError, Request, Response
|
||||
from pydantic import BaseModel, ConfigDict, RootModel
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.session.base_session import BaseSession, RequestResponder
|
||||
from core.mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCResponse,
|
||||
Notification,
|
||||
RequestParams,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.types import (
|
||||
Request as MCPRequest,
|
||||
)
|
||||
|
||||
|
||||
class MockRequestParams(RequestParams):
|
||||
name: str = "default"
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class MockRequest(MCPRequest[MockRequestParams, str]):
|
||||
method: str = "test/request"
|
||||
params: MockRequestParams = MockRequestParams()
|
||||
|
||||
|
||||
class MockResult(BaseModel):
|
||||
result: str
|
||||
|
||||
|
||||
class MockNotificationParams(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class MockNotification(Notification[MockNotificationParams, str]):
|
||||
method: str = "test/notification"
|
||||
params: MockNotificationParams
|
||||
|
||||
|
||||
class ReceiveRequest(RootModel[Union[MockRequest, ClientRequest]]):
|
||||
pass
|
||||
|
||||
|
||||
class ReceiveNotification(RootModel[Union[CancelledNotification, MockNotification, JSONRPCNotification]]):
|
||||
pass
|
||||
|
||||
|
||||
class MockSession(BaseSession[MockRequest, MockNotification, MockResult, ReceiveRequest, ReceiveNotification]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.received_requests = []
|
||||
self.received_notifications = []
|
||||
self.handled_incoming = []
|
||||
|
||||
def _received_request(self, responder):
|
||||
self.received_requests.append(responder)
|
||||
|
||||
def _received_notification(self, notification):
|
||||
self.received_notifications.append(notification)
|
||||
|
||||
def _handle_incoming(self, item):
|
||||
self.handled_incoming.append(item)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streams():
|
||||
return queue.Queue(), queue.Queue()
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_request_responder_respond(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
on_complete = MagicMock()
|
||||
request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
|
||||
responder.respond(MockResult(result="ok"))
|
||||
|
||||
with responder as r:
|
||||
r.respond(MockResult(result="ok"))
|
||||
with pytest.raises(AssertionError, match="Request already responded to"):
|
||||
r.respond(MockResult(result="error"))
|
||||
|
||||
assert responder.completed is True
|
||||
on_complete.assert_called_once_with(responder)
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert isinstance(msg.message.root, JSONRPCResponse)
|
||||
assert msg.message.root.result == {"result": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_request_responder_cancel(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
on_complete = MagicMock()
|
||||
request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test")))
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"):
|
||||
responder.cancel()
|
||||
|
||||
with responder as r:
|
||||
r.cancel()
|
||||
|
||||
assert responder.completed is True
|
||||
on_complete.assert_called_once_with(responder)
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert isinstance(msg.message.root, JSONRPCError)
|
||||
assert msg.message.root.error.message == "Request cancelled"
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_base_session_lifecycle(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session as s:
|
||||
assert isinstance(s, MockSession)
|
||||
assert s._executor is not None
|
||||
assert s._receiver_future is not None
|
||||
|
||||
session._receiver_future.result(timeout=5.0)
|
||||
assert session._receiver_future.done()
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_success(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_response():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "hello world"})
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_response, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.send_request(request, MockResult)
|
||||
assert result.result == "hello world"
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_retry_loop_coverage(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_delayed_response():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
time.sleep(0.2)
|
||||
response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "slow"})
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(response)))
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_delayed_response, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.send_request(request, MockResult, request_read_timeout_seconds=timedelta(seconds=0.1))
|
||||
assert result.result == "slow"
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_jsonrpc_error(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_error():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=-32000, message="Error"))
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_error, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(MCPConnectionError) as exc:
|
||||
session.send_request(request, MockResult)
|
||||
assert exc.value.args[0].message == "Error"
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_auth_error(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_error():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=401, message="Unauthorized"))
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(error)))
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_error, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(MCPAuthError):
|
||||
session.send_request(request, MockResult)
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_http_status_error_coverage(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_direct_http_error():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
# To cover line 263 in base_session.py, we MUST put non-401 HTTPStatusError
|
||||
# DIRECTLY into response_streams, as _receive_loop would convert it to JSONRPCError.
|
||||
response = Response(status_code=403, request=Request("GET", "http://test"))
|
||||
error = HTTPStatusError("Forbidden", request=response.request, response=response)
|
||||
session._response_streams[req_id].put(error)
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_direct_http_error, daemon=True)
|
||||
t.start()
|
||||
|
||||
# We still need the session for request ID generation and queue setup
|
||||
with session:
|
||||
with pytest.raises(MCPConnectionError) as exc:
|
||||
session.send_request(request, MockResult)
|
||||
assert exc.value.args[0].code == 403
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_http_status_auth_error(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_error():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
response = Response(status_code=401, request=Request("GET", "http://test"))
|
||||
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
|
||||
read_stream.put(error)
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_error, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(MCPAuthError):
|
||||
session.send_request(request, MockResult)
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_notification(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
notification = MockNotification(method="notify", params=MockNotificationParams(message="hi"))
|
||||
|
||||
session.send_notification(notification, related_request_id="rel-1")
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert isinstance(msg.message.root, JSONRPCNotification)
|
||||
assert msg.message.root.method == "notify"
|
||||
assert msg.message.root.params == {"message": "hi"}
|
||||
assert msg.metadata.related_request_id == "rel-1"
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_request(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
req_payload = {"jsonrpc": "2.0", "id": 1, "method": "test/request", "params": {"name": "test"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
|
||||
|
||||
for _ in range(30):
|
||||
if session.received_requests:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(session.received_requests) == 1
|
||||
responder = session.received_requests[0]
|
||||
assert responder.request_id == 1
|
||||
assert responder.request.root.method == "test/request"
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_notification(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
notif_payload = {"jsonrpc": "2.0", "method": "test/notification", "params": {"message": "hello"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
|
||||
|
||||
for _ in range(30):
|
||||
if session.received_notifications:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(session.received_notifications) == 1
|
||||
assert isinstance(session.received_notifications[0].root, MockNotification)
|
||||
assert session.received_notifications[0].root.method == "test/notification"
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_receive_loop_cancel_notification(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ClientNotification)
|
||||
|
||||
with session:
|
||||
req_payload = {"jsonrpc": "2.0", "id": "req-1", "method": "test/request", "params": {"name": "test"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload)))
|
||||
|
||||
for _ in range(30):
|
||||
if "req-1" in session._in_flight:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert "req-1" in session._in_flight
|
||||
responder = session._in_flight["req-1"]
|
||||
|
||||
with responder:
|
||||
cancel_payload = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": {"requestId": "req-1"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(cancel_payload)))
|
||||
|
||||
for _ in range(30):
|
||||
if responder.completed:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert responder.completed is True
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert isinstance(msg.message.root, JSONRPCError)
|
||||
assert msg.message.root.id == "req-1"
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_exception(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
read_stream.put(Exception("Unexpected error"))
|
||||
for _ in range(30):
|
||||
if any(isinstance(x, Exception) for x in session.handled_incoming):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert any(isinstance(x, Exception) and str(x) == "Unexpected error" for x in session.handled_incoming)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_http_status_error(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
session._request_id = 1
|
||||
resp_queue = queue.Queue()
|
||||
session._response_streams[0] = resp_queue
|
||||
|
||||
response = Response(status_code=401, request=Request("GET", "http://test"))
|
||||
# Using 401 specifically as _receive_loop preserves it
|
||||
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
|
||||
read_stream.put(error)
|
||||
|
||||
got = resp_queue.get(timeout=2)
|
||||
assert isinstance(got, HTTPStatusError)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_http_status_error_non_401(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
session._request_id = 1
|
||||
resp_queue = queue.Queue()
|
||||
session._response_streams[0] = resp_queue
|
||||
|
||||
response = Response(status_code=500, request=Request("GET", "http://test"))
|
||||
error = HTTPStatusError("Server Error", request=response.request, response=response)
|
||||
read_stream.put(error)
|
||||
|
||||
got = resp_queue.get(timeout=2)
|
||||
assert isinstance(got, JSONRPCError)
|
||||
assert got.error.code == 500
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_check_receiver_status_fail(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def raise_err():
|
||||
raise RuntimeError("Receiver failed")
|
||||
|
||||
future = executor.submit(raise_err)
|
||||
session._receiver_future = future
|
||||
|
||||
try:
|
||||
future.result()
|
||||
except:
|
||||
pass
|
||||
|
||||
with pytest.raises(RuntimeError, match="Receiver failed"):
|
||||
session.check_receiver_status()
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_unknown_request_id(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
resp = JSONRPCResponse(jsonrpc="2.0", id=999, result={"ok": True})
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage(resp)))
|
||||
|
||||
for _ in range(30):
|
||||
if any(isinstance(x, RuntimeError) and "Server Error" in str(x) for x in session.handled_incoming):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert any("Server Error" in str(x) for x in session.handled_incoming)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_http_error_unknown_id(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with session:
|
||||
response = Response(status_code=401, request=Request("GET", "http://test"))
|
||||
error = HTTPStatusError("Unauthorized", request=response.request, response=response)
|
||||
read_stream.put(error)
|
||||
|
||||
for _ in range(30):
|
||||
if any(isinstance(x, RuntimeError) and "unknown request ID" in str(x) for x in session.handled_incoming):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert any("unknown request ID" in str(x) for x in session.handled_incoming)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_validation_error_notification(streams):
|
||||
from core.mcp.session.base_session import logger
|
||||
|
||||
with patch.object(logger, "warning") as mock_warning:
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, RootModel[MockNotification])
|
||||
|
||||
with session:
|
||||
notif_payload = {"jsonrpc": "2.0", "method": "bad", "params": {"some": "data"}}
|
||||
read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload)))
|
||||
time.sleep(1.0)
|
||||
|
||||
assert mock_warning.called
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_none_response(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
def mock_none():
|
||||
try:
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
session._response_streams[req_id].put(None)
|
||||
except:
|
||||
pass
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_none, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(MCPConnectionError) as exc:
|
||||
session.send_request(request, MockResult)
|
||||
assert exc.value.args[0].message == "No response received"
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_session_exit_timeout(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
mock_future = MagicMock(spec=Future)
|
||||
mock_future.result.side_effect = TimeoutError()
|
||||
mock_future.done.return_value = False
|
||||
|
||||
session._receiver_future = mock_future
|
||||
session._executor = MagicMock(spec=ThreadPoolExecutor)
|
||||
|
||||
session.__exit__(None, None, None)
|
||||
|
||||
mock_future.cancel.assert_called_once()
|
||||
session._executor.shutdown.assert_called_once_with(wait=False)
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_receive_loop_fatal_exception(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
with patch.object(read_stream, "get", side_effect=RuntimeError("Fatal loop error")):
|
||||
with patch("core.mcp.session.base_session.logger") as mock_logger:
|
||||
with pytest.raises(RuntimeError, match="Fatal loop error"):
|
||||
with session:
|
||||
pass
|
||||
mock_logger.exception.assert_called_with("Error in message processing loop")
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_receive_loop_empty_coverage(streams):
|
||||
with patch("core.mcp.session.base_session.DEFAULT_RESPONSE_READ_TIMEOUT", 0.1):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
with session:
|
||||
time.sleep(0.3)
|
||||
|
||||
|
||||
@pytest.mark.timeout(2)
|
||||
def test_base_methods_noop(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = BaseSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification)
|
||||
|
||||
session._received_request(MagicMock())
|
||||
session._received_notification(MagicMock())
|
||||
session.send_progress_notification("token", 0.5)
|
||||
session._handle_incoming(MagicMock())
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
def test_send_request_session_timeout_retry_6(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = MockSession(
|
||||
read_stream, write_stream, ReceiveRequest, ReceiveNotification, read_timeout_seconds=timedelta(seconds=0.1)
|
||||
)
|
||||
|
||||
request = MockRequest(method="test", params=MockRequestParams(name="world"))
|
||||
|
||||
with patch.object(session, "check_receiver_status", side_effect=[None, RuntimeError("timeout_broken")]):
|
||||
with pytest.raises(RuntimeError, match="timeout_broken"):
|
||||
session.send_request(request, MockResult)
|
||||
576
api/tests/unit_tests/core/mcp/session/test_client_session.py
Normal file
576
api/tests/unit_tests/core/mcp/session/test_client_session.py
Normal file
@ -0,0 +1,576 @@
|
||||
import queue
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.session.base_session import RequestResponder, SessionMessage
|
||||
from core.mcp.session.client_session import (
|
||||
ClientSession,
|
||||
_default_list_roots_callback,
|
||||
_default_logging_callback,
|
||||
_default_message_handler,
|
||||
_default_sampling_callback,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streams():
|
||||
return queue.Queue(), queue.Queue()
|
||||
|
||||
|
||||
def test_client_session_init(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
assert session._client_info.name == "Dify"
|
||||
assert session._sampling_callback == _default_sampling_callback
|
||||
assert session._list_roots_callback == _default_list_roots_callback
|
||||
assert session._logging_callback == _default_logging_callback
|
||||
assert session._message_handler == _default_message_handler
|
||||
|
||||
|
||||
def test_client_session_init_custom(streams):
|
||||
read_stream, write_stream = streams
|
||||
sampling_cb = MagicMock()
|
||||
list_roots_cb = MagicMock()
|
||||
logging_cb = MagicMock()
|
||||
msg_handler = MagicMock()
|
||||
client_info = types.Implementation(name="Custom", version="1.0")
|
||||
|
||||
session = ClientSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
sampling_callback=sampling_cb,
|
||||
list_roots_callback=list_roots_cb,
|
||||
logging_callback=logging_cb,
|
||||
message_handler=msg_handler,
|
||||
client_info=client_info,
|
||||
)
|
||||
|
||||
assert session._client_info == client_info
|
||||
assert session._sampling_callback == sampling_cb
|
||||
assert session._list_roots_callback == list_roots_cb
|
||||
assert session._logging_callback == logging_cb
|
||||
assert session._message_handler == msg_handler
|
||||
|
||||
|
||||
def test_initialize_success(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
expected_result = types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ServerCapabilities(),
|
||||
serverInfo=types.Implementation(name="test-server", version="1.0"),
|
||||
)
|
||||
|
||||
def mock_server():
|
||||
# Handle initialize request
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result=expected_result.model_dump())
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
# Expect initialized notification
|
||||
notif = write_stream.get(timeout=2)
|
||||
assert notif.message.root.method == "notifications/initialized"
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.initialize()
|
||||
assert result.protocolVersion == types.LATEST_PROTOCOL_VERSION
|
||||
assert result.serverInfo.name == "test-server"
|
||||
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_initialize_custom_capabilities(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(
|
||||
read_stream, write_stream, sampling_callback=lambda c, p: None, list_roots_callback=lambda c: None
|
||||
)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
params = msg.message.root.params
|
||||
# Check that capabilities are set because we provided custom callbacks
|
||||
assert params["capabilities"]["sampling"] is not None
|
||||
assert params["capabilities"]["roots"]["listChanged"] is True
|
||||
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=req_id,
|
||||
result={
|
||||
"protocolVersion": types.LATEST_PROTOCOL_VERSION,
|
||||
"capabilities": {},
|
||||
"serverInfo": {"name": "test", "version": "1.0"},
|
||||
},
|
||||
)
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
write_stream.get(timeout=2) # initialized notif
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
session.initialize()
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_initialize_unsupported_version(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=req_id,
|
||||
result={
|
||||
"protocolVersion": "0.0.1", # Unsupported
|
||||
"capabilities": {},
|
||||
"serverInfo": {"name": "test", "version": "1.0"},
|
||||
},
|
||||
)
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
|
||||
session.initialize()
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_send_ping(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "ping"
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
session.send_ping()
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_send_progress_notification(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
session.send_progress_notification(progress_token="token", progress=50.0, total=100.0)
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert msg.message.root.method == "notifications/progress"
|
||||
assert msg.message.root.params["progressToken"] == "token"
|
||||
assert msg.message.root.params["progress"] == 50.0
|
||||
assert msg.message.root.params["total"] == 100.0
|
||||
|
||||
|
||||
def test_set_logging_level(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "logging/setLevel"
|
||||
assert msg.message.root.params["level"] == "debug"
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
session.set_logging_level("debug")
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_list_resources(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "resources/list"
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resources": []})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.list_resources()
|
||||
assert result.resources == []
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_list_resource_templates(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "resources/templates/list"
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resourceTemplates": []})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.list_resource_templates()
|
||||
assert result.resourceTemplates == []
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_read_resource(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
uri = AnyUrl("file:///test")
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "resources/read"
|
||||
assert msg.message.root.params["uri"] == str(uri)
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"contents": []})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.read_resource(uri)
|
||||
assert result.contents == []
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_subscribe_resource(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
uri = AnyUrl("file:///test")
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "resources/subscribe"
|
||||
assert msg.message.root.params["uri"] == str(uri)
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
session.subscribe_resource(uri)
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_unsubscribe_resource(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
uri = AnyUrl("file:///test")
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "resources/unsubscribe"
|
||||
assert msg.message.root.params["uri"] == str(uri)
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
session.unsubscribe_resource(uri)
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_call_tool(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "tools/call"
|
||||
assert msg.message.root.params["name"] == "test-tool"
|
||||
assert msg.message.root.params["arguments"] == {"arg": 1}
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"content": [], "isError": False})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.call_tool("test-tool", arguments={"arg": 1})
|
||||
assert result.isError is False
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_list_prompts(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "prompts/list"
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"prompts": []})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.list_prompts()
|
||||
assert result.prompts == []
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_get_prompt(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "prompts/get"
|
||||
assert msg.message.root.params["name"] == "test-prompt"
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"messages": []})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.get_prompt("test-prompt")
|
||||
assert result.messages == []
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_complete(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
ref = types.PromptReference(type="ref/prompt", name="test")
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "completion/complete"
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"completion": {"values": [], "hasMore": False}})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.complete(ref, argument={"name": "val", "value": "x"})
|
||||
assert result.completion.hasMore is False
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_list_tools(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
def mock_server():
|
||||
msg = write_stream.get(timeout=2)
|
||||
assert msg.message.root.method == "tools/list"
|
||||
req_id = msg.message.root.id
|
||||
resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"tools": []})
|
||||
read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp)))
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(target=mock_server, daemon=True)
|
||||
t.start()
|
||||
|
||||
with session:
|
||||
result = session.list_tools()
|
||||
assert result.tools == []
|
||||
t.join(timeout=1)
|
||||
|
||||
|
||||
def test_send_roots_list_changed(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
session.send_roots_list_changed()
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert msg.message.root.method == "notifications/roots/list_changed"
|
||||
|
||||
|
||||
def test_received_request_sampling(streams):
|
||||
read_stream, write_stream = streams
|
||||
sampling_cb = MagicMock(
|
||||
return_value=types.CreateMessageResult(
|
||||
role="assistant", content=types.TextContent(type="text", text="hello"), model="gpt-4"
|
||||
)
|
||||
)
|
||||
session = ClientSession(read_stream, write_stream, sampling_callback=sampling_cb)
|
||||
|
||||
req = types.ServerRequest(
|
||||
root=types.CreateMessageRequest(
|
||||
method="sampling/createMessage", params=types.CreateMessageRequestParams(messages=[], maxTokens=100)
|
||||
)
|
||||
)
|
||||
|
||||
responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
|
||||
|
||||
session._received_request(responder)
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert msg.message.root.result["model"] == "gpt-4"
|
||||
sampling_cb.assert_called_once()
|
||||
|
||||
|
||||
def test_received_request_list_roots(streams):
|
||||
read_stream, write_stream = streams
|
||||
list_roots_cb = MagicMock(return_value=types.ListRootsResult(roots=[]))
|
||||
session = ClientSession(read_stream, write_stream, list_roots_callback=list_roots_cb)
|
||||
|
||||
req = types.ServerRequest(root=types.ListRootsRequest(method="roots/list"))
|
||||
|
||||
responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
|
||||
|
||||
session._received_request(responder)
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert msg.message.root.result["roots"] == []
|
||||
list_roots_cb.assert_called_once()
|
||||
|
||||
|
||||
def test_received_request_ping(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
req = types.ServerRequest(root=types.PingRequest(method="ping"))
|
||||
|
||||
responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock())
|
||||
|
||||
session._received_request(responder)
|
||||
|
||||
msg = write_stream.get_nowait()
|
||||
assert msg.message.root.result == {}
|
||||
|
||||
|
||||
def test_handle_incoming(streams):
|
||||
read_stream, write_stream = streams
|
||||
msg_handler = MagicMock()
|
||||
session = ClientSession(read_stream, write_stream, message_handler=msg_handler)
|
||||
|
||||
item = MagicMock()
|
||||
session._handle_incoming(item)
|
||||
msg_handler.assert_called_once_with(item)
|
||||
|
||||
|
||||
def test_received_notification_logging(streams):
|
||||
read_stream, write_stream = streams
|
||||
logging_cb = MagicMock()
|
||||
session = ClientSession(read_stream, write_stream, logging_callback=logging_cb)
|
||||
|
||||
notif = types.ServerNotification(
|
||||
root=types.LoggingMessageNotification(
|
||||
method="notifications/message",
|
||||
params=types.LoggingMessageNotificationParams(level="info", data={"msg": "test"}),
|
||||
)
|
||||
)
|
||||
|
||||
session._received_notification(notif)
|
||||
logging_cb.assert_called_once()
|
||||
assert logging_cb.call_args[0][0].level == "info"
|
||||
|
||||
|
||||
def test_default_message_handler():
|
||||
# Exception case
|
||||
with pytest.raises(ValueError, match="test error"):
|
||||
_default_message_handler(Exception("test error"))
|
||||
|
||||
# Notification case - should do nothing
|
||||
_default_message_handler(MagicMock(spec=types.ServerNotification))
|
||||
|
||||
# RequestResponder case - should do nothing
|
||||
_default_message_handler(MagicMock(spec=RequestResponder))
|
||||
|
||||
|
||||
def test_default_sampling_callback():
|
||||
ctx = MagicMock()
|
||||
params = MagicMock()
|
||||
res = _default_sampling_callback(ctx, params)
|
||||
assert res.code == types.INVALID_REQUEST
|
||||
assert "not supported" in res.message
|
||||
|
||||
|
||||
def test_default_list_roots_callback():
|
||||
ctx = MagicMock()
|
||||
res = _default_list_roots_callback(ctx)
|
||||
assert res.code == types.INVALID_REQUEST
|
||||
assert "not supported" in res.message
|
||||
|
||||
|
||||
def test_default_logging_callback():
|
||||
params = MagicMock()
|
||||
_default_logging_callback(params) # Should do nothing
|
||||
|
||||
|
||||
def test_received_notification_unknown(streams):
|
||||
read_stream, write_stream = streams
|
||||
session = ClientSession(read_stream, write_stream)
|
||||
|
||||
# Use a notification type that is NOT LoggingMessageNotification
|
||||
notif = types.ServerNotification(
|
||||
root=types.ResourceListChangedNotification(method="notifications/resources/list_changed")
|
||||
)
|
||||
|
||||
session._received_notification(notif)
|
||||
# Should just pass (case _:)
|
||||
@ -2,13 +2,16 @@
|
||||
|
||||
from contextlib import ExitStack
|
||||
from types import TracebackType
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
|
||||
from core.mcp.types import CallToolResult, ListToolsResult, OAuthTokens, TextContent, Tool, ToolAnnotations
|
||||
|
||||
|
||||
class TestMCPClient:
|
||||
@ -380,3 +383,256 @@ class TestMCPClient:
|
||||
timeout=30.0,
|
||||
sse_read_timeout=60.0,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPClientWithAuthRetry:
|
||||
"""Test suite for MCPClientWithAuthRetry."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider(self):
|
||||
provider = MagicMock(spec=MCPProviderEntity)
|
||||
provider.id = "test-provider-id"
|
||||
provider.tenant_id = "test-tenant-id"
|
||||
provider.retrieve_tokens.return_value = OAuthTokens(
|
||||
access_token="new-token",
|
||||
token_type="Bearer",
|
||||
expires_in=3600,
|
||||
refresh_token="refresh-token",
|
||||
)
|
||||
return provider
|
||||
|
||||
@pytest.fixture
|
||||
def auth_client(self, mock_provider):
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
headers={"Authorization": "Bearer old-token"},
|
||||
provider_entity=mock_provider,
|
||||
authorization_code="test-code",
|
||||
by_server_id=True,
|
||||
)
|
||||
return client
|
||||
|
||||
def test_init(self, mock_provider):
|
||||
"""Test initialization."""
|
||||
client = MCPClientWithAuthRetry(
|
||||
server_url="http://test.example.com",
|
||||
headers={"Authorization": "Bearer test"},
|
||||
timeout=30.0,
|
||||
provider_entity=mock_provider,
|
||||
authorization_code="initial-code",
|
||||
by_server_id=True,
|
||||
)
|
||||
|
||||
assert client.server_url == "http://test.example.com"
|
||||
assert client.headers == {"Authorization": "Bearer test"}
|
||||
assert client.timeout == 30.0
|
||||
assert client.provider_entity == mock_provider
|
||||
assert client.authorization_code == "initial-code"
|
||||
assert client.by_server_id is True
|
||||
assert client._has_retried is False
|
||||
|
||||
@patch("core.mcp.auth_client.db")
|
||||
@patch("core.mcp.auth_client.Session")
|
||||
@patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
|
||||
def test_handle_auth_error_success(
|
||||
self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider
|
||||
):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_service = mock_service_class.return_value
|
||||
new_provider = MagicMock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = OAuthTokens(
|
||||
access_token="new-access-token",
|
||||
token_type="Bearer",
|
||||
expires_in=3600,
|
||||
refresh_token="new-refresh-token",
|
||||
)
|
||||
mock_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
# MCPAuthError parses resource_metadata and scope from www_authenticate_header
|
||||
www_auth = 'Bearer resource_metadata="http://meta", scope="read"'
|
||||
error = MCPAuthError("Auth failed", www_authenticate_header=www_auth)
|
||||
|
||||
auth_client._handle_auth_error(error)
|
||||
|
||||
# Verify service calls - error.resource_metadata_url and error.scope_hint are parsed from header
|
||||
mock_service.auth_with_actions.assert_called_once_with(
|
||||
mock_provider,
|
||||
"test-code",
|
||||
resource_metadata_url="http://meta",
|
||||
scope_hint="read",
|
||||
)
|
||||
mock_service.get_provider_entity.assert_called_once_with(
|
||||
mock_provider.id, mock_provider.tenant_id, by_server_id=True
|
||||
)
|
||||
|
||||
# Verify client updates
|
||||
assert auth_client.headers["Authorization"] == "Bearer new-access-token"
|
||||
assert auth_client.authorization_code is None
|
||||
assert auth_client._has_retried is True
|
||||
assert auth_client.provider_entity == new_provider
|
||||
|
||||
def test_handle_auth_error_no_provider(self, auth_client):
|
||||
"""Test auth error handling when no provider entity is set."""
|
||||
auth_client.provider_entity = None
|
||||
error = MCPAuthError("Auth failed")
|
||||
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
auth_client._handle_auth_error(error)
|
||||
|
||||
assert exc_info.value == error
|
||||
|
||||
def test_handle_auth_error_already_retried(self, auth_client):
|
||||
"""Test auth error handling when already retried."""
|
||||
auth_client._has_retried = True
|
||||
error = MCPAuthError("Auth failed")
|
||||
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
auth_client._handle_auth_error(error)
|
||||
|
||||
assert exc_info.value == error
|
||||
|
||||
@patch("core.mcp.auth_client.db")
|
||||
@patch("core.mcp.auth_client.Session")
|
||||
@patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
|
||||
def test_handle_auth_error_no_token(
|
||||
self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider
|
||||
):
|
||||
"""Test auth error handling when no token is received."""
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_service = mock_service_class.return_value
|
||||
|
||||
new_provider = MagicMock(spec=MCPProviderEntity)
|
||||
new_provider.retrieve_tokens.return_value = None
|
||||
mock_service.get_provider_entity.return_value = new_provider
|
||||
|
||||
error = MCPAuthError("Auth failed")
|
||||
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
auth_client._handle_auth_error(error)
|
||||
|
||||
assert "Authentication failed - no token received" in str(exc_info.value)
|
||||
|
||||
@patch("core.mcp.auth_client.db")
|
||||
@patch("core.mcp.auth_client.Session")
|
||||
@patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
|
||||
def test_handle_auth_error_generic_exception(self, mock_service_class, mock_session_class, mock_db, auth_client):
|
||||
"""Test auth error handling when a generic exception occurs."""
|
||||
mock_session_class.side_effect = Exception("DB error")
|
||||
|
||||
error = MCPAuthError("Auth failed")
|
||||
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
auth_client._handle_auth_error(error)
|
||||
|
||||
assert "Authentication retry failed: DB error" in str(exc_info.value)
|
||||
|
||||
@patch("core.mcp.auth_client.db")
|
||||
@patch("core.mcp.auth_client.Session")
|
||||
@patch("services.tools.mcp_tools_manage_service.MCPToolManageService")
|
||||
def test_handle_auth_error_mcp_auth_error_propagation(
|
||||
self, mock_service_class, mock_session_class, mock_db, auth_client
|
||||
):
|
||||
"""Test that MCPAuthError during refresh is propagated as is."""
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_service = mock_service_class.return_value
|
||||
mock_service.auth_with_actions.side_effect = MCPAuthError("Refresh failed")
|
||||
|
||||
error = MCPAuthError("Initial auth failed")
|
||||
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
auth_client._handle_auth_error(error)
|
||||
|
||||
assert "Refresh failed" in str(exc_info.value)
|
||||
|
||||
def test_execute_with_retry_success_first_try(self, auth_client):
|
||||
"""Test execution success on first try."""
|
||||
mock_func = MagicMock(return_value="success")
|
||||
|
||||
result = auth_client._execute_with_retry(mock_func, "arg1", kwarg1="val1")
|
||||
|
||||
assert result == "success"
|
||||
mock_func.assert_called_once_with("arg1", kwarg1="val1")
|
||||
assert auth_client._has_retried is False
|
||||
|
||||
@patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
|
||||
@patch.object(MCPClientWithAuthRetry, "_initialize")
|
||||
def test_execute_with_retry_success_on_retry_initialized(self, mock_initialize, mock_handle_auth, auth_client):
|
||||
"""Test execution success on retry after auth error when client was already initialized."""
|
||||
mock_func = MagicMock()
|
||||
mock_func.side_effect = [MCPAuthError("Auth failed"), "success"]
|
||||
|
||||
auth_client._initialized = True
|
||||
auth_client._exit_stack = MagicMock()
|
||||
|
||||
result = auth_client._execute_with_retry(mock_func, "arg")
|
||||
|
||||
assert result == "success"
|
||||
assert mock_func.call_count == 2
|
||||
mock_handle_auth.assert_called_once()
|
||||
mock_initialize.assert_called_once()
|
||||
auth_client._exit_stack.close.assert_called_once()
|
||||
assert auth_client._has_retried is False
|
||||
|
||||
@patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
|
||||
@patch.object(MCPClientWithAuthRetry, "_initialize")
|
||||
def test_execute_with_retry_success_on_retry_not_initialized(self, mock_initialize, mock_handle_auth, auth_client):
|
||||
"""Test retry when client was NOT initialized (skips cleanup/re-init)."""
|
||||
mock_func = MagicMock()
|
||||
mock_func.side_effect = [MCPAuthError("Auth failed"), "result"]
|
||||
|
||||
auth_client._initialized = False
|
||||
|
||||
result = auth_client._execute_with_retry(mock_func, "arg")
|
||||
|
||||
assert result == "result"
|
||||
assert mock_func.call_count == 2
|
||||
mock_handle_auth.assert_called_once()
|
||||
mock_initialize.assert_not_called()
|
||||
assert auth_client._has_retried is False
|
||||
|
||||
@patch.object(MCPClientWithAuthRetry, "_handle_auth_error")
|
||||
def test_execute_with_retry_failure_on_retry(self, mock_handle_auth, auth_client):
|
||||
"""Test execution failure even after retry."""
|
||||
mock_func = MagicMock()
|
||||
mock_func.side_effect = [MCPAuthError("First fail"), MCPAuthError("Second fail")]
|
||||
|
||||
with pytest.raises(MCPAuthError) as exc_info:
|
||||
auth_client._execute_with_retry(mock_func, "arg")
|
||||
|
||||
assert "Second fail" in str(exc_info.value)
|
||||
assert mock_func.call_count == 2
|
||||
mock_handle_auth.assert_called_once()
|
||||
assert auth_client._has_retried is False
|
||||
|
||||
@patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
|
||||
def test_auth_client_context_manager_enter(self, mock_execute_retry, auth_client):
|
||||
"""Test context manager __enter__."""
|
||||
auth_client.__enter__()
|
||||
|
||||
mock_execute_retry.assert_called_once()
|
||||
func = mock_execute_retry.call_args[0][0]
|
||||
|
||||
with patch("core.mcp.mcp_client.MCPClient.__enter__") as mock_base_enter:
|
||||
result = func()
|
||||
assert result == auth_client
|
||||
mock_base_enter.assert_called_once()
|
||||
|
||||
@patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
|
||||
def test_auth_client_list_tools(self, mock_execute_retry, auth_client):
|
||||
"""Test list_tools with retry."""
|
||||
auth_client.list_tools()
|
||||
|
||||
mock_execute_retry.assert_called_once()
|
||||
assert mock_execute_retry.call_args[0][0].__name__ == "list_tools"
|
||||
|
||||
@patch.object(MCPClientWithAuthRetry, "_execute_with_retry")
|
||||
def test_auth_client_invoke_tool(self, mock_execute_retry, auth_client):
|
||||
"""Test invoke_tool with retry."""
|
||||
auth_client.invoke_tool("test-tool", {"arg": "val"})
|
||||
|
||||
mock_execute_retry.assert_called_once()
|
||||
assert mock_execute_retry.call_args[0][0].__name__ == "invoke_tool"
|
||||
assert mock_execute_retry.call_args[0][1] == "test-tool"
|
||||
assert mock_execute_retry.call_args[0][2] == {"arg": "val"}
|
||||
|
||||
Reference in New Issue
Block a user