Merge branch 'feat/mcp-06-18' into deploy/dev

This commit is contained in:
Novice
2025-10-14 21:40:29 +08:00
187 changed files with 9526 additions and 3521 deletions

View File

@ -25,9 +25,7 @@ class TestAnnotationService:
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
patch(
"services.annotation_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
patch("services.annotation_service.current_account_with_tenant") as mock_current_account_with_tenant,
):
# Setup default mock returns
mock_account_feature_service.get_features.return_value.billing.enabled = False
@ -38,6 +36,9 @@ class TestAnnotationService:
mock_disable_task.delay.return_value = None
mock_batch_import_task.delay.return_value = None
# Create mock user that will be returned by current_account_with_tenant
mock_user = create_autospec(Account, instance=True)
yield {
"account_feature_service": mock_account_feature_service,
"feature_service": mock_feature_service,
@ -47,7 +48,8 @@ class TestAnnotationService:
"enable_task": mock_enable_task,
"disable_task": mock_disable_task,
"batch_import_task": mock_batch_import_task,
"current_user": mock_current_user,
"current_account_with_tenant": mock_current_account_with_tenant,
"current_user": mock_user,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
@ -107,6 +109,11 @@ class TestAnnotationService:
"""
mock_external_service_dependencies["current_user"].id = account_id
mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
# Configure current_account_with_tenant to return (user, tenant_id)
mock_external_service_dependencies["current_account_with_tenant"].return_value = (
mock_external_service_dependencies["current_user"],
tenant_id,
)
def _create_test_conversation(self, app, account, fake):
"""

View File

@ -794,16 +794,12 @@ class TestWorkflowAppService:
new_email = "changed@example.com"
account.email = new_email
db_session_with_containers.commit()
assert account.email == new_email
# Results for new email, is expected to be the same as the original email
result_with_new_email = service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account=new_email,
page=1,
limit=20
session=db_session_with_containers, app_model=app, created_by_account=new_email, page=1, limit=20
)
assert result_with_new_email["total"] == 3
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_with_new_email["data"])
@ -1087,15 +1083,15 @@ class TestWorkflowAppService:
assert len(result_no_session["data"]) == 0
# Test with account email that doesn't exist
result_no_account = service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account="nonexistent@example.com",
page=1,
limit=20,
)
assert result_no_account["total"] == 0
assert len(result_no_account["data"]) == 0
with pytest.raises(ValueError) as exc_info:
service.get_paginate_workflow_app_logs(
session=db_session_with_containers,
app_model=app,
created_by_account="nonexistent@example.com",
page=1,
limit=20,
)
assert "Account not found" in str(exc_info.value)
def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
self, db_session_with_containers, mock_external_service_dependencies

View File

@ -20,12 +20,21 @@ class TestMCPToolManageService:
patch("services.tools.mcp_tools_manage_service.ToolTransformService") as mock_tool_transform_service,
):
# Setup default mock returns
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_encrypter.encrypt_token.return_value = "encrypted_server_url"
mock_tool_transform_service.mcp_provider_to_user_provider.return_value = {
"id": "test_id",
"name": "test_name",
"type": ToolProviderType.MCP,
}
mock_tool_transform_service.mcp_provider_to_user_provider.return_value = ToolProviderApiEntity(
id="test_id",
author="test_author",
name="test_name",
type=ToolProviderType.MCP,
description=I18nObject(en_US="Test Description", zh_Hans="测试描述"),
icon={"type": "emoji", "content": "🤖"},
label=I18nObject(en_US="Test Label", zh_Hans="测试标签"),
labels=[],
tools=[],
)
yield {
"encrypter": mock_encrypter,
@ -104,9 +113,9 @@ class TestMCPToolManageService:
mcp_provider = MCPToolProvider(
tenant_id=tenant_id,
name=fake.company(),
server_identifier=fake.uuid4(),
server_identifier=str(fake.uuid4()),
server_url="encrypted_server_url",
server_url_hash=fake.sha256(),
server_url_hash=str(fake.sha256()),
user_id=user_id,
authed=False,
tools="[]",
@ -144,7 +153,10 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
result = MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider.id, tenant.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -154,8 +166,6 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
assert result.id is not None
assert result.server_identifier == mcp_provider.server_identifier
@ -177,11 +187,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_id = fake.uuid4()
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_provider_id(non_existent_id, tenant.id)
service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id)
def test_get_mcp_provider_by_provider_id_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
@ -210,8 +223,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_provider_id(mcp_provider1.id, tenant2.id)
service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id)
def test_get_mcp_provider_by_server_identifier_success(
self, db_session_with_containers, mock_external_service_dependencies
@ -235,7 +251,10 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
result = MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider.server_identifier, tenant.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -245,8 +264,6 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
assert result.id is not None
assert result.name == mcp_provider.name
@ -268,11 +285,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_identifier = fake.uuid4()
non_existent_identifier = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_server_identifier(non_existent_identifier, tenant.id)
service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id)
def test_get_mcp_provider_by_server_identifier_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
@ -301,8 +321,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.get_mcp_provider_by_server_identifier(mcp_provider1.server_identifier, tenant2.id)
service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id)
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -322,15 +345,30 @@ class TestMCPToolManageService:
)
# Setup mocks for provider creation
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["encrypter"].encrypt_token.return_value = "encrypted_server_url"
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.return_value = {
"id": "new_provider_id",
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
}
mock_external_service_dependencies[
"tool_transform_service"
].mcp_provider_to_user_provider.return_value = ToolProviderApiEntity(
id="new_provider_id",
author=account.name,
name="Test MCP Provider",
type=ToolProviderType.MCP,
description=I18nObject(en_US="Test MCP Provider Description", zh_Hans="测试MCP提供者描述"),
icon={"type": "emoji", "content": "🤖"},
label=I18nObject(en_US="Test MCP Provider", zh_Hans="测试MCP提供者"),
labels=[],
tools=[],
)
# Act: Execute the method under test
result = MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
server_url="https://example.com/mcp",
@ -339,14 +377,16 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_123",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Assert: Verify the expected outcomes
assert result is not None
assert result["name"] == "Test MCP Provider"
assert result["type"] == ToolProviderType.MCP
assert result.name == "Test MCP Provider"
assert result.type == ToolProviderType.MCP
# Verify database state
from extensions.ext_database import db
@ -386,7 +426,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
server_url="https://example1.com/mcp",
@ -395,13 +439,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_1",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate name
with pytest.raises(ValueError, match="MCP tool Test MCP Provider already exists"):
MCPToolManageService.create_mcp_provider(
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider", # Duplicate name
server_url="https://example2.com/mcp",
@ -410,8 +456,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_2",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_create_mcp_provider_duplicate_server_url(
@ -432,7 +480,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
server_url="https://example.com/mcp",
@ -441,13 +493,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_1",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate server URL
with pytest.raises(ValueError, match="MCP tool https://example.com/mcp already exists"):
MCPToolManageService.create_mcp_provider(
with pytest.raises(ValueError, match="MCP tool with this server URL already exists"):
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 2",
server_url="https://example.com/mcp", # Duplicate URL
@ -456,8 +510,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_2",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_create_mcp_provider_duplicate_server_identifier(
@ -478,7 +534,11 @@ class TestMCPToolManageService:
)
# Create first provider
MCPToolManageService.create_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
server_url="https://example1.com/mcp",
@ -487,13 +547,15 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_123",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate server identifier
with pytest.raises(ValueError, match="MCP tool test_identifier_123 already exists"):
MCPToolManageService.create_mcp_provider(
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 2",
server_url="https://example2.com/mcp",
@ -502,8 +564,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_123", # Duplicate identifier
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
@ -543,23 +607,59 @@ class TestMCPToolManageService:
db.session.commit()
# Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
{"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP},
{"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP},
{"id": provider3.id, "name": provider3.name, "type": ToolProviderType.MCP},
ToolProviderApiEntity(
id=provider1.id,
author=account.name,
name=provider1.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Alpha Provider Description", zh_Hans="Alpha提供者描述"),
icon={"type": "emoji", "content": "🅰️"},
label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider2.id,
author=account.name,
name=provider2.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Beta Provider Description", zh_Hans="Beta提供者描述"),
icon={"type": "emoji", "content": "🅱️"},
label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider3.id,
author=account.name,
name=provider3.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Gamma Provider Description", zh_Hans="Gamma提供者描述"),
icon={"type": "emoji", "content": "Γ"},
label=I18nObject(en_US=provider3.name, zh_Hans=provider3.name),
labels=[],
tools=[],
),
]
# Act: Execute the method under test
result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=True)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_providers(tenant_id=tenant.id, for_list=True)
# Assert: Verify the expected outcomes
assert result is not None
assert len(result) == 3
# Verify correct ordering by name
assert result[0]["name"] == "Alpha Provider"
assert result[1]["name"] == "Beta Provider"
assert result[2]["name"] == "Gamma Provider"
assert result[0].name == "Alpha Provider"
assert result[1].name == "Beta Provider"
assert result[2].name == "Gamma Provider"
# Verify mock interactions
assert (
@ -584,7 +684,10 @@ class TestMCPToolManageService:
# No MCP providers created for this tenant
# Act: Execute the method under test
result = MCPToolManageService.retrieve_mcp_tools(tenant.id, for_list=False)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_providers(tenant_id=tenant.id, for_list=False)
# Assert: Verify the expected outcomes
assert result is not None
@ -624,20 +727,46 @@ class TestMCPToolManageService:
)
# Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.side_effect = [
{"id": provider1.id, "name": provider1.name, "type": ToolProviderType.MCP},
{"id": provider2.id, "name": provider2.name, "type": ToolProviderType.MCP},
ToolProviderApiEntity(
id=provider1.id,
author=account1.name,
name=provider1.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Provider 1 Description", zh_Hans="提供者1描述"),
icon={"type": "emoji", "content": "1"},
label=I18nObject(en_US=provider1.name, zh_Hans=provider1.name),
labels=[],
tools=[],
),
ToolProviderApiEntity(
id=provider2.id,
author=account2.name,
name=provider2.name,
type=ToolProviderType.MCP,
description=I18nObject(en_US="Provider 2 Description", zh_Hans="提供者2描述"),
icon={"type": "emoji", "content": "2"},
label=I18nObject(en_US=provider2.name, zh_Hans=provider2.name),
labels=[],
tools=[],
),
]
# Act: Execute the method under test for both tenants
result1 = MCPToolManageService.retrieve_mcp_tools(tenant1.id, for_list=True)
result2 = MCPToolManageService.retrieve_mcp_tools(tenant2.id, for_list=True)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result1 = service.list_providers(tenant_id=tenant1.id, for_list=True)
result2 = service.list_providers(tenant_id=tenant2.id, for_list=True)
# Assert: Verify tenant isolation
assert len(result1) == 1
assert len(result2) == 1
assert result1[0]["id"] == provider1.id
assert result2[0]["id"] == provider2.id
assert result1[0].id == provider1.id
assert result2[0].id == provider2.id
def test_list_mcp_tool_from_remote_server_success(
self, db_session_with_containers, mock_external_service_dependencies
@ -661,17 +790,20 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
mcp_provider.authed = False
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = True # Provider must be authenticated to list tools
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient and its context manager
mock_tools = [
@ -683,13 +815,16 @@ class TestMCPToolManageService:
)(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
result = MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -705,16 +840,8 @@ class TestMCPToolManageService:
assert mcp_provider.updated_at is not None
# Verify mock interactions
mock_mcp_client.assert_called_once_with(
"https://example.com/mcp",
mcp_provider.id,
tenant.id,
authed=False,
for_list=True,
headers={},
timeout=30.0,
sse_read_timeout=300.0,
)
# MCPClientWithAuthRetry is called with different parameters
mock_mcp_client.assert_called_once()
def test_list_mcp_tool_from_remote_server_auth_error(
self, db_session_with_containers, mock_external_service_dependencies
@ -737,7 +864,10 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = False
mcp_provider.tools = "[]"
@ -745,20 +875,23 @@ class TestMCPToolManageService:
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Please auth the tool first"):
MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
@ -786,32 +919,38 @@ class TestMCPToolManageService:
mcp_provider = self._create_test_mcp_provider(
db_session_with_containers, mock_external_service_dependencies, tenant.id, account.id
)
mcp_provider.server_url = "encrypted_server_url"
mcp_provider.authed = False
# Use a valid base64 encoded string to avoid decryption errors
import base64
mcp_provider.server_url = base64.b64encode(b"encrypted_server_url").decode()
mcp_provider.authed = True # Provider must be authenticated to test connection errors
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
# Mock the decrypted_server_url property to avoid encryption issues
with patch("models.tools.encrypter") as mock_encrypter:
mock_encrypter.decrypt_token.return_value = "https://example.com/mcp"
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
mock_decrypt.return_value = "https://example.com/mcp"
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
MCPToolManageService.list_mcp_tool_from_remote_server(tenant.id, mcp_provider.id)
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
assert mcp_provider.authed is False
assert mcp_provider.authed is True # Provider remains authenticated
assert mcp_provider.tools == "[]"
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
@ -840,7 +979,8 @@ class TestMCPToolManageService:
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
# Act: Execute the method under test
MCPToolManageService.delete_mcp_tool(tenant.id, mcp_provider.id)
service = MCPToolManageService(db.session())
service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
# Provider should be deleted from database
@ -862,11 +1002,14 @@ class TestMCPToolManageService:
db_session_with_containers, mock_external_service_dependencies
)
non_existent_id = fake.uuid4()
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.delete_mcp_tool(tenant.id, non_existent_id)
service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id)
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -893,8 +1036,11 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool not found"):
MCPToolManageService.delete_mcp_tool(tenant2.id, mcp_provider1.id)
service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id)
# Verify provider still exists in tenant1
from extensions.ext_database import db
@ -929,7 +1075,10 @@ class TestMCPToolManageService:
db.session.commit()
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
service = MCPToolManageService(db.session())
service.update_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
name="Updated MCP Provider",
@ -938,8 +1087,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
# Assert: Verify the expected outcomes
@ -953,7 +1104,7 @@ class TestMCPToolManageService:
# Verify icon was updated
import json
icon_data = json.loads(mcp_provider.icon)
icon_data = json.loads(mcp_provider.icon or "{}")
assert icon_data["content"] == "🚀"
assert icon_data["background"] == "#4ECDC4"
@ -985,7 +1136,7 @@ class TestMCPToolManageService:
db.session.commit()
# Mock the reconnection method
with patch.object(MCPToolManageService, "_re_connect_mcp_provider") as mock_reconnect:
with patch.object(MCPToolManageService, "_reconnect_provider") as mock_reconnect:
mock_reconnect.return_value = {
"authed": True,
"tools": '[{"name": "test_tool"}]',
@ -993,7 +1144,11 @@ class TestMCPToolManageService:
}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider(
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
name="Updated MCP Provider",
@ -1002,8 +1157,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
# Assert: Verify the expected outcomes
@ -1015,7 +1172,10 @@ class TestMCPToolManageService:
assert mcp_provider.updated_at is not None
# Verify reconnection was called
mock_reconnect.assert_called_once_with("https://new-example.com/mcp", mcp_provider.id, tenant.id)
mock_reconnect.assert_called_once_with(
server_url="https://new-example.com/mcp",
provider=mcp_provider,
)
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -1048,8 +1208,12 @@ class TestMCPToolManageService:
db.session.commit()
# Act & Assert: Verify proper error handling for duplicate name
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
MCPToolManageService.update_mcp_provider(
service.update_provider(
tenant_id=tenant.id,
provider_id=provider2.id,
name="First Provider", # Duplicate name
@ -1058,8 +1222,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="unique_identifier",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_update_mcp_provider_credentials_success(
@ -1094,19 +1260,22 @@ class TestMCPToolManageService:
# Mock the provider controller and encryption
with (
patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller,
patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter,
patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller,
patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter,
):
# Setup mocks
mock_controller_instance = mock_controller._from_db.return_value
mock_controller_instance = mock_controller.from_db.return_value
mock_controller_instance.get_credentials_schema.return_value = []
mock_encrypter_instance = mock_encrypter.return_value
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider_credentials(
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=True
)
# Assert: Verify the expected outcomes
@ -1117,7 +1286,7 @@ class TestMCPToolManageService:
# Verify credentials were encrypted and merged
import json
credentials = json.loads(mcp_provider.encrypted_credentials)
credentials = json.loads(mcp_provider.encrypted_credentials or "{}")
assert "existing_key" in credentials
assert "new_key" in credentials
@ -1152,19 +1321,22 @@ class TestMCPToolManageService:
# Mock the provider controller and encryption
with (
patch("services.tools.mcp_tools_manage_service.MCPToolProviderController") as mock_controller,
patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter") as mock_encrypter,
patch("core.tools.mcp_tool.provider.MCPToolProviderController") as mock_controller,
patch("core.tools.utils.encryption.ProviderConfigEncrypter") as mock_encrypter,
):
# Setup mocks
mock_controller_instance = mock_controller._from_db.return_value
mock_controller_instance = mock_controller.from_db.return_value
mock_controller_instance.get_credentials_schema.return_value = []
mock_encrypter_instance = mock_encrypter.return_value
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service.update_provider_credentials(
provider=mcp_provider, credentials={"new_key": "new_value"}, authed=False
)
# Assert: Verify the expected outcomes
@ -1199,14 +1371,18 @@ class TestMCPToolManageService:
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
result = MCPToolManageService._re_connect_mcp_provider(
"https://example.com/mcp", mcp_provider.id, tenant.id
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)
# Assert: Verify the expected outcomes
@ -1224,16 +1400,8 @@ class TestMCPToolManageService:
assert tools_data[1]["name"] == "test_tool_2"
# Verify mock interactions
mock_mcp_client.assert_called_once_with(
"https://example.com/mcp",
mcp_provider.id,
tenant.id,
authed=False,
for_list=True,
headers={},
timeout=30.0,
sse_read_timeout=300.0,
)
provider_entity = mcp_provider.to_entity()
mock_mcp_client.assert_called_once()
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -1256,15 +1424,19 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act: Execute the method under test
result = MCPToolManageService._re_connect_mcp_provider(
"https://example.com/mcp", mcp_provider.id, tenant.id
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)
# Assert: Verify the expected outcomes
@ -1295,12 +1467,18 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClient") as mock_mcp_client:
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
MCPToolManageService._re_connect_mcp_provider("https://example.com/mcp", mcp_provider.id, tenant.id)
service._reconnect_provider(
server_url="https://example.com/mcp",
provider=mcp_provider,
)

View File

@ -0,0 +1,401 @@
"""
TestContainers-based integration tests for mail_owner_transfer_task.
This module provides comprehensive integration tests for the mail owner transfer tasks
using TestContainers to ensure real email service integration and proper functionality
testing with actual database and service dependencies.
"""
import logging
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_owner_transfer_task import (
send_new_owner_transfer_notify_email_task,
send_old_owner_transfer_notify_email_task,
send_owner_transfer_confirm_task,
)
logger = logging.getLogger(__name__)
class TestMailOwnerTransferTask:
"""Integration tests for mail owner transfer tasks using testcontainers."""
@pytest.fixture
def mock_mail_dependencies(self):
"""Mock setup for mail service dependencies."""
with (
patch("tasks.mail_owner_transfer_task.mail") as mock_mail,
patch("tasks.mail_owner_transfer_task.get_email_i18n_service") as mock_get_email_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email service
mock_email_service = MagicMock()
mock_get_email_service.return_value = mock_email_service
yield {
"mail": mock_mail,
"email_service": mock_email_service,
"get_email_service": mock_get_email_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers):
"""
Helper method to create test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
Returns:
tuple: (account, tenant) - Created account and tenant instances
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return account, tenant
def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies):
"""
Test successful owner transfer confirmation email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context is properly constructed
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_code = "123456"
test_workspace = tenant.name
# Act: Execute the task
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_CONFIRM
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["code"] == test_code
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
def test_send_owner_transfer_confirm_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test owner transfer confirmation email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_code = "123456"
test_workspace = "Test Workspace"
# Act: Execute the task
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_owner_transfer_confirm_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in owner transfer confirmation email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_code = "123456"
test_workspace = "Test Workspace"
# Act & Assert: Verify no exception is raised
try:
send_owner_transfer_confirm_task(
language=test_language,
to=test_email,
code=test_code,
workspace=test_workspace,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()
def test_send_old_owner_transfer_notify_email_task_success(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test successful old owner transfer notification email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context includes new owner email
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_workspace = tenant.name
test_new_owner_email = "newowner@example.com"
# Act: Execute the task
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_OLD_NOTIFY
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email
def test_send_old_owner_transfer_notify_email_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test old owner transfer notification email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
test_new_owner_email = "newowner@example.com"
# Act: Execute the task
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_old_owner_transfer_notify_email_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in old owner transfer notification email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
test_new_owner_email = "newowner@example.com"
# Act & Assert: Verify no exception is raised
try:
send_old_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
new_owner_email=test_new_owner_email,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()
def test_send_new_owner_transfer_notify_email_task_success(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test successful new owner transfer notification email sending.
This test verifies:
- Proper email service initialization check
- Correct email service method calls with right parameters
- Email template context is properly constructed
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
test_language = "en-US"
test_email = account.email
test_workspace = tenant.name
# Act: Execute the task
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
# Assert: Verify the expected outcomes
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["get_email_service"].assert_called_once()
# Verify email service was called with correct parameters
mock_mail_dependencies["email_service"].send_email.assert_called_once()
call_args = mock_mail_dependencies["email_service"].send_email.call_args
assert call_args[1]["email_type"] == EmailType.OWNER_TRANSFER_NEW_NOTIFY
assert call_args[1]["language_code"] == test_language
assert call_args[1]["to"] == test_email
assert call_args[1]["template_context"]["to"] == test_email
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
def test_send_new_owner_transfer_notify_email_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test new owner transfer notification email when mail service is not initialized.
This test verifies:
- Early return when mail service is not initialized
- No email service calls are made
- No exceptions are raised
"""
# Arrange: Set mail service as not initialized
mock_mail_dependencies["mail"].is_inited.return_value = False
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
# Act: Execute the task
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
# Assert: Verify no email service calls were made
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_new_owner_transfer_notify_email_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""
Test exception handling in new owner transfer notification email.
This test verifies:
- Exceptions are properly caught and logged
- No exceptions are propagated to caller
- Email service calls are attempted
- Error logging works correctly
"""
# Arrange: Setup email service to raise exception
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
test_language = "en-US"
test_email = "test@example.com"
test_workspace = "Test Workspace"
# Act & Assert: Verify no exception is raised
try:
send_new_owner_transfer_notify_email_task(
language=test_language,
to=test_email,
workspace=test_workspace,
)
except Exception as e:
pytest.fail(f"Task should not raise exceptions, but raised: {e}")
# Verify email service was called despite the exception
mock_mail_dependencies["email_service"].send_email.assert_called_once()

View File

@ -60,7 +60,7 @@ class TestAccountInitialization:
return "success"
# Act
with patch("controllers.console.wraps._current_account", return_value=mock_user):
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
result = protected_view()
# Assert
@ -77,7 +77,7 @@ class TestAccountInitialization:
return "success"
# Act & Assert
with patch("controllers.console.wraps._current_account", return_value=mock_user):
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
with pytest.raises(AccountNotInitializedError):
protected_view()
@ -163,7 +163,9 @@ class TestBillingResourceLimits:
return "member_added"
# Act
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member()
@ -185,7 +187,10 @@ class TestBillingResourceLimits:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
add_member()
@ -207,7 +212,10 @@ class TestBillingResourceLimits:
# Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
upload_document()
@ -215,7 +223,10 @@ class TestBillingResourceLimits:
# Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document()
assert result == "document_uploaded"
@ -239,7 +250,9 @@ class TestRateLimiting:
return "knowledge_success"
# Act
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
@ -271,7 +284,10 @@ class TestRateLimiting:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):

View File

@ -0,0 +1,720 @@
"""Unit tests for MCP OAuth authentication flow."""
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth.auth_flow import (
OAUTH_STATE_EXPIRY_SECONDS,
OAUTH_STATE_REDIS_KEY_PREFIX,
OAuthCallbackState,
_create_secure_redis_state,
_retrieve_redis_state,
auth,
check_support_resource_discovery,
discover_oauth_metadata,
exchange_authorization,
generate_pkce_challenge,
handle_callback,
refresh_authorization,
register_client,
start_authorization,
)
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
)
class TestPKCEGeneration:
"""Test PKCE challenge generation."""
def test_generate_pkce_challenge(self):
"""Test PKCE challenge and verifier generation."""
code_verifier, code_challenge = generate_pkce_challenge()
# Verify format - should be URL-safe base64 without padding
assert "=" not in code_verifier
assert "+" not in code_verifier
assert "/" not in code_verifier
assert "=" not in code_challenge
assert "+" not in code_challenge
assert "/" not in code_challenge
# Verify length
assert len(code_verifier) > 40 # Should be around 54 characters
assert len(code_challenge) > 40 # Should be around 43 characters
def test_generate_pkce_challenge_uniqueness(self):
"""Test that PKCE generation produces unique values."""
results = set()
for _ in range(10):
code_verifier, code_challenge = generate_pkce_challenge()
results.add((code_verifier, code_challenge))
# All should be unique
assert len(results) == 10
class TestRedisStateManagement:
"""Test Redis state management functions."""
@patch("core.mcp.auth.auth_flow.redis_client")
def test_create_secure_redis_state(self, mock_redis):
"""Test creating secure Redis state."""
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
state_key = _create_secure_redis_state(state_data)
# Verify state key format
assert len(state_key) > 20 # Should be a secure random token
# Verify Redis call
mock_redis.setex.assert_called_once()
call_args = mock_redis.setex.call_args
assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX)
assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS
assert state_data.model_dump_json() in call_args[0][2]
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_success(self, mock_redis):
"""Test retrieving state from Redis."""
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_redis.get.return_value = state_data.model_dump_json()
result = _retrieve_redis_state("test-state-key")
# Verify result
assert result.provider_id == "test-provider"
assert result.tenant_id == "test-tenant"
assert result.server_url == "https://example.com"
# Verify Redis calls
mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_not_found(self, mock_redis):
"""Test retrieving non-existent state from Redis."""
mock_redis.get.return_value = None
with pytest.raises(ValueError) as exc_info:
_retrieve_redis_state("nonexistent-key")
assert "State parameter has expired or does not exist" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.redis_client")
def test_retrieve_redis_state_invalid_json(self, mock_redis):
"""Test retrieving invalid JSON state from Redis."""
mock_redis.get.return_value = '{"invalid": json}'
with pytest.raises(ValueError) as exc_info:
_retrieve_redis_state("test-key")
assert "Invalid state parameter" in str(exc_info.value)
# State should still be deleted
mock_redis.delete.assert_called_once()
class TestOAuthDiscovery:
"""Test OAuth discovery functions."""
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_success(self, mock_get):
"""Test successful resource discovery check."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint")
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_not_supported(self, mock_get):
"""Test resource discovery not supported."""
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com")
assert supported is False
assert auth_url == ""
@patch("core.helper.ssrf_proxy.get")
def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
"""Test resource discovery with query and fragment."""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
mock_get.return_value = mock_response
supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment")
assert supported is True
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery with resource discovery support."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (True, "https://auth.example.com")
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert metadata.token_endpoint == "https://auth.example.com/token"
mock_get.assert_called_once_with(
"https://auth.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
"""Test OAuth metadata discovery without resource discovery."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_not_found(self, mock_get):
"""Test OAuth metadata discovery when not found."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
assert metadata is None
class TestAuthorizationFlow:
"""Test authorization flow functions."""
@patch("core.mcp.auth.auth_flow._create_secure_redis_state")
def test_start_authorization_with_metadata(self, mock_create_state):
"""Test starting authorization with metadata."""
mock_create_state.return_value = "secure-state-key"
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
code_challenge_methods_supported=["S256"],
)
client_info = OAuthClientInformation(client_id="test-client-id")
auth_url, code_verifier = start_authorization(
"https://api.example.com",
metadata,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
# Verify URL format
assert auth_url.startswith("https://auth.example.com/authorize?")
assert "response_type=code" in auth_url
assert "client_id=test-client-id" in auth_url
assert "code_challenge=" in auth_url
assert "code_challenge_method=S256" in auth_url
assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url
assert "state=secure-state-key" in auth_url
# Verify code verifier
assert len(code_verifier) > 40
# Verify state was stored
mock_create_state.assert_called_once()
state_data = mock_create_state.call_args[0][0]
assert state_data.provider_id == "provider-id"
assert state_data.tenant_id == "tenant-id"
assert state_data.code_verifier == code_verifier
def test_start_authorization_without_metadata(self):
"""Test starting authorization without metadata."""
with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state:
mock_create_state.return_value = "secure-state-key"
client_info = OAuthClientInformation(client_id="test-client-id")
auth_url, code_verifier = start_authorization(
"https://api.example.com",
None,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
# Should use default authorization endpoint
assert auth_url.startswith("https://api.example.com/authorize?")
def test_start_authorization_invalid_metadata(self):
"""Test starting authorization with invalid metadata."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["token"], # No "code" support
code_challenge_methods_supported=["plain"], # No "S256" support
)
client_info = OAuthClientInformation(client_id="test-client-id")
with pytest.raises(ValueError) as exc_info:
start_authorization(
"https://api.example.com",
metadata,
client_info,
"https://redirect.example.com",
"provider-id",
"tenant-id",
)
assert "does not support response type code" in str(exc_info.value)
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization_success(self, mock_post):
"""Test successful authorization code exchange."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
tokens = exchange_authorization(
"https://api.example.com",
metadata,
client_info,
"auth-code-123",
"code-verifier-xyz",
"https://redirect.example.com",
)
assert tokens.access_token == "new-access-token"
assert tokens.token_type == "Bearer"
assert tokens.expires_in == 3600
assert tokens.refresh_token == "new-refresh-token"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "authorization_code",
"client_id": "test-client-id",
"client_secret": "test-secret",
"code": "auth-code-123",
"code_verifier": "code-verifier-xyz",
"redirect_uri": "https://redirect.example.com",
},
)
@patch("core.helper.ssrf_proxy.post")
def test_exchange_authorization_failure(self, mock_post):
"""Test failed authorization code exchange."""
mock_response = Mock()
mock_response.is_success = False
mock_response.status_code = 400
mock_post.return_value = mock_response
client_info = OAuthClientInformation(client_id="test-client-id")
with pytest.raises(ValueError) as exc_info:
exchange_authorization(
"https://api.example.com",
None,
client_info,
"invalid-code",
"code-verifier",
"https://redirect.example.com",
)
assert "Token exchange failed: HTTP 400" in str(exc_info.value)
@patch("core.helper.ssrf_proxy.post")
def test_refresh_authorization_success(self, mock_post):
"""Test successful token refresh."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"access_token": "refreshed-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["refresh_token"],
)
client_info = OAuthClientInformation(client_id="test-client-id")
tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token")
assert tokens.access_token == "refreshed-access-token"
assert tokens.refresh_token == "new-refresh-token"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/token",
data={
"grant_type": "refresh_token",
"client_id": "test-client-id",
"refresh_token": "old-refresh-token",
},
)
@patch("core.helper.ssrf_proxy.post")
def test_register_client_success(self, mock_post):
"""Test successful client registration."""
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"client_id": "new-client-id",
"client_secret": "new-client-secret",
"client_name": "Dify",
"redirect_uris": ["https://redirect.example.com"],
}
mock_post.return_value = mock_response
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
grant_types=["authorization_code"],
response_types=["code"],
)
client_info = register_client("https://api.example.com", metadata, client_metadata)
assert isinstance(client_info, OAuthClientInformationFull)
assert client_info.client_id == "new-client-id"
assert client_info.client_secret == "new-client-secret"
# Verify request
mock_post.assert_called_once_with(
"https://auth.example.com/register",
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
def test_register_client_no_endpoint(self):
"""Test client registration when no endpoint available."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint=None,
response_types_supported=["code"],
)
client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"])
with pytest.raises(ValueError) as exc_info:
register_client("https://api.example.com", metadata, client_metadata)
assert "does not support dynamic client registration" in str(exc_info.value)
class TestCallbackHandling:
"""Test OAuth callback handling."""
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_handle_callback_success(self, mock_exchange, mock_retrieve_state):
"""Test successful callback handling."""
# Setup state
state_data = OAuthCallbackState(
provider_id="test-provider",
tenant_id="test-tenant",
server_url="https://api.example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="test-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_retrieve_state.return_value = state_data
# Setup token exchange
tokens = OAuthTokens(
access_token="new-token",
token_type="Bearer",
expires_in=3600,
)
mock_exchange.return_value = tokens
# Setup service
mock_service = Mock()
result = handle_callback("state-key", "auth-code", mock_service)
assert result == state_data
# Verify calls
mock_retrieve_state.assert_called_once_with("state-key")
mock_exchange.assert_called_once_with(
"https://api.example.com",
None,
state_data.client_information,
"auth-code",
"test-verifier",
"https://redirect.example.com",
)
mock_service.save_oauth_data.assert_called_once_with(
"test-provider", "test-tenant", tokens.model_dump(), "tokens"
)
class TestAuthOrchestration:
"""Test the main auth orchestration function."""
@pytest.fixture
def mock_provider(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "provider-id"
provider.tenant_id = "tenant-id"
provider.decrypt_server_url.return_value = "https://api.example.com"
provider.client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
provider.redirect_url = "https://redirect.example.com"
provider.retrieve_client_information.return_value = None
provider.retrieve_tokens.return_value = None
return provider
@pytest.fixture
def mock_service(self):
"""Create a mock MCP service."""
return Mock()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow.register_client")
@patch("core.mcp.auth.auth_flow.start_authorization")
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
"""Test auth flow for new client registration."""
# Setup
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_register.return_value = OAuthClientInformationFull(
client_id="new-client-id",
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
result = auth(mock_provider, mock_service)
assert result == {"authorization_url": "https://auth.example.com/authorize?..."}
# Verify calls
mock_register.assert_called_once()
mock_service.save_oauth_data.assert_any_call(
"provider-id",
"tenant-id",
{"client_information": mock_register.return_value.model_dump()},
"client_info",
)
mock_service.save_oauth_data.assert_any_call(
"provider-id", "tenant-id", {"code_verifier": "code-verifier"}, "code_verifier"
)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
"""Test auth flow for exchanging authorization code."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
# Setup existing client
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
# Setup state retrieval
state_data = OAuthCallbackState(
provider_id="provider-id",
tenant_id="tenant-id",
server_url="https://api.example.com",
metadata=None,
client_information=OAuthClientInformation(client_id="existing-client"),
code_verifier="test-verifier",
redirect_uri="https://redirect.example.com",
)
mock_retrieve_state.return_value = state_data
# Setup token exchange
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
mock_exchange.return_value = tokens
result = auth(mock_provider, mock_service, authorization_code="auth-code", state_param="state-key")
assert result == {"result": "success"}
# Verify token save
mock_service.save_oauth_data.assert_called_with("provider-id", "tenant-id", tokens.model_dump(), "tokens")
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
"""Test auth flow fails when exchanging code without state."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, mock_service, authorization_code="auth-code")
assert "State parameter is required" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.refresh_authorization")
def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service):
"""Test auth flow for refreshing tokens."""
# Setup existing client and tokens
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
mock_provider.retrieve_tokens.return_value = OAuthTokens(
access_token="old-token",
token_type="Bearer",
expires_in=0,
refresh_token="refresh-token",
)
# Setup refresh
new_tokens = OAuthTokens(
access_token="refreshed-token",
token_type="Bearer",
expires_in=3600,
refresh_token="new-refresh-token",
)
mock_refresh.return_value = new_tokens
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
result = auth(mock_provider, mock_service)
assert result == {"result": "success"}
# Verify refresh was called
mock_refresh.assert_called_once()
mock_service.save_oauth_data.assert_called_with(
"provider-id", "tenant-id", new_tokens.model_dump(), "tokens"
)
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
"""Test auth fails when no client info exists but code is provided."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
)
mock_provider.retrieve_client_information.return_value = None
with pytest.raises(ValueError) as exc_info:
auth(mock_provider, mock_service, authorization_code="auth-code")
assert "Existing OAuth client information is required" in str(exc_info.value)

View File

@ -0,0 +1,420 @@
"""Unit tests for MCP auth client with retry logic."""
from types import TracebackType
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, TextContent, Tool, ToolAnnotations
class TestMCPClientWithAuthRetry:
"""Test suite for MCPClientWithAuthRetry."""
@pytest.fixture
def mock_provider_entity(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "test-provider-id"
provider.tenant_id = "test-tenant-id"
provider.retrieve_tokens.return_value = Mock(
access_token="test-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
return provider
@pytest.fixture
def mock_mcp_service(self):
"""Create a mock MCP service."""
service = Mock()
service.get_provider_entity.return_value = Mock(
retrieve_tokens=lambda: Mock(
access_token="new-test-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
)
return service
@pytest.fixture
def auth_callback(self):
"""Create a mock auth callback."""
return Mock()
def test_init(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test client initialization."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
timeout=30.0,
sse_read_timeout=60.0,
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
authorization_code="test-auth-code",
by_server_id=True,
mcp_service=mock_mcp_service,
)
assert client.server_url == "http://test.example.com"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.sse_read_timeout == 60.0
assert client.provider_entity == mock_provider_entity
assert client.auth_callback == auth_callback
assert client.authorization_code == "test-auth-code"
assert client.by_server_id is True
assert client.mcp_service == mock_mcp_service
assert client._has_retried is False
# In inheritance design, we don't have _client attribute
assert hasattr(client, "_session") # Inherited from MCPClient
def test_inheritance_structure(self):
"""Test that MCPClientWithAuthRetry properly inherits from MCPClient."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
headers={"Authorization": "Bearer test"},
)
# Verify inheritance
assert isinstance(client, MCPClient)
# Verify inherited attributes are accessible
assert hasattr(client, "server_url")
assert hasattr(client, "headers")
assert hasattr(client, "_session")
assert hasattr(client, "_exit_stack")
assert hasattr(client, "_initialized")
def test_handle_auth_error_no_retry_components(self):
"""Test auth error handling when retry components are missing."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert exc_info.value == error
def test_handle_auth_error_already_retried(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth error handling when already retried."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
client._has_retried = True
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert exc_info.value == error
auth_callback.assert_not_called()
def test_handle_auth_error_successful_refresh(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test successful auth refresh on error."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
authorization_code="test-code",
by_server_id=True,
mcp_service=mock_mcp_service,
)
# Configure mocks
new_provider = Mock(spec=MCPProviderEntity)
new_provider.id = "test-provider-id"
new_provider.tenant_id = "test-tenant-id"
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
client._handle_auth_error(error)
# Verify auth flow
auth_callback.assert_called_once_with(mock_provider_entity, mock_mcp_service, "test-code")
mock_mcp_service.get_provider_entity.assert_called_once_with(
"test-provider-id", "test-tenant-id", by_server_id=True
)
assert client.headers["Authorization"] == "Bearer new-token"
assert client.authorization_code is None # Should be cleared after use
assert client._has_retried is True
def test_handle_auth_error_refresh_fails(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth refresh failure."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
auth_callback.side_effect = Exception("Auth callback failed")
error = MCPAuthError("Original auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert "Authentication retry failed" in str(exc_info.value)
def test_handle_auth_error_no_token_received(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test auth refresh when no token is received."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure mock to return no token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = None
mock_mcp_service.get_provider_entity.return_value = new_provider
error = MCPAuthError("Auth failed")
with pytest.raises(MCPAuthError) as exc_info:
client._handle_auth_error(error)
assert "no token received" in str(exc_info.value)
def test_execute_with_retry_success(self):
"""Test successful execution without retry."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_func = Mock(return_value="success")
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
mock_func.assert_called_once_with("arg1", kwarg1="value1")
assert client._has_retried is False
def test_execute_with_retry_auth_error_then_success(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test execution with auth error followed by successful retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
# Mock function that fails first, then succeeds
mock_func = Mock(side_effect=[MCPAuthError("Auth failed"), "success"])
# Mock the exit stack and session cleanup
with (
patch.object(client, "_exit_stack") as mock_exit_stack,
patch.object(client, "_session") as mock_session,
patch.object(client, "_initialize") as mock_initialize,
):
client._initialized = True
result = client._execute_with_retry(mock_func, "arg1", kwarg1="value1")
assert result == "success"
assert mock_func.call_count == 2
mock_func.assert_called_with("arg1", kwarg1="value1")
auth_callback.assert_called_once()
mock_exit_stack.close.assert_called_once()
mock_initialize.assert_called_once()
assert client._has_retried is False # Reset after completion
def test_execute_with_retry_non_auth_error(self):
"""Test execution with non-auth error (no retry)."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
mock_func = Mock(side_effect=ValueError("Some other error"))
with pytest.raises(ValueError) as exc_info:
client._execute_with_retry(mock_func)
assert str(exc_info.value) == "Some other error"
mock_func.assert_called_once()
def test_context_manager_enter(self):
"""Test context manager enter."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with patch.object(client, "_initialize") as mock_initialize:
result = client.__enter__()
assert result == client
assert client._initialized is True
mock_initialize.assert_called_once()
def test_context_manager_enter_with_auth_error(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test context manager enter with auth error and retry."""
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Mock parent class __enter__ to raise auth error first, then succeed
with patch.object(MCPClient, "__enter__") as mock_parent_enter:
mock_parent_enter.side_effect = [MCPAuthError("Auth failed"), client]
result = client.__enter__()
assert result == client
assert mock_parent_enter.call_count == 2
auth_callback.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with patch.object(client, "cleanup") as mock_cleanup:
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_cleanup.assert_called_once()
def test_list_tools_not_initialized(self):
"""Test list_tools when client not initialized."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Session not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
expected_tools = [
Tool(
name="test-tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
annotations=ToolAnnotations(title="Test Tool"),
)
]
# Mock the parent class list_tools method
with patch.object(MCPClient, "list_tools", return_value=expected_tools):
result = client.list_tools()
assert result == expected_tools
def test_list_tools_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test list_tools with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_tools = [Tool(name="test-tool", description="A test tool", inputSchema={})]
# Mock parent class list_tools to raise auth error first, then succeed
with patch.object(MCPClient, "list_tools") as mock_list_tools:
mock_list_tools.side_effect = [MCPAuthError("Auth failed"), expected_tools]
result = client.list_tools()
assert result == expected_tools
assert mock_list_tools.call_count == 2
auth_callback.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when client not initialized."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Session not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")], isError=False
)
# Mock the parent class invoke_tool method
with patch.object(MCPClient, "invoke_tool", return_value=expected_result) as mock_invoke:
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_invoke.assert_called_once_with("test-tool", {"arg": "value"})
def test_invoke_tool_with_auth_retry(self, mock_provider_entity, mock_mcp_service, auth_callback):
"""Test invoke_tool with auth retry."""
client = MCPClientWithAuthRetry(
server_url="http://test.example.com",
provider_entity=mock_provider_entity,
auth_callback=auth_callback,
mcp_service=mock_mcp_service,
)
# Configure new provider with token
new_provider = Mock(spec=MCPProviderEntity)
new_provider.retrieve_tokens.return_value = Mock(
access_token="new-token", token_type="Bearer", expires_in=3600, refresh_token=None
)
mock_mcp_service.get_provider_entity.return_value = new_provider
expected_result = CallToolResult(content=[TextContent(type="text", text="Success")], isError=False)
# Mock parent class invoke_tool to raise auth error first, then succeed
with patch.object(MCPClient, "invoke_tool") as mock_invoke_tool:
mock_invoke_tool.side_effect = [MCPAuthError("Auth failed"), expected_result]
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
assert mock_invoke_tool.call_count == 2
mock_invoke_tool.assert_called_with("test-tool", {"arg": "value"})
auth_callback.assert_called_once()
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
# Mock the parent class cleanup method
with patch.object(MCPClient, "cleanup") as mock_cleanup:
client.cleanup()
mock_cleanup.assert_called_once()
def test_cleanup_no_client(self):
"""Test cleanup when no client exists."""
client = MCPClientWithAuthRetry(server_url="http://test.example.com")
# Should not raise
client.cleanup()
# Since MCPClientWithAuthRetry inherits from MCPClient,
# it doesn't have a _client attribute. The test should just
# verify that cleanup can be called without error.
assert not hasattr(client, "_client")

View File

@ -0,0 +1,239 @@
"""Unit tests for MCP entities module."""
from unittest.mock import Mock
from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext,
SessionT,
)
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
class TestProtocolVersions:
"""Test protocol version constants."""
def test_supported_protocol_versions(self):
"""Test supported protocol versions list."""
assert isinstance(SUPPORTED_PROTOCOL_VERSIONS, list)
assert len(SUPPORTED_PROTOCOL_VERSIONS) >= 3
assert "2024-11-05" in SUPPORTED_PROTOCOL_VERSIONS
assert "2025-03-26" in SUPPORTED_PROTOCOL_VERSIONS
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
def test_latest_protocol_version_is_supported(self):
"""Test that latest protocol version is in supported versions."""
assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
class TestRequestContext:
"""Test RequestContext dataclass."""
def test_request_context_creation(self):
"""Test creating a RequestContext instance."""
mock_session = Mock(spec=BaseSession)
mock_lifespan = {"key": "value"}
mock_meta = RequestParams.Meta(progressToken="test-token")
context = RequestContext(
request_id="test-request-123",
meta=mock_meta,
session=mock_session,
lifespan_context=mock_lifespan,
)
assert context.request_id == "test-request-123"
assert context.meta == mock_meta
assert context.session == mock_session
assert context.lifespan_context == mock_lifespan
def test_request_context_with_none_meta(self):
"""Test creating RequestContext with None meta."""
mock_session = Mock(spec=BaseSession)
context = RequestContext(
request_id=42, # Can be int or string
meta=None,
session=mock_session,
lifespan_context=None,
)
assert context.request_id == 42
assert context.meta is None
assert context.session == mock_session
assert context.lifespan_context is None
def test_request_context_attributes(self):
"""Test RequestContext attributes are accessible."""
mock_session = Mock(spec=BaseSession)
context = RequestContext(
request_id="test-123",
meta=None,
session=mock_session,
lifespan_context=None,
)
# Verify attributes are accessible
assert hasattr(context, "request_id")
assert hasattr(context, "meta")
assert hasattr(context, "session")
assert hasattr(context, "lifespan_context")
# Verify values
assert context.request_id == "test-123"
assert context.meta is None
assert context.session == mock_session
assert context.lifespan_context is None
def test_request_context_generic_typing(self):
"""Test RequestContext with different generic types."""
# Create a mock session with specific type
mock_session = Mock(spec=BaseSession)
# Create context with string lifespan context
context_str = RequestContext[BaseSession, str](
request_id="test-1",
meta=None,
session=mock_session,
lifespan_context="string-context",
)
assert isinstance(context_str.lifespan_context, str)
# Create context with dict lifespan context
context_dict = RequestContext[BaseSession, dict](
request_id="test-2",
meta=None,
session=mock_session,
lifespan_context={"key": "value"},
)
assert isinstance(context_dict.lifespan_context, dict)
# Create context with custom object lifespan context
class CustomLifespan:
def __init__(self, data):
self.data = data
custom_lifespan = CustomLifespan("test-data")
context_custom = RequestContext[BaseSession, CustomLifespan](
request_id="test-3",
meta=None,
session=mock_session,
lifespan_context=custom_lifespan,
)
assert isinstance(context_custom.lifespan_context, CustomLifespan)
assert context_custom.lifespan_context.data == "test-data"
def test_request_context_with_progress_meta(self):
"""Test RequestContext with progress metadata."""
mock_session = Mock(spec=BaseSession)
progress_meta = RequestParams.Meta(progressToken="progress-123")
context = RequestContext(
request_id="req-456",
meta=progress_meta,
session=mock_session,
lifespan_context=None,
)
assert context.meta is not None
assert context.meta.progressToken == "progress-123"
def test_request_context_equality(self):
"""Test RequestContext equality comparison."""
mock_session1 = Mock(spec=BaseSession)
mock_session2 = Mock(spec=BaseSession)
context1 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session1,
lifespan_context="context",
)
context2 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session1,
lifespan_context="context",
)
context3 = RequestContext(
request_id="test-456",
meta=None,
session=mock_session1,
lifespan_context="context",
)
# Same values should be equal
assert context1 == context2
# Different request_id should not be equal
assert context1 != context3
# Different session should not be equal
context4 = RequestContext(
request_id="test-123",
meta=None,
session=mock_session2,
lifespan_context="context",
)
assert context1 != context4
def test_request_context_repr(self):
"""Test RequestContext string representation."""
mock_session = Mock(spec=BaseSession)
mock_session.__repr__ = Mock(return_value="<MockSession>")
context = RequestContext(
request_id="test-123",
meta=None,
session=mock_session,
lifespan_context={"data": "test"},
)
repr_str = repr(context)
assert "RequestContext" in repr_str
assert "test-123" in repr_str
assert "MockSession" in repr_str
class TestTypeVariables:
"""Test type variables defined in the module."""
def test_session_type_var(self):
"""Test SessionT type variable."""
# Create a custom session class
class CustomSession(BaseSession):
pass
# Use in generic context
def process_session(session: SessionT) -> SessionT:
return session
mock_session = Mock(spec=CustomSession)
result = process_session(mock_session)
assert result == mock_session
def test_lifespan_context_type_var(self):
"""Test LifespanContextT type variable."""
# Use in generic context
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
return context
# Test with different types
str_context = "string-context"
assert process_lifespan(str_context) == str_context
dict_context = {"key": "value"}
assert process_lifespan(dict_context) == dict_context
class CustomContext:
pass
custom_context = CustomContext()
assert process_lifespan(custom_context) == custom_context

View File

@ -0,0 +1,205 @@
"""Unit tests for MCP error classes."""
import pytest
from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError
class TestMCPError:
"""Test MCPError base exception class."""
def test_mcp_error_creation(self):
"""Test creating MCPError instance."""
error = MCPError("Test error message")
assert str(error) == "Test error message"
assert isinstance(error, Exception)
def test_mcp_error_inheritance(self):
"""Test MCPError inherits from Exception."""
error = MCPError()
assert isinstance(error, Exception)
assert type(error).__name__ == "MCPError"
def test_mcp_error_with_empty_message(self):
"""Test MCPError with empty message."""
error = MCPError()
assert str(error) == ""
def test_mcp_error_raise(self):
"""Test raising MCPError."""
with pytest.raises(MCPError) as exc_info:
raise MCPError("Something went wrong")
assert str(exc_info.value) == "Something went wrong"
class TestMCPConnectionError:
"""Test MCPConnectionError exception class."""
def test_mcp_connection_error_creation(self):
"""Test creating MCPConnectionError instance."""
error = MCPConnectionError("Connection failed")
assert str(error) == "Connection failed"
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_connection_error_inheritance(self):
"""Test MCPConnectionError inheritance chain."""
error = MCPConnectionError()
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_connection_error_raise(self):
"""Test raising MCPConnectionError."""
with pytest.raises(MCPConnectionError) as exc_info:
raise MCPConnectionError("Unable to connect to server")
assert str(exc_info.value) == "Unable to connect to server"
def test_mcp_connection_error_catch_as_mcp_error(self):
"""Test catching MCPConnectionError as MCPError."""
with pytest.raises(MCPError) as exc_info:
raise MCPConnectionError("Connection issue")
assert isinstance(exc_info.value, MCPConnectionError)
assert str(exc_info.value) == "Connection issue"
class TestMCPAuthError:
"""Test MCPAuthError exception class."""
def test_mcp_auth_error_creation(self):
"""Test creating MCPAuthError instance."""
error = MCPAuthError("Authentication failed")
assert str(error) == "Authentication failed"
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_auth_error_inheritance(self):
"""Test MCPAuthError inheritance chain."""
error = MCPAuthError()
assert isinstance(error, MCPAuthError)
assert isinstance(error, MCPConnectionError)
assert isinstance(error, MCPError)
assert isinstance(error, Exception)
def test_mcp_auth_error_raise(self):
"""Test raising MCPAuthError."""
with pytest.raises(MCPAuthError) as exc_info:
raise MCPAuthError("Invalid credentials")
assert str(exc_info.value) == "Invalid credentials"
def test_mcp_auth_error_catch_hierarchy(self):
"""Test catching MCPAuthError at different levels."""
# Catch as MCPAuthError
with pytest.raises(MCPAuthError) as exc_info:
raise MCPAuthError("Auth specific error")
assert str(exc_info.value) == "Auth specific error"
# Catch as MCPConnectionError
with pytest.raises(MCPConnectionError) as exc_info:
raise MCPAuthError("Auth connection error")
assert isinstance(exc_info.value, MCPAuthError)
assert str(exc_info.value) == "Auth connection error"
# Catch as MCPError
with pytest.raises(MCPError) as exc_info:
raise MCPAuthError("Auth base error")
assert isinstance(exc_info.value, MCPAuthError)
assert str(exc_info.value) == "Auth base error"
class TestErrorHierarchy:
"""Test the complete error hierarchy."""
def test_exception_hierarchy(self):
"""Test the complete exception hierarchy."""
# Create instances
base_error = MCPError("base")
connection_error = MCPConnectionError("connection")
auth_error = MCPAuthError("auth")
# Test type relationships
assert not isinstance(base_error, MCPConnectionError)
assert not isinstance(base_error, MCPAuthError)
assert isinstance(connection_error, MCPError)
assert not isinstance(connection_error, MCPAuthError)
assert isinstance(auth_error, MCPError)
assert isinstance(auth_error, MCPConnectionError)
def test_error_handling_patterns(self):
"""Test common error handling patterns."""
def raise_auth_error():
raise MCPAuthError("401 Unauthorized")
def raise_connection_error():
raise MCPConnectionError("Connection timeout")
def raise_base_error():
raise MCPError("Generic error")
# Pattern 1: Catch specific errors first
errors_caught = []
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
try:
error_func()
except MCPAuthError:
errors_caught.append("auth")
except MCPConnectionError:
errors_caught.append("connection")
except MCPError:
errors_caught.append("base")
assert errors_caught == ["auth", "connection", "base"]
# Pattern 2: Catch all as base error
for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
with pytest.raises(MCPError) as exc_info:
error_func()
assert isinstance(exc_info.value, MCPError)
def test_error_with_cause(self):
"""Test errors with cause (chained exceptions)."""
original_error = ValueError("Original error")
def raise_chained_error():
try:
raise original_error
except ValueError as e:
raise MCPConnectionError("Connection failed") from e
with pytest.raises(MCPConnectionError) as exc_info:
raise_chained_error()
assert str(exc_info.value) == "Connection failed"
assert exc_info.value.__cause__ == original_error
def test_error_comparison(self):
"""Test error instance comparison."""
error1 = MCPError("Test message")
error2 = MCPError("Test message")
error3 = MCPError("Different message")
# Errors are not equal even with same message (different instances)
assert error1 != error2
assert error1 != error3
# But they have the same type
assert type(error1) == type(error2) == type(error3)
def test_error_representation(self):
"""Test error string representation."""
base_error = MCPError("Base error message")
connection_error = MCPConnectionError("Connection error message")
auth_error = MCPAuthError("Auth error message")
assert repr(base_error) == "MCPError('Base error message')"
assert repr(connection_error) == "MCPConnectionError('Connection error message')"
assert repr(auth_error) == "MCPAuthError('Auth error message')"

View File

@ -0,0 +1,382 @@
"""Unit tests for MCP client."""
from contextlib import ExitStack
from types import TracebackType
from unittest.mock import Mock, patch
import pytest
from core.mcp.error import MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
class TestMCPClient:
"""Test suite for MCPClient."""
def test_init(self):
"""Test client initialization."""
client = MCPClient(
server_url="http://test.example.com/mcp",
headers={"Authorization": "Bearer test"},
timeout=30.0,
sse_read_timeout=60.0,
)
assert client.server_url == "http://test.example.com/mcp"
assert client.headers == {"Authorization": "Bearer test"}
assert client.timeout == 30.0
assert client.sse_read_timeout == 60.0
assert client._session is None
assert isinstance(client._exit_stack, ExitStack)
assert client._initialized is False
def test_init_defaults(self):
"""Test client initialization with defaults."""
client = MCPClient(server_url="http://test.example.com")
assert client.server_url == "http://test.example.com"
assert client.headers == {}
assert client.timeout is None
assert client.sse_read_timeout is None
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client):
"""Test initialization with MCP URL."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/mcp")
client._initialize()
# Verify streamable client was called
mock_streamable_client.assert_called_once_with(
url="http://test.example.com/mcp",
headers={},
timeout=None,
sse_read_timeout=None,
)
# Verify session was created
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client):
"""Test initialization with SSE URL."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/sse")
client._initialize()
# Verify SSE client was called
mock_sse_client.assert_called_once_with(
url="http://test.example.com/sse",
headers={},
timeout=None,
sse_read_timeout=None,
)
# Verify session was created
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_unknown_method_fallback_to_sse(
self, mock_client_session, mock_streamable_client, mock_sse_client
):
"""Test initialization with unknown method falls back to SSE."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/unknown")
client._initialize()
# Verify SSE client was tried
mock_sse_client.assert_called_once()
mock_streamable_client.assert_not_called()
# Verify session was created
assert client._session == mock_session
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client):
"""Test initialization falls back from SSE to MCP on connection error."""
# Setup SSE to fail
mock_sse_client.side_effect = MCPConnectionError("SSE connection failed")
# Setup MCP to succeed
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com/unknown")
client._initialize()
# Verify both were tried
mock_sse_client.assert_called_once()
mock_streamable_client.assert_called_once()
# Verify session was created with MCP
assert client._session == mock_session
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_mcp(self, mock_client_session, mock_streamable_client):
"""Test connect_server with MCP method."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com")
client.connect_server(mock_streamable_client, "mcp")
# Verify correct streams were passed
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_sse(self, mock_client_session, mock_sse_client):
"""Test connect_server with SSE method."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(server_url="http://test.example.com")
client.connect_server(mock_sse_client, "sse")
# Verify correct streams were passed
mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_session.initialize.assert_called_once()
def test_context_manager_enter(self):
"""Test context manager enter."""
client = MCPClient(server_url="http://test.example.com")
with patch.object(client, "_initialize") as mock_initialize:
result = client.__enter__()
assert result == client
assert client._initialized is True
mock_initialize.assert_called_once()
def test_context_manager_exit(self):
"""Test context manager exit."""
client = MCPClient(server_url="http://test.example.com")
with patch.object(client, "cleanup") as mock_cleanup:
exc_type: type[BaseException] | None = None
exc_val: BaseException | None = None
exc_tb: TracebackType | None = None
client.__exit__(exc_type, exc_val, exc_tb)
mock_cleanup.assert_called_once()
def test_list_tools_not_initialized(self):
"""Test list_tools when session not initialized."""
client = MCPClient(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.list_tools()
assert "Session not initialized" in str(exc_info.value)
def test_list_tools_success(self):
"""Test successful list_tools call."""
client = MCPClient(server_url="http://test.example.com")
# Setup mock session
mock_session = Mock()
expected_tools = [
Tool(
name="test-tool",
description="A test tool",
inputSchema={"type": "object", "properties": {}},
annotations=ToolAnnotations(title="Test Tool"),
)
]
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
client._session = mock_session
result = client.list_tools()
assert result == expected_tools
mock_session.list_tools.assert_called_once()
def test_invoke_tool_not_initialized(self):
"""Test invoke_tool when session not initialized."""
client = MCPClient(server_url="http://test.example.com")
with pytest.raises(ValueError) as exc_info:
client.invoke_tool("test-tool", {"arg": "value"})
assert "Session not initialized" in str(exc_info.value)
def test_invoke_tool_success(self):
"""Test successful invoke_tool call."""
client = MCPClient(server_url="http://test.example.com")
# Setup mock session
mock_session = Mock()
expected_result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")],
isError=False,
)
mock_session.call_tool.return_value = expected_result
client._session = mock_session
result = client.invoke_tool("test-tool", {"arg": "value"})
assert result == expected_result
mock_session.call_tool.assert_called_once_with("test-tool", {"arg": "value"})
def test_cleanup(self):
"""Test cleanup method."""
client = MCPClient(server_url="http://test.example.com")
mock_exit_stack = Mock(spec=ExitStack)
client._exit_stack = mock_exit_stack
client._session = Mock()
client._initialized = True
client.cleanup()
mock_exit_stack.close.assert_called_once()
assert client._session is None
assert client._initialized is False
def test_cleanup_with_error(self):
"""Test cleanup method with error."""
client = MCPClient(server_url="http://test.example.com")
mock_exit_stack = Mock(spec=ExitStack)
mock_exit_stack.close.side_effect = Exception("Cleanup error")
client._exit_stack = mock_exit_stack
client._session = Mock()
client._initialized = True
with pytest.raises(ValueError) as exc_info:
client.cleanup()
assert "Error during cleanup: Cleanup error" in str(exc_info.value)
assert client._session is None
assert client._initialized is False
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client):
"""Test full context manager flow."""
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
expected_tools = [Tool(name="test-tool", description="Test", inputSchema={})]
mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
with MCPClient(server_url="http://test.example.com/mcp") as client:
assert client._initialized is True
assert client._session == mock_session
# Test tool operations
tools = client.list_tools()
assert tools == expected_tools
# After exit, should be cleaned up
assert client._initialized is False
assert client._session is None
def test_headers_passed_to_clients(self):
"""Test that headers are properly passed to underlying clients."""
custom_headers = {
"Authorization": "Bearer test-token",
"X-Custom-Header": "test-value",
}
with patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client:
with patch("core.mcp.mcp_client.ClientSession") as mock_client_session:
# Setup mocks
mock_read_stream = Mock()
mock_write_stream = Mock()
mock_client_context = Mock()
mock_streamable_client.return_value.__enter__.return_value = (
mock_read_stream,
mock_write_stream,
mock_client_context,
)
mock_session = Mock()
mock_client_session.return_value.__enter__.return_value = mock_session
client = MCPClient(
server_url="http://test.example.com/mcp",
headers=custom_headers,
timeout=30.0,
sse_read_timeout=60.0,
)
client._initialize()
# Verify headers were passed
mock_streamable_client.assert_called_once_with(
url="http://test.example.com/mcp",
headers=custom_headers,
timeout=30.0,
sse_read_timeout=60.0,
)

View File

@ -0,0 +1,492 @@
"""Unit tests for MCP types module."""
import pytest
from pydantic import ValidationError
from core.mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
LATEST_PROTOCOL_VERSION,
METHOD_NOT_FOUND,
PARSE_ERROR,
SERVER_LATEST_PROTOCOL_VERSION,
Annotations,
CallToolRequest,
CallToolRequestParams,
CallToolResult,
ClientCapabilities,
CompleteRequest,
CompleteRequestParams,
CompleteResult,
Completion,
CompletionArgument,
CompletionContext,
ErrorData,
ImageContent,
Implementation,
InitializeRequest,
InitializeRequestParams,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListToolsRequest,
ListToolsResult,
OAuthClientInformation,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
PingRequest,
ProgressNotification,
ProgressNotificationParams,
PromptReference,
RequestParams,
ResourceTemplateReference,
Result,
ServerCapabilities,
TextContent,
Tool,
ToolAnnotations,
)
class TestConstants:
"""Test module constants."""
def test_protocol_versions(self):
"""Test protocol version constants."""
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
def test_error_codes(self):
"""Test JSON-RPC error code constants."""
assert PARSE_ERROR == -32700
assert INVALID_REQUEST == -32600
assert METHOD_NOT_FOUND == -32601
assert INVALID_PARAMS == -32602
assert INTERNAL_ERROR == -32603
class TestRequestParams:
"""Test RequestParams and related classes."""
def test_request_params_basic(self):
"""Test basic RequestParams creation."""
params = RequestParams()
assert params.meta is None
def test_request_params_with_meta(self):
"""Test RequestParams with meta."""
meta = RequestParams.Meta(progressToken="test-token")
params = RequestParams(_meta=meta)
assert params.meta is not None
assert params.meta.progressToken == "test-token"
def test_request_params_meta_extra_fields(self):
"""Test RequestParams.Meta allows extra fields."""
meta = RequestParams.Meta(progressToken="token", customField="value")
assert meta.progressToken == "token"
assert meta.customField == "value" # type: ignore
def test_request_params_serialization(self):
"""Test RequestParams serialization with _meta alias."""
meta = RequestParams.Meta(progressToken="test")
params = RequestParams(_meta=meta)
# Model dump should use the alias
dumped = params.model_dump(by_alias=True)
assert "_meta" in dumped
assert dumped["_meta"] is not None
assert dumped["_meta"]["progressToken"] == "test"
class TestJSONRPCMessages:
"""Test JSON-RPC message types."""
def test_jsonrpc_request(self):
"""Test JSONRPCRequest creation and validation."""
request = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"})
assert request.jsonrpc == "2.0"
assert request.id == "test-123"
assert request.method == "test_method"
assert request.params == {"key": "value"}
def test_jsonrpc_request_numeric_id(self):
"""Test JSONRPCRequest with numeric ID."""
request = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None)
assert request.id == 123
def test_jsonrpc_notification(self):
"""Test JSONRPCNotification creation."""
notification = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"})
assert notification.jsonrpc == "2.0"
assert notification.method == "notification_method"
assert not hasattr(notification, "id") # Notifications don't have ID
def test_jsonrpc_response(self):
"""Test JSONRPCResponse creation."""
response = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True})
assert response.jsonrpc == "2.0"
assert response.id == "req-123"
assert response.result == {"success": True}
def test_jsonrpc_error(self):
"""Test JSONRPCError creation."""
error_data = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"})
error = JSONRPCError(jsonrpc="2.0", id="req-123", error=error_data)
assert error.jsonrpc == "2.0"
assert error.id == "req-123"
assert error.error.code == INVALID_PARAMS
assert error.error.message == "Invalid parameters"
assert error.error.data == {"field": "missing"}
def test_jsonrpc_message_parsing(self):
"""Test JSONRPCMessage parsing different message types."""
# Parse request
request_json = '{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}'
msg = JSONRPCMessage.model_validate_json(request_json)
assert isinstance(msg.root, JSONRPCRequest)
# Parse response
response_json = '{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}'
msg = JSONRPCMessage.model_validate_json(response_json)
assert isinstance(msg.root, JSONRPCResponse)
# Parse error
error_json = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}'
msg = JSONRPCMessage.model_validate_json(error_json)
assert isinstance(msg.root, JSONRPCError)
class TestCapabilities:
"""Test capability classes."""
def test_client_capabilities(self):
"""Test ClientCapabilities creation."""
caps = ClientCapabilities(
experimental={"feature": {"enabled": True}},
sampling={"model_config": {"extra": "allow"}},
roots={"listChanged": True},
)
assert caps.experimental == {"feature": {"enabled": True}}
assert caps.sampling is not None
assert caps.roots.listChanged is True # type: ignore
def test_server_capabilities(self):
"""Test ServerCapabilities creation."""
caps = ServerCapabilities(
tools={"listChanged": True},
resources={"subscribe": True, "listChanged": False},
prompts={"listChanged": True},
logging={},
completions={},
)
assert caps.tools.listChanged is True # type: ignore
assert caps.resources.subscribe is True # type: ignore
assert caps.resources.listChanged is False # type: ignore
class TestInitialization:
"""Test initialization request/response types."""
def test_initialize_request(self):
"""Test InitializeRequest creation."""
client_info = Implementation(name="test-client", version="1.0.0")
capabilities = ClientCapabilities()
params = InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=capabilities, clientInfo=client_info
)
request = InitializeRequest(params=params)
assert request.method == "initialize"
assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION
assert request.params.clientInfo.name == "test-client"
def test_initialize_result(self):
"""Test InitializeResult creation."""
server_info = Implementation(name="test-server", version="1.0.0")
capabilities = ServerCapabilities()
result = InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=capabilities,
serverInfo=server_info,
instructions="Welcome to test server",
)
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
assert result.serverInfo.name == "test-server"
assert result.instructions == "Welcome to test server"
class TestTools:
"""Test tool-related types."""
def test_tool_creation(self):
"""Test Tool creation with all fields."""
tool = Tool(
name="test_tool",
title="Test Tool",
description="A tool for testing",
inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]},
outputSchema={"type": "object", "properties": {"result": {"type": "string"}}},
annotations=ToolAnnotations(
title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True
),
)
assert tool.name == "test_tool"
assert tool.title == "Test Tool"
assert tool.description == "A tool for testing"
assert tool.inputSchema["properties"]["input"]["type"] == "string"
assert tool.annotations.idempotentHint is True
def test_call_tool_request(self):
"""Test CallToolRequest creation."""
params = CallToolRequestParams(name="test_tool", arguments={"input": "test value"})
request = CallToolRequest(params=params)
assert request.method == "tools/call"
assert request.params.name == "test_tool"
assert request.params.arguments == {"input": "test value"}
def test_call_tool_result(self):
"""Test CallToolResult creation."""
result = CallToolResult(
content=[TextContent(type="text", text="Tool executed successfully")],
structuredContent={"status": "success", "data": "test"},
isError=False,
)
assert len(result.content) == 1
assert result.content[0].text == "Tool executed successfully" # type: ignore
assert result.structuredContent == {"status": "success", "data": "test"}
assert result.isError is False
def test_list_tools_request(self):
"""Test ListToolsRequest creation."""
request = ListToolsRequest()
assert request.method == "tools/list"
def test_list_tools_result(self):
"""Test ListToolsResult creation."""
tool1 = Tool(name="tool1", inputSchema={})
tool2 = Tool(name="tool2", inputSchema={})
result = ListToolsResult(tools=[tool1, tool2])
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool2"
class TestContent:
"""Test content types."""
def test_text_content(self):
"""Test TextContent creation."""
annotations = Annotations(audience=["user"], priority=0.8)
content = TextContent(type="text", text="Hello, world!", annotations=annotations)
assert content.type == "text"
assert content.text == "Hello, world!"
assert content.annotations is not None
assert content.annotations.priority == 0.8
def test_image_content(self):
"""Test ImageContent creation."""
content = ImageContent(type="image", data="base64encodeddata", mimeType="image/png")
assert content.type == "image"
assert content.data == "base64encodeddata"
assert content.mimeType == "image/png"
class TestOAuth:
"""Test OAuth-related types."""
def test_oauth_client_metadata(self):
"""Test OAuthClientMetadata creation."""
metadata = OAuthClientMetadata(
client_name="Test Client",
redirect_uris=["https://example.com/callback"],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
token_endpoint_auth_method="none",
client_uri="https://example.com",
scope="read write",
)
assert metadata.client_name == "Test Client"
assert len(metadata.redirect_uris) == 1
assert "authorization_code" in metadata.grant_types
def test_oauth_client_information(self):
"""Test OAuthClientInformation creation."""
info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
assert info.client_id == "test-client-id"
assert info.client_secret == "test-secret"
def test_oauth_client_information_without_secret(self):
"""Test OAuthClientInformation without secret."""
info = OAuthClientInformation(client_id="public-client")
assert info.client_id == "public-client"
assert info.client_secret is None
def test_oauth_tokens(self):
"""Test OAuthTokens creation."""
tokens = OAuthTokens(
access_token="access-token-123",
token_type="Bearer",
expires_in=3600,
refresh_token="refresh-token-456",
scope="read write",
)
assert tokens.access_token == "access-token-123"
assert tokens.token_type == "Bearer"
assert tokens.expires_in == 3600
assert tokens.refresh_token == "refresh-token-456"
assert tokens.scope == "read write"
def test_oauth_metadata(self):
"""Test OAuthMetadata creation."""
metadata = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
response_types_supported=["code", "token"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["plain", "S256"],
)
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert "code" in metadata.response_types_supported
assert "S256" in metadata.code_challenge_methods_supported
class TestNotifications:
"""Test notification types."""
def test_progress_notification(self):
"""Test ProgressNotification creation."""
params = ProgressNotificationParams(
progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%"
)
notification = ProgressNotification(params=params)
assert notification.method == "notifications/progress"
assert notification.params.progressToken == "progress-123"
assert notification.params.progress == 50.0
assert notification.params.total == 100.0
assert notification.params.message == "Processing... 50%"
def test_ping_request(self):
"""Test PingRequest creation."""
request = PingRequest()
assert request.method == "ping"
assert request.params is None
class TestCompletion:
"""Test completion-related types."""
def test_completion_context(self):
"""Test CompletionContext creation."""
context = CompletionContext(arguments={"template_var": "value"})
assert context.arguments == {"template_var": "value"}
def test_resource_template_reference(self):
"""Test ResourceTemplateReference creation."""
ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/{filename}")
assert ref.type == "ref/resource"
assert ref.uri == "file:///path/to/{filename}"
def test_prompt_reference(self):
"""Test PromptReference creation."""
ref = PromptReference(type="ref/prompt", name="test_prompt")
assert ref.type == "ref/prompt"
assert ref.name == "test_prompt"
def test_complete_request(self):
"""Test CompleteRequest creation."""
ref = PromptReference(type="ref/prompt", name="test_prompt")
arg = CompletionArgument(name="arg1", value="val")
params = CompleteRequestParams(ref=ref, argument=arg, context=CompletionContext(arguments={"key": "value"}))
request = CompleteRequest(params=params)
assert request.method == "completion/complete"
assert request.params.ref.name == "test_prompt" # type: ignore
assert request.params.argument.name == "arg1"
def test_complete_result(self):
"""Test CompleteResult creation."""
completion = Completion(values=["option1", "option2", "option3"], total=10, hasMore=True)
result = CompleteResult(completion=completion)
assert len(result.completion.values) == 3
assert result.completion.total == 10
assert result.completion.hasMore is True
class TestValidation:
"""Test validation of various types."""
def test_invalid_jsonrpc_version(self):
"""Test invalid JSON-RPC version validation."""
with pytest.raises(ValidationError):
JSONRPCRequest(
jsonrpc="1.0", # Invalid version
id=1,
method="test",
)
def test_tool_annotations_validation(self):
"""Test ToolAnnotations with invalid values."""
# Valid annotations
annotations = ToolAnnotations(
title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False
)
assert annotations.title == "Test"
def test_extra_fields_allowed(self):
"""Test that extra fields are allowed in models."""
# Most models should allow extra fields
tool = Tool(
name="test",
inputSchema={},
customField="allowed", # type: ignore
)
assert tool.customField == "allowed" # type: ignore
def test_result_meta_alias(self):
"""Test Result model with _meta alias."""
# Create with the field name (not alias)
result = Result(_meta={"key": "value"})
# Verify the field is set correctly
assert result.meta == {"key": "value"}
# Dump with alias
dumped = result.model_dump(by_alias=True)
assert "_meta" in dumped
assert dumped["_meta"] == {"key": "value"}

View File

@ -0,0 +1,355 @@
"""Unit tests for MCP utils module."""
import json
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import httpx
import httpx_sse
import pytest
from core.mcp.utils import (
STATUS_FORCELIST,
create_mcp_error_response,
create_ssrf_proxy_mcp_http_client,
ssrf_proxy_sse_connect,
)
class TestConstants:
"""Test module constants."""
def test_status_forcelist(self):
"""Test STATUS_FORCELIST contains expected HTTP status codes."""
assert STATUS_FORCELIST == [429, 500, 502, 503, 504]
assert 429 in STATUS_FORCELIST # Too Many Requests
assert 500 in STATUS_FORCELIST # Internal Server Error
assert 502 in STATUS_FORCELIST # Bad Gateway
assert 503 in STATUS_FORCELIST # Service Unavailable
assert 504 in STATUS_FORCELIST # Gateway Timeout
class TestCreateSSRFProxyMCPHTTPClient:
"""Test create_ssrf_proxy_mcp_http_client function."""
@patch("core.mcp.utils.dify_config")
def test_create_client_with_all_url_proxy(self, mock_config):
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
client = create_ssrf_proxy_mcp_http_client(
headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0)
)
assert isinstance(client, httpx.Client)
assert client.headers["Authorization"] == "Bearer token"
assert client.timeout.connect == 30.0
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_with_http_https_proxies(self, mock_config):
"""Test client creation with separate HTTP/HTTPS proxies."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080"
mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443"
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False
client = create_ssrf_proxy_mcp_http_client()
assert isinstance(client, httpx.Client)
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_without_proxy(self, mock_config):
"""Test client creation without proxy configuration."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = None
mock_config.SSRF_PROXY_HTTPS_URL = None
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
headers = {"X-Custom-Header": "value"}
timeout = httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0)
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
assert isinstance(client, httpx.Client)
assert client.headers["X-Custom-Header"] == "value"
assert client.timeout.connect == 5.0
assert client.timeout.read == 10.0
assert client.follow_redirects is True
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
def test_create_client_default_params(self, mock_config):
"""Test client creation with default parameters."""
mock_config.SSRF_PROXY_ALL_URL = None
mock_config.SSRF_PROXY_HTTP_URL = None
mock_config.SSRF_PROXY_HTTPS_URL = None
mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
client = create_ssrf_proxy_mcp_http_client()
assert isinstance(client, httpx.Client)
# httpx.Client adds default headers, so we just check it's a Headers object
assert isinstance(client.headers, httpx.Headers)
# When no timeout is provided, httpx uses its default timeout
assert client.timeout is not None
# Clean up
client.close()
class TestSSRFProxySSEConnect:
"""Test ssrf_proxy_sse_connect function."""
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with pre-configured client."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
# Call with provided client
result = ssrf_proxy_sse_connect(
"http://example.com/sse", client=mock_client, method="POST", headers={"Authorization": "Bearer token"}
)
# Verify client creation was not called
mock_create_client.assert_not_called()
# Verify connect_sse was called correctly
mock_connect_sse.assert_called_once_with(
mock_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"}
)
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
@patch("core.mcp.utils.dify_config")
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
"""Test SSE connection without pre-configured client."""
# Setup config
mock_config.SSRF_DEFAULT_TIME_OUT = 30.0
mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0
mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0
mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
# Call without client
result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"})
# Verify client was created
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args
assert call_args[1]["headers"] == {"X-Custom": "value"}
timeout = call_args[1]["timeout"]
# httpx.Timeout object has these attributes
assert isinstance(timeout, httpx.Timeout)
assert timeout.connect == 10.0
assert timeout.read == 60.0
assert timeout.write == 30.0
# Verify connect_sse was called
mock_connect_sse.assert_called_once_with(
mock_client,
"GET", # Default method
"http://example.com/sse",
)
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with custom timeout."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
mock_event_source = Mock(spec=httpx_sse.EventSource)
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_event_source
mock_connect_sse.return_value = mock_context
custom_timeout = httpx.Timeout(timeout=60.0, read=120.0)
# Call with custom timeout
result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=custom_timeout)
# Verify client was created with custom timeout
mock_create_client.assert_called_once()
call_args = mock_create_client.call_args
assert call_args[1]["timeout"] == custom_timeout
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
"""Test SSE connection cleans up client on error."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
mock_create_client.return_value = mock_client
# Make connect_sse raise an exception
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
# Call should raise the exception
with pytest.raises(httpx.ConnectError):
ssrf_proxy_sse_connect("http://example.com/sse")
# Verify client was cleaned up
mock_client.close.assert_called_once()
@patch("core.mcp.utils.connect_sse")
def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
"""Test SSE connection doesn't clean up provided client on error."""
# Setup mocks
mock_client = Mock(spec=httpx.Client)
# Make connect_sse raise an exception
mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
# Call should raise the exception
with pytest.raises(httpx.ConnectError):
ssrf_proxy_sse_connect("http://example.com/sse", client=mock_client)
# Verify client was NOT cleaned up (because it was provided)
mock_client.close.assert_not_called()
class TestCreateMCPErrorResponse:
"""Test create_mcp_error_response function."""
def test_create_error_response_basic(self):
"""Test creating basic error response."""
generator = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request")
# Generator should yield bytes
assert isinstance(generator, Generator)
# Get the response
response_bytes = next(generator)
assert isinstance(response_bytes, bytes)
# Parse the response
response_str = response_bytes.decode("utf-8")
response_json = json.loads(response_str)
assert response_json["jsonrpc"] == "2.0"
assert response_json["id"] == "req-123"
assert response_json["error"]["code"] == -32600
assert response_json["error"]["message"] == "Invalid Request"
assert response_json["error"]["data"] is None
# Generator should be exhausted
with pytest.raises(StopIteration):
next(generator)
def test_create_error_response_with_data(self):
"""Test creating error response with additional data."""
error_data = {"field": "username", "reason": "required"}
generator = create_mcp_error_response(
request_id=456, # Numeric ID
code=-32602,
message="Invalid params",
data=error_data,
)
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
assert response_json["id"] == 456
assert response_json["error"]["code"] == -32602
assert response_json["error"]["message"] == "Invalid params"
assert response_json["error"]["data"] == error_data
def test_create_error_response_without_request_id(self):
"""Test creating error response without request ID."""
generator = create_mcp_error_response(request_id=None, code=-32700, message="Parse error")
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
# Should default to ID 1
assert response_json["id"] == 1
assert response_json["error"]["code"] == -32700
assert response_json["error"]["message"] == "Parse error"
def test_create_error_response_with_complex_data(self):
"""Test creating error response with complex error data."""
complex_data = {
"errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}],
"timestamp": "2024-01-01T00:00:00Z",
}
generator = create_mcp_error_response(
request_id="complex-req", code=-32602, message="Validation failed", data=complex_data
)
response_bytes = next(generator)
response_json = json.loads(response_bytes.decode("utf-8"))
assert response_json["error"]["data"] == complex_data
assert len(response_json["error"]["data"]["errors"]) == 2
def test_create_error_response_encoding(self):
"""Test error response with non-ASCII characters."""
generator = create_mcp_error_response(
request_id="unicode-req",
code=-32603,
message="内部错误", # Chinese characters
data={"details": "エラー詳細"}, # Japanese characters
)
response_bytes = next(generator)
# Should be valid UTF-8
response_str = response_bytes.decode("utf-8")
response_json = json.loads(response_str)
assert response_json["error"]["message"] == "内部错误"
assert response_json["error"]["data"]["details"] == "エラー詳細"
def test_create_error_response_yields_once(self):
"""Test that error response generator yields exactly once."""
generator = create_mcp_error_response(request_id="test", code=-32600, message="Test")
# First yield should work
first_yield = next(generator)
assert isinstance(first_yield, bytes)
# Second yield should raise StopIteration
with pytest.raises(StopIteration):
next(generator)
# Subsequent calls should also raise
with pytest.raises(StopIteration):
next(generator)

View File

@ -180,6 +180,25 @@ class TestMCPToolTransform:
# Set tools data with null description
mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]'
# Mock the to_entity and to_api_response methods
mock_entity = Mock()
mock_entity.to_api_response.return_value = {
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
"is_team_authorization": True,
"server_url": "https://*****.com/mcp",
"provider_icon": "icon.png",
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
"icon": "icon.png",
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
"masked_credentials": {},
}
mock_provider_full.to_entity.return_value = mock_entity
# Call the method with for_list=True
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True)
@ -198,6 +217,27 @@ class TestMCPToolTransform:
# Set tools data with description
mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]'
# Mock the to_entity and to_api_response methods
mock_entity = Mock()
mock_entity.to_api_response.return_value = {
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
"is_team_authorization": True,
"server_url": "https://*****.com/mcp",
"provider_icon": "icon.png",
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"configuration": {"timeout": "30", "sse_read_timeout": "300"},
"original_headers": {"Authorization": "Bearer secret-token"},
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
"icon": "icon.png",
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
"masked_credentials": {},
}
mock_provider_full.to_entity.return_value = mock_entity
# Call the method with for_list=False
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False)
@ -205,8 +245,9 @@ class TestMCPToolTransform:
assert isinstance(result, ToolProviderApiEntity)
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
assert result.server_identifier == "server-identifier-456"
assert result.timeout == 30
assert result.sse_read_timeout == 300
assert result.configuration is not None
assert result.configuration.timeout == 30
assert result.configuration.sse_read_timeout == 300
assert result.original_headers == {"Authorization": "Bearer secret-token"}
assert len(result.tools) == 1
assert result.tools[0].description.en_US == "Tool description"