mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
feat(trigger): refactor trigger provider to subscription model
- Rename classes and methods to reflect the transition from credentials to subscriptions - Update API endpoints for managing trigger subscriptions - Modify data models and entities to support subscription attributes - Enhance service methods for listing, adding, updating, and deleting subscriptions - Adjust encryption utilities to handle subscription data Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@ -15,15 +15,15 @@ from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import (
|
||||
create_trigger_provider_encrypter_for_credential,
|
||||
create_trigger_provider_encrypter_for_subscription,
|
||||
create_trigger_provider_oauth_encrypter,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider
|
||||
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerSubscription
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -40,29 +40,29 @@ class TriggerProviderService:
|
||||
return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]
|
||||
|
||||
@classmethod
|
||||
def list_trigger_provider_credentials(
|
||||
def list_trigger_provider_subscriptions(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID
|
||||
) -> list[TriggerProviderCredentialApiEntity]:
|
||||
"""List all trigger providers for the current tenant"""
|
||||
credentials: list[TriggerProviderCredentialApiEntity] = []
|
||||
) -> list[TriggerProviderSubscriptionApiEntity]:
|
||||
"""List all trigger subscriptions for the current tenant"""
|
||||
subscriptions: list[TriggerProviderSubscriptionApiEntity] = []
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
credentials_db = (
|
||||
session.query(TriggerProvider)
|
||||
subscriptions_db = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
||||
.order_by(desc(TriggerProvider.created_at))
|
||||
.order_by(desc(TriggerSubscription.created_at))
|
||||
.all()
|
||||
)
|
||||
credentials = [credential.to_api_entity() for credential in credentials_db]
|
||||
subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
|
||||
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
for credential in credentials:
|
||||
encrypter, _ = create_trigger_provider_encrypter_for_credential(
|
||||
for subscription in subscriptions:
|
||||
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
credential=credential,
|
||||
subscription=subscription,
|
||||
)
|
||||
credential.credentials = encrypter.decrypt(credential.credentials)
|
||||
return credentials
|
||||
subscription.credentials = encrypter.decrypt(subscription.credentials)
|
||||
return subscriptions
|
||||
|
||||
@classmethod
|
||||
def add_trigger_provider(
|
||||
@ -95,7 +95,9 @@ class TriggerProviderService:
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
# Check provider count limit
|
||||
provider_count = (
|
||||
session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).count()
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
|
||||
@ -115,7 +117,7 @@ class TriggerProviderService:
|
||||
else:
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerProvider)
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
|
||||
.first()
|
||||
)
|
||||
@ -129,12 +131,12 @@ class TriggerProviderService:
|
||||
)
|
||||
|
||||
# Create provider record
|
||||
db_provider = TriggerProvider(
|
||||
db_provider = TriggerSubscription(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type.value,
|
||||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||
credentials=encrypter.encrypt(credentials),
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
@ -152,7 +154,7 @@ class TriggerProviderService:
|
||||
def update_trigger_provider(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credential_id: str,
|
||||
subscription_id: str,
|
||||
credentials: Optional[dict] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> dict:
|
||||
@ -160,15 +162,15 @@ class TriggerProviderService:
|
||||
Update an existing trigger provider's credentials or name.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param credential_id: Credential instance ID
|
||||
:param subscription_id: Subscription instance ID
|
||||
:param credentials: New credentials (optional)
|
||||
:param name: New name (optional)
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider credential {credential_id} not found")
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
|
||||
try:
|
||||
provider_controller = TriggerManager.get_trigger_provider(
|
||||
@ -176,10 +178,10 @@ class TriggerProviderService:
|
||||
)
|
||||
|
||||
if credentials:
|
||||
encrypter, cache = create_trigger_provider_encrypter_for_credential(
|
||||
encrypter, cache = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
credential=db_provider,
|
||||
subscription=db_provider,
|
||||
)
|
||||
# Handle hidden values
|
||||
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
@ -188,16 +190,16 @@ class TriggerProviderService:
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
|
||||
db_provider.credentials = encrypter.encrypt(new_credentials)
|
||||
cache.delete()
|
||||
|
||||
# Update name if provided
|
||||
if name and name != db_provider.name:
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerProvider)
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name)
|
||||
.filter(TriggerProvider.id != credential_id)
|
||||
.filter(TriggerSubscription.id != subscription_id)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
@ -213,27 +215,27 @@ class TriggerProviderService:
|
||||
raise ValueError(str(e))
|
||||
|
||||
@classmethod
|
||||
def delete_trigger_provider(cls, tenant_id: str, credential_id: str) -> dict:
|
||||
def delete_trigger_provider(cls, tenant_id: str, subscription_id: str) -> dict:
|
||||
"""
|
||||
Delete a trigger provider credential.
|
||||
Delete a trigger provider subscription.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param credential_id: Credential instance ID
|
||||
:param subscription_id: Subscription instance ID
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider credential {credential_id} not found")
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
|
||||
provider_controller = TriggerManager.get_trigger_provider(
|
||||
tenant_id, TriggerProviderID(db_provider.provider_id)
|
||||
)
|
||||
# Clear cache
|
||||
_, cache = create_trigger_provider_encrypter_for_credential(
|
||||
_, cache = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
credential=db_provider,
|
||||
subscription=db_provider,
|
||||
)
|
||||
|
||||
session.delete(db_provider)
|
||||
@ -246,20 +248,20 @@ class TriggerProviderService:
|
||||
def refresh_oauth_token(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credential_id: str,
|
||||
subscription_id: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Refresh OAuth token for a trigger provider.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param credential_id: Credential instance ID
|
||||
:param subscription_id: Subscription instance ID
|
||||
:return: New token info
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider credential {credential_id} not found")
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
|
||||
if db_provider.credential_type != CredentialType.OAUTH2.value:
|
||||
raise ValueError("Only OAuth credentials can be refreshed")
|
||||
@ -267,10 +269,10 @@ class TriggerProviderService:
|
||||
provider_id = TriggerProviderID(db_provider.provider_id)
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
# Create encrypter
|
||||
encrypter, cache = create_trigger_provider_encrypter_for_credential(
|
||||
encrypter, cache = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
credential=db_provider,
|
||||
subscription=db_provider,
|
||||
)
|
||||
|
||||
# Decrypt current credentials
|
||||
@ -295,7 +297,7 @@ class TriggerProviderService:
|
||||
)
|
||||
|
||||
# Update credentials
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(dict(refreshed_credentials.credentials)))
|
||||
db_provider.credentials = encrypter.encrypt(dict(refreshed_credentials.credentials))
|
||||
db_provider.expires_at = refreshed_credentials.expires_at
|
||||
session.commit()
|
||||
|
||||
@ -518,13 +520,13 @@ class TriggerProviderService:
|
||||
"""
|
||||
try:
|
||||
db_providers = (
|
||||
session.query(TriggerProvider)
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
.order_by(desc(TriggerProvider.created_at))
|
||||
.order_by(desc(TriggerSubscription.created_at))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user