Preserve auth service interface compatibility

This commit is contained in:
Yanli 盐粒
2026-03-18 23:34:24 +08:00
parent a4dbb76d3a
commit 7ff470d8a0
3 changed files with 81 additions and 79 deletions

View File

@ -1,19 +1,15 @@
from abc import ABC, abstractmethod
from typing import Annotated, NotRequired
from pydantic import StringConstraints
from typing_extensions import TypedDict
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
class ApiKeyAuthConfig(TypedDict):
api_key: NotRequired[str]
base_url: NotRequired[str]
class ApiKeyAuthConfig(TypedDict, total=False):
api_key: str
base_url: str
class ApiKeyAuthCredentials(TypedDict):
auth_type: NonEmptyString
auth_type: object
config: ApiKeyAuthConfig

View File

@ -1,8 +1,6 @@
import json
from collections.abc import Mapping
from typing import Annotated, TypeVar
from pydantic import StringConstraints, TypeAdapter
from pydantic import TypeAdapter
from sqlalchemy import select
from typing_extensions import TypedDict
@ -11,20 +9,16 @@ from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_base import ApiKeyAuthCredentials
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.auth_type import AuthProvider
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
ValidatedPayload = TypeVar("ValidatedPayload")
class ApiKeyAuthCreateArgs(TypedDict):
category: NonEmptyString
provider: NonEmptyString
category: str
provider: str
credentials: ApiKeyAuthCredentials
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(ApiKeyAuthCredentials)
AUTH_CREATE_ARGS_ADAPTER = TypeAdapter(ApiKeyAuthCreateArgs)
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(dict[str, object])
class ApiKeyAuthService:
@ -37,30 +31,32 @@ class ApiKeyAuthService:
).all()
return list(data_source_api_key_bindings)
@classmethod
def create_provider_auth(cls, tenant_id: str, args: Mapping[str, object] | ApiKeyAuthCreateArgs) -> None:
validated_args = cls.validate_api_key_auth_args(args)
@staticmethod
def create_provider_auth(tenant_id: str, args: dict[str, object]) -> None:
validated_args = ApiKeyAuthService.validate_api_key_auth_args(args)
raw_credentials = ApiKeyAuthService._get_credentials_dict(args)
auth_result = ApiKeyAuthFactory(
validated_args["provider"], validated_args["credentials"]
).validate_credentials()
if auth_result:
api_key_value = validated_args["credentials"]["config"].get("api_key")
if api_key_value is None:
raise ValueError("credentials config api_key is required")
encrypted_api_key = encrypter.encrypt_token(tenant_id, api_key_value)
validated_args["credentials"]["config"]["api_key"] = encrypted_api_key
raise KeyError("api_key")
api_key = encrypter.encrypt_token(tenant_id, api_key_value)
raw_config = ApiKeyAuthService._get_config_dict(raw_credentials)
raw_config["api_key"] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id,
category=validated_args["category"],
provider=validated_args["provider"],
)
data_source_api_key_binding.credentials = json.dumps(validated_args["credentials"], ensure_ascii=False)
data_source_api_key_binding.credentials = json.dumps(raw_credentials, ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: AuthProvider) -> ApiKeyAuthCredentials | None:
def get_auth_credentials(tenant_id: str, category: str, provider: str) -> dict[str, object] | None:
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(
@ -75,8 +71,8 @@ class ApiKeyAuthService:
return None
if not data_source_api_key_bindings.credentials:
return None
raw_credentials = json.loads(data_source_api_key_bindings.credentials)
return ApiKeyAuthService._validate_credentials_payload(raw_credentials)
credentials = json.loads(data_source_api_key_bindings.credentials)
return AUTH_CREDENTIALS_ADAPTER.validate_python(credentials)
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str) -> None:
@ -89,14 +85,32 @@ class ApiKeyAuthService:
db.session.delete(data_source_api_key_binding)
db.session.commit()
@classmethod
def validate_api_key_auth_args(cls, args: Mapping[str, object] | None) -> ApiKeyAuthCreateArgs:
return cls._validate_payload(AUTH_CREATE_ARGS_ADAPTER, args)
@staticmethod
def validate_api_key_auth_args(args: dict[str, object] | None) -> ApiKeyAuthCreateArgs:
if args is None:
raise TypeError("argument of type 'NoneType' is not iterable")
if "category" not in args or not args["category"]:
raise ValueError("category is required")
if "provider" not in args or not args["provider"]:
raise ValueError("provider is required")
if "credentials" not in args or not args["credentials"]:
raise ValueError("credentials is required")
if not isinstance(args["credentials"], dict):
raise ValueError("credentials must be a dictionary")
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
raise ValueError("auth_type is required")
return AUTH_CREATE_ARGS_ADAPTER.validate_python(args)
@staticmethod
def _validate_credentials_payload(raw_credentials: object) -> ApiKeyAuthCredentials:
return ApiKeyAuthService._validate_payload(AUTH_CREDENTIALS_ADAPTER, raw_credentials)
def _get_credentials_dict(args: dict[str, object]) -> dict[str, object]:
credentials = args["credentials"]
if not isinstance(credentials, dict):
raise ValueError("credentials must be a dictionary")
return credentials
@staticmethod
def _validate_payload(adapter: TypeAdapter[ValidatedPayload], payload: object) -> ValidatedPayload:
return adapter.validate_python(payload)
def _get_config_dict(credentials: dict[str, object]) -> dict[str, object]:
config = credentials["config"]
if not isinstance(config, dict):
raise TypeError("string indices must be integers")
return config

View File

@ -1,8 +1,8 @@
import json
from copy import deepcopy
from unittest.mock import Mock, patch
import pytest
from pydantic import ValidationError
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_service import ApiKeyAuthService
@ -11,14 +11,6 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
class TestApiKeyAuthService:
"""API key authentication service security tests"""
@staticmethod
def _assert_validation_error_loc(
exc_info: pytest.ExceptionInfo[ValidationError], expected_loc: tuple[str, ...]
) -> None:
errors = exc_info.value.errors()
assert len(errors) == 1
assert errors[0]["loc"] == expected_loc
def setup_method(self):
"""Setup test fixtures"""
self.tenant_id = "test_tenant_123"
@ -77,7 +69,16 @@ class TestApiKeyAuthService:
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
captured_provider = None
captured_credentials = None
def factory_side_effect(provider, credentials):
nonlocal captured_provider, captured_credentials
captured_provider = provider
captured_credentials = deepcopy(credentials)
return mock_auth_instance
mock_factory.side_effect = factory_side_effect
# Mock encryption
encrypted_key = "encrypted_test_key_123"
@ -86,12 +87,14 @@ class TestApiKeyAuthService:
# Mock database operations
mock_session.add = Mock()
mock_session.commit = Mock()
expected_credentials = deepcopy(self.mock_credentials)
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
# Verify factory class calls
assert mock_factory.call_count == 1
assert mock_factory.call_args.args[0] == self.provider
assert captured_provider == self.provider
assert captured_credentials == expected_credentials
mock_auth_instance.validate_credentials.assert_called_once()
# Verify encryption calls
@ -120,7 +123,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - stores encrypted API key"""
"""Test create provider auth - ensures API key is encrypted"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
@ -134,18 +137,17 @@ class TestApiKeyAuthService:
mock_session.add = Mock()
mock_session.commit = Mock()
args_copy = {
"category": self.category,
"provider": self.provider,
"credentials": {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}},
}
args_copy = self.mock_args.copy()
original_key = args_copy["credentials"]["config"]["api_key"]
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
stored_binding = mock_session.add.call_args.args[0]
stored_credentials = json.loads(stored_binding.credentials)
assert stored_credentials["config"]["api_key"] == encrypted_key
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123")
# Verify original key is replaced with encrypted key
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
assert args_copy["credentials"]["config"]["api_key"] != original_key
# Verify encryption function is called correctly
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_success(self, mock_session):
@ -241,81 +243,72 @@ class TestApiKeyAuthService:
args = self.mock_args.copy()
del args["category"]
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="category is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("category",))
def test_validate_api_key_auth_args_empty_category(self):
"""Test API key auth args validation - empty category"""
args = self.mock_args.copy()
args["category"] = ""
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="category is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("category",))
def test_validate_api_key_auth_args_missing_provider(self):
"""Test API key auth args validation - missing provider"""
args = self.mock_args.copy()
del args["provider"]
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="provider is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("provider",))
def test_validate_api_key_auth_args_empty_provider(self):
"""Test API key auth args validation - empty provider"""
args = self.mock_args.copy()
args["provider"] = ""
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="provider is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("provider",))
def test_validate_api_key_auth_args_missing_credentials(self):
"""Test API key auth args validation - missing credentials"""
args = self.mock_args.copy()
del args["credentials"]
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("credentials",))
def test_validate_api_key_auth_args_empty_credentials(self):
"""Test API key auth args validation - empty credentials"""
args = self.mock_args.copy()
args["credentials"] = None
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("credentials",))
def test_validate_api_key_auth_args_invalid_credentials_type(self):
"""Test API key auth args validation - invalid credentials type"""
args = self.mock_args.copy()
args["credentials"] = "not_a_dict"
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="credentials must be a dictionary"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("credentials",))
def test_validate_api_key_auth_args_missing_auth_type(self):
"""Test API key auth args validation - missing auth_type"""
args = self.mock_args.copy()
del args["credentials"]["auth_type"]
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("credentials", "auth_type"))
def test_validate_api_key_auth_args_empty_auth_type(self):
"""Test API key auth args validation - empty auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ""
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("credentials", "auth_type"))
@pytest.mark.parametrize(
"malicious_input",
@ -394,15 +387,14 @@ class TestApiKeyAuthService:
def test_validate_api_key_auth_args_none_input(self):
"""Test API key auth args validation - None input"""
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(TypeError):
ApiKeyAuthService.validate_api_key_auth_args(None)
self._assert_validation_error_loc(exc_info, ())
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
"""Test API key auth args validation - dict credentials with list auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ["api_key"]
with pytest.raises(ValidationError) as exc_info:
ApiKeyAuthService.validate_api_key_auth_args(args)
self._assert_validation_error_loc(exc_info, ("credentials", "auth_type"))
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
# So this should not raise exception, this test should pass
ApiKeyAuthService.validate_api_key_auth_args(args)