mirror of
https://github.com/langgenius/dify.git
synced 2026-03-06 16:16:38 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -28,18 +28,20 @@ class TestApiKeyAuthService:
|
||||
mock_binding.provider = self.provider
|
||||
mock_binding.disabled = False
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [mock_binding]
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].tenant_id == self.tenant_id
|
||||
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
|
||||
assert mock_session.scalars.call_count == 1
|
||||
select_arg = mock_session.scalars.call_args[0][0]
|
||||
assert "data_source_api_key_auth_binding" in str(select_arg).lower()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_empty(self, mock_session):
|
||||
"""Test get provider auth list - empty result"""
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
@ -48,13 +50,15 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_filters_disabled(self, mock_session):
|
||||
"""Test get provider auth list - filters disabled items"""
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
# Verify where conditions include disabled.is_(False)
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 2 # tenant_id and disabled filter conditions
|
||||
select_stmt = mock_session.scalars.call_args[0][0]
|
||||
where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
|
||||
# Ensure both tenant filter and disabled filter exist
|
||||
where_strs = [str(c).lower() for c in where_clauses]
|
||||
assert any("tenant_id" in s for s in where_strs)
|
||||
assert any("disabled" in s for s in where_strs)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
|
||||
@ -6,8 +6,8 @@ import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@ -26,7 +26,7 @@ class TestAuthIntegration:
|
||||
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test complete authentication flow: request → validation → encryption → storage"""
|
||||
@ -47,7 +47,7 @@ class TestAuthIntegration:
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_cross_component_integration(self, mock_http):
|
||||
"""Test factory → provider → HTTP call integration"""
|
||||
mock_http.return_value = self._create_success_response()
|
||||
@ -63,10 +63,10 @@ class TestAuthIntegration:
|
||||
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [tenant1_binding]
|
||||
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [tenant2_binding]
|
||||
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
|
||||
|
||||
assert len(result1) == 1
|
||||
@ -97,7 +97,7 @@ class TestAuthIntegration:
|
||||
assert "another_secret" not in factory_str
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test concurrent authentication creation safety"""
|
||||
@ -142,31 +142,31 @@ class TestAuthIntegration:
|
||||
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
|
||||
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_http_error_handling(self, mock_http):
|
||||
"""Test proper HTTP error handling"""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = '{"error": "Unauthorized"}'
|
||||
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
|
||||
mock_http.return_value = mock_response
|
||||
|
||||
# PT012: Split into single statement for pytest.raises
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
with pytest.raises((requests.exceptions.HTTPError, Exception)):
|
||||
with pytest.raises((httpx.HTTPError, Exception)):
|
||||
factory.validate_credentials()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_network_failure_recovery(self, mock_http, mock_session):
|
||||
"""Test system recovery from network failures"""
|
||||
mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
|
||||
mock_http.side_effect = httpx.RequestError("Network timeout")
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
|
||||
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
with pytest.raises(httpx.RequestError):
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||
|
||||
@ -64,7 +64,7 @@ class TestFirecrawlAuth:
|
||||
FirecrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -95,7 +95,7 @@ class TestFirecrawlAuth:
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
@ -115,7 +115,7 @@ class TestFirecrawlAuth:
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
@ -134,13 +134,13 @@ class TestFirecrawlAuth:
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(requests.ConnectionError, "Network error"),
|
||||
(requests.Timeout, "Request timeout"),
|
||||
(requests.ReadTimeout, "Read timeout"),
|
||||
(requests.ConnectTimeout, "Connection timeout"),
|
||||
(httpx.ConnectError, "Network error"),
|
||||
(httpx.TimeoutException, "Request timeout"),
|
||||
(httpx.ReadTimeout, "Read timeout"),
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_post.side_effect = exception_type(exception_message)
|
||||
@ -162,7 +162,7 @@ class TestFirecrawlAuth:
|
||||
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -179,12 +179,12 @@ class TestFirecrawlAuth:
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.requests.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")
|
||||
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
with pytest.raises(requests.Timeout) as exc_info:
|
||||
with pytest.raises(httpx.TimeoutException) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
|
||||
# Verify the timeout exception is raised with original message
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.jina.jina import JinaAuth
|
||||
|
||||
@ -35,7 +35,7 @@ class TestJinaAuth:
|
||||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -53,7 +53,7 @@ class TestJinaAuth:
|
||||
json={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_402_error(self, mock_post):
|
||||
"""Test handling of 402 Payment Required error"""
|
||||
mock_response = MagicMock()
|
||||
@ -68,7 +68,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_409_error(self, mock_post):
|
||||
"""Test handling of 409 Conflict error"""
|
||||
mock_response = MagicMock()
|
||||
@ -83,7 +83,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_500_error(self, mock_post):
|
||||
"""Test handling of 500 Internal Server Error"""
|
||||
mock_response = MagicMock()
|
||||
@ -98,7 +98,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
|
||||
"""Test handling of unexpected errors with text response"""
|
||||
mock_response = MagicMock()
|
||||
@ -114,7 +114,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_unexpected_error_without_text(self, mock_post):
|
||||
"""Test handling of unexpected errors without text response"""
|
||||
mock_response = MagicMock()
|
||||
@ -130,15 +130,15 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
|
||||
|
||||
@patch("services.auth.jina.jina.requests.post")
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_network_errors(self, mock_post):
|
||||
"""Test handling of network connection errors"""
|
||||
mock_post.side_effect = requests.ConnectionError("Network error")
|
||||
mock_post.side_effect = httpx.ConnectError("Network error")
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(requests.ConnectionError):
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
auth.validate_credentials()
|
||||
|
||||
def test_should_not_expose_api_key_in_error_messages(self):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from services.auth.watercrawl.watercrawl import WatercrawlAuth
|
||||
|
||||
@ -64,7 +64,7 @@ class TestWatercrawlAuth:
|
||||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
@ -126,13 +126,13 @@ class TestWatercrawlAuth:
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(requests.ConnectionError, "Network error"),
|
||||
(requests.Timeout, "Request timeout"),
|
||||
(requests.ReadTimeout, "Read timeout"),
|
||||
(requests.ConnectTimeout, "Connection timeout"),
|
||||
(httpx.ConnectError, "Network error"),
|
||||
(httpx.TimeoutException, "Request timeout"),
|
||||
(httpx.ReadTimeout, "Read timeout"),
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_get.side_effect = exception_type(exception_message)
|
||||
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
|
||||
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_get):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
|
||||
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
|
||||
"""Test that urljoin is used correctly for URL construction with various base URLs"""
|
||||
mock_response = MagicMock()
|
||||
@ -193,12 +193,12 @@ class TestWatercrawlAuth:
|
||||
# Verify the correct URL was called
|
||||
assert mock_get.call_args[0][0] == expected_url
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.requests.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
|
||||
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
with pytest.raises(requests.Timeout) as exc_info:
|
||||
with pytest.raises(httpx.TimeoutException) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
|
||||
# Verify the timeout exception is raised with original message
|
||||
|
||||
@ -10,7 +10,6 @@ from services.account_service import AccountService, RegisterService, TenantServ
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
AccountLoginError,
|
||||
AccountNotFoundError,
|
||||
AccountPasswordError,
|
||||
AccountRegisterError,
|
||||
CurrentPasswordIncorrectError,
|
||||
@ -195,7 +194,7 @@ class TestAccountService:
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password"
|
||||
AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password"
|
||||
)
|
||||
|
||||
def test_authenticate_account_banned(self, mock_db_dependencies):
|
||||
@ -1370,8 +1369,8 @@ class TestRegisterService:
|
||||
account_id="user-123", email="test@example.com"
|
||||
)
|
||||
|
||||
with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token:
|
||||
# Mock the invitation data returned by _get_invitation_by_token
|
||||
with patch("services.account_service.RegisterService.get_invitation_by_token") as mock_get_invitation_by_token:
|
||||
# Mock the invitation data returned by get_invitation_by_token
|
||||
invitation_data = {
|
||||
"account_id": "user-123",
|
||||
"email": "test@example.com",
|
||||
@ -1503,12 +1502,12 @@ class TestRegisterService:
|
||||
assert result == "member_invite:token:test-token"
|
||||
|
||||
def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies):
|
||||
"""Test _get_invitation_by_token with workspace ID and email."""
|
||||
"""Test get_invitation_by_token with workspace ID and email."""
|
||||
# Setup mock
|
||||
mock_redis_dependencies.get.return_value = b"user-123"
|
||||
|
||||
# Execute test
|
||||
result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com")
|
||||
result = RegisterService.get_invitation_by_token("token-123", "workspace-456", "test@example.com")
|
||||
|
||||
# Verify results
|
||||
assert result is not None
|
||||
@ -1517,7 +1516,7 @@ class TestRegisterService:
|
||||
assert result["workspace_id"] == "workspace-456"
|
||||
|
||||
def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies):
|
||||
"""Test _get_invitation_by_token without workspace ID and email."""
|
||||
"""Test get_invitation_by_token without workspace ID and email."""
|
||||
# Setup mock
|
||||
invitation_data = {
|
||||
"account_id": "user-123",
|
||||
@ -1527,19 +1526,19 @@ class TestRegisterService:
|
||||
mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
|
||||
|
||||
# Execute test
|
||||
result = RegisterService._get_invitation_by_token("token-123")
|
||||
result = RegisterService.get_invitation_by_token("token-123")
|
||||
|
||||
# Verify results
|
||||
assert result is not None
|
||||
assert result == invitation_data
|
||||
|
||||
def test_get_invitation_by_token_no_data(self, mock_redis_dependencies):
|
||||
"""Test _get_invitation_by_token with no data."""
|
||||
"""Test get_invitation_by_token with no data."""
|
||||
# Setup mock
|
||||
mock_redis_dependencies.get.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = RegisterService._get_invitation_by_token("token-123")
|
||||
result = RegisterService.get_invitation_by_token("token-123")
|
||||
|
||||
# Verify results
|
||||
assert result is None
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
# Mock redis_client before importing dataset_service
|
||||
from unittest.mock import Mock, call, patch
|
||||
@ -37,7 +36,7 @@ class DocumentBatchUpdateTestDataFactory:
|
||||
enabled: bool = True,
|
||||
archived: bool = False,
|
||||
indexing_status: str = "completed",
|
||||
completed_at: Optional[datetime.datetime] = None,
|
||||
completed_at: datetime.datetime | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock document with specified attributes."""
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
# Mock redis_client before importing dataset_service
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
@ -23,9 +24,9 @@ class DatasetUpdateTestDataFactory:
|
||||
description: str = "old_description",
|
||||
indexing_technique: str = "high_quality",
|
||||
retrieval_model: str = "old_model",
|
||||
embedding_model_provider: Optional[str] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
collection_binding_id: Optional[str] = None,
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
collection_binding_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
@ -78,7 +79,7 @@ class DatasetUpdateTestDataFactory:
|
||||
@staticmethod
|
||||
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
|
||||
"""Create a mock current user."""
|
||||
current_user = Mock()
|
||||
current_user = create_autospec(Account, instance=True)
|
||||
current_user.current_tenant_id = tenant_id
|
||||
return current_user
|
||||
|
||||
@ -103,6 +104,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
@ -113,6 +115,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"db_session": mock_db,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
"has_dataset_same_name": has_dataset_same_name,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
@ -135,7 +138,9 @@ class TestDatasetServiceUpdateDataset:
|
||||
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
|
||||
) as mock_get_binding,
|
||||
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
|
||||
patch("services.dataset_service.current_user") as mock_current_user,
|
||||
patch(
|
||||
"services.dataset_service.current_user", create_autospec(Account, instance=True)
|
||||
) as mock_current_user,
|
||||
):
|
||||
mock_current_user.current_tenant_id = "tenant-123"
|
||||
yield {
|
||||
@ -187,9 +192,9 @@ class TestDatasetServiceUpdateDataset:
|
||||
"external_knowledge_api_id": "new_api_id",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify permission check was called
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify dataset and binding updates
|
||||
@ -211,6 +216,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
@ -224,6 +230,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
@ -247,6 +254,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"external_knowledge_id": "knowledge_id",
|
||||
"external_knowledge_api_id": "api_id",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
@ -277,6 +285,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify permission check was called
|
||||
@ -317,6 +326,8 @@ class TestDatasetServiceUpdateDataset:
|
||||
"embedding_model": None, # Should be filtered out
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with filtered data
|
||||
@ -353,6 +364,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@ -399,6 +411,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@ -450,6 +463,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@ -502,6 +516,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@ -555,6 +570,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"indexing_technique": "high_quality", # Same as current
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@ -585,6 +601,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
@ -601,6 +618,8 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
update_data = {"name": "new_name"}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
@ -625,6 +644,8 @@ class TestDatasetServiceUpdateDataset:
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(Exception) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from unittest.mock import Mock, patch
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from models.account import Account
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
@ -35,19 +37,21 @@ class TestMetadataBugCompleteValidation:
|
||||
mock_metadata_args.name = None
|
||||
mock_metadata_args.type = "string"
|
||||
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch("services.metadata_service.current_user", mock_user):
|
||||
# Should crash with TypeError
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
# Test update method as well
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch("services.metadata_service.current_user", mock_user):
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
|
||||
@ -143,19 +147,17 @@ class TestMetadataBugCompleteValidation:
|
||||
# Console API create
|
||||
console_create_file = "api/controllers/console/datasets/metadata.py"
|
||||
if os.path.exists(console_create_file):
|
||||
with open(console_create_file) as f:
|
||||
content = f.read()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
|
||||
content = Path(console_create_file).read_text()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
|
||||
|
||||
# Service API create
|
||||
service_create_file = "api/controllers/service_api/dataset/metadata.py"
|
||||
if os.path.exists(service_create_file):
|
||||
with open(service_create_file) as f:
|
||||
content = f.read()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
|
||||
assert "nullable=True" not in create_api_section
|
||||
content = Path(service_create_file).read_text()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
|
||||
assert "nullable=True" not in create_api_section
|
||||
|
||||
|
||||
class TestMetadataValidationSummary:
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import reqparse
|
||||
|
||||
from models.account import Account
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
@ -24,20 +25,22 @@ class TestMetadataNullableBug:
|
||||
mock_metadata_args.name = None # This will cause len() to crash
|
||||
mock_metadata_args.type = "string"
|
||||
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch("services.metadata_service.current_user", mock_user):
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_metadata_service_update_with_none_name_crashes(self):
|
||||
"""Test that MetadataService.update_metadata_name crashes when name is None."""
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch("services.metadata_service.current_user", mock_user):
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
@ -81,10 +84,11 @@ class TestMetadataNullableBug:
|
||||
mock_metadata_args.name = None # From args["name"]
|
||||
mock_metadata_args.type = None # From args["type"]
|
||||
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch("services.metadata_service.current_user", mock_user):
|
||||
# Step 4: Service layer crashes on len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
590
api/tests/unit_tests/services/test_variable_truncator.py
Normal file
590
api/tests/unit_tests/services/test_variable_truncator.py
Normal file
@ -0,0 +1,590 @@
|
||||
"""
|
||||
Comprehensive unit tests for VariableTruncator class based on current implementation.
|
||||
|
||||
This test suite covers all functionality of the current VariableTruncator including:
|
||||
- JSON size calculation for different data types
|
||||
- String, array, and object truncation logic
|
||||
- Segment-based truncation interface
|
||||
- Helper methods for budget-based truncation
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from services.variable_truncator import (
|
||||
MaxDepthExceededError,
|
||||
TruncationResult,
|
||||
UnknownTypeError,
|
||||
VariableTruncator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file() -> File:
|
||||
return File(
|
||||
id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id=str(uuid.uuid4()),
|
||||
filename="test_file.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="initial_key",
|
||||
)
|
||||
|
||||
|
||||
_compact_json_dumps = functools.partial(json.dumps, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestCalculateJsonSize:
|
||||
"""Test calculate_json_size method with different data types."""
|
||||
|
||||
@pytest.fixture
|
||||
def truncator(self):
|
||||
return VariableTruncator()
|
||||
|
||||
def test_string_size_calculation(self):
|
||||
"""Test JSON size calculation for strings."""
|
||||
# Simple ASCII string
|
||||
assert VariableTruncator.calculate_json_size("hello") == 7 # "hello" + 2 quotes
|
||||
|
||||
# Empty string
|
||||
assert VariableTruncator.calculate_json_size("") == 2 # Just quotes
|
||||
|
||||
# Unicode string
|
||||
assert VariableTruncator.calculate_json_size("你好") == 4
|
||||
|
||||
def test_number_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for numbers."""
|
||||
assert truncator.calculate_json_size(123) == 3
|
||||
assert truncator.calculate_json_size(12.34) == 5
|
||||
assert truncator.calculate_json_size(-456) == 4
|
||||
assert truncator.calculate_json_size(0) == 1
|
||||
|
||||
def test_boolean_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for booleans."""
|
||||
assert truncator.calculate_json_size(True) == 4 # "true"
|
||||
assert truncator.calculate_json_size(False) == 5 # "false"
|
||||
|
||||
def test_null_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for None/null."""
|
||||
assert truncator.calculate_json_size(None) == 4 # "null"
|
||||
|
||||
def test_array_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for arrays."""
|
||||
# Empty array
|
||||
assert truncator.calculate_json_size([]) == 2 # "[]"
|
||||
|
||||
# Simple array
|
||||
simple_array = [1, 2, 3]
|
||||
# [1,2,3] = 1 + 1 + 1 + 1 + 1 + 2 = 7 (numbers + commas + brackets)
|
||||
assert truncator.calculate_json_size(simple_array) == 7
|
||||
|
||||
# Array with strings
|
||||
string_array = ["a", "b"]
|
||||
# ["a","b"] = 3 + 3 + 1 + 2 = 9 (quoted strings + comma + brackets)
|
||||
assert truncator.calculate_json_size(string_array) == 9
|
||||
|
||||
def test_object_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for objects."""
|
||||
# Empty object
|
||||
assert truncator.calculate_json_size({}) == 2 # "{}"
|
||||
|
||||
# Simple object
|
||||
simple_obj = {"a": 1}
|
||||
# {"a":1} = 3 + 1 + 1 + 2 = 7 (key + colon + value + brackets)
|
||||
assert truncator.calculate_json_size(simple_obj) == 7
|
||||
|
||||
# Multiple keys
|
||||
multi_obj = {"a": 1, "b": 2}
|
||||
# {"a":1,"b":2} = 3 + 1 + 1 + 1 + 3 + 1 + 1 + 2 = 13
|
||||
assert truncator.calculate_json_size(multi_obj) == 13
|
||||
|
||||
def test_nested_structure_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for nested structures."""
|
||||
nested = {"items": [1, 2, {"nested": "value"}]}
|
||||
size = truncator.calculate_json_size(nested)
|
||||
assert size > 0 # Should calculate without error
|
||||
|
||||
# Verify it matches actual JSON length roughly
|
||||
|
||||
actual_json = _compact_json_dumps(nested)
|
||||
# Should be close but not exact due to UTF-8 encoding considerations
|
||||
assert abs(size - len(actual_json.encode())) <= 5
|
||||
|
||||
def test_calculate_json_size_max_depth_exceeded(self, truncator):
|
||||
"""Test that calculate_json_size handles deep nesting gracefully."""
|
||||
# Create deeply nested structure
|
||||
nested: dict[str, Any] = {"level": 0}
|
||||
current = nested
|
||||
for i in range(105): # Create deep nesting
|
||||
current["next"] = {"level": i + 1}
|
||||
current = current["next"]
|
||||
|
||||
# Should either raise an error or handle gracefully
|
||||
with pytest.raises(MaxDepthExceededError):
|
||||
truncator.calculate_json_size(nested)
|
||||
|
||||
def test_calculate_json_size_unknown_type(self, truncator):
|
||||
"""Test that calculate_json_size raises error for unknown types."""
|
||||
|
||||
class CustomType:
|
||||
pass
|
||||
|
||||
with pytest.raises(UnknownTypeError):
|
||||
truncator.calculate_json_size(CustomType())
|
||||
|
||||
|
||||
class TestStringTruncation:
|
||||
LENGTH_LIMIT = 10
|
||||
"""Test string truncation functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def small_truncator(self):
|
||||
return VariableTruncator(string_length_limit=10)
|
||||
|
||||
def test_short_string_no_truncation(self, small_truncator):
|
||||
"""Test that short strings are not truncated."""
|
||||
short_str = "hello"
|
||||
result = small_truncator._truncate_string(short_str, self.LENGTH_LIMIT)
|
||||
assert result.value == short_str
|
||||
assert result.truncated is False
|
||||
assert result.value_size == VariableTruncator.calculate_json_size(short_str)
|
||||
|
||||
def test_long_string_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that long strings are truncated with ellipsis."""
|
||||
long_str = "this is a very long string that exceeds the limit"
|
||||
result = small_truncator._truncate_string(long_str, self.LENGTH_LIMIT)
|
||||
|
||||
assert result.truncated is True
|
||||
assert result.value == long_str[:5] + "..."
|
||||
assert result.value_size == 10 # 10 chars + "..."
|
||||
|
||||
def test_exact_limit_string(self, small_truncator: VariableTruncator):
|
||||
"""Test string exactly at limit."""
|
||||
exact_str = "1234567890" # Exactly 10 chars
|
||||
result = small_truncator._truncate_string(exact_str, self.LENGTH_LIMIT)
|
||||
assert result.value == "12345..."
|
||||
assert result.truncated is True
|
||||
assert result.value_size == 10
|
||||
|
||||
|
||||
class TestArrayTruncation:
|
||||
"""Test array truncation functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def small_truncator(self):
|
||||
return VariableTruncator(array_element_limit=3, max_size_bytes=100)
|
||||
|
||||
def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that small arrays are not truncated."""
|
||||
small_array = [1, 2]
|
||||
result = small_truncator._truncate_array(small_array, 1000)
|
||||
assert result.value == small_array
|
||||
assert result.truncated is False
|
||||
|
||||
def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that arrays over element limit are truncated."""
|
||||
large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
|
||||
result = small_truncator._truncate_array(large_array, 1000)
|
||||
|
||||
assert result.truncated is True
|
||||
assert result.value == [1, 2, 3]
|
||||
|
||||
def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test array truncation due to size budget constraints."""
|
||||
# Create array with strings that will exceed size budget
|
||||
large_strings = ["very long string " * 5, "another long string " * 5]
|
||||
result = small_truncator._truncate_array(large_strings, 50)
|
||||
|
||||
assert result.truncated is True
|
||||
# Should have truncated the strings within the array
|
||||
for item in result.value:
|
||||
assert isinstance(item, str)
|
||||
assert VariableTruncator.calculate_json_size(result.value) <= 50
|
||||
|
||||
def test_array_with_nested_objects(self, small_truncator):
|
||||
"""Test array truncation with nested objects."""
|
||||
nested_array = [
|
||||
{"name": "item1", "data": "some data"},
|
||||
{"name": "item2", "data": "more data"},
|
||||
{"name": "item3", "data": "even more data"},
|
||||
]
|
||||
result = small_truncator._truncate_array(nested_array, 30)
|
||||
|
||||
assert isinstance(result.value, list)
|
||||
assert len(result.value) <= 3
|
||||
for item in result.value:
|
||||
assert isinstance(item, dict)
|
||||
|
||||
|
||||
class TestObjectTruncation:
|
||||
"""Test object truncation functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def small_truncator(self):
|
||||
return VariableTruncator(max_size_bytes=100)
|
||||
|
||||
def test_small_object_no_truncation(self, small_truncator):
|
||||
"""Test that small objects are not truncated."""
|
||||
small_obj = {"a": 1, "b": 2}
|
||||
result = small_truncator._truncate_object(small_obj, 1000)
|
||||
assert result.value == small_obj
|
||||
assert result.truncated is False
|
||||
|
||||
def test_empty_object_no_truncation(self, small_truncator):
|
||||
"""Test that empty objects are not truncated."""
|
||||
empty_obj = {}
|
||||
result = small_truncator._truncate_object(empty_obj, 100)
|
||||
assert result.value == empty_obj
|
||||
assert result.truncated is False
|
||||
|
||||
def test_object_value_truncation(self, small_truncator):
|
||||
"""Test object truncation where values are truncated to fit budget."""
|
||||
obj_with_long_values = {
|
||||
"key1": "very long string " * 10,
|
||||
"key2": "another long string " * 10,
|
||||
"key3": "third long string " * 10,
|
||||
}
|
||||
result = small_truncator._truncate_object(obj_with_long_values, 80)
|
||||
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.value, dict)
|
||||
|
||||
assert set(result.value.keys()).issubset(obj_with_long_values.keys())
|
||||
|
||||
# Values should be truncated if they exist
|
||||
for key, value in result.value.items():
|
||||
if isinstance(value, str):
|
||||
original_value = obj_with_long_values[key]
|
||||
# Value should be same or smaller
|
||||
assert len(value) <= len(original_value)
|
||||
|
||||
def test_object_key_dropping(self, small_truncator):
|
||||
"""Test object truncation where keys are dropped due to size constraints."""
|
||||
large_obj = {f"key{i:02d}": f"value{i}" for i in range(20)}
|
||||
result = small_truncator._truncate_object(large_obj, 50)
|
||||
|
||||
assert result.truncated is True
|
||||
assert len(result.value) < len(large_obj)
|
||||
|
||||
# Should maintain sorted key order
|
||||
result_keys = list(result.value.keys())
|
||||
assert result_keys == sorted(result_keys)
|
||||
|
||||
def test_object_with_nested_structures(self, small_truncator):
|
||||
"""Test object truncation with nested arrays and objects."""
|
||||
nested_obj = {"simple": "value", "array": [1, 2, 3, 4, 5], "nested": {"inner": "data", "more": ["a", "b", "c"]}}
|
||||
result = small_truncator._truncate_object(nested_obj, 60)
|
||||
|
||||
assert isinstance(result.value, dict)
|
||||
|
||||
|
||||
class TestSegmentBasedTruncation:
|
||||
"""Test the main truncate method that works with Segments."""
|
||||
|
||||
@pytest.fixture
|
||||
def truncator(self):
|
||||
return VariableTruncator()
|
||||
|
||||
@pytest.fixture
|
||||
def small_truncator(self):
|
||||
return VariableTruncator(string_length_limit=20, array_element_limit=3, max_size_bytes=200)
|
||||
|
||||
def test_integer_segment_no_truncation(self, truncator):
|
||||
"""Test that integer segments are never truncated."""
|
||||
segment = IntegerSegment(value=12345)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_boolean_as_integer_segment(self, truncator):
|
||||
"""Test boolean values in IntegerSegment are converted to int."""
|
||||
segment = IntegerSegment(value=True)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert isinstance(result.result, IntegerSegment)
|
||||
assert result.result.value == 1 # True converted to 1
|
||||
|
||||
def test_float_segment_no_truncation(self, truncator):
|
||||
"""Test that float segments are never truncated."""
|
||||
segment = FloatSegment(value=123.456)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_none_segment_no_truncation(self, truncator):
|
||||
"""Test that None segments are never truncated."""
|
||||
segment = NoneSegment()
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_file_segment_no_truncation(self, truncator, file):
|
||||
"""Test that file segments are never truncated."""
|
||||
file_segment = FileSegment(value=file)
|
||||
result = truncator.truncate(file_segment)
|
||||
assert result.result == file_segment
|
||||
assert result.truncated is False
|
||||
|
||||
def test_array_file_segment_no_truncation(self, truncator, file):
|
||||
"""Test that array file segments are never truncated."""
|
||||
|
||||
array_file_segment = ArrayFileSegment(value=[file] * 20)
|
||||
result = truncator.truncate(array_file_segment)
|
||||
assert result.result == array_file_segment
|
||||
assert result.truncated is False
|
||||
|
||||
def test_string_segment_small_no_truncation(self, truncator):
|
||||
"""Test small string segments are not truncated."""
|
||||
segment = StringSegment(value="hello world")
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_string_segment_large_truncation(self, small_truncator):
|
||||
"""Test large string segments are truncated."""
|
||||
long_text = "this is a very long string that will definitely exceed the limit"
|
||||
segment = StringSegment(value=long_text)
|
||||
result = small_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
assert len(result.result.value) < len(long_text)
|
||||
assert result.result.value.endswith("...")
|
||||
|
||||
def test_array_segment_small_no_truncation(self, truncator):
|
||||
"""Test small array segments are not truncated."""
|
||||
from factories.variable_factory import build_segment
|
||||
|
||||
segment = build_segment([1, 2, 3])
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_array_segment_large_truncation(self, small_truncator):
|
||||
"""Test large array segments are truncated."""
|
||||
from factories.variable_factory import build_segment
|
||||
|
||||
large_array = list(range(10)) # Exceeds element limit of 3
|
||||
segment = build_segment(large_array)
|
||||
result = small_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, ArraySegment)
|
||||
assert len(result.result.value) <= 3
|
||||
|
||||
def test_object_segment_small_no_truncation(self, truncator):
|
||||
"""Test small object segments are not truncated."""
|
||||
segment = ObjectSegment(value={"key": "value"})
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_object_segment_large_truncation(self, small_truncator):
|
||||
"""Test large object segments are truncated."""
|
||||
large_obj = {f"key{i}": f"very long value {i}" * 5 for i in range(5)}
|
||||
segment = ObjectSegment(value=large_obj)
|
||||
result = small_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, ObjectSegment)
|
||||
# Object should be smaller or equal than original
|
||||
original_size = small_truncator.calculate_json_size(large_obj)
|
||||
result_size = small_truncator.calculate_json_size(result.result.value)
|
||||
assert result_size <= original_size
|
||||
|
||||
def test_final_size_fallback_to_json_string(self, small_truncator):
|
||||
"""Test final fallback when truncated result still exceeds size limit."""
|
||||
# Create data that will still be large after initial truncation
|
||||
large_nested_data = {"data": ["very long string " * 5] * 5, "more": {"nested": "content " * 20}}
|
||||
segment = ObjectSegment(value=large_nested_data)
|
||||
|
||||
# Use very small limit to force JSON string fallback
|
||||
tiny_truncator = VariableTruncator(max_size_bytes=50)
|
||||
result = tiny_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
# Should be JSON string with possible truncation
|
||||
assert len(result.result.value) <= 53 # 50 + "..." = 53
|
||||
|
||||
def test_final_size_fallback_string_truncation(self, small_truncator):
|
||||
"""Test final fallback for string that still exceeds limit."""
|
||||
# Create very long string that exceeds string length limit
|
||||
very_long_string = "x" * 6000 # Exceeds default string_length_limit of 5000
|
||||
segment = StringSegment(value=very_long_string)
|
||||
|
||||
# Use small limit to test string fallback path
|
||||
tiny_truncator = VariableTruncator(string_length_limit=100, max_size_bytes=50)
|
||||
result = tiny_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
# Should be truncated due to string limit or final size limit
|
||||
assert len(result.result.value) <= 1000 # Much smaller than original
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error conditions."""
|
||||
|
||||
def test_empty_inputs(self):
|
||||
"""Test truncator with empty inputs."""
|
||||
truncator = VariableTruncator()
|
||||
|
||||
# Empty string
|
||||
result = truncator.truncate(StringSegment(value=""))
|
||||
assert not result.truncated
|
||||
assert result.result.value == ""
|
||||
|
||||
# Empty array
|
||||
from factories.variable_factory import build_segment
|
||||
|
||||
result = truncator.truncate(build_segment([]))
|
||||
assert not result.truncated
|
||||
assert result.result.value == []
|
||||
|
||||
# Empty object
|
||||
result = truncator.truncate(ObjectSegment(value={}))
|
||||
assert not result.truncated
|
||||
assert result.result.value == {}
|
||||
|
||||
def test_zero_and_negative_limits(self):
|
||||
"""Test truncator behavior with zero or very small limits."""
|
||||
# Zero string limit
|
||||
with pytest.raises(ValueError):
|
||||
truncator = VariableTruncator(string_length_limit=3)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
truncator = VariableTruncator(array_element_limit=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
truncator = VariableTruncator(max_size_bytes=0)
|
||||
|
||||
def test_unicode_and_special_characters(self):
|
||||
"""Test truncator with unicode and special characters."""
|
||||
truncator = VariableTruncator(string_length_limit=10)
|
||||
|
||||
# Unicode characters
|
||||
unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character
|
||||
result = truncator.truncate(StringSegment(value=unicode_text))
|
||||
if len(unicode_text) > 10:
|
||||
assert result.truncated is True
|
||||
|
||||
# Special JSON characters
|
||||
special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
|
||||
result = truncator.truncate(StringSegment(value=special_chars))
|
||||
assert isinstance(result.result, StringSegment)
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Test realistic integration scenarios."""
|
||||
|
||||
def test_workflow_output_scenario(self):
|
||||
"""Test truncation of typical workflow output data."""
|
||||
truncator = VariableTruncator()
|
||||
|
||||
workflow_data = {
|
||||
"result": "success",
|
||||
"data": {
|
||||
"users": [
|
||||
{"id": 1, "name": "Alice", "email": "alice@example.com"},
|
||||
{"id": 2, "name": "Bob", "email": "bob@example.com"},
|
||||
]
|
||||
* 3, # Multiply to make it larger
|
||||
"metadata": {
|
||||
"count": 6,
|
||||
"processing_time": "1.23s",
|
||||
"details": "x" * 200, # Long string but not too long
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
segment = ObjectSegment(value=workflow_data)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert isinstance(result.result, (ObjectSegment, StringSegment))
|
||||
# Should handle complex nested structure appropriately
|
||||
|
||||
def test_large_text_processing_scenario(self):
|
||||
"""Test truncation of large text data."""
|
||||
truncator = VariableTruncator(string_length_limit=100)
|
||||
|
||||
large_text = "This is a very long text document. " * 20 # Make it larger than limit
|
||||
|
||||
segment = StringSegment(value=large_text)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
assert len(result.result.value) <= 103 # 100 + "..."
|
||||
assert result.result.value.endswith("...")
|
||||
|
||||
def test_mixed_data_types_scenario(self):
|
||||
"""Test truncation with mixed data types in complex structure."""
|
||||
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
|
||||
|
||||
mixed_data = {
|
||||
"strings": ["short", "medium length", "very long string " * 3],
|
||||
"numbers": [1, 2.5, 999999],
|
||||
"booleans": [True, False, True],
|
||||
"nested": {
|
||||
"more_strings": ["nested string " * 2],
|
||||
"more_numbers": list(range(5)),
|
||||
"deep": {"level": 3, "content": "deep content " * 3},
|
||||
},
|
||||
"nulls": [None, None],
|
||||
}
|
||||
|
||||
segment = ObjectSegment(value=mixed_data)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
# Should handle all data types appropriately
|
||||
if result.truncated:
|
||||
# Verify the result is smaller or equal than original
|
||||
original_size = truncator.calculate_json_size(mixed_data)
|
||||
if isinstance(result.result, ObjectSegment):
|
||||
result_size = truncator.calculate_json_size(result.result.value)
|
||||
assert result_size <= original_size
|
||||
212
api/tests/unit_tests/services/tools/test_mcp_tools_transform.py
Normal file
212
api/tests/unit_tests/services/tools/test_mcp_tools_transform.py
Normal file
@ -0,0 +1,212 @@
|
||||
"""Test cases for MCP tool transformation functionality."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Provides a mock user object."""
|
||||
user = Mock()
|
||||
user.name = "Test User"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider(mock_user):
|
||||
"""Provides a mock MCPToolProvider with a loaded user."""
|
||||
provider = Mock(spec=MCPToolProvider)
|
||||
provider.load_user.return_value = mock_user
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_no_user():
|
||||
"""Provides a mock MCPToolProvider with no user."""
|
||||
provider = Mock(spec=MCPToolProvider)
|
||||
provider.load_user.return_value = None
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_full(mock_user):
|
||||
"""Provides a fully configured mock MCPToolProvider for detailed tests."""
|
||||
provider = Mock(spec=MCPToolProvider)
|
||||
provider.id = "provider-id-123"
|
||||
provider.server_identifier = "server-identifier-456"
|
||||
provider.name = "Test MCP Provider"
|
||||
provider.provider_icon = "icon.png"
|
||||
provider.authed = True
|
||||
provider.masked_server_url = "https://*****.com/mcp"
|
||||
provider.timeout = 30
|
||||
provider.sse_read_timeout = 300
|
||||
provider.masked_headers = {"Authorization": "Bearer *****"}
|
||||
provider.decrypted_headers = {"Authorization": "Bearer secret-token"}
|
||||
|
||||
# Mock timestamp
|
||||
mock_updated_at = Mock()
|
||||
mock_updated_at.timestamp.return_value = 1234567890
|
||||
provider.updated_at = mock_updated_at
|
||||
|
||||
provider.load_user.return_value = mock_user
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_mcp_tools():
|
||||
"""Provides sample MCP tools for testing."""
|
||||
return {
|
||||
"simple": MCPTool(
|
||||
name="simple_tool", description="A simple test tool", inputSchema={"type": "object", "properties": {}}
|
||||
),
|
||||
"none_desc": MCPTool(name="tool_none_desc", description=None, inputSchema={"type": "object", "properties": {}}),
|
||||
"complex": MCPTool(
|
||||
name="complex_tool",
|
||||
description="A tool with complex parameters",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "Input text"},
|
||||
"count": {"type": "integer", "description": "Number of items", "minimum": 1, "maximum": 100},
|
||||
"options": {"type": "array", "items": {"type": "string"}, "description": "List of options"},
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestMCPToolTransform:
|
||||
"""Test cases for MCP tool transformation methods."""
|
||||
|
||||
def test_mcp_tool_to_user_tool_with_none_description(self, mock_provider):
|
||||
"""Test that mcp_tool_to_user_tool handles None description correctly."""
|
||||
# Create MCP tools with None description
|
||||
tools = [
|
||||
MCPTool(
|
||||
name="tool1",
|
||||
description=None, # This is the case that caused the error
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
MCPTool(
|
||||
name="tool2",
|
||||
description=None,
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param1": {"type": "string", "description": "A parameter"}},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(tool, ToolApiEntity) for tool in result)
|
||||
|
||||
# Check first tool
|
||||
assert result[0].name == "tool1"
|
||||
assert result[0].author == "Test User"
|
||||
assert isinstance(result[0].label, I18nObject)
|
||||
assert result[0].label.en_US == "tool1"
|
||||
assert isinstance(result[0].description, I18nObject)
|
||||
assert result[0].description.en_US == "" # Should be empty string, not None
|
||||
assert result[0].description.zh_Hans == ""
|
||||
|
||||
# Check second tool
|
||||
assert result[1].name == "tool2"
|
||||
assert result[1].description.en_US == ""
|
||||
assert result[1].description.zh_Hans == ""
|
||||
|
||||
def test_mcp_tool_to_user_tool_with_description(self, mock_provider):
|
||||
"""Test that mcp_tool_to_user_tool handles normal description correctly."""
|
||||
# Create MCP tools with description
|
||||
tools = [
|
||||
MCPTool(
|
||||
name="tool_with_desc",
|
||||
description="This is a test tool that does something useful",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
]
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ToolApiEntity)
|
||||
assert result[0].name == "tool_with_desc"
|
||||
assert result[0].description.en_US == "This is a test tool that does something useful"
|
||||
assert result[0].description.zh_Hans == "This is a test tool that does something useful"
|
||||
|
||||
def test_mcp_tool_to_user_tool_with_no_user(self, mock_provider_no_user):
|
||||
"""Test that mcp_tool_to_user_tool handles None user correctly."""
|
||||
# Create MCP tool
|
||||
tools = [MCPTool(name="tool1", description="Test tool", inputSchema={"type": "object", "properties": {}})]
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider_no_user, tools)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1
|
||||
assert result[0].author == "Anonymous"
|
||||
|
||||
def test_mcp_tool_to_user_tool_with_complex_schema(self, mock_provider, sample_mcp_tools):
|
||||
"""Test that mcp_tool_to_user_tool correctly converts complex input schemas."""
|
||||
# Use complex tool from fixtures
|
||||
tools = [sample_mcp_tools["complex"]]
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "complex_tool"
|
||||
assert result[0].parameters is not None
|
||||
# The actual parameter conversion is handled by convert_mcp_schema_to_parameter
|
||||
# which should be tested separately
|
||||
|
||||
def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full):
|
||||
"""Test mcp_provider_to_user_provider with for_list=True."""
|
||||
# Set tools data with null description
|
||||
mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]'
|
||||
|
||||
# Call the method with for_list=True
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == "provider-id-123" # Should use provider.id when for_list=True
|
||||
assert result.name == "Test MCP Provider"
|
||||
assert result.type == ToolProviderType.MCP
|
||||
assert result.is_team_authorization is True
|
||||
assert result.server_url == "https://*****.com/mcp"
|
||||
assert len(result.tools) == 1
|
||||
assert result.tools[0].description.en_US == "" # Should handle None description
|
||||
|
||||
def test_mcp_provider_to_user_provider_not_for_list(self, mock_provider_full):
|
||||
"""Test mcp_provider_to_user_provider with for_list=False."""
|
||||
# Set tools data with description
|
||||
mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]'
|
||||
|
||||
# Call the method with for_list=False
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
|
||||
assert result.server_identifier == "server-identifier-456"
|
||||
assert result.timeout == 30
|
||||
assert result.sse_read_timeout == 300
|
||||
assert result.original_headers == {"Authorization": "Bearer secret-token"}
|
||||
assert len(result.tools) == 1
|
||||
assert result.tools[0].description.en_US == "Tool description"
|
||||
@ -0,0 +1,377 @@
|
||||
"""Simplified unit tests for DraftVarLoader focusing on core functionality."""
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.variables.segments import ObjectSegment, StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from models.model import UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from services.workflow_draft_variable_service import DraftVarLoader
|
||||
|
||||
|
||||
class TestDraftVarLoaderSimple:
|
||||
"""Simplified unit tests for DraftVarLoader core methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine(self) -> Engine:
|
||||
return Mock(spec=Engine)
|
||||
|
||||
@pytest.fixture
|
||||
def draft_var_loader(self, mock_engine):
|
||||
"""Create DraftVarLoader instance for testing."""
|
||||
return DraftVarLoader(
|
||||
engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[]
|
||||
)
|
||||
|
||||
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with string type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test.txt"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.STRING
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_variable"
|
||||
draft_var.description = "test description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_variable"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_content = "This is the full string content"
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_content.encode()
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_variable"
|
||||
mock_variable.value = StringSegment(value=test_content)
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_variable")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_variable"
|
||||
assert variable.description == "test description"
|
||||
assert variable.value == test_content
|
||||
|
||||
# Verify storage was called correctly
|
||||
mock_storage.load.assert_called_once_with("storage/key/test.txt")
|
||||
|
||||
def test_load_offloaded_variable_object_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with object type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.OBJECT
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_object"
|
||||
draft_var.description = "test description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_object"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_object = {"key1": "value1", "key2": 42}
|
||||
test_json_content = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
mock_segment = ObjectSegment(value=test_object)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_object"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_object")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_object"
|
||||
assert variable.description == "test description"
|
||||
assert variable.value == test_object
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object)
|
||||
|
||||
def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader):
|
||||
"""Test that assertion error is raised when variable_file is None."""
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.variable_file = None
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader):
|
||||
"""Test that assertion error is raised when upload_file is None."""
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.upload_file = None
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
def test_load_variables_empty_selectors_unit(self, draft_var_loader):
|
||||
"""Test load_variables returns empty list for empty selectors."""
|
||||
result = draft_var_loader.load_variables([])
|
||||
assert result == []
|
||||
|
||||
def test_selector_to_tuple_unit(self, draft_var_loader):
|
||||
"""Test _selector_to_tuple method."""
|
||||
selector = ["node_id", "var_name", "extra_field"]
|
||||
result = draft_var_loader._selector_to_tuple(selector)
|
||||
assert result == ("node_id", "var_name")
|
||||
|
||||
def test_load_offloaded_variable_number_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with number type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test_number.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.NUMBER
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_number"
|
||||
draft_var.description = "test number description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_number"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_number = 123.45
|
||||
test_json_content = json.dumps(test_number)
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
from core.variables.segments import FloatSegment
|
||||
|
||||
mock_segment = FloatSegment(value=test_number)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_number"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_number")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_number"
|
||||
assert variable.description == "test number description"
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test_number.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number)
|
||||
|
||||
def test_load_offloaded_variable_array_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with array type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test_array.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.ARRAY_ANY
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_array"
|
||||
draft_var.description = "test array description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_array"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_array = ["item1", "item2", "item3"]
|
||||
test_json_content = json.dumps(test_array)
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
|
||||
mock_segment = ArrayAnySegment(value=test_array)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_array"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_array")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_array"
|
||||
assert variable.description == "test array description"
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test_array.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array)
|
||||
|
||||
def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader):
|
||||
"""Test load_variables method with mix of regular and offloaded variables."""
|
||||
selectors = [["node1", "regular_var"], ["node2", "offloaded_var"]]
|
||||
|
||||
# Mock regular variable
|
||||
regular_draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
regular_draft_var.is_truncated.return_value = False
|
||||
regular_draft_var.node_id = "node1"
|
||||
regular_draft_var.name = "regular_var"
|
||||
regular_draft_var.get_value.return_value = StringSegment(value="regular_value")
|
||||
regular_draft_var.get_selector.return_value = ["node1", "regular_var"]
|
||||
regular_draft_var.id = "regular-var-id"
|
||||
regular_draft_var.description = "regular description"
|
||||
|
||||
# Mock offloaded variable
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/offloaded.txt"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.STRING
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
offloaded_draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_draft_var.is_truncated.return_value = True
|
||||
offloaded_draft_var.node_id = "node2"
|
||||
offloaded_draft_var.name = "offloaded_var"
|
||||
offloaded_draft_var.get_selector.return_value = ["node2", "offloaded_var"]
|
||||
offloaded_draft_var.variable_file = variable_file
|
||||
offloaded_draft_var.id = "offloaded-var-id"
|
||||
offloaded_draft_var.description = "offloaded description"
|
||||
|
||||
draft_vars = [regular_draft_var, offloaded_draft_var]
|
||||
|
||||
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
|
||||
mock_session = Mock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
|
||||
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
|
||||
):
|
||||
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
# Mock regular variable creation
|
||||
regular_variable = Mock()
|
||||
regular_variable.selector = ["node1", "regular_var"]
|
||||
|
||||
# Mock offloaded variable creation
|
||||
offloaded_variable = Mock()
|
||||
offloaded_variable.selector = ["node2", "offloaded_var"]
|
||||
|
||||
mock_segment_to_variable.return_value = regular_variable
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = b"offloaded_content"
|
||||
|
||||
with patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded:
|
||||
mock_load_offloaded.return_value = (("node2", "offloaded_var"), offloaded_variable)
|
||||
|
||||
with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_cls:
|
||||
mock_executor = Mock()
|
||||
mock_executor_cls.return_value.__enter__.return_value = mock_executor
|
||||
mock_executor.map.return_value = [(("node2", "offloaded_var"), offloaded_variable)]
|
||||
|
||||
# Execute the method
|
||||
result = draft_var_loader.load_variables(selectors)
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify service method was called
|
||||
mock_service.get_draft_variables_by_selectors.assert_called_once_with(
|
||||
draft_var_loader._app_id, selectors
|
||||
)
|
||||
|
||||
# Verify offloaded variable loading was called
|
||||
mock_load_offloaded.assert_called_once_with(offloaded_draft_var)
|
||||
|
||||
def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader):
|
||||
"""Test load_variables method with only offloaded variables."""
|
||||
selectors = [["node1", "offloaded_var1"], ["node2", "offloaded_var2"]]
|
||||
|
||||
# Mock first offloaded variable
|
||||
offloaded_var1 = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_var1.is_truncated.return_value = True
|
||||
offloaded_var1.node_id = "node1"
|
||||
offloaded_var1.name = "offloaded_var1"
|
||||
|
||||
# Mock second offloaded variable
|
||||
offloaded_var2 = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_var2.is_truncated.return_value = True
|
||||
offloaded_var2.node_id = "node2"
|
||||
offloaded_var2.name = "offloaded_var2"
|
||||
|
||||
draft_vars = [offloaded_var1, offloaded_var2]
|
||||
|
||||
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
|
||||
mock_session = Mock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
|
||||
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
|
||||
):
|
||||
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
|
||||
with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls:
|
||||
mock_executor = Mock()
|
||||
mock_executor_cls.return_value.__enter__.return_value = mock_executor
|
||||
mock_executor.map.return_value = [
|
||||
(("node1", "offloaded_var1"), Mock()),
|
||||
(("node2", "offloaded_var2"), Mock()),
|
||||
]
|
||||
|
||||
# Execute the method
|
||||
result = draft_var_loader.load_variables(selectors)
|
||||
|
||||
# Verify results - since we have only offloaded variables, should have 2 results
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify ThreadPoolExecutor was used
|
||||
mock_executor_cls.assert_called_once_with(max_workers=10)
|
||||
mock_executor.map.assert_called_once()
|
||||
@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.CHAT.value
|
||||
app_model.mode = AppMode.CHAT
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.WORKFLOW.value
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
|
||||
@ -1,16 +1,26 @@
|
||||
import dataclasses
|
||||
import secrets
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import StringSegment
|
||||
from core.variables.segments import StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDraftVariable,
|
||||
WorkflowDraftVariableFile,
|
||||
WorkflowNodeExecutionModel,
|
||||
is_system_variable_editable,
|
||||
)
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
VariableResetError,
|
||||
@ -37,6 +47,7 @@ class TestDraftVariableSaver:
|
||||
|
||||
def test__should_variable_be_visible(self):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = Account(id=str(uuid.uuid4()))
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
@ -44,6 +55,7 @@ class TestDraftVariableSaver:
|
||||
node_id="test_node_id",
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="test_execution_id",
|
||||
user=mock_user,
|
||||
)
|
||||
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
|
||||
assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
|
||||
@ -83,6 +95,7 @@ class TestDraftVariableSaver:
|
||||
]
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = MagicMock()
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
@ -90,6 +103,7 @@ class TestDraftVariableSaver:
|
||||
node_id=_NODE_ID,
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="test_execution_id",
|
||||
user=mock_user,
|
||||
)
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||
@ -97,6 +111,76 @@ class TestDraftVariableSaver:
|
||||
assert node_id == c.expected_node_id, fail_msg
|
||||
assert name == c.expected_name, fail_msg
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Mock SQLAlchemy session."""
|
||||
from sqlalchemy import Engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_engine = MagicMock(spec=Engine)
|
||||
mock_session.get_bind.return_value = mock_engine
|
||||
return mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def draft_saver(self, mock_session):
|
||||
"""Create DraftVariableSaver instance with user context."""
|
||||
# Create a mock user
|
||||
mock_user = MagicMock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.tenant_id = "test-tenant-id"
|
||||
|
||||
return DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id="test-app-id",
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
node_execution_id="test-execution-id",
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
) as _mock_try_offload:
|
||||
_mock_try_offload.return_value = None
|
||||
mock_segment = StringSegment(value="small value")
|
||||
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
|
||||
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id is None
|
||||
_mock_try_offload.return_value = None
|
||||
|
||||
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
) as _mock_try_offload:
|
||||
mock_segment = StringSegment(value="small value")
|
||||
mock_draft_var_file = WorkflowDraftVariableFile(
|
||||
id=str(uuidv7()),
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.ARRAY_STRING,
|
||||
upload_file_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
_mock_try_offload.return_value = mock_segment, mock_draft_var_file
|
||||
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
|
||||
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id == mock_draft_var_file.id
|
||||
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
|
||||
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
|
||||
"""Test complete save workflow."""
|
||||
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
|
||||
|
||||
draft_saver.save(outputs=outputs)
|
||||
|
||||
# Should batch upsert draft variables
|
||||
mock_batch_upsert.assert_called_once()
|
||||
draft_vars = mock_batch_upsert.call_args[0][1]
|
||||
assert len(draft_vars) == 2
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
def _get_test_app_id(self):
|
||||
@ -115,6 +199,7 @@ class TestWorkflowDraftVariableService:
|
||||
created_by="test_user_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
|
||||
def test_reset_conversation_variable(self, mock_session):
|
||||
@ -225,7 +310,7 @@ class TestWorkflowDraftVariableService:
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"test_var": "output_value"}
|
||||
mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
@ -298,7 +383,7 @@ class TestWorkflowDraftVariableService:
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"sys.files": "[]"}
|
||||
mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
@ -330,7 +415,7 @@ class TestWorkflowDraftVariableService:
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.outputs_dict = {"sys.query": "reset query"}
|
||||
mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
|
||||
Reference in New Issue
Block a user