mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat(trigger): introduce subscription builder and enhance trigger management
- Refactor trigger provider classes to improve naming consistency, including renaming classes for subscription management - Implement new TriggerSubscriptionBuilderService for creating and verifying subscription builders - Update API endpoints to support subscription builder creation and verification - Enhance data models to include new attributes for subscription builders - Remove the deprecated TriggerSubscriptionValidationService to streamline the codebase Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -16,7 +15,6 @@ from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.trigger.entities.api_entities import (
|
||||
SubscriptionValidation,
|
||||
TriggerProviderApiEntity,
|
||||
TriggerProviderSubscriptionApiEntity,
|
||||
)
|
||||
@ -36,6 +34,9 @@ logger = logging.getLogger(__name__)
|
||||
class TriggerProviderService:
|
||||
"""Service for managing trigger providers and credentials"""
|
||||
|
||||
##########################
|
||||
# Trigger provider
|
||||
##########################
|
||||
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
|
||||
|
||||
@classmethod
|
||||
@ -73,10 +74,14 @@ class TriggerProviderService:
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
provider_id: TriggerProviderID,
|
||||
endpoint_id: str,
|
||||
credential_type: CredentialType,
|
||||
credentials: dict,
|
||||
name: Optional[str] = None,
|
||||
parameters: Mapping[str, Any],
|
||||
properties: Mapping[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
credential_expires_at: int = -1,
|
||||
expires_at: int = -1,
|
||||
) -> dict:
|
||||
"""
|
||||
@ -93,7 +98,7 @@ class TriggerProviderService:
|
||||
"""
|
||||
try:
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
# Use distributed lock to prevent race conditions
|
||||
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
@ -110,23 +115,14 @@ class TriggerProviderService:
|
||||
f"reached for {provider_id}"
|
||||
)
|
||||
|
||||
# Generate name if not provided
|
||||
if not name:
|
||||
name = cls._generate_provider_name(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
@ -138,10 +134,14 @@ class TriggerProviderService:
|
||||
db_provider = TriggerSubscription(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type.value,
|
||||
credentials=encrypter.encrypt(credentials),
|
||||
name=name,
|
||||
endpoint_id=endpoint_id,
|
||||
provider_id=provider_id,
|
||||
parameters=parameters,
|
||||
properties=properties,
|
||||
credentials=encrypter.encrypt(dict(credentials)),
|
||||
credential_type=credential_type.value,
|
||||
credential_expires_at=credential_expires_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
@ -154,70 +154,6 @@ class TriggerProviderService:
|
||||
logger.exception("Failed to add trigger provider")
|
||||
raise ValueError(str(e))
|
||||
|
||||
@classmethod
|
||||
def update_trigger_provider(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
credentials: Optional[dict] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Update an existing trigger provider's credentials or name.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param subscription_id: Subscription instance ID
|
||||
:param credentials: New credentials (optional)
|
||||
:param name: New name (optional)
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
|
||||
try:
|
||||
provider_controller = TriggerManager.get_trigger_provider(
|
||||
tenant_id, TriggerProviderID(db_provider.provider_id)
|
||||
)
|
||||
|
||||
if credentials:
|
||||
encrypter, cache = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=db_provider,
|
||||
)
|
||||
# Handle hidden values
|
||||
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
new_credentials = {
|
||||
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
db_provider.credentials = encrypter.encrypt(new_credentials)
|
||||
cache.delete()
|
||||
|
||||
# Update name if provided
|
||||
if name and name != db_provider.name:
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name)
|
||||
.filter(TriggerSubscription.id != subscription_id)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists")
|
||||
|
||||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
@classmethod
|
||||
def delete_trigger_provider(cls, tenant_id: str, subscription_id: str) -> dict:
|
||||
"""
|
||||
@ -505,59 +441,6 @@ class TriggerProviderService:
|
||||
)
|
||||
return custom_client is not None
|
||||
|
||||
@classmethod
|
||||
def _generate_provider_name(
|
||||
cls,
|
||||
session: Session,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
credential_type: CredentialType,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique name for a provider credential instance.
|
||||
|
||||
:param session: Database session
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider: Provider identifier
|
||||
:param credential_type: Credential type
|
||||
:return: Generated name
|
||||
"""
|
||||
try:
|
||||
db_providers = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
.order_by(desc(TriggerSubscription.created_at))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get base name
|
||||
base_name = credential_type.get_name()
|
||||
|
||||
# Find existing numbered names
|
||||
pattern = rf"^{re.escape(base_name)}\s+(\d+)$"
|
||||
numbers = []
|
||||
|
||||
for db_provider in db_providers:
|
||||
if db_provider.name:
|
||||
match = re.match(pattern, db_provider.name.strip())
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
# Generate next number
|
||||
if not numbers:
|
||||
return f"{base_name} 1"
|
||||
|
||||
max_number = max(numbers)
|
||||
return f"{base_name} {max_number + 1}"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Error generating provider name")
|
||||
return f"{credential_type.get_name()} 1"
|
||||
|
||||
@classmethod
|
||||
def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
|
||||
"""
|
||||
@ -566,15 +449,3 @@ class TriggerProviderService:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first()
|
||||
return subscription
|
||||
|
||||
@classmethod
|
||||
def get_subscription_validation(cls, endpoint_id: str) -> SubscriptionValidation | None:
|
||||
"""
|
||||
Get a trigger subscription by the endpoint ID.
|
||||
"""
|
||||
cache_key = f"trigger:subscription:validation:endpoint:{endpoint_id}"
|
||||
subscription_cache = redis_client.get(cache_key)
|
||||
if subscription_cache:
|
||||
return SubscriptionValidation.model_validate(json.loads(subscription_cache))
|
||||
|
||||
return None
|
||||
240
api/services/trigger/trigger_subscription_builder_service.py
Normal file
240
api/services/trigger/trigger_subscription_builder_service.py
Normal file
@ -0,0 +1,240 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import (
|
||||
RequestLog,
|
||||
SubscriptionBuilder,
|
||||
)
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderService:
|
||||
"""Service for managing trigger providers and credentials"""
|
||||
|
||||
##########################
|
||||
# Trigger provider
|
||||
##########################
|
||||
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
|
||||
|
||||
##########################
|
||||
# Validation endpoint
|
||||
##########################
|
||||
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
|
||||
__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 30 * 60 * 1000
|
||||
|
||||
@classmethod
|
||||
def encode_cache_key(cls, subscription_id: str) -> str:
|
||||
return f"trigger:subscription:validation:{subscription_id}"
|
||||
|
||||
@classmethod
|
||||
def verify_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
) -> None:
|
||||
"""Verify a trigger subscription builder"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
|
||||
provider_controller.validate_credentials(user_id, subscription_builder.credentials)
|
||||
|
||||
@classmethod
|
||||
def build_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
) -> None:
|
||||
"""Build a trigger subscription builder"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
|
||||
if subscription_builder.name is None:
|
||||
raise ValueError("Subscription builder name is required")
|
||||
|
||||
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value)
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
# manually create
|
||||
TriggerProviderService.add_trigger_provider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription_builder.properties,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# automatically create
|
||||
subscription = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
credentials=subscription_builder.credentials,
|
||||
)
|
||||
|
||||
TriggerProviderService.add_trigger_provider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription.properties,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
)
|
||||
|
||||
cls.delete_trigger_subscription_builder(subscription_builder_id)
|
||||
|
||||
@classmethod
|
||||
def create_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
credential_expires_at: int,
|
||||
expires_at: int,
|
||||
) -> SubscriptionBuilder:
|
||||
"""
|
||||
Add a new trigger subscription validation.
|
||||
"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription_schema = provider_controller.get_subscription_schema()
|
||||
subscription_id = str(uuid.uuid4())
|
||||
subscription_builder = SubscriptionBuilder(
|
||||
id=subscription_id,
|
||||
name="",
|
||||
endpoint_id=subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=str(provider_id),
|
||||
parameters=subscription_schema.get_default_parameters(),
|
||||
properties=subscription_schema.get_default_properties(),
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=credential_expires_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
redis_client.setex(
|
||||
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json()
|
||||
)
|
||||
return subscription_builder
|
||||
|
||||
@classmethod
|
||||
def update_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_builder: SubscriptionBuilder,
|
||||
) -> SubscriptionBuilder:
|
||||
"""
|
||||
Update a trigger subscription validation.
|
||||
"""
|
||||
subscription_id = subscription_builder.id
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
subscription_builder_cache = cls.get_subscription_builder(subscription_id)
|
||||
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
|
||||
redis_client.setex(
|
||||
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json()
|
||||
)
|
||||
return subscription_builder
|
||||
|
||||
@classmethod
|
||||
def delete_trigger_subscription_builder(cls, subscription_id: str) -> None:
|
||||
"""
|
||||
Delete a trigger subscription validation.
|
||||
"""
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
@classmethod
|
||||
def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
|
||||
"""
|
||||
Get a trigger subscription by the endpoint ID.
|
||||
"""
|
||||
cache_key = cls.encode_cache_key(endpoint_id)
|
||||
subscription_cache = redis_client.get(cache_key)
|
||||
if subscription_cache:
|
||||
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def append_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
|
||||
"""
|
||||
Append the validation request log to Redis.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def list_request_logs(cls, endpoint_id: str) -> list[RequestLog]:
|
||||
"""
|
||||
List the request logs for a validation endpoint.
|
||||
"""
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
"""
|
||||
Process a temporary endpoint request.
|
||||
|
||||
:param endpoint_id: The endpoint identifier
|
||||
:param request: The Flask request object
|
||||
:return: The Flask response object
|
||||
"""
|
||||
# check if validation endpoint exists
|
||||
subscription_builder = cls.get_subscription_builder(endpoint_id)
|
||||
if not subscription_builder:
|
||||
return None
|
||||
|
||||
# response to validation endpoint
|
||||
controller = TriggerManager.get_trigger_provider(
|
||||
subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
|
||||
)
|
||||
response = controller.dispatch(
|
||||
user_id=subscription_builder.user_id,
|
||||
request=request,
|
||||
subscription=subscription_builder.to_subscription(),
|
||||
)
|
||||
# append the request log
|
||||
cls.append_request_log(endpoint_id, request, response.response)
|
||||
return response.response
|
||||
@ -1,48 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerSubscriptionValidationService:
|
||||
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
|
||||
__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
|
||||
|
||||
@classmethod
|
||||
def append_validation_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
|
||||
"""
|
||||
Append the validation request log to Redis.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def process_validating_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
"""
|
||||
Process a temporary endpoint request.
|
||||
|
||||
:param endpoint_id: The endpoint identifier
|
||||
:param request: The Flask request object
|
||||
:return: The Flask response object
|
||||
"""
|
||||
# check if validation endpoint exists
|
||||
subscription_validation = TriggerProviderService.get_subscription_validation(endpoint_id)
|
||||
if not subscription_validation:
|
||||
return None
|
||||
|
||||
# response to validation endpoint
|
||||
controller = TriggerManager.get_trigger_provider(
|
||||
subscription_validation.tenant_id, TriggerProviderID(subscription_validation.provider_id)
|
||||
)
|
||||
response = controller.dispatch(
|
||||
user_id=subscription_validation.user_id,
|
||||
request=request,
|
||||
subscription=subscription_validation.to_subscription(),
|
||||
)
|
||||
# append the request log
|
||||
cls.append_validation_request_log(endpoint_id, request, response.response)
|
||||
return response.response
|
||||
Reference in New Issue
Block a user