refactor(services): migrate trigger_provider_service to SQLAlchemy 2.0 select() API (#34972)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wdeveloper16
2026-04-12 03:36:13 +02:00
committed by GitHub
parent 4ef67fef3a
commit 510120410b
2 changed files with 165 additions and 158 deletions

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Mapping
from typing import Any, TypedDict
from sqlalchemy import desc, func
from sqlalchemy import delete, desc, func, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -73,27 +73,28 @@ class TriggerProviderService:
workflows_in_use_map: dict[str, int] = {}
with Session(db.engine, expire_on_commit=False) as session:
# Get all subscriptions
subscriptions_db = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
subscriptions_db = session.scalars(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.provider_id == str(provider_id),
)
.order_by(desc(TriggerSubscription.created_at))
.all()
)
).all()
subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
if not subscriptions:
return []
usage_counts = (
session.query(
usage_counts = session.execute(
select(
WorkflowPluginTrigger.subscription_id,
func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
)
.filter(
.where(
WorkflowPluginTrigger.tenant_id == tenant_id,
WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
)
.group_by(WorkflowPluginTrigger.subscription_id)
.all()
)
).all()
workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
@ -156,9 +157,13 @@ class TriggerProviderService:
with redis_client.lock(lock_key, timeout=20):
# Check provider count limit
provider_count = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
.count()
session.scalar(
select(func.count(TriggerSubscription.id)).where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.provider_id == str(provider_id),
)
)
or 0
)
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
@ -168,10 +173,14 @@ class TriggerProviderService:
)
# Check if name already exists
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
existing = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.provider_id == str(provider_id),
TriggerSubscription.name == name,
)
.limit(1)
)
if existing:
raise ValueError(f"Credential name '{name}' already exists for this provider")
@ -248,8 +257,13 @@ class TriggerProviderService:
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
if not subscription:
raise ValueError(f"Trigger subscription {subscription_id} not found")
@ -259,10 +273,14 @@ class TriggerProviderService:
# Check for name uniqueness if name is being updated
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
existing = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.provider_id == str(provider_id),
TriggerSubscription.name == name,
)
.limit(1)
)
if existing:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
@ -320,11 +338,18 @@ class TriggerProviderService:
with Session(db.engine, expire_on_commit=False) as session:
subscription: TriggerSubscription | None = None
if subscription_id:
subscription = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
else:
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
subscription = session.scalar(
select(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant_id).limit(1)
)
if subscription:
provider_controller = TriggerManager.get_trigger_provider(
tenant_id, TriggerProviderID(subscription.provider_id)
@ -353,8 +378,13 @@ class TriggerProviderService:
:param subscription_id: Subscription instance ID
:return: Success response
"""
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
@ -406,7 +436,14 @@ class TriggerProviderService:
:return: New token info
"""
with sessionmaker(bind=db.engine).begin() as session:
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
@ -479,8 +516,13 @@ class TriggerProviderService:
now_ts: int = int(now if now is not None else _time.time())
with sessionmaker(bind=db.engine).begin() as session:
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
subscription = session.scalar(
select(TriggerSubscription)
.where(
TriggerSubscription.tenant_id == tenant_id,
TriggerSubscription.id == subscription_id,
)
.limit(1)
)
if subscription is None:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
@ -556,15 +598,15 @@ class TriggerProviderService:
tenant_id=tenant_id, provider_id=provider_id
)
with Session(db.engine, expire_on_commit=False) as session:
tenant_client: TriggerOAuthTenantClient | None = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
enabled=True,
tenant_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
oauth_params: Mapping[str, Any] | None = None
@ -582,10 +624,13 @@ class TriggerProviderService:
return None
# Check for system-level OAuth client
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
system_client = session.scalar(
select(TriggerOAuthSystemClient)
.where(
TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id,
TriggerOAuthSystemClient.provider == provider_id.provider_name,
)
.limit(1)
)
if system_client:
@ -606,10 +651,13 @@ class TriggerProviderService:
if not is_verified:
return False
with Session(db.engine, expire_on_commit=False) as session:
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
system_client = session.scalar(
select(TriggerOAuthSystemClient)
.where(
TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id,
TriggerOAuthSystemClient.provider == provider_id.provider_name,
)
.limit(1)
)
return system_client is not None
@ -640,14 +688,14 @@ class TriggerProviderService:
with sessionmaker(bind=db.engine).begin() as session:
# Find existing custom client params
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
)
.first()
.limit(1)
)
# Create new record if doesn't exist
@ -694,14 +742,14 @@ class TriggerProviderService:
:return: Masked OAuth client parameters
"""
with Session(db.engine) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
)
.first()
.limit(1)
)
if custom_client is None:
@ -731,11 +779,15 @@ class TriggerProviderService:
:return: Success response
"""
with sessionmaker(bind=db.engine).begin() as session:
session.query(TriggerOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
).delete()
session.execute(
delete(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
)
.execution_options(synchronize_session=False)
)
return {"result": "success"}
@ -749,15 +801,15 @@ class TriggerProviderService:
:return: True if enabled, False otherwise
"""
with Session(db.engine, expire_on_commit=False) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
enabled=True,
custom_client = session.scalar(
select(TriggerOAuthTenantClient)
.where(
TriggerOAuthTenantClient.tenant_id == tenant_id,
TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id,
TriggerOAuthTenantClient.provider == provider_id.provider_name,
TriggerOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
return custom_client is not None
@ -767,7 +819,9 @@ class TriggerProviderService:
Get a trigger subscription by the endpoint ID.
"""
with Session(db.engine, expire_on_commit=False) as session:
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
subscription = session.scalar(
select(TriggerSubscription).where(TriggerSubscription.endpoint_id == endpoint_id).limit(1)
)
if not subscription:
return None
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(

View File

@ -124,9 +124,7 @@ def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_su
provider_id: TriggerProviderID,
) -> None:
# Arrange
query = MagicMock()
query.filter_by.return_value.order_by.return_value.all.return_value = []
mock_session.query.return_value = query
mock_session.scalars.return_value.all.return_value = []
# Act
result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id)
@ -152,11 +150,8 @@ def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workf
db_sub = SimpleNamespace(to_api_entity=lambda: api_sub)
usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2)
query_subs = MagicMock()
query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub]
query_usage = MagicMock()
query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row]
mock_session.query.side_effect = [query_subs, query_usage]
mock_session.scalars.return_value.all.return_value = [db_sub]
mock_session.execute.return_value.all.return_value = [usage_row]
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"})
@ -188,11 +183,7 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, None] # count=0, no existing name
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(encrypted={"api_key": "enc"})
@ -228,11 +219,7 @@ def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorize
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, None] # count=0, no existing name
_mock_get_trigger_provider(mocker, provider_controller)
prop_enc = _encrypter_mock(encrypted={"p": "enc"})
@ -267,9 +254,7 @@ def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__
mock_session.query.return_value = query_count
mock_session.scalar.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__
_mock_get_trigger_provider(mocker, provider_controller)
mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger")
@ -297,11 +282,7 @@ def test_add_trigger_subscription_should_raise_error_when_name_exists(
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_count = MagicMock()
query_count.filter_by.return_value.count.return_value = 0
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = object()
mock_session.query.side_effect = [query_count, query_existing]
mock_session.scalar.side_effect = [0, object()] # count=0, existing name conflict
_mock_get_trigger_provider(mocker, provider_controller)
# Act + Assert
@ -325,9 +306,7 @@ def test_update_trigger_subscription_should_raise_error_when_subscription_not_fo
) -> None:
# Arrange
_patch_redis_lock(mocker)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = None
mock_session.query.return_value = query_sub
mock_session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -347,11 +326,7 @@ def test_update_trigger_subscription_should_raise_error_when_name_conflicts(
provider_id="langgenius/github/github",
credential_type=CredentialType.API_KEY.value,
)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = subscription
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = object()
mock_session.query.side_effect = [query_sub, query_existing]
mock_session.scalar.side_effect = [subscription, object()] # found sub, name conflict
_mock_get_trigger_provider(mocker, provider_controller)
# Act + Assert
@ -378,11 +353,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
credential_expires_at=0,
expires_at=0,
)
query_sub = MagicMock()
query_sub.filter_by.return_value.first.return_value = subscription
query_existing = MagicMock()
query_existing.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_sub, query_existing]
mock_session.scalar.side_effect = [subscription, None] # found sub, no name conflict
_mock_get_trigger_provider(mocker, provider_controller)
prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"})
@ -417,7 +388,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache(
def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act
result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1")
@ -439,7 +410,7 @@ def test_get_subscription_by_id_should_decrypt_credentials_and_properties(
credentials={"token": "enc"},
properties={"project": "enc"},
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"})
prop_enc = _encrypter_mock(decrypted={"project": "plain"})
@ -466,7 +437,7 @@ def test_delete_trigger_provider_should_raise_error_when_subscription_missing(
mock_session: MagicMock,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -488,7 +459,7 @@ def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscri
credentials={"token": "enc"},
to_entity=lambda: SimpleNamespace(id="sub-1"),
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"token": "plain"})
mocker.patch(
@ -524,7 +495,7 @@ def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized(
credentials={},
to_entity=lambda: SimpleNamespace(id="sub-2"),
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger")
mocker.patch(
@ -544,7 +515,7 @@ def test_refresh_oauth_token_should_raise_error_when_subscription_missing(
mocker: MockerFixture, mock_session: MagicMock
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -556,7 +527,7 @@ def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials(
) -> None:
# Arrange
subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
# Act + Assert
with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"):
@ -577,7 +548,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials(
credentials={"access_token": "enc"},
credential_expires_at=0,
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cache = MagicMock()
cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"})
@ -606,7 +577,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing(
mocker: MockerFixture, mock_session: MagicMock
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act + Assert
with pytest.raises(ValueError, match="not found"):
@ -616,7 +587,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing(
def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None:
# Arrange
subscription = SimpleNamespace(expires_at=200)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
# Act
result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100)
@ -643,7 +614,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties(
credentials={"c": "enc"},
credential_type=CredentialType.API_KEY.value,
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
cred_enc = _encrypter_mock(decrypted={"c": "plain"})
prop_cache = MagicMock()
@ -681,10 +652,7 @@ def test_get_oauth_client_should_return_tenant_client_when_available(
) -> None:
# Arrange
tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"})
system_client = None
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = tenant_client
mock_session.query.return_value = query_tenant
mock_session.scalar.return_value = tenant_client
_mock_get_trigger_provider(mocker, provider_controller)
enc = _encrypter_mock(decrypted={"client_id": "plain"})
mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock()))
@ -703,11 +671,7 @@ def test_get_oauth_client_should_return_none_when_plugin_not_verified(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = None
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.return_value = None # no tenant client; plugin not verified → early return
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False)
@ -725,11 +689,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc")
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")]
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
mocker.patch(
@ -751,11 +711,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails(
provider_controller: MagicMock,
) -> None:
# Arrange
query_tenant = MagicMock()
query_tenant.filter_by.return_value.first.return_value = None
query_system = MagicMock()
query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc")
mock_session.query.side_effect = [query_tenant, query_system]
mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")]
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
mocker.patch(
@ -794,7 +750,7 @@ def test_is_oauth_system_client_exists_should_reflect_database_record(
provider_controller: MagicMock,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None
mock_session.scalar.return_value = object() if has_client else None
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True)
@ -823,11 +779,11 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w
provider_controller: MagicMock,
) -> None:
# Arrange
query = MagicMock()
query.filter_by.return_value.first.return_value = None
mock_session.query.return_value = query
mock_session.scalar.return_value = None
_mock_get_trigger_provider(mocker, provider_controller)
fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={})
# Also mock select() so SQLAlchemy doesn't validate the patched TriggerOAuthTenantClient.
mocker.patch("services.trigger.trigger_provider_service.select", MagicMock(return_value=MagicMock()))
mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model)
# Act
@ -853,7 +809,7 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c
) -> None:
# Arrange
custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False)
mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client
mock_session.scalar.return_value = custom_client
_mock_get_trigger_provider(mocker, provider_controller)
cache = MagicMock()
enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"})
@ -882,7 +838,7 @@ def test_get_custom_oauth_client_params_should_return_empty_when_record_missing(
provider_id: TriggerProviderID,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act
result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id)
@ -899,7 +855,7 @@ def test_get_custom_oauth_client_params_should_return_masked_decrypted_values(
) -> None:
# Arrange
custom_client = SimpleNamespace(oauth_params={"client_id": "enc"})
mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client
mock_session.scalar.return_value = custom_client
_mock_get_trigger_provider(mocker, provider_controller)
enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"})
mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock()))
@ -916,9 +872,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit(
mock_session: MagicMock,
provider_id: TriggerProviderID,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.delete.return_value = 1
# Act
result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id)
@ -934,7 +887,7 @@ def test_is_oauth_custom_client_enabled_should_return_expected_boolean(
provider_id: TriggerProviderID,
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None
mock_session.scalar.return_value = object() if exists else None
# Act
result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id)
@ -947,7 +900,7 @@ def test_get_subscription_by_endpoint_should_return_none_when_not_found(
mocker: MockerFixture, mock_session: MagicMock
) -> None:
# Arrange
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Act
result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1")
@ -968,7 +921,7 @@ def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties(
credentials={"token": "enc"},
properties={"hook": "enc"},
)
mock_session.query.return_value.filter_by.return_value.first.return_value = subscription
mock_session.scalar.return_value = subscription
_mock_get_trigger_provider(mocker, provider_controller)
mocker.patch(
"services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription",