mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
refactor(services): migrate trigger_provider_service to SQLAlchemy 2.0 select() API (#34972)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -5,7 +5,7 @@ import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy import delete, desc, func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@ -73,27 +73,28 @@ class TriggerProviderService:
|
||||
workflows_in_use_map: dict[str, int] = {}
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get all subscriptions
|
||||
subscriptions_db = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
||||
subscriptions_db = session.scalars(
|
||||
select(TriggerSubscription)
|
||||
.where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.provider_id == str(provider_id),
|
||||
)
|
||||
.order_by(desc(TriggerSubscription.created_at))
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
|
||||
if not subscriptions:
|
||||
return []
|
||||
usage_counts = (
|
||||
session.query(
|
||||
usage_counts = session.execute(
|
||||
select(
|
||||
WorkflowPluginTrigger.subscription_id,
|
||||
func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
|
||||
)
|
||||
.filter(
|
||||
.where(
|
||||
WorkflowPluginTrigger.tenant_id == tenant_id,
|
||||
WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
|
||||
)
|
||||
.group_by(WorkflowPluginTrigger.subscription_id)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}
|
||||
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
@ -156,9 +157,13 @@ class TriggerProviderService:
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
# Check provider count limit
|
||||
provider_count = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
||||
.count()
|
||||
session.scalar(
|
||||
select(func.count(TriggerSubscription.id)).where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.provider_id == str(provider_id),
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
|
||||
@ -168,10 +173,14 @@ class TriggerProviderService:
|
||||
)
|
||||
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||
.first()
|
||||
existing = session.scalar(
|
||||
select(TriggerSubscription)
|
||||
.where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.provider_id == str(provider_id),
|
||||
TriggerSubscription.name == name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||
@ -248,8 +257,13 @@ class TriggerProviderService:
|
||||
# Use distributed lock to prevent race conditions on the same subscription
|
||||
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
subscription = session.scalar(
|
||||
select(TriggerSubscription)
|
||||
.where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.id == subscription_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Trigger subscription {subscription_id} not found")
|
||||
@ -259,10 +273,14 @@ class TriggerProviderService:
|
||||
|
||||
# Check for name uniqueness if name is being updated
|
||||
if name is not None and name != subscription.name:
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||
.first()
|
||||
existing = session.scalar(
|
||||
select(TriggerSubscription)
|
||||
.where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.provider_id == str(provider_id),
|
||||
TriggerSubscription.name == name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Subscription name '{name}' already exists for this provider")
|
||||
@ -320,11 +338,18 @@ class TriggerProviderService:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
subscription: TriggerSubscription | None = None
|
||||
if subscription_id:
|
||||
subscription = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
subscription = session.scalar(
|
||||
select(TriggerSubscription)
|
||||
.where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.id == subscription_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
|
||||
subscription = session.scalar(
|
||||
select(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
if subscription:
|
||||
provider_controller = TriggerManager.get_trigger_provider(
|
||||
tenant_id, TriggerProviderID(subscription.provider_id)
|
||||
@ -353,8 +378,13 @@ class TriggerProviderService:
|
||||
:param subscription_id: Subscription instance ID
|
||||
:return: Success response
|
||||
"""
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
subscription = session.scalar(
|
||||
select(TriggerSubscription)
|
||||
.where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.id == subscription_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
@ -406,7 +436,14 @@ class TriggerProviderService:
|
||||
:return: New token info
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
subscription = session.scalar(
|
||||
select(TriggerSubscription)
|
||||
.where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.id == subscription_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
@ -479,8 +516,13 @@ class TriggerProviderService:
|
||||
now_ts: int = int(now if now is not None else _time.time())
|
||||
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
subscription = session.scalar(
|
||||
select(TriggerSubscription)
|
||||
.where(
|
||||
TriggerSubscription.tenant_id == tenant_id,
|
||||
TriggerSubscription.id == subscription_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if subscription is None:
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
@ -556,15 +598,15 @@ class TriggerProviderService:
|
||||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
tenant_client: TriggerOAuthTenantClient | None = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
enabled=True,
|
||||
tenant_client = session.scalar(
|
||||
select(TriggerOAuthTenantClient)
|
||||
.where(
|
||||
TriggerOAuthTenantClient.tenant_id == tenant_id,
|
||||
TriggerOAuthTenantClient.provider == provider_id.provider_name,
|
||||
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
|
||||
TriggerOAuthTenantClient.enabled.is_(True),
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
oauth_params: Mapping[str, Any] | None = None
|
||||
@ -582,10 +624,13 @@ class TriggerProviderService:
|
||||
return None
|
||||
|
||||
# Check for system-level OAuth client
|
||||
system_client: TriggerOAuthSystemClient | None = (
|
||||
session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
|
||||
.first()
|
||||
system_client = session.scalar(
|
||||
select(TriggerOAuthSystemClient)
|
||||
.where(
|
||||
TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id,
|
||||
TriggerOAuthSystemClient.provider == provider_id.provider_name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if system_client:
|
||||
@ -606,10 +651,13 @@ class TriggerProviderService:
|
||||
if not is_verified:
|
||||
return False
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
system_client: TriggerOAuthSystemClient | None = (
|
||||
session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
|
||||
.first()
|
||||
system_client = session.scalar(
|
||||
select(TriggerOAuthSystemClient)
|
||||
.where(
|
||||
TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id,
|
||||
TriggerOAuthSystemClient.provider == provider_id.provider_name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return system_client is not None
|
||||
|
||||
@ -640,14 +688,14 @@ class TriggerProviderService:
|
||||
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# Find existing custom client params
|
||||
custom_client = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
custom_client = session.scalar(
|
||||
select(TriggerOAuthTenantClient)
|
||||
.where(
|
||||
TriggerOAuthTenantClient.tenant_id == tenant_id,
|
||||
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
|
||||
TriggerOAuthTenantClient.provider == provider_id.provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# Create new record if doesn't exist
|
||||
@ -694,14 +742,14 @@ class TriggerProviderService:
|
||||
:return: Masked OAuth client parameters
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
custom_client = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
custom_client = session.scalar(
|
||||
select(TriggerOAuthTenantClient)
|
||||
.where(
|
||||
TriggerOAuthTenantClient.tenant_id == tenant_id,
|
||||
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
|
||||
TriggerOAuthTenantClient.provider == provider_id.provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if custom_client is None:
|
||||
@ -731,11 +779,15 @@ class TriggerProviderService:
|
||||
:return: Success response
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(TriggerOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
).delete()
|
||||
session.execute(
|
||||
delete(TriggerOAuthTenantClient)
|
||||
.where(
|
||||
TriggerOAuthTenantClient.tenant_id == tenant_id,
|
||||
TriggerOAuthTenantClient.provider == provider_id.provider_name,
|
||||
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -749,15 +801,15 @@ class TriggerProviderService:
|
||||
:return: True if enabled, False otherwise
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
custom_client = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
enabled=True,
|
||||
custom_client = session.scalar(
|
||||
select(TriggerOAuthTenantClient)
|
||||
.where(
|
||||
TriggerOAuthTenantClient.tenant_id == tenant_id,
|
||||
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
|
||||
TriggerOAuthTenantClient.provider == provider_id.provider_name,
|
||||
TriggerOAuthTenantClient.enabled.is_(True),
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
return custom_client is not None
|
||||
|
||||
@ -767,7 +819,9 @@ class TriggerProviderService:
|
||||
Get a trigger subscription by the endpoint ID.
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
|
||||
subscription = session.scalar(
|
||||
select(TriggerSubscription).where(TriggerSubscription.endpoint_id == endpoint_id).limit(1)
|
||||
)
|
||||
if not subscription:
|
||||
return None
|
||||
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||
|
||||
@ -124,9 +124,7 @@ def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_su
|
||||
provider_id: TriggerProviderID,
|
||||
) -> None:
|
||||
# Arrange
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value.order_by.return_value.all.return_value = []
|
||||
mock_session.query.return_value = query
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id)
|
||||
@ -152,11 +150,8 @@ def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workf
|
||||
db_sub = SimpleNamespace(to_api_entity=lambda: api_sub)
|
||||
usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2)
|
||||
|
||||
query_subs = MagicMock()
|
||||
query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub]
|
||||
query_usage = MagicMock()
|
||||
query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row]
|
||||
mock_session.query.side_effect = [query_subs, query_usage]
|
||||
mock_session.scalars.return_value.all.return_value = [db_sub]
|
||||
mock_session.execute.return_value.all.return_value = [usage_row]
|
||||
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"})
|
||||
@ -188,11 +183,7 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap
|
||||
) -> None:
|
||||
# Arrange
|
||||
_patch_redis_lock(mocker)
|
||||
query_count = MagicMock()
|
||||
query_count.filter_by.return_value.count.return_value = 0
|
||||
query_existing = MagicMock()
|
||||
query_existing.filter_by.return_value.first.return_value = None
|
||||
mock_session.query.side_effect = [query_count, query_existing]
|
||||
mock_session.scalar.side_effect = [0, None] # count=0, no existing name
|
||||
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
cred_enc = _encrypter_mock(encrypted={"api_key": "enc"})
|
||||
@ -228,11 +219,7 @@ def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorize
|
||||
) -> None:
|
||||
# Arrange
|
||||
_patch_redis_lock(mocker)
|
||||
query_count = MagicMock()
|
||||
query_count.filter_by.return_value.count.return_value = 0
|
||||
query_existing = MagicMock()
|
||||
query_existing.filter_by.return_value.first.return_value = None
|
||||
mock_session.query.side_effect = [query_count, query_existing]
|
||||
mock_session.scalar.side_effect = [0, None] # count=0, no existing name
|
||||
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
prop_enc = _encrypter_mock(encrypted={"p": "enc"})
|
||||
@ -267,9 +254,7 @@ def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached
|
||||
) -> None:
|
||||
# Arrange
|
||||
_patch_redis_lock(mocker)
|
||||
query_count = MagicMock()
|
||||
query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__
|
||||
mock_session.query.return_value = query_count
|
||||
mock_session.scalar.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger")
|
||||
|
||||
@ -297,11 +282,7 @@ def test_add_trigger_subscription_should_raise_error_when_name_exists(
|
||||
) -> None:
|
||||
# Arrange
|
||||
_patch_redis_lock(mocker)
|
||||
query_count = MagicMock()
|
||||
query_count.filter_by.return_value.count.return_value = 0
|
||||
query_existing = MagicMock()
|
||||
query_existing.filter_by.return_value.first.return_value = object()
|
||||
mock_session.query.side_effect = [query_count, query_existing]
|
||||
mock_session.scalar.side_effect = [0, object()] # count=0, existing name conflict
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
|
||||
# Act + Assert
|
||||
@ -325,9 +306,7 @@ def test_update_trigger_subscription_should_raise_error_when_subscription_not_fo
|
||||
) -> None:
|
||||
# Arrange
|
||||
_patch_redis_lock(mocker)
|
||||
query_sub = MagicMock()
|
||||
query_sub.filter_by.return_value.first.return_value = None
|
||||
mock_session.query.return_value = query_sub
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
@ -347,11 +326,7 @@ def test_update_trigger_subscription_should_raise_error_when_name_conflicts(
|
||||
provider_id="langgenius/github/github",
|
||||
credential_type=CredentialType.API_KEY.value,
|
||||
)
|
||||
query_sub = MagicMock()
|
||||
query_sub.filter_by.return_value.first.return_value = subscription
|
||||
query_existing = MagicMock()
|
||||
query_existing.filter_by.return_value.first.return_value = object()
|
||||
mock_session.query.side_effect = [query_sub, query_existing]
|
||||
mock_session.scalar.side_effect = [subscription, object()] # found sub, name conflict
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
|
||||
# Act + Assert
|
||||
@ -378,11 +353,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
|
||||
credential_expires_at=0,
|
||||
expires_at=0,
|
||||
)
|
||||
query_sub = MagicMock()
|
||||
query_sub.filter_by.return_value.first.return_value = subscription
|
||||
query_existing = MagicMock()
|
||||
query_existing.filter_by.return_value.first.return_value = None
|
||||
mock_session.query.side_effect = [query_sub, query_existing]
|
||||
mock_session.scalar.side_effect = [subscription, None] # found sub, no name conflict
|
||||
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"})
|
||||
@ -417,7 +388,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
|
||||
|
||||
def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1")
|
||||
@ -439,7 +410,7 @@ def test_get_subscription_by_id_should_decrypt_credentials_and_properties(
|
||||
credentials={"token": "enc"},
|
||||
properties={"project": "enc"},
|
||||
)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
cred_enc = _encrypter_mock(decrypted={"token": "plain"})
|
||||
prop_enc = _encrypter_mock(decrypted={"project": "plain"})
|
||||
@ -466,7 +437,7 @@ def test_delete_trigger_provider_should_raise_error_when_subscription_missing(
|
||||
mock_session: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
@ -488,7 +459,7 @@ def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscri
|
||||
credentials={"token": "enc"},
|
||||
to_entity=lambda: SimpleNamespace(id="sub-1"),
|
||||
)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
cred_enc = _encrypter_mock(decrypted={"token": "plain"})
|
||||
mocker.patch(
|
||||
@ -524,7 +495,7 @@ def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized(
|
||||
credentials={},
|
||||
to_entity=lambda: SimpleNamespace(id="sub-2"),
|
||||
)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger")
|
||||
mocker.patch(
|
||||
@ -544,7 +515,7 @@ def test_refresh_oauth_token_should_raise_error_when_subscription_missing(
|
||||
mocker: MockerFixture, mock_session: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
@ -556,7 +527,7 @@ def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials(
|
||||
) -> None:
|
||||
# Arrange
|
||||
subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"):
|
||||
@ -577,7 +548,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials(
|
||||
credentials={"access_token": "enc"},
|
||||
credential_expires_at=0,
|
||||
)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
cache = MagicMock()
|
||||
cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"})
|
||||
@ -606,7 +577,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing(
|
||||
mocker: MockerFixture, mock_session: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
@ -616,7 +587,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing(
|
||||
def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None:
|
||||
# Arrange
|
||||
subscription = SimpleNamespace(expires_at=200)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100)
|
||||
@ -643,7 +614,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties(
|
||||
credentials={"c": "enc"},
|
||||
credential_type=CredentialType.API_KEY.value,
|
||||
)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
cred_enc = _encrypter_mock(decrypted={"c": "plain"})
|
||||
prop_cache = MagicMock()
|
||||
@ -681,10 +652,7 @@ def test_get_oauth_client_should_return_tenant_client_when_available(
|
||||
) -> None:
|
||||
# Arrange
|
||||
tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"})
|
||||
system_client = None
|
||||
query_tenant = MagicMock()
|
||||
query_tenant.filter_by.return_value.first.return_value = tenant_client
|
||||
mock_session.query.return_value = query_tenant
|
||||
mock_session.scalar.return_value = tenant_client
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
enc = _encrypter_mock(decrypted={"client_id": "plain"})
|
||||
mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock()))
|
||||
@ -703,11 +671,7 @@ def test_get_oauth_client_should_return_none_when_plugin_not_verified(
|
||||
provider_controller: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
query_tenant = MagicMock()
|
||||
query_tenant.filter_by.return_value.first.return_value = None
|
||||
query_system = MagicMock()
|
||||
query_system.filter_by.return_value.first.return_value = None
|
||||
mock_session.query.side_effect = [query_tenant, query_system]
|
||||
mock_session.scalar.return_value = None # no tenant client; plugin not verified → early return
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False)
|
||||
|
||||
@ -725,11 +689,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified(
|
||||
provider_controller: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
query_tenant = MagicMock()
|
||||
query_tenant.filter_by.return_value.first.return_value = None
|
||||
query_system = MagicMock()
|
||||
query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc")
|
||||
mock_session.query.side_effect = [query_tenant, query_system]
|
||||
mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")]
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
|
||||
mocker.patch(
|
||||
@ -751,11 +711,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails(
|
||||
provider_controller: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
query_tenant = MagicMock()
|
||||
query_tenant.filter_by.return_value.first.return_value = None
|
||||
query_system = MagicMock()
|
||||
query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc")
|
||||
mock_session.query.side_effect = [query_tenant, query_system]
|
||||
mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")]
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
|
||||
mocker.patch(
|
||||
@ -794,7 +750,7 @@ def test_is_oauth_system_client_exists_should_reflect_database_record(
|
||||
provider_controller: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None
|
||||
mock_session.scalar.return_value = object() if has_client else None
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
|
||||
|
||||
@ -823,11 +779,11 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w
|
||||
provider_controller: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value.first.return_value = None
|
||||
mock_session.query.return_value = query
|
||||
mock_session.scalar.return_value = None
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={})
|
||||
# Also mock select() so SQLAlchemy doesn't validate the patched TriggerOAuthTenantClient.
|
||||
mocker.patch("services.trigger.trigger_provider_service.select", MagicMock(return_value=MagicMock()))
|
||||
mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model)
|
||||
|
||||
# Act
|
||||
@ -853,7 +809,7 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c
|
||||
) -> None:
|
||||
# Arrange
|
||||
custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client
|
||||
mock_session.scalar.return_value = custom_client
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
cache = MagicMock()
|
||||
enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"})
|
||||
@ -882,7 +838,7 @@ def test_get_custom_oauth_client_params_should_return_empty_when_record_missing(
|
||||
provider_id: TriggerProviderID,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id)
|
||||
@ -899,7 +855,7 @@ def test_get_custom_oauth_client_params_should_return_masked_decrypted_values(
|
||||
) -> None:
|
||||
# Arrange
|
||||
custom_client = SimpleNamespace(oauth_params={"client_id": "enc"})
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client
|
||||
mock_session.scalar.return_value = custom_client
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"})
|
||||
mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock()))
|
||||
@ -916,9 +872,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit(
|
||||
mock_session: MagicMock,
|
||||
provider_id: TriggerProviderID,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.delete.return_value = 1
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id)
|
||||
|
||||
@ -934,7 +887,7 @@ def test_is_oauth_custom_client_enabled_should_return_expected_boolean(
|
||||
provider_id: TriggerProviderID,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None
|
||||
mock_session.scalar.return_value = object() if exists else None
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id)
|
||||
@ -947,7 +900,7 @@ def test_get_subscription_by_endpoint_should_return_none_when_not_found(
|
||||
mocker: MockerFixture, mock_session: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1")
|
||||
@ -968,7 +921,7 @@ def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties(
|
||||
credentials={"token": "enc"},
|
||||
properties={"hook": "enc"},
|
||||
)
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
|
||||
mock_session.scalar.return_value = subscription
|
||||
_mock_get_trigger_provider(mocker, provider_controller)
|
||||
mocker.patch(
|
||||
"services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription",
|
||||
|
||||
Reference in New Issue
Block a user