feat(trigger): refactor trigger provider to subscription model

- Rename classes and methods to reflect the transition from credentials to subscriptions
- Update API endpoints for managing trigger subscriptions
- Modify data models and entities to support subscription attributes
- Enhance service methods for listing, adding, updating, and deleting subscriptions
- Adjust encryption utilities to handle subscription data

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Harry
2025-08-29 17:07:06 +08:00
parent 5ddd5e49ee
commit 6acc77d86d
7 changed files with 110 additions and 140 deletions

View File

@ -15,15 +15,15 @@ from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
create_trigger_provider_encrypter_for_credential,
create_trigger_provider_encrypter_for_subscription,
create_trigger_provider_oauth_encrypter,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerSubscription
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
@ -40,29 +40,29 @@ class TriggerProviderService:
return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]
@classmethod
def list_trigger_provider_credentials(
def list_trigger_provider_subscriptions(
cls, tenant_id: str, provider_id: TriggerProviderID
) -> list[TriggerProviderCredentialApiEntity]:
"""List all trigger providers for the current tenant"""
credentials: list[TriggerProviderCredentialApiEntity] = []
) -> list[TriggerProviderSubscriptionApiEntity]:
"""List all trigger subscriptions for the current tenant"""
subscriptions: list[TriggerProviderSubscriptionApiEntity] = []
with Session(db.engine, autoflush=False) as session:
credentials_db = (
session.query(TriggerProvider)
subscriptions_db = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
.order_by(desc(TriggerProvider.created_at))
.order_by(desc(TriggerSubscription.created_at))
.all()
)
credentials = [credential.to_api_entity() for credential in credentials_db]
subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
for credential in credentials:
encrypter, _ = create_trigger_provider_encrypter_for_credential(
for subscription in subscriptions:
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
credential=credential,
subscription=subscription,
)
credential.credentials = encrypter.decrypt(credential.credentials)
return credentials
subscription.credentials = encrypter.decrypt(subscription.credentials)
return subscriptions
@classmethod
def add_trigger_provider(
@ -95,7 +95,9 @@ class TriggerProviderService:
with redis_client.lock(lock_key, timeout=20):
# Check provider count limit
provider_count = (
session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).count()
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=provider_id)
.count()
)
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
@ -115,7 +117,7 @@ class TriggerProviderService:
else:
# Check if name already exists
existing = (
session.query(TriggerProvider)
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
.first()
)
@ -129,12 +131,12 @@ class TriggerProviderService:
)
# Create provider record
db_provider = TriggerProvider(
db_provider = TriggerSubscription(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
credential_type=credential_type.value,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credentials=encrypter.encrypt(credentials),
name=name,
expires_at=expires_at,
)
@ -152,7 +154,7 @@ class TriggerProviderService:
def update_trigger_provider(
cls,
tenant_id: str,
credential_id: str,
subscription_id: str,
credentials: Optional[dict] = None,
name: Optional[str] = None,
) -> dict:
@ -160,15 +162,15 @@ class TriggerProviderService:
Update an existing trigger provider's credentials or name.
:param tenant_id: Tenant ID
:param credential_id: Credential instance ID
:param subscription_id: Subscription instance ID
:param credentials: New credentials (optional)
:param name: New name (optional)
:return: Success response
"""
with Session(db.engine) as session:
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found")
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
try:
provider_controller = TriggerManager.get_trigger_provider(
@ -176,10 +178,10 @@ class TriggerProviderService:
)
if credentials:
encrypter, cache = create_trigger_provider_encrypter_for_credential(
encrypter, cache = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
credential=db_provider,
subscription=db_provider,
)
# Handle hidden values
original_credentials = encrypter.decrypt(db_provider.credentials)
@ -188,16 +190,16 @@ class TriggerProviderService:
for key, value in credentials.items()
}
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
db_provider.credentials = encrypter.encrypt(new_credentials)
cache.delete()
# Update name if provided
if name and name != db_provider.name:
# Check if name already exists
existing = (
session.query(TriggerProvider)
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name)
.filter(TriggerProvider.id != credential_id)
.filter(TriggerSubscription.id != subscription_id)
.first()
)
if existing:
@ -213,27 +215,27 @@ class TriggerProviderService:
raise ValueError(str(e))
@classmethod
def delete_trigger_provider(cls, tenant_id: str, credential_id: str) -> dict:
def delete_trigger_provider(cls, tenant_id: str, subscription_id: str) -> dict:
"""
Delete a trigger provider credential.
Delete a trigger provider subscription.
:param tenant_id: Tenant ID
:param credential_id: Credential instance ID
:param subscription_id: Subscription instance ID
:return: Success response
"""
with Session(db.engine) as session:
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found")
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
provider_controller = TriggerManager.get_trigger_provider(
tenant_id, TriggerProviderID(db_provider.provider_id)
)
# Clear cache
_, cache = create_trigger_provider_encrypter_for_credential(
_, cache = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
credential=db_provider,
subscription=db_provider,
)
session.delete(db_provider)
@ -246,20 +248,20 @@ class TriggerProviderService:
def refresh_oauth_token(
cls,
tenant_id: str,
credential_id: str,
subscription_id: str,
) -> dict:
"""
Refresh OAuth token for a trigger provider.
:param tenant_id: Tenant ID
:param credential_id: Credential instance ID
:param subscription_id: Subscription instance ID
:return: New token info
"""
with Session(db.engine) as session:
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found")
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
if db_provider.credential_type != CredentialType.OAUTH2.value:
raise ValueError("Only OAuth credentials can be refreshed")
@ -267,10 +269,10 @@ class TriggerProviderService:
provider_id = TriggerProviderID(db_provider.provider_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
# Create encrypter
encrypter, cache = create_trigger_provider_encrypter_for_credential(
encrypter, cache = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
credential=db_provider,
subscription=db_provider,
)
# Decrypt current credentials
@ -295,7 +297,7 @@ class TriggerProviderService:
)
# Update credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(dict(refreshed_credentials.credentials)))
db_provider.credentials = encrypter.encrypt(dict(refreshed_credentials.credentials))
db_provider.expires_at = refreshed_credentials.expires_at
session.commit()
@ -518,13 +520,13 @@ class TriggerProviderService:
"""
try:
db_providers = (
session.query(TriggerProvider)
session.query(TriggerSubscription)
.filter_by(
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type.value,
)
.order_by(desc(TriggerProvider.created_at))
.order_by(desc(TriggerSubscription.created_at))
.all()
)