feat(trigger): introduce subscription builder and enhance trigger management

- Refactor trigger provider classes to improve naming consistency, including renaming classes for subscription management
- Implement new TriggerSubscriptionBuilderService for creating and verifying subscription builders
- Update API endpoints to support subscription builder creation and verification
- Enhance data models to include new attributes for subscription builders
- Remove the deprecated TriggerSubscriptionValidationService to streamline the codebase

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Harry
2025-09-02 12:06:27 +08:00
parent 694197a701
commit afd8989150
17 changed files with 544 additions and 476 deletions

View File

@ -1,6 +1,5 @@
import json
import logging
import re
from collections.abc import Mapping
from typing import Any, Optional
@ -16,7 +15,6 @@ from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import (
SubscriptionValidation,
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
)
@ -36,6 +34,9 @@ logger = logging.getLogger(__name__)
class TriggerProviderService:
"""Service for managing trigger providers and credentials"""
##########################
# Trigger provider
##########################
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
@classmethod
@ -73,10 +74,14 @@ class TriggerProviderService:
cls,
tenant_id: str,
user_id: str,
name: str,
provider_id: TriggerProviderID,
endpoint_id: str,
credential_type: CredentialType,
credentials: dict,
name: Optional[str] = None,
parameters: Mapping[str, Any],
properties: Mapping[str, Any],
credentials: Mapping[str, str],
credential_expires_at: int = -1,
expires_at: int = -1,
) -> dict:
"""
@ -93,7 +98,7 @@ class TriggerProviderService:
"""
try:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine) as session:
with Session(db.engine, autoflush=False) as session:
# Use distributed lock to prevent race conditions
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
with redis_client.lock(lock_key, timeout=20):
@ -110,23 +115,14 @@ class TriggerProviderService:
f"reached for {provider_id}"
)
# Generate name if not provided
if not name:
name = cls._generate_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type,
)
else:
# Check if name already exists
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
.first()
)
if existing:
raise ValueError(f"Credential name '{name}' already exists for this provider")
# Check if name already exists
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
.first()
)
if existing:
raise ValueError(f"Credential name '{name}' already exists for this provider")
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
@ -138,10 +134,14 @@ class TriggerProviderService:
db_provider = TriggerSubscription(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
credential_type=credential_type.value,
credentials=encrypter.encrypt(credentials),
name=name,
endpoint_id=endpoint_id,
provider_id=provider_id,
parameters=parameters,
properties=properties,
credentials=encrypter.encrypt(dict(credentials)),
credential_type=credential_type.value,
credential_expires_at=credential_expires_at,
expires_at=expires_at,
)
@ -154,70 +154,6 @@ class TriggerProviderService:
logger.exception("Failed to add trigger provider")
raise ValueError(str(e))
@classmethod
def update_trigger_provider(
    cls,
    tenant_id: str,
    subscription_id: str,
    credentials: Optional[dict] = None,
    name: Optional[str] = None,
) -> dict:
    """
    Update an existing trigger provider's credentials or name.

    :param tenant_id: Tenant ID
    :param subscription_id: Subscription instance ID
    :param credentials: New credentials (optional); values equal to HIDDEN_VALUE
        keep the previously stored secret for that key.
    :param name: New name (optional); must be unique within the provider.
    :return: Success response
    :raises ValueError: if the subscription is missing, the name collides,
        or any part of the update fails.
    """
    with Session(db.engine) as session:
        db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
        if not db_provider:
            raise ValueError(f"Trigger provider subscription {subscription_id} not found")
        try:
            provider_controller = TriggerManager.get_trigger_provider(
                tenant_id, TriggerProviderID(db_provider.provider_id)
            )
            if credentials:
                encrypter, cache = create_trigger_provider_encrypter_for_subscription(
                    tenant_id=tenant_id,
                    controller=provider_controller,
                    subscription=db_provider,
                )
                # Handle hidden values: the client masks unchanged secrets with
                # HIDDEN_VALUE, so merge them back from the stored credentials.
                original_credentials = encrypter.decrypt(db_provider.credentials)
                new_credentials = {
                    key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
                    for key, value in credentials.items()
                }
                db_provider.credentials = encrypter.encrypt(new_credentials)
                # Invalidate the decrypted-credentials cache.
                cache.delete()
            # Update name if provided
            if name and name != db_provider.name:
                # Enforce per-provider name uniqueness, excluding this row.
                existing = (
                    session.query(TriggerSubscription)
                    .filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name)
                    .filter(TriggerSubscription.id != subscription_id)
                    .first()
                )
                if existing:
                    raise ValueError(f"Credential name '{name}' already exists")
                db_provider.name = name
            session.commit()
            return {"result": "success"}
        except Exception as e:
            session.rollback()
            # Chain the cause so the original traceback is preserved.
            raise ValueError(str(e)) from e
@classmethod
def delete_trigger_provider(cls, tenant_id: str, subscription_id: str) -> dict:
"""
@ -505,59 +441,6 @@ class TriggerProviderService:
)
return custom_client is not None
@classmethod
def _generate_provider_name(
    cls,
    session: Session,
    tenant_id: str,
    provider_id: TriggerProviderID,
    credential_type: CredentialType,
) -> str:
    """
    Generate a unique name for a provider credential instance.

    Names follow the pattern "<base> <n>", where <base> comes from the
    credential type and <n> is one greater than the highest number already
    in use for this tenant/provider/credential-type combination.

    :param session: Database session
    :param tenant_id: Tenant ID
    :param provider_id: Provider identifier
    :param credential_type: Credential type
    :return: Generated name
    """
    try:
        db_providers = (
            session.query(TriggerSubscription)
            .filter_by(
                tenant_id=tenant_id,
                provider_id=provider_id,
                credential_type=credential_type.value,
            )
            .order_by(desc(TriggerSubscription.created_at))
            .all()
        )
        # Get base name
        base_name = credential_type.get_name()
        # Collect the numeric suffixes of existing names matching "<base> <number>"
        pattern = rf"^{re.escape(base_name)}\s+(\d+)$"
        numbers = []
        for db_provider in db_providers:
            if db_provider.name:
                match = re.match(pattern, db_provider.name.strip())
                if match:
                    numbers.append(int(match.group(1)))
        # max(..., default=0) covers the "no numbered names yet" case -> "<base> 1"
        return f"{base_name} {max(numbers, default=0) + 1}"
    except Exception:
        # Best-effort: log the full traceback but fall back to a default
        # name rather than failing the credential creation.
        logger.exception("Error generating provider name")
        return f"{credential_type.get_name()} 1"
@classmethod
def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
    """
    Look up the trigger subscription bound to an endpoint ID, or None if absent.
    """
    with Session(db.engine, autoflush=False) as session:
        return session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first()
@classmethod
def get_subscription_validation(cls, endpoint_id: str) -> SubscriptionValidation | None:
    """
    Fetch the cached subscription validation for an endpoint ID, or None
    when nothing is cached.
    """
    cached = redis_client.get(f"trigger:subscription:validation:endpoint:{endpoint_id}")
    if not cached:
        return None
    return SubscriptionValidation.model_validate(json.loads(cached))

View File

@ -0,0 +1,240 @@
import json
import logging
import uuid
from collections.abc import Mapping
from flask import Request, Response
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import (
RequestLog,
SubscriptionBuilder,
)
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_redis import redis_client
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
class TriggerSubscriptionBuilderService:
    """Service for creating, verifying, and building trigger subscription builders.

    A subscription builder is a draft subscription held in Redis while a user
    configures and validates it; once built it is persisted as a real trigger
    provider subscription and the draft is removed from the cache.
    """

    ##########################
    # Trigger provider
    ##########################
    # NOTE(review): appears unused within this class (mirrors
    # TriggerProviderService); kept for backward compatibility.
    __MAX_TRIGGER_PROVIDER_COUNT__ = 10

    ##########################
    # Validation endpoint
    ##########################
    __VALIDATION_REQUEST_CACHE_COUNT__ = 10
    # TTL of a cached builder draft, in MILLISECONDS (30 minutes).
    __VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 30 * 60 * 1000

    @classmethod
    def encode_cache_key(cls, subscription_id: str) -> str:
        """Return the Redis key under which a builder draft is cached."""
        return f"trigger:subscription:validation:{subscription_id}"

    @classmethod
    def verify_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
    ) -> None:
        """Verify a trigger subscription builder's credentials against its provider.

        :raises ValueError: if the provider or the builder cannot be found.
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")
        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        if not subscription_builder:
            raise ValueError(f"Subscription builder {subscription_builder_id} not found")
        provider_controller.validate_credentials(user_id, subscription_builder.credentials)

    @classmethod
    def build_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
    ) -> None:
        """Materialize a builder draft into a persisted trigger subscription.

        Credential-less (UNAUTHORIZED) builders are stored directly; all other
        credential types first subscribe through the provider plugin so that
        the provider-reported properties are persisted. The draft is deleted
        from the cache on success.

        :raises ValueError: if the provider or builder is missing, or the
            builder has no name.
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")
        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        if not subscription_builder:
            raise ValueError(f"Subscription builder {subscription_builder_id} not found")
        if subscription_builder.name is None:
            raise ValueError("Subscription builder name is required")
        credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value)
        if credential_type == CredentialType.UNAUTHORIZED:
            # manually create
            TriggerProviderService.add_trigger_provider(
                tenant_id=tenant_id,
                user_id=user_id,
                name=subscription_builder.name,
                provider_id=provider_id,
                endpoint_id=subscription_builder.endpoint_id,
                parameters=subscription_builder.parameters,
                properties=subscription_builder.properties,
                credential_expires_at=subscription_builder.credential_expires_at or -1,
                expires_at=subscription_builder.expires_at,
                credentials=subscription_builder.credentials,
                credential_type=credential_type,
            )
        else:
            # automatically create: let the provider plugin perform the
            # subscription and persist the properties it reports back
            subscription = TriggerManager.subscribe_trigger(
                tenant_id=tenant_id,
                user_id=user_id,
                provider_id=provider_id,
                endpoint=subscription_builder.endpoint_id,
                parameters=subscription_builder.parameters,
                credentials=subscription_builder.credentials,
            )
            TriggerProviderService.add_trigger_provider(
                tenant_id=tenant_id,
                user_id=user_id,
                name=subscription_builder.name,
                provider_id=provider_id,
                endpoint_id=subscription_builder.endpoint_id,
                parameters=subscription_builder.parameters,
                properties=subscription.properties,
                credentials=subscription_builder.credentials,
                credential_type=credential_type,
                credential_expires_at=subscription_builder.credential_expires_at or -1,
                expires_at=subscription_builder.expires_at,
            )
        # Drop the draft now that the subscription is persisted.
        cls.delete_trigger_subscription_builder(subscription_builder_id)

    @classmethod
    def create_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        credentials: Mapping[str, str],
        credential_type: CredentialType,
        credential_expires_at: int,
        expires_at: int,
    ) -> SubscriptionBuilder:
        """
        Create a new subscription builder draft and cache it in Redis.

        The generated draft id doubles as its validation endpoint id; the
        draft is pre-filled with the provider's default parameters/properties.

        :raises ValueError: if the provider cannot be found.
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")
        subscription_schema = provider_controller.get_subscription_schema()
        subscription_id = str(uuid.uuid4())
        subscription_builder = SubscriptionBuilder(
            id=subscription_id,
            name="",
            endpoint_id=subscription_id,
            tenant_id=tenant_id,
            user_id=user_id,
            provider_id=str(provider_id),
            parameters=subscription_schema.get_default_parameters(),
            properties=subscription_schema.get_default_properties(),
            credentials=credentials,
            credential_type=credential_type,
            credential_expires_at=credential_expires_at,
            expires_at=expires_at,
        )
        cache_key = cls.encode_cache_key(subscription_id)
        # setex takes the TTL in SECONDS; the constant is in milliseconds,
        # so convert (previously the raw ms value was passed -> ~20 day TTL).
        redis_client.setex(
            cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ // 1000, subscription_builder.model_dump_json()
        )
        return subscription_builder

    @classmethod
    def update_trigger_subscription_builder(
        cls,
        tenant_id: str,
        subscription_builder: SubscriptionBuilder,
    ) -> SubscriptionBuilder:
        """
        Replace a cached builder draft, refreshing its TTL.

        :raises ValueError: if no draft exists for this tenant and builder id.
        """
        subscription_id = subscription_builder.id
        cache_key = cls.encode_cache_key(subscription_id)
        subscription_builder_cache = cls.get_subscription_builder(subscription_id)
        if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
            raise ValueError(f"Subscription {subscription_id} not found")
        # setex takes the TTL in SECONDS; the constant is in milliseconds.
        redis_client.setex(
            cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ // 1000, subscription_builder.model_dump_json()
        )
        return subscription_builder

    @classmethod
    def delete_trigger_subscription_builder(cls, subscription_id: str) -> None:
        """
        Remove a builder draft from the cache (no-op if absent).
        """
        cache_key = cls.encode_cache_key(subscription_id)
        redis_client.delete(cache_key)

    @classmethod
    def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
        """
        Fetch a cached builder draft by its endpoint (== builder) id, or None.
        """
        cache_key = cls.encode_cache_key(endpoint_id)
        subscription_cache = redis_client.get(cache_key)
        if subscription_cache:
            return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
        return None

    @classmethod
    def append_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
        """
        Append the validation request log to Redis.

        TODO: not implemented yet; intentionally a no-op.
        """
        pass

    @classmethod
    def list_request_logs(cls, endpoint_id: str) -> list[RequestLog]:
        """
        List the request logs for a validation endpoint.

        TODO: not implemented yet; always returns an empty list.
        """
        return []

    @classmethod
    def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
        """
        Process a request hitting a builder's temporary validation endpoint.

        :param endpoint_id: The endpoint identifier
        :param request: The Flask request object
        :return: The Flask response object, or None when no draft matches
        """
        # check if validation endpoint exists
        subscription_builder = cls.get_subscription_builder(endpoint_id)
        if not subscription_builder:
            return None
        # dispatch the request to the provider controller
        controller = TriggerManager.get_trigger_provider(
            subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
        )
        response = controller.dispatch(
            user_id=subscription_builder.user_id,
            request=request,
            subscription=subscription_builder.to_subscription(),
        )
        # append the request log (currently a no-op)
        cls.append_request_log(endpoint_id, request, response.response)
        return response.response

View File

@ -1,48 +0,0 @@
import logging
from flask import Request, Response
from core.plugin.entities.plugin import TriggerProviderID
from core.trigger.trigger_manager import TriggerManager
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
class TriggerSubscriptionValidationService:
    """Handles temporary validation endpoints for trigger subscriptions."""

    __VALIDATION_REQUEST_CACHE_COUNT__ = 10
    __VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000

    @classmethod
    def append_validation_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
        """
        Append the validation request log to Redis.

        Currently a stub; intentionally does nothing.
        """

    @classmethod
    def process_validating_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
        """
        Process a temporary endpoint request.

        :param endpoint_id: The endpoint identifier
        :param request: The Flask request object
        :return: The Flask response object, or None if the endpoint is unknown
        """
        # Unknown endpoints yield None so the caller can 404.
        validation = TriggerProviderService.get_subscription_validation(endpoint_id)
        if validation is None:
            return None
        # Hand the request to the provider controller for dispatch.
        provider = TriggerManager.get_trigger_provider(
            validation.tenant_id, TriggerProviderID(validation.provider_id)
        )
        dispatched = provider.dispatch(
            user_id=validation.user_id,
            request=request,
            subscription=validation.to_subscription(),
        )
        # Record the exchange before returning the provider's response.
        cls.append_validation_request_log(endpoint_id, request, dispatched.response)
        return dispatched.response

View File

@ -1,14 +1,11 @@
import json
import logging
import time
import uuid
from typing import Any, Optional
from flask import Request, Response
from core.plugin.entities.plugin import TriggerProviderID
from core.trigger.entities.entities import TriggerEntity
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
@ -18,34 +15,27 @@ class TriggerService:
__TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000
__ENDPOINT_REQUEST_CACHE_COUNT__ = 10
__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
# Lua script for atomic write with time & count based cleanup
__LUA_SCRIPT__ = """
-- KEYS[1] = zset key
-- ARGV[1] = max_count (maximum number of entries to keep)
-- ARGV[2] = min_ts_ms (minimum timestamp to keep = now_ms - ttl_ms)
-- ARGV[3] = now_ms (current timestamp in milliseconds)
-- ARGV[4] = member (log entry JSON)
local key = KEYS[1]
local maxCount = tonumber(ARGV[1])
local minTs = tonumber(ARGV[2])
local nowMs = tonumber(ARGV[3])
local member = ARGV[4]
@classmethod
def process_triggered_workflows(cls, subscription: TriggerSubscription, trigger: TriggerEntity, request: Request) -> None:
"""Process triggered workflows."""
-- 1) Add new entry with timestamp as score
redis.call('ZADD', key, nowMs, member)
-- 2) Remove entries older than minTs (time-based cleanup)
redis.call('ZREMRANGEBYSCORE', key, '-inf', minTs)
-- 3) Remove oldest entries if count exceeds maxCount (count-based cleanup)
local n = redis.call('ZCARD', key)
if n > maxCount then
redis.call('ZREMRANGEBYRANK', key, 0, n - maxCount - 1) -- 0 is oldest
end
return n
"""
@classmethod
def select_triggers(cls, controller, dispatch_response, provider_id, subscription) -> list[TriggerEntity]:
triggers = []
for trigger_name in dispatch_response.triggers:
trigger = controller.get_trigger(trigger_name)
if trigger is None:
logger.error(
"Trigger '%s' not found in provider '%s' for tenant '%s'",
trigger_name,
provider_id,
subscription.tenant_id,
)
raise ValueError(f"Trigger '{trigger_name}' not found")
triggers.append(trigger)
return triggers
@classmethod
def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
    """
    Handle an incoming trigger-endpoint request: dispatch it to the provider
    and run any triggered workflows. Returns None when the endpoint or its
    provider controller is unknown.
    """
    subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id)
    if subscription is None:
        return None
    provider_id = TriggerProviderID(subscription.provider_id)
    controller = TriggerManager.get_trigger_provider(subscription.tenant_id, provider_id)
    if controller is None:
        return None
    dispatch_response = controller.dispatch(
        user_id=subscription.user_id, request=request, subscription=subscription.to_entity()
    )
    # Fire workflows for every trigger the provider reported.
    if dispatch_response.triggers:
        for trigger in cls.select_triggers(controller, dispatch_response, provider_id, subscription):
            cls.process_triggered_workflows(
                subscription=subscription,
                trigger=trigger,
                request=request,
            )
    return dispatch_response.response
@classmethod
def _extract_request_data(cls, request: Request, now_ms: int) -> dict:
    """Build the JSON-serializable log entry for one endpoint request."""
    request_data = {
        "id": str(uuid.uuid4()),
        "timestamp": now_ms,
        "method": request.method,
        "path": request.path,
        "headers": dict(request.headers),
        "query_params": request.args.to_dict(flat=False) if request.args else {},
        "body": None,
        "remote_addr": request.remote_addr,
    }
    # Try to get request body if it exists
    if request.is_json:
        try:
            request_data["body"] = request.get_json(force=True)
        except Exception:
            # Malformed JSON: fall back to the raw text body.
            request_data["body"] = request.get_data(as_text=True)
    elif request.data:
        request_data["body"] = request.get_data(as_text=True)
    return request_data

@classmethod
def log_endpoint_request(cls, endpoint_id: str, request: Request) -> int:
    """
    Log the endpoint request to Redis using ZSET for rolling log with time & count based retention.

    Args:
        endpoint_id: The endpoint identifier
        request: The Flask request object

    Returns:
        The current number of logged requests for this endpoint (0 on failure)
    """
    try:
        # Prepare timestamps: score is now, everything older than min_ts is pruned.
        now_ms = int(time.time() * 1000)
        min_ts = now_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__
        # Serialize the log entry compactly.
        member = json.dumps(cls._extract_request_data(request, now_ms), separators=(",", ":"))
        # Execute Lua script atomically: add entry, then prune by age and count.
        key = f"trigger:endpoint_requests:{endpoint_id}"
        count = redis_client.eval(
            cls.__LUA_SCRIPT__,
            1,  # number of keys
            key,  # KEYS[1]
            str(cls.__ENDPOINT_REQUEST_CACHE_COUNT__),  # ARGV[1] - max count
            str(min_ts),  # ARGV[2] - minimum timestamp
            str(now_ms),  # ARGV[3] - current timestamp
            member,  # ARGV[4] - log entry
        )
        logger.debug("Logged request for endpoint %s, current count: %s", endpoint_id, count)
        return count
    except Exception:
        # logger.exception already records the active exception; the previous
        # exc_info=e argument was redundant.
        logger.exception("Failed to log endpoint request for %s", endpoint_id)
        # Don't fail the main request processing if logging fails
        return 0
@classmethod
def get_recent_endpoint_requests(
    cls, endpoint_id: str, limit: int = 100, start_time_ms: Optional[int] = None, end_time_ms: Optional[int] = None
) -> list[dict[str, Any]]:
    """
    Retrieve recent logged requests for an endpoint.

    Args:
        endpoint_id: The endpoint identifier
        limit: Maximum number of entries to return
        start_time_ms: Start timestamp in milliseconds (optional; defaults to
            end_time_ms minus the cache retention window)
        end_time_ms: End timestamp in milliseconds (optional, defaults to now)

    Returns:
        List of request log entries, newest first (empty list on failure)
    """
    try:
        key = f"trigger:endpoint_requests:{endpoint_id}"
        # Set time bounds
        if end_time_ms is None:
            end_time_ms = int(time.time() * 1000)
        if start_time_ms is None:
            start_time_ms = end_time_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__
        # Get entries in reverse order (newest first)
        entries = redis_client.zrevrangebyscore(key, max=end_time_ms, min=start_time_ms, start=0, num=limit)
        # Parse JSON entries; skip (but log) any that fail to decode.
        requests = []
        for entry in entries:
            try:
                requests.append(json.loads(entry))
            except json.JSONDecodeError:
                logger.warning("Failed to parse log entry: %s", entry)
        return requests
    except Exception:
        # logger.exception records the active exception automatically.
        logger.exception("Failed to retrieve endpoint requests for %s", endpoint_id)
        return []
@classmethod
def clear_endpoint_requests(cls, endpoint_id: str) -> bool:
    """
    Clear all logged requests for an endpoint.

    Args:
        endpoint_id: The endpoint identifier

    Returns:
        True if successful, False otherwise
    """
    try:
        key = f"trigger:endpoint_requests:{endpoint_id}"
        redis_client.delete(key)
        logger.info("Cleared request logs for endpoint %s", endpoint_id)
        return True
    except Exception:
        # logger.exception records the active exception automatically;
        # the previous exc_info=e argument was redundant.
        logger.exception("Failed to clear endpoint requests for %s", endpoint_id)
        return False