mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 18:27:27 +08:00
Preserve auth service interface compatibility
This commit is contained in:
@ -1,19 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Annotated, NotRequired
|
||||
|
||||
from pydantic import StringConstraints
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
|
||||
|
||||
|
||||
class ApiKeyAuthConfig(TypedDict):
|
||||
api_key: NotRequired[str]
|
||||
base_url: NotRequired[str]
|
||||
class ApiKeyAuthConfig(TypedDict, total=False):
|
||||
api_key: str
|
||||
base_url: str
|
||||
|
||||
|
||||
class ApiKeyAuthCredentials(TypedDict):
|
||||
auth_type: NonEmptyString
|
||||
auth_type: object
|
||||
config: ApiKeyAuthConfig
|
||||
|
||||
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Annotated, TypeVar
|
||||
|
||||
from pydantic import StringConstraints, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@ -11,20 +9,16 @@ from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthCredentials
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.auth_type import AuthProvider
|
||||
|
||||
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
|
||||
ValidatedPayload = TypeVar("ValidatedPayload")
|
||||
|
||||
|
||||
class ApiKeyAuthCreateArgs(TypedDict):
|
||||
category: NonEmptyString
|
||||
provider: NonEmptyString
|
||||
category: str
|
||||
provider: str
|
||||
credentials: ApiKeyAuthCredentials
|
||||
|
||||
|
||||
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(ApiKeyAuthCredentials)
|
||||
AUTH_CREATE_ARGS_ADAPTER = TypeAdapter(ApiKeyAuthCreateArgs)
|
||||
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(dict[str, object])
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@ -37,30 +31,32 @@ class ApiKeyAuthService:
|
||||
).all()
|
||||
return list(data_source_api_key_bindings)
|
||||
|
||||
@classmethod
|
||||
def create_provider_auth(cls, tenant_id: str, args: Mapping[str, object] | ApiKeyAuthCreateArgs) -> None:
|
||||
validated_args = cls.validate_api_key_auth_args(args)
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict[str, object]) -> None:
|
||||
validated_args = ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
raw_credentials = ApiKeyAuthService._get_credentials_dict(args)
|
||||
auth_result = ApiKeyAuthFactory(
|
||||
validated_args["provider"], validated_args["credentials"]
|
||||
).validate_credentials()
|
||||
if auth_result:
|
||||
api_key_value = validated_args["credentials"]["config"].get("api_key")
|
||||
if api_key_value is None:
|
||||
raise ValueError("credentials config api_key is required")
|
||||
encrypted_api_key = encrypter.encrypt_token(tenant_id, api_key_value)
|
||||
validated_args["credentials"]["config"]["api_key"] = encrypted_api_key
|
||||
raise KeyError("api_key")
|
||||
api_key = encrypter.encrypt_token(tenant_id, api_key_value)
|
||||
raw_config = ApiKeyAuthService._get_config_dict(raw_credentials)
|
||||
raw_config["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id,
|
||||
category=validated_args["category"],
|
||||
provider=validated_args["provider"],
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(validated_args["credentials"], ensure_ascii=False)
|
||||
data_source_api_key_binding.credentials = json.dumps(raw_credentials, ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: AuthProvider) -> ApiKeyAuthCredentials | None:
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str) -> dict[str, object] | None:
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(
|
||||
@ -75,8 +71,8 @@ class ApiKeyAuthService:
|
||||
return None
|
||||
if not data_source_api_key_bindings.credentials:
|
||||
return None
|
||||
raw_credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return ApiKeyAuthService._validate_credentials_payload(raw_credentials)
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return AUTH_CREDENTIALS_ADAPTER.validate_python(credentials)
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str) -> None:
|
||||
@ -89,14 +85,32 @@ class ApiKeyAuthService:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args: Mapping[str, object] | None) -> ApiKeyAuthCreateArgs:
|
||||
return cls._validate_payload(AUTH_CREATE_ARGS_ADAPTER, args)
|
||||
@staticmethod
|
||||
def validate_api_key_auth_args(args: dict[str, object] | None) -> ApiKeyAuthCreateArgs:
|
||||
if args is None:
|
||||
raise TypeError("argument of type 'NoneType' is not iterable")
|
||||
if "category" not in args or not args["category"]:
|
||||
raise ValueError("category is required")
|
||||
if "provider" not in args or not args["provider"]:
|
||||
raise ValueError("provider is required")
|
||||
if "credentials" not in args or not args["credentials"]:
|
||||
raise ValueError("credentials is required")
|
||||
if not isinstance(args["credentials"], dict):
|
||||
raise ValueError("credentials must be a dictionary")
|
||||
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
|
||||
raise ValueError("auth_type is required")
|
||||
return AUTH_CREATE_ARGS_ADAPTER.validate_python(args)
|
||||
|
||||
@staticmethod
|
||||
def _validate_credentials_payload(raw_credentials: object) -> ApiKeyAuthCredentials:
|
||||
return ApiKeyAuthService._validate_payload(AUTH_CREDENTIALS_ADAPTER, raw_credentials)
|
||||
def _get_credentials_dict(args: dict[str, object]) -> dict[str, object]:
|
||||
credentials = args["credentials"]
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError("credentials must be a dictionary")
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def _validate_payload(adapter: TypeAdapter[ValidatedPayload], payload: object) -> ValidatedPayload:
|
||||
return adapter.validate_python(payload)
|
||||
def _get_config_dict(credentials: dict[str, object]) -> dict[str, object]:
|
||||
config = credentials["config"]
|
||||
if not isinstance(config, dict):
|
||||
raise TypeError("string indices must be integers")
|
||||
return config
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@ -11,14 +11,6 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
class TestApiKeyAuthService:
|
||||
"""API key authentication service security tests"""
|
||||
|
||||
@staticmethod
|
||||
def _assert_validation_error_loc(
|
||||
exc_info: pytest.ExceptionInfo[ValidationError], expected_loc: tuple[str, ...]
|
||||
) -> None:
|
||||
errors = exc_info.value.errors()
|
||||
assert len(errors) == 1
|
||||
assert errors[0]["loc"] == expected_loc
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test fixtures"""
|
||||
self.tenant_id = "test_tenant_123"
|
||||
@ -77,7 +69,16 @@ class TestApiKeyAuthService:
|
||||
# Mock successful auth validation
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
captured_provider = None
|
||||
captured_credentials = None
|
||||
|
||||
def factory_side_effect(provider, credentials):
|
||||
nonlocal captured_provider, captured_credentials
|
||||
captured_provider = provider
|
||||
captured_credentials = deepcopy(credentials)
|
||||
return mock_auth_instance
|
||||
|
||||
mock_factory.side_effect = factory_side_effect
|
||||
|
||||
# Mock encryption
|
||||
encrypted_key = "encrypted_test_key_123"
|
||||
@ -86,12 +87,14 @@ class TestApiKeyAuthService:
|
||||
# Mock database operations
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
expected_credentials = deepcopy(self.mock_credentials)
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
|
||||
|
||||
# Verify factory class calls
|
||||
assert mock_factory.call_count == 1
|
||||
assert mock_factory.call_args.args[0] == self.provider
|
||||
assert captured_provider == self.provider
|
||||
assert captured_credentials == expected_credentials
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
|
||||
# Verify encryption calls
|
||||
@ -120,7 +123,7 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session):
|
||||
"""Test create provider auth - stores encrypted API key"""
|
||||
"""Test create provider auth - ensures API key is encrypted"""
|
||||
# Mock successful auth validation
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
@ -134,18 +137,17 @@ class TestApiKeyAuthService:
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args_copy = {
|
||||
"category": self.category,
|
||||
"provider": self.provider,
|
||||
"credentials": {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}},
|
||||
}
|
||||
args_copy = self.mock_args.copy()
|
||||
original_key = args_copy["credentials"]["config"]["api_key"]
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
|
||||
|
||||
stored_binding = mock_session.add.call_args.args[0]
|
||||
stored_credentials = json.loads(stored_binding.credentials)
|
||||
assert stored_credentials["config"]["api_key"] == encrypted_key
|
||||
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123")
|
||||
# Verify original key is replaced with encrypted key
|
||||
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
|
||||
assert args_copy["credentials"]["config"]["api_key"] != original_key
|
||||
|
||||
# Verify encryption function is called correctly
|
||||
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_success(self, mock_session):
|
||||
@ -241,81 +243,72 @@ class TestApiKeyAuthService:
|
||||
args = self.mock_args.copy()
|
||||
del args["category"]
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="category is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("category",))
|
||||
|
||||
def test_validate_api_key_auth_args_empty_category(self):
|
||||
"""Test API key auth args validation - empty category"""
|
||||
args = self.mock_args.copy()
|
||||
args["category"] = ""
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="category is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("category",))
|
||||
|
||||
def test_validate_api_key_auth_args_missing_provider(self):
|
||||
"""Test API key auth args validation - missing provider"""
|
||||
args = self.mock_args.copy()
|
||||
del args["provider"]
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="provider is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("provider",))
|
||||
|
||||
def test_validate_api_key_auth_args_empty_provider(self):
|
||||
"""Test API key auth args validation - empty provider"""
|
||||
args = self.mock_args.copy()
|
||||
args["provider"] = ""
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="provider is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("provider",))
|
||||
|
||||
def test_validate_api_key_auth_args_missing_credentials(self):
|
||||
"""Test API key auth args validation - missing credentials"""
|
||||
args = self.mock_args.copy()
|
||||
del args["credentials"]
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="credentials is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("credentials",))
|
||||
|
||||
def test_validate_api_key_auth_args_empty_credentials(self):
|
||||
"""Test API key auth args validation - empty credentials"""
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"] = None
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="credentials is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("credentials",))
|
||||
|
||||
def test_validate_api_key_auth_args_invalid_credentials_type(self):
|
||||
"""Test API key auth args validation - invalid credentials type"""
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"] = "not_a_dict"
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="credentials must be a dictionary"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("credentials",))
|
||||
|
||||
def test_validate_api_key_auth_args_missing_auth_type(self):
|
||||
"""Test API key auth args validation - missing auth_type"""
|
||||
args = self.mock_args.copy()
|
||||
del args["credentials"]["auth_type"]
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("credentials", "auth_type"))
|
||||
|
||||
def test_validate_api_key_auth_args_empty_auth_type(self):
|
||||
"""Test API key auth args validation - empty auth_type"""
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"]["auth_type"] = ""
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("credentials", "auth_type"))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"malicious_input",
|
||||
@ -394,15 +387,14 @@ class TestApiKeyAuthService:
|
||||
|
||||
def test_validate_api_key_auth_args_none_input(self):
|
||||
"""Test API key auth args validation - None input"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
with pytest.raises(TypeError):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(None)
|
||||
self._assert_validation_error_loc(exc_info, ())
|
||||
|
||||
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
|
||||
"""Test API key auth args validation - dict credentials with list auth_type"""
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"]["auth_type"] = ["api_key"]
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
self._assert_validation_error_loc(exc_info, ("credentials", "auth_type"))
|
||||
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
|
||||
# So this should not raise exception, this test should pass
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
Reference in New Issue
Block a user