feat: introduce trigger functionality (#27644)

Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: Stream <Stream_2@qq.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
Co-authored-by: Harry <xh001x@hotmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
Co-authored-by: hjlarry <hjlarry@163.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WTW0313 <twwu@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Yeuoly
2025-11-12 17:59:37 +08:00
committed by GitHub
parent ca7794305b
commit b76e17b25d
785 changed files with 41186 additions and 3725 deletions

View File

@ -26,6 +26,7 @@ from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
from events.app_event import app_model_config_was_updated, app_was_created
from extensions.ext_redis import redis_client
from factories import variable_factory
@ -43,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.4.0"
CURRENT_DSL_VERSION = "0.5.0"
class ImportMode(StrEnum):
@ -599,6 +600,16 @@ class AppDslService:
if not include_secret and data_type == NodeType.AGENT:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
if data_type == NodeType.TRIGGER_SCHEDULE.value:
# override the config with the default config
node_data["config"] = TriggerScheduleNode.get_default_config()["config"]
if data_type == NodeType.TRIGGER_WEBHOOK.value:
# clear the webhook_url
node_data["webhook_url"] = ""
node_data["webhook_debug_url"] = ""
if data_type == NodeType.TRIGGER_PLUGIN.value:
# clear the subscription_id
node_data["subscription_id"] = ""
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
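A sketch of the sanitization above applied to a hypothetical webhook trigger node during DSL export (the field names mirror the assignments in this hunk; the URL values are illustrative):
# Before export:
node_data = {"webhook_url": "https://api.example.com/hooks/abc", "webhook_debug_url": "https://api.example.com/hooks/abc/debug"}
# After export sanitization, instance-specific endpoints are stripped:
node_data = {"webhook_url": "", "webhook_debug_url": ""}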

View File

@ -31,6 +31,7 @@ class AppGenerateService:
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
root_node_id: str | None = None,
):
"""
Generate app content.
@ -114,6 +115,7 @@ class AppGenerateService:
args=args,
invoke_from=invoke_from,
streaming=streaming,
root_node_id=root_node_id,
call_depth=0,
),
),
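A hedged invocation sketch for the new parameter; the method name generate and the app_model/user arguments are assumed from the elided context above, and the node ID is illustrative. Passing root_node_id selects which trigger node the run starts from, while None preserves the previous behavior:

from core.app.entities.app_invoke_entities import InvokeFrom

AppGenerateService.generate(
    app_model=app_model,
    user=user,
    args={"inputs": {}},
    invoke_from=InvokeFrom.DEBUGGER,
    streaming=True,
    root_node_id="trigger-webhook-1",  # hypothetical trigger node ID
)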

View File

@ -0,0 +1,323 @@
"""
Universal async workflow execution service.
This service provides a centralized entry point for triggering workflows asynchronously
with support for different subscription tiers, rate limiting, and execution tracking.
"""
import json
from datetime import UTC, datetime
from typing import Any, Union
from celery.result import AsyncResult
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import InvokeDailyRateLimitError, WorkflowNotFoundError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow.rate_limiter import TenantDailyRateLimiter
from services.workflow_service import WorkflowService
from tasks.async_workflow_tasks import (
execute_workflow_professional,
execute_workflow_sandbox,
execute_workflow_team,
)
class AsyncWorkflowService:
"""
Universal entry point for async workflow execution - ALL METHODS ARE NON-BLOCKING
This service handles:
- Trigger data validation and processing
- Queue routing based on subscription tier
- Daily rate limiting with timezone support
- Execution tracking and logging
- Retry mechanisms for failed executions
Important: All trigger methods return immediately after queuing tasks.
Actual workflow execution happens asynchronously in background Celery workers.
Use trigger log IDs to monitor execution status and results.
"""
@classmethod
def trigger_workflow_async(
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
) -> AsyncTriggerResponse:
"""
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
Creates a trigger log and dispatches to appropriate queue based on subscription tier.
The workflow execution happens asynchronously in the background via Celery workers.
This method returns immediately after queuing the task, not after execution completion.
Args:
session: Database session to use for operations
user: User (Account or EndUser) who initiated the workflow trigger
trigger_data: Validated Pydantic model containing trigger information
Returns:
AsyncTriggerResponse with workflow_trigger_log_id, task_id, status="queued", and queue
Note: The actual workflow execution status must be checked separately via workflow_trigger_log_id
Raises:
WorkflowNotFoundError: If app or workflow not found
InvokeDailyRateLimitError: If daily rate limit exceeded
Behavior:
- Non-blocking: Returns immediately after queuing
- Asynchronous: Actual execution happens in background Celery workers
- Status tracking: Use workflow_trigger_log_id to monitor progress
- Queue-based: Routes to different queues based on subscription tier
"""
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
dispatcher_manager = QueueDispatcherManager()
workflow_service = WorkflowService()
rate_limiter = TenantDailyRateLimiter(redis_client)
# 1. Validate app exists
app_model = session.scalar(select(App).where(App.id == trigger_data.app_id))
if not app_model:
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
# 2. Get workflow
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
# 3. Get dispatcher based on tenant subscription
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
# 4. Rate limiting is checked at step 7 below; the tenant's timezone is only resolved if the limit is hit
# 5. Determine user role and ID
if isinstance(user, Account):
created_by_role = CreatorUserRole.ACCOUNT
created_by = user.id
else: # EndUser
created_by_role = CreatorUserRole.END_USER
created_by = user.id
# 6. Create trigger log entry first (for tracking)
trigger_log = WorkflowTriggerLog(
tenant_id=trigger_data.tenant_id,
app_id=trigger_data.app_id,
workflow_id=workflow.id,
root_node_id=trigger_data.root_node_id,
trigger_metadata=(
trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}"
),
trigger_type=trigger_data.trigger_type,
trigger_data=trigger_data.model_dump_json(),
inputs=json.dumps(dict(trigger_data.inputs)),
status=WorkflowTriggerStatus.PENDING,
queue_name=dispatcher.get_queue_name(),
retry_count=0,
created_by_role=created_by_role,
created_by=created_by,
)
trigger_log = trigger_log_repo.create(trigger_log)
session.commit()
# 7. Check and consume daily quota
if not dispatcher.consume_quota(trigger_data.tenant_id):
# Update trigger log status
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
trigger_log.error = f"Daily limit reached for {dispatcher.get_queue_name()}"
trigger_log_repo.update(trigger_log)
session.commit()
tenant_owner_tz = rate_limiter.get_tenant_owner_timezone(trigger_data.tenant_id)
remaining = rate_limiter.get_remaining_quota(trigger_data.tenant_id, dispatcher.get_daily_limit())
reset_time = rate_limiter.get_quota_reset_time(trigger_data.tenant_id, tenant_owner_tz)
raise InvokeDailyRateLimitError(
f"Daily workflow execution limit reached. "
f"Limit resets at {reset_time.strftime('%Y-%m-%d %H:%M:%S %Z')}. "
f"Remaining quota: {remaining}"
)
# 8. Create task data
queue_name = dispatcher.get_queue_name()
task_data = WorkflowTaskData(workflow_trigger_log_id=trigger_log.id)
# 9. Dispatch to appropriate queue
task_data_dict = task_data.model_dump(mode="json")
task: AsyncResult[Any] | None = None
if queue_name == QueuePriority.PROFESSIONAL:
task = execute_workflow_professional.delay(task_data_dict) # type: ignore
elif queue_name == QueuePriority.TEAM:
task = execute_workflow_team.delay(task_data_dict) # type: ignore
else: # SANDBOX
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore
# 10. Update trigger log with task info
trigger_log.status = WorkflowTriggerStatus.QUEUED
trigger_log.celery_task_id = task.id
trigger_log.triggered_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
return AsyncTriggerResponse(
workflow_trigger_log_id=trigger_log.id,
task_id=task.id, # type: ignore
status="queued",
queue=queue_name,
)
@classmethod
def reinvoke_trigger(
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
) -> AsyncTriggerResponse:
"""
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
Updates the existing trigger log to retry status and creates a new async execution.
Returns immediately after queuing the retry, not after execution completion.
Args:
session: Database session to use for operations
user: User (Account or EndUser) who initiated the retry
workflow_trigger_log_id: ID of the trigger log to re-invoke
Returns:
AsyncTriggerResponse with new execution information (status="queued")
Note: This creates a new trigger log entry for the retry attempt
Raises:
ValueError: If trigger log not found
Behavior:
- Non-blocking: Returns immediately after queuing retry
- Creates new trigger log: Original log marked as retrying, new log for execution
- Preserves original trigger data: Uses same inputs and configuration
"""
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id)
if not trigger_log:
raise ValueError(f"Trigger log not found: {workflow_trigger_log_id}")
# Reconstruct trigger data from log
trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
# Reset log for retry
trigger_log.status = WorkflowTriggerStatus.RETRYING
trigger_log.retry_count += 1
trigger_log.error = None
trigger_log.triggered_at = datetime.now(UTC)
trigger_log_repo.update(trigger_log)
session.commit()
# Re-trigger workflow (this will create a new trigger log)
return cls.trigger_workflow_async(session, user, trigger_data)
@classmethod
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None:
"""
Get trigger log by ID
Args:
workflow_trigger_log_id: ID of the trigger log
tenant_id: Optional tenant ID for security check
Returns:
Trigger log as dictionary or None if not found
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
if not trigger_log:
return None
return trigger_log.to_dict()
@classmethod
def get_recent_logs(
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
) -> list[dict[str, Any]]:
"""
Get recent trigger logs
Args:
tenant_id: Tenant ID
app_id: Application ID
hours: Number of hours to look back
limit: Maximum number of results
offset: Number of results to skip
Returns:
List of trigger logs as dictionaries
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_recent_logs(
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
)
return [log.to_dict() for log in logs]
@classmethod
def get_failed_logs_for_retry(
cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> list[dict[str, Any]]:
"""
Get failed logs eligible for retry
Args:
tenant_id: Tenant ID
max_retry_count: Maximum retry count
limit: Maximum number of results
Returns:
List of failed trigger logs as dictionaries
"""
with Session(db.engine) as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_failed_for_retry(
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
)
return [log.to_dict() for log in logs]
@staticmethod
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
"""
Get workflow for the app
Args:
workflow_service: Service used to resolve the app's published workflows
app_model: App model instance
workflow_id: Optional specific workflow ID
Returns:
Workflow instance
Raises:
WorkflowNotFoundError: If workflow not found
"""
if workflow_id:
# Get specific published workflow
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
if not workflow:
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
else:
# Get default published workflow
workflow = workflow_service.get_published_workflow(app_model)
if not workflow:
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")
return workflow
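A minimal caller sketch for the non-blocking contract described above. The module path of AsyncWorkflowService, the TriggerData defaults, and the trigger_type value are assumptions; the field names follow the attribute reads in this file:

from sqlalchemy.orm import Session

from extensions.ext_database import db
from services.workflow.entities import TriggerData

def fire_and_track(account, app_id: str, inputs: dict):
    with Session(db.engine) as session:
        response = AsyncWorkflowService.trigger_workflow_async(
            session=session,
            user=account,
            trigger_data=TriggerData(
                tenant_id=account.current_tenant_id,  # assumed Account attribute
                app_id=app_id,
                workflow_id=None,  # fall back to the default published workflow
                root_node_id=None,
                trigger_type="webhook",  # illustrative value
                inputs=inputs,
            ),
        )
    # Returns immediately with status="queued"; execution happens in a Celery
    # worker. Poll the trigger log to observe progress:
    return AsyncWorkflowService.get_trigger_log(response.workflow_trigger_log_id)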

View File

@ -11,9 +11,9 @@ from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.model_runtime.entities.provider_entities import FormType
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -338,7 +338,7 @@ class DatasourceProviderService:
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
tenant_oauth_client_params.client_params = dict(encrypter.encrypt(new_params))
if enabled is not None:
tenant_oauth_client_params.enabled = enabled
@ -374,7 +374,7 @@ class DatasourceProviderService:
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
) -> dict[str, Any] | None:
) -> Mapping[str, Any] | None:
"""
get tenant oauth client
"""
@ -390,7 +390,7 @@ class DatasourceProviderService:
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
if mask:
return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
return encrypter.mask_plugin_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
else:
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return None
@ -434,7 +434,7 @@ class DatasourceProviderService:
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return dict(encrypter.decrypt(tenant_oauth_client_params.client_params))
provider_controller = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)

View File

@ -0,0 +1,141 @@
from collections.abc import Mapping
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.model import App, DefaultEndUserSessionID, EndUser
class EndUserService:
"""
Service for managing end users.
"""
@classmethod
def get_or_create_end_user(cls, app_model: App, user_id: str | None = None) -> EndUser:
"""
Get or create an end user for a given app.
"""
return cls.get_or_create_end_user_by_type(InvokeFrom.SERVICE_API, app_model.tenant_id, app_model.id, user_id)
@classmethod
def get_or_create_end_user_by_type(
cls, type: InvokeFrom, tenant_id: str, app_id: str, user_id: str | None = None
) -> EndUser:
"""
Get or create an end user for a given app and type.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
session.query(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
EndUser.session_id == user_id,
EndUser.type == type,
)
.first()
)
if end_user is None:
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_id,
type=type,
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
session_id=user_id,
external_user_id=user_id,
)
session.add(end_user)
session.commit()
return end_user
@classmethod
def create_end_user_batch(
cls, type: InvokeFrom, tenant_id: str, app_ids: list[str], user_id: str
) -> Mapping[str, EndUser]:
"""Create end users in batch.
Creates end users for the specified tenant and application IDs using a constant number of database queries (one batched lookup plus at most one batched insert).
This batch creation is necessary because trigger subscriptions can span multiple applications,
and trigger events may be dispatched to multiple applications simultaneously.
For each app_id in app_ids, check if an `EndUser` with the given
`user_id` (as session_id/external_user_id) already exists for the
tenant/app and type `type`. If it exists, return it; otherwise,
create it. Operates with minimal DB I/O by querying and inserting in
batches.
Returns a mapping of `app_id -> EndUser`.
"""
# Normalize user_id to default if empty
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Deduplicate app_ids while preserving input order
seen: set[str] = set()
unique_app_ids: list[str] = []
for app_id in app_ids:
if app_id not in seen:
seen.add(app_id)
unique_app_ids.append(app_id)
# Result is a simple app_id -> EndUser mapping
result: dict[str, EndUser] = {}
if not unique_app_ids:
return result
with Session(db.engine, expire_on_commit=False) as session:
# Fetch existing end users for all target apps in a single query
existing_end_users: list[EndUser] = (
session.query(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id.in_(unique_app_ids),
EndUser.session_id == user_id,
EndUser.type == type,
)
.all()
)
found_app_ids: set[str] = set()
for eu in existing_end_users:
# If duplicates exist due to weak DB constraints, prefer the first
if eu.app_id not in result:
result[eu.app_id] = eu
found_app_ids.add(eu.app_id)
# Determine which apps still need an EndUser created
missing_app_ids = [app_id for app_id in unique_app_ids if app_id not in found_app_ids]
if missing_app_ids:
new_end_users: list[EndUser] = []
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
for app_id in missing_app_ids:
new_end_users.append(
EndUser(
tenant_id=tenant_id,
app_id=app_id,
type=type,
is_anonymous=is_anonymous,
session_id=user_id,
external_user_id=user_id,
)
)
session.add_all(new_end_users)
session.commit()
for eu in new_end_users:
result[eu.app_id] = eu
return result
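A usage sketch for the batch path (the module path is an assumption): one trigger event fanned out to several apps resolves all EndUser rows with a single lookup plus a single bulk insert.

from core.app.entities.app_invoke_entities import InvokeFrom

end_users = EndUserService.create_end_user_batch(
    type=InvokeFrom.SERVICE_API,
    tenant_id="tenant-uuid",
    app_ids=["app-a", "app-b", "app-b"],  # duplicates are collapsed
    user_id="external-user-42",
)
assert set(end_users) == {"app-a", "app-b"}  # app_id -> EndUser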

View File

@ -16,3 +16,9 @@ class WorkflowNotFoundError(Exception):
class WorkflowIdFormatError(Exception):
pass
class InvokeDailyRateLimitError(Exception):
"""Raised when daily rate limit is exceeded for workflow invocations."""
pass
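A hedged handler sketch; the HTTP status mapping is illustrative, not part of this commit:

from services.errors.app import InvokeDailyRateLimitError

try:
    response = AsyncWorkflowService.trigger_workflow_async(session, user, trigger_data)
except InvokeDailyRateLimitError as e:
    # The message carries the reset time and remaining quota (see
    # AsyncWorkflowService.trigger_workflow_async).
    return {"error": str(e)}, 429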

View File

@ -16,6 +16,7 @@ class OAuthProxyService(BasePluginClient):
tenant_id: str,
plugin_id: str,
provider: str,
extra_data: dict = {},
credential_id: str | None = None,
):
"""
@ -32,6 +33,7 @@ class OAuthProxyService(BasePluginClient):
"""
context_id = str(uuid.uuid4())
data = {
**extra_data,
"user_id": user_id,
"plugin_id": plugin_id,
"tenant_id": tenant_id,

View File

@ -4,11 +4,16 @@ from typing import Any, Literal
from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_tool_provider_encrypter
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
from core.trigger.entities.entities import SubscriptionBuilder
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
class PluginParameterService:
@ -20,7 +25,8 @@ class PluginParameterService:
provider: str,
action: str,
parameter: str,
provider_type: Literal["tool"],
credential_id: str | None,
provider_type: Literal["tool", "trigger"],
) -> Sequence[PluginParameterOption]:
"""
Get dynamic select options for a plugin parameter.
@ -33,7 +39,7 @@ class PluginParameterService:
parameter: The parameter name.
"""
credentials: Mapping[str, Any] = {}
credential_type: str = CredentialType.UNAUTHORIZED.value
match provider_type:
case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -49,24 +55,53 @@ class PluginParameterService:
else:
# fetch credentials from db
with Session(db.engine) as session:
db_record = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
if credential_id:
db_record = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.id == credential_id,
)
.first()
)
else:
db_record = (
session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
.first()
)
if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
credentials = encrypter.decrypt(db_record.credentials)
case _:
raise ValueError(f"Invalid provider type: {provider_type}")
credential_type = db_record.credential_type
case "trigger":
subscription: TriggerProviderSubscriptionApiEntity | SubscriptionBuilder | None
if credential_id:
subscription = TriggerSubscriptionBuilderService.get_subscription_builder(credential_id)
if not subscription:
trigger_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
subscription = trigger_subscription.to_api_entity() if trigger_subscription else None
else:
trigger_subscription = TriggerProviderService.get_subscription_by_id(tenant_id)
subscription = trigger_subscription.to_api_entity() if trigger_subscription else None
if subscription is None:
raise ValueError(f"Subscription {credential_id} not found")
credentials = subscription.credentials
credential_type = subscription.credential_type or CredentialType.UNAUTHORIZED
return (
DynamicSelectClient()
.fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter)
.fetch_dynamic_select_options(
tenant_id, user_id, plugin_id, provider, action, credentials, credential_type, parameter
)
.options
)
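A call sketch, assuming the classmethod above is named get_dynamic_select_options (its definition line sits above this hunk) and using illustrative identifiers. With provider_type="trigger", options are resolved against a stored subscription (or a subscription builder) instead of a tool credential:

options = PluginParameterService.get_dynamic_select_options(
    tenant_id="tenant-uuid",
    user_id="account-uuid",
    plugin_id="org/plugin",
    provider="org/plugin/github",
    action="on_issue_opened",
    parameter="repository",
    credential_id="subscription-uuid",
    provider_type="trigger",
)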

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
from yarl import URL
from configs import dify_config
from core.helper import marketplace
@ -175,6 +176,13 @@ class PluginService:
manager = PluginInstaller()
return manager.fetch_plugin_installation_by_ids(tenant_id, ids)
@classmethod
def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
url_prefix = (
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
)
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
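# Illustrative result (yarl's `%` operator appends the mapping as the query
# string; the host is hypothetical):
#   get_plugin_icon_url("t-123", "icon.png")
#   -> "https://cloud.example.com/console/api/workspaces/current/plugin/icon?tenant_id=t-123&filename=icon.png"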
@staticmethod
def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
"""
@ -185,6 +193,11 @@ class PluginService:
mime_type, _ = guess_type(asset_file)
return manager.fetch_asset(tenant_id, asset_file), mime_type or "application/octet-stream"
@staticmethod
def extract_asset(tenant_id: str, plugin_unique_identifier: str, file_name: str) -> bytes:
manager = PluginAssetManager()
return manager.extract_asset(tenant_id, plugin_unique_identifier, file_name)
@staticmethod
def check_plugin_unique_identifier(tenant_id: str, plugin_unique_identifier: str) -> bool:
"""
@ -502,3 +515,11 @@ class PluginService:
"""
manager = PluginInstaller()
return manager.check_tools_existence(tenant_id, provider_ids)
@staticmethod
def fetch_plugin_readme(tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
"""
Fetch plugin readme
"""
manager = PluginInstaller()
return manager.fetch_plugin_readme(tenant_id, plugin_unique_identifier, language)

View File

@ -300,13 +300,13 @@ class ApiToolManageService:
)
original_credentials = encrypter.decrypt(provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
credentials = encrypter.encrypt(credentials)
credentials = dict(encrypter.encrypt(credentials))
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)
@ -417,7 +417,7 @@ class ApiToolManageService:
)
decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_plugin_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]

View File

@ -12,6 +12,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import (
@ -20,7 +21,6 @@ from core.tools.entities.api_entities import (
ToolProviderCredentialApiEntity,
ToolProviderCredentialInfoApiEntity,
)
from core.tools.entities.tool_entities import CredentialType
from core.tools.errors import ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
@ -39,7 +39,6 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
__DEFAULT_EXPIRES_AT__ = 2147483647
@staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
@ -278,9 +277,7 @@ class BuiltinToolManageService:
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
expires_at=expires_at
if expires_at is not None
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
expires_at=expires_at if expires_at is not None else -1,
)
session.add(db_provider)
@ -353,10 +350,10 @@ class BuiltinToolManageService:
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
tenant_id, provider, provider.provider, provider_controller
)
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
credentials=dict(decrypt_credential),
)
credentials.append(credential_entity)
return credentials
@ -727,4 +724,4 @@ class BuiltinToolManageService:
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
return encrypter.mask_plugin_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))

View File

@ -1,6 +1,7 @@
import hashlib
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any
@ -420,7 +421,7 @@ class MCPToolManageService:
return json.dumps({"content": icon, "background": icon_background})
return icon
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]:
def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> Mapping[str, str]:
"""Encrypt specified fields in a dictionary.
Args:

View File

@ -9,7 +9,7 @@ from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -19,7 +19,6 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolParameter,
ToolProviderType,
)
@ -28,18 +27,12 @@ from core.tools.utils.encryption import create_provider_encrypter, create_tool_p
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
class ToolTransformService:
@classmethod
def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
url_prefix = (
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
)
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
@classmethod
def get_tool_provider_icon_url(
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
@ -79,11 +72,9 @@ class ToolTransformService:
elif isinstance(provider, ToolProviderApiEntity):
if provider.plugin_id:
if isinstance(provider.icon, str):
provider.icon = ToolTransformService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon
)
provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
if isinstance(provider.icon_dark, str) and provider.icon_dark:
provider.icon_dark = ToolTransformService.get_plugin_icon_url(
provider.icon_dark = PluginService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon_dark
)
else:
@ -97,7 +88,7 @@ class ToolTransformService:
elif isinstance(provider, PluginDatasourceProviderEntity):
if provider.plugin_id:
if isinstance(provider.declaration.identity.icon, str):
provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
provider.declaration.identity.icon = PluginService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.declaration.identity.icon
)
@ -172,7 +163,7 @@ class ToolTransformService:
)
# decrypt the credentials and mask the credentials
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
@ -345,7 +336,7 @@ class ToolTransformService:
# decrypt the credentials and mask the credentials
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials

View File

@ -0,0 +1,312 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.nodes import NodeType
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
from models.account import Account, TenantAccountJoin
from models.trigger import WorkflowSchedulePlan
from models.workflow import Workflow
from services.errors.account import AccountNotFoundError
logger = logging.getLogger(__name__)
class ScheduleService:
@staticmethod
def create_schedule(
session: Session,
tenant_id: str,
app_id: str,
config: ScheduleConfig,
) -> WorkflowSchedulePlan:
"""
Create a new schedule with validated configuration.
Args:
session: Database session
tenant_id: Tenant ID
app_id: Application ID
config: Validated schedule configuration
Returns:
Created WorkflowSchedulePlan instance
"""
next_run_at = calculate_next_run_at(
config.cron_expression,
config.timezone,
)
schedule = WorkflowSchedulePlan(
tenant_id=tenant_id,
app_id=app_id,
node_id=config.node_id,
cron_expression=config.cron_expression,
timezone=config.timezone,
next_run_at=next_run_at,
)
session.add(schedule)
session.flush()
return schedule
@staticmethod
def update_schedule(
session: Session,
schedule_id: str,
updates: SchedulePlanUpdate,
) -> WorkflowSchedulePlan:
"""
Update an existing schedule with validated configuration.
Args:
session: Database session
schedule_id: Schedule ID to update
updates: Validated update configuration
Raises:
ScheduleNotFoundError: If schedule not found
Returns:
Updated WorkflowSchedulePlan instance
"""
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")
# If time-related fields are updated, synchronously update the next_run_at.
time_fields_updated = False
if updates.node_id is not None:
schedule.node_id = updates.node_id
if updates.cron_expression is not None:
schedule.cron_expression = updates.cron_expression
time_fields_updated = True
if updates.timezone is not None:
schedule.timezone = updates.timezone
time_fields_updated = True
if time_fields_updated:
schedule.next_run_at = calculate_next_run_at(
schedule.cron_expression,
schedule.timezone,
)
session.flush()
return schedule
@staticmethod
def delete_schedule(
session: Session,
schedule_id: str,
) -> None:
"""
Delete a schedule plan.
Args:
session: Database session
schedule_id: Schedule ID to delete
Raises:
ScheduleNotFoundError: If schedule not found
"""
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")
session.delete(schedule)
session.flush()
@staticmethod
def get_tenant_owner(session: Session, tenant_id: str) -> Account:
"""
Returns an account to execute scheduled workflows on behalf of the tenant.
Prioritizes owner over admin to ensure proper authorization hierarchy.
"""
result = session.execute(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == "owner")
.limit(1)
).scalar_one_or_none()
if not result:
# Owner may not exist in some tenant configurations, fallback to admin
result = session.execute(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == "admin")
.limit(1)
).scalar_one_or_none()
if result:
account = session.get(Account, result.account_id)
if not account:
raise AccountNotFoundError(f"Account not found: {result.account_id}")
return account
else:
raise AccountNotFoundError(f"Account not found for tenant: {tenant_id}")
@staticmethod
def update_next_run_at(
session: Session,
schedule_id: str,
) -> datetime:
"""
Advances the schedule to its next execution time after a successful trigger.
Uses current time as base to prevent missing executions during delays.
"""
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")
# Base on current time to handle execution delays gracefully
next_run_at = calculate_next_run_at(
schedule.cron_expression,
schedule.timezone,
)
schedule.next_run_at = next_run_at
session.flush()
return next_run_at
@staticmethod
def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig:
"""
Converts user-friendly visual schedule settings to cron expression.
Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
"""
node_data = node_config.get("data", {})
mode = node_data.get("mode", "visual")
timezone = node_data.get("timezone", "UTC")
node_id = node_config.get("id", "start")
cron_expression = None
if mode == "cron":
cron_expression = node_data.get("cron_expression")
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for cron mode")
elif mode == "visual":
frequency = node_data.get("frequency")
if not frequency:
raise ScheduleConfigError("Frequency is required for visual mode")
visual_config = VisualConfig(**node_data.get("visual_config", {}))
cron_expression = ScheduleService.visual_to_cron(frequency=frequency, visual_config=visual_config)
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for visual mode")
else:
raise ScheduleConfigError(f"Invalid schedule mode: {mode}")
return ScheduleConfig(node_id=node_id, cron_expression=cron_expression, timezone=timezone)
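# Illustrative node_config accepted above (field names mirror the reads in
# this method; the 12-hour time string format is assumed from
# convert_12h_to_24h):
#   {
#       "id": "schedule-node-1",
#       "data": {
#           "mode": "visual",
#           "frequency": "daily",
#           "timezone": "Asia/Shanghai",
#           "visual_config": {"time": "09:30 AM"},
#       },
#   }
# -> ScheduleConfig(node_id="schedule-node-1", cron_expression="30 9 * * *",
#    timezone="Asia/Shanghai")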
@staticmethod
def extract_schedule_config(workflow: Workflow) -> ScheduleConfig | None:
"""
Extracts schedule configuration from workflow graph.
Searches for the first schedule trigger node in the workflow and converts
its configuration (either visual or cron mode) into a unified ScheduleConfig.
Args:
workflow: The workflow containing the graph definition
Returns:
ScheduleConfig if a valid schedule node is found, None if no schedule node exists
Raises:
ScheduleConfigError: If graph parsing fails or schedule configuration is invalid
Note:
Currently only returns the first schedule node found.
Multiple schedule nodes in the same workflow are not supported.
"""
try:
graph_data = workflow.graph_dict
except (json.JSONDecodeError, TypeError, AttributeError) as e:
raise ScheduleConfigError(f"Failed to parse workflow graph: {e}")
if not graph_data:
raise ScheduleConfigError("Workflow graph is empty")
nodes = graph_data.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value:
continue
mode = node_data.get("mode", "visual")
timezone = node_data.get("timezone", "UTC")
node_id = node.get("id", "start")
cron_expression = None
if mode == "cron":
cron_expression = node_data.get("cron_expression")
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for cron mode")
elif mode == "visual":
frequency = node_data.get("frequency")
visual_config_dict = node_data.get("visual_config", {})
visual_config = VisualConfig(**visual_config_dict)
cron_expression = ScheduleService.visual_to_cron(frequency, visual_config)
else:
raise ScheduleConfigError(f"Invalid schedule mode: {mode}")
return ScheduleConfig(node_id=node_id, cron_expression=cron_expression, timezone=timezone)
return None
@staticmethod
def visual_to_cron(frequency: str, visual_config: VisualConfig) -> str:
"""
Converts user-friendly visual schedule settings to cron expression.
Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
"""
if frequency == "hourly":
if visual_config.on_minute is None:
raise ScheduleConfigError("on_minute is required for hourly schedules")
return f"{visual_config.on_minute} * * * *"
elif frequency == "daily":
if not visual_config.time:
raise ScheduleConfigError("time is required for daily schedules")
hour, minute = convert_12h_to_24h(visual_config.time)
return f"{minute} {hour} * * *"
elif frequency == "weekly":
if not visual_config.time:
raise ScheduleConfigError("time is required for weekly schedules")
if not visual_config.weekdays:
raise ScheduleConfigError("Weekdays are required for weekly schedules")
hour, minute = convert_12h_to_24h(visual_config.time)
weekday_map = {"sun": "0", "mon": "1", "tue": "2", "wed": "3", "thu": "4", "fri": "5", "sat": "6"}
cron_weekdays = [weekday_map[day] for day in visual_config.weekdays]
return f"{minute} {hour} * * {','.join(sorted(cron_weekdays))}"
elif frequency == "monthly":
if not visual_config.time:
raise ScheduleConfigError("time is required for monthly schedules")
if not visual_config.monthly_days:
raise ScheduleConfigError("Monthly days are required for monthly schedules")
hour, minute = convert_12h_to_24h(visual_config.time)
numeric_days: list[int] = []
has_last = False
for day in visual_config.monthly_days:
if day == "last":
has_last = True
else:
numeric_days.append(day)
result_days = [str(d) for d in sorted(set(numeric_days))]
if has_last:
result_days.append("L")
return f"{minute} {hour} {','.join(result_days)} * *"
else:
raise ScheduleConfigError(f"Unsupported frequency: {frequency}")
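A few illustrative conversions (the 12-hour time string format is an assumption based on convert_12h_to_24h; unset VisualConfig fields are assumed to default to None):

ScheduleService.visual_to_cron("hourly", VisualConfig(on_minute=15))
# -> "15 * * * *"
ScheduleService.visual_to_cron("weekly", VisualConfig(time="09:30 AM", weekdays=["mon", "fri"]))
# -> "30 9 * * 1,5"
ScheduleService.visual_to_cron("monthly", VisualConfig(time="11:00 PM", monthly_days=[1, 15, "last"]))
# -> "0 23 1,15,L * *"  (croniter's "L" means the last day of the month)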

View File

@ -0,0 +1,687 @@
import json
import logging
import time as _time
import uuid
from collections.abc import Mapping
from typing import Any
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import (
TriggerProviderApiEntity,
TriggerProviderSubscriptionApiEntity,
)
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
create_trigger_provider_encrypter_for_properties,
create_trigger_provider_encrypter_for_subscription,
delete_cache_for_subscription,
)
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import TriggerProviderID
from models.trigger import (
TriggerOAuthSystemClient,
TriggerOAuthTenantClient,
TriggerSubscription,
WorkflowPluginTrigger,
)
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
class TriggerProviderService:
"""Service for managing trigger providers and credentials"""
##########################
# Trigger provider
##########################
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
@classmethod
def get_trigger_provider(cls, tenant_id: str, provider: TriggerProviderID) -> TriggerProviderApiEntity:
"""Get info for a trigger provider"""
return TriggerManager.get_trigger_provider(tenant_id, provider).to_api_entity()
@classmethod
def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
"""List all trigger providers for the current tenant"""
return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]
@classmethod
def list_trigger_provider_subscriptions(
cls, tenant_id: str, provider_id: TriggerProviderID
) -> list[TriggerProviderSubscriptionApiEntity]:
"""List all trigger subscriptions for the current tenant"""
subscriptions: list[TriggerProviderSubscriptionApiEntity] = []
workflows_in_use_map: dict[str, int] = {}
with Session(db.engine, expire_on_commit=False) as session:
# Get all subscriptions
subscriptions_db = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
.order_by(desc(TriggerSubscription.created_at))
.all()
)
subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
if not subscriptions:
return []
usage_counts = (
session.query(
WorkflowPluginTrigger.subscription_id,
func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
)
.filter(
WorkflowPluginTrigger.tenant_id == tenant_id,
WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
)
.group_by(WorkflowPluginTrigger.subscription_id)
.all()
)
workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
for subscription in subscriptions:
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.credentials = dict(
encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
)
subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
count = workflows_in_use_map.get(subscription.id)
subscription.workflows_in_use = count if count is not None else 0
return subscriptions
@classmethod
def add_trigger_subscription(
cls,
tenant_id: str,
user_id: str,
name: str,
provider_id: TriggerProviderID,
endpoint_id: str,
credential_type: CredentialType,
parameters: Mapping[str, Any],
properties: Mapping[str, Any],
credentials: Mapping[str, str],
subscription_id: str | None = None,
credential_expires_at: int = -1,
expires_at: int = -1,
) -> Mapping[str, Any]:
"""
Add a new trigger provider with credentials.
Supports multiple credential instances per provider.
:param tenant_id: Tenant ID
:param user_id: ID of the user creating the subscription
:param name: Name for this subscription (must be unique per provider)
:param provider_id: Provider identifier (e.g., "plugin_id/provider_name")
:param endpoint_id: Endpoint ID backing the subscription's callback URL
:param credential_type: Type of credential (oauth, api_key, or unauthorized)
:param parameters: Subscription parameters
:param properties: Provider properties to encrypt and store
:param credentials: Credential data to encrypt and store
:param subscription_id: Optional pre-generated subscription ID (defaults to a new UUID)
:param credential_expires_at: Credential expiration timestamp (-1 for no expiry)
:param expires_at: Subscription expiration timestamp (-1 for no expiry)
:return: Success response with the new subscription ID
"""
try:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine, expire_on_commit=False) as session:
# Use distributed lock to prevent race conditions
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
with redis_client.lock(lock_key, timeout=20):
# Check provider count limit
provider_count = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
.count()
)
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
raise ValueError(
f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) "
f"reached for {provider_id}"
)
# Check if name already exists
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing:
raise ValueError(f"Credential name '{name}' already exists for this provider")
credential_encrypter: ProviderConfigEncrypter | None = None
if credential_type != CredentialType.UNAUTHORIZED:
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
# Create provider record
subscription = TriggerSubscription(
id=subscription_id or str(uuid.uuid4()),
tenant_id=tenant_id,
user_id=user_id,
name=name,
endpoint_id=endpoint_id,
provider_id=str(provider_id),
parameters=parameters,
properties=properties_encrypter.encrypt(dict(properties)),
credentials=credential_encrypter.encrypt(dict(credentials)) if credential_encrypter else {},
credential_type=credential_type.value,
credential_expires_at=credential_expires_at,
expires_at=expires_at,
)
session.add(subscription)
session.commit()
return {
"result": "success",
"id": str(subscription.id),
}
except Exception as e:
logger.exception("Failed to add trigger provider")
raise ValueError(str(e))
@classmethod
def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
"""
Get a trigger subscription by ID, with credentials and properties decrypted.
Falls back to the tenant's first subscription when subscription_id is None.
"""
with Session(db.engine, expire_on_commit=False) as session:
subscription: TriggerSubscription | None = None
if subscription_id:
subscription = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
else:
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
if subscription:
provider_controller = TriggerManager.get_trigger_provider(
tenant_id, TriggerProviderID(subscription.provider_id)
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.credentials = dict(encrypter.decrypt(subscription.credentials))
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
tenant_id=subscription.tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
return subscription
@classmethod
def delete_trigger_provider(cls, session: Session, tenant_id: str, subscription_id: str):
"""
Delete a trigger provider subscription within an existing session.
:param session: Database session
:param tenant_id: Tenant ID
:param subscription_id: Subscription instance ID
:return: Success response
"""
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
if is_auto_created:
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=subscription.user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=encrypter.decrypt(subscription.credentials),
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger", exc_info=e)
# Clear cache
session.delete(subscription)
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
@classmethod
def refresh_oauth_token(
cls,
tenant_id: str,
subscription_id: str,
) -> Mapping[str, Any]:
"""
Refresh OAuth token for a trigger provider.
:param tenant_id: Tenant ID
:param subscription_id: Subscription instance ID
:return: New token info
"""
with Session(db.engine) as session:
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
if subscription.credential_type != CredentialType.OAUTH2.value:
raise ValueError("Only OAuth credentials can be refreshed")
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
# Create encrypter
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
# Decrypt current credentials
current_credentials = encrypter.decrypt(subscription.credentials)
# Get OAuth client configuration
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{subscription.provider_id}/trigger/callback"
)
system_credentials = cls.get_oauth_client(tenant_id, provider_id)
# Refresh token
oauth_handler = OAuthHandler()
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=subscription.user_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=current_credentials,
)
# Update credentials
subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials)))
subscription.credential_expires_at = refreshed_credentials.expires_at
session.commit()
# Clear cache
cache.delete()
return {
"result": "success",
"expires_at": refreshed_credentials.expires_at,
}
@classmethod
def refresh_subscription(
cls,
tenant_id: str,
subscription_id: str,
now: int | None = None,
) -> Mapping[str, Any]:
"""
Refresh trigger subscription if expired.
Args:
tenant_id: Tenant ID
subscription_id: Subscription instance ID
now: Current timestamp, defaults to `int(time.time())`
Returns:
Mapping with keys: `result` ("success"|"skipped") and `expires_at` (new or existing value)
"""
now_ts: int = int(now if now is not None else _time.time())
with Session(db.engine) as session:
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if subscription is None:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
if subscription.expires_at == -1 or int(subscription.expires_at) > now_ts:
logger.debug(
"Subscription not due for refresh: tenant=%s id=%s expires_at=%s now=%s",
tenant_id,
subscription_id,
subscription.expires_at,
now_ts,
)
return {"result": "skipped", "expires_at": int(subscription.expires_at)}
provider_id = TriggerProviderID(subscription.provider_id)
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
# Decrypt credentials and properties for runtime
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=controller,
subscription=subscription,
)
properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties(
tenant_id=tenant_id,
controller=controller,
subscription=subscription,
)
decrypted_credentials = credential_encrypter.decrypt(subscription.credentials)
decrypted_properties = properties_encrypter.decrypt(subscription.properties)
sub_entity: TriggerSubscriptionEntity = TriggerSubscriptionEntity(
expires_at=int(subscription.expires_at),
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=subscription.parameters,
properties=decrypted_properties,
)
refreshed: TriggerSubscriptionEntity = controller.refresh_trigger(
subscription=sub_entity,
credentials=decrypted_credentials,
credential_type=CredentialType.of(subscription.credential_type),
)
# Persist refreshed properties and expires_at
subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
subscription.expires_at = int(refreshed.expires_at)
session.commit()
properties_cache.delete()
logger.info(
"Subscription refreshed (service): tenant=%s id=%s new_expires_at=%s",
tenant_id,
subscription_id,
subscription.expires_at,
)
return {"result": "success", "expires_at": int(refreshed.expires_at)}
@classmethod
def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any] | None:
"""
Get OAuth client configuration for a provider.
First tries tenant-level OAuth, then falls back to system OAuth.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:return: OAuth client configuration or None
"""
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
with Session(db.engine, expire_on_commit=False) as session:
tenant_client: TriggerOAuthTenantClient | None = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
enabled=True,
)
.first()
)
oauth_params: Mapping[str, Any] | None = None
if tenant_client:
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
oauth_params = encrypter.decrypt(dict(tenant_client.oauth_params))
return oauth_params
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
if not is_verified:
return None
# Check for system-level OAuth client
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")
return oauth_params
@classmethod
def is_oauth_system_client_exists(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
"""
Check if system OAuth client exists for a trigger provider.
"""
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
if not is_verified:
return False
with Session(db.engine, expire_on_commit=False) as session:
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
)
return system_client is not None
@classmethod
def save_custom_oauth_client_params(
cls,
tenant_id: str,
provider_id: TriggerProviderID,
client_params: Mapping[str, Any] | None = None,
enabled: bool | None = None,
) -> Mapping[str, Any]:
"""
Save or update custom OAuth client parameters for a trigger provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:param client_params: OAuth client parameters (client_id, client_secret, etc.)
:param enabled: Enable/disable the custom OAuth client
:return: Success response
"""
if client_params is None and enabled is None:
return {"result": "success"}
# Get provider controller to access schema
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
with Session(db.engine) as session:
# Find existing custom client params
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
.first()
)
            # Create a new record if one doesn't exist
if custom_client is None:
custom_client = TriggerOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
session.add(custom_client)
# Update client params if provided
if client_params is None:
custom_client.encrypted_oauth_params = json.dumps({})
else:
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
# Handle hidden values
original_params = encrypter.decrypt(dict(custom_client.oauth_params))
new_params: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
cache.delete()
# Update enabled status if provided
if enabled is not None:
custom_client.enabled = enabled
session.commit()
return {"result": "success"}
@classmethod
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
"""
Get custom OAuth client parameters for a trigger provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:return: Masked OAuth client parameters
"""
with Session(db.engine) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
.first()
)
if custom_client is None:
return {}
# Get provider controller to access schema
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
# Create encrypter to decrypt and mask values
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_plugin_credentials(encrypter.decrypt(dict(custom_client.oauth_params)))
@classmethod
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
"""
Delete custom OAuth client parameters for a trigger provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:return: Success response
"""
with Session(db.engine) as session:
session.query(TriggerOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
).delete()
session.commit()
return {"result": "success"}
@classmethod
def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
"""
Check if custom OAuth client is enabled for a trigger provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:return: True if enabled, False otherwise
"""
with Session(db.engine, expire_on_commit=False) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
enabled=True,
)
.first()
)
return custom_client is not None
@classmethod
def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
"""
Get a trigger subscription by the endpoint ID.
"""
with Session(db.engine, expire_on_commit=False) as session:
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
if not subscription:
return None
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=subscription.tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.credentials = dict(credential_encrypter.decrypt(subscription.credentials))
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
tenant_id=subscription.tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
return subscription
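    # Note for callers (comment only): the returned ORM object carries
    # decrypted credentials and properties in memory only; the session uses
    # expire_on_commit=False and is never committed here, so no plaintext is
    # written back to the database.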

View File

@ -0,0 +1,65 @@
from collections.abc import Mapping
from typing import Any
from flask import Request
from pydantic import TypeAdapter
from core.plugin.utils.http_parser import deserialize_request, serialize_request
from extensions.ext_storage import storage
class TriggerHttpRequestCachingService:
"""
Service for caching trigger requests.
"""
_TRIGGER_STORAGE_PATH = "triggers"
@classmethod
def get_request(cls, request_id: str) -> Request:
"""
Get the request object from the storage.
Args:
request_id: The ID of the request.
Returns:
The request object.
"""
return deserialize_request(storage.load_once(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.raw"))
@classmethod
def get_payload(cls, request_id: str) -> Mapping[str, Any]:
"""
Get the payload from the storage.
Args:
request_id: The ID of the request.
Returns:
The payload.
"""
return TypeAdapter(Mapping[str, Any]).validate_json(
storage.load_once(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.payload")
)
@classmethod
def persist_request(cls, request_id: str, request: Request) -> None:
"""
Persist the request in the storage.
Args:
request_id: The ID of the request.
request: The request object.
"""
storage.save(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.raw", serialize_request(request))
@classmethod
def persist_payload(cls, request_id: str, payload: Mapping[str, Any]) -> None:
"""
Persist the payload in the storage.
"""
storage.save(
f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.payload",
TypeAdapter(Mapping[str, Any]).dump_json(payload), # type: ignore
)
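# Illustrative round-trip sketch (comment only; the request_id format mirrors
# the one used by the dispatcher, e.g. "trigger_request_<ts>_<hex>"). The
# `load_once` calls above suggest each stored blob is meant to be consumed by
# a single reader:
#
#     TriggerHttpRequestCachingService.persist_request(request_id, request)
#     TriggerHttpRequestCachingService.persist_payload(request_id, payload)
#     ...
#     request = TriggerHttpRequestCachingService.get_request(request_id)
#     payload = TriggerHttpRequestCachingService.get_payload(request_id)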

View File

@ -0,0 +1,307 @@
import logging
import secrets
import time
from collections.abc import Mapping
from typing import Any
from flask import Request, Response
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginNotFoundError
from core.trigger.debug.events import PluginTriggerDebugEvent
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription
from core.workflow.enums import NodeType
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import App
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription, WorkflowPluginTrigger
from models.workflow import Workflow
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
from services.workflow.entities import PluginTriggerDispatchData
from tasks.trigger_processing_tasks import dispatch_triggered_workflows_async
logger = logging.getLogger(__name__)
class TriggerService:
__TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000
__ENDPOINT_REQUEST_CACHE_COUNT__ = 10
__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
__PLUGIN_TRIGGER_NODE_CACHE_KEY__ = "plugin_trigger_nodes"
MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW = 5 # Maximum allowed plugin trigger nodes per workflow
@classmethod
def invoke_trigger_event(
cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent
) -> TriggerInvokeEventResponse:
"""Invoke a trigger event."""
subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=event.subscription_id,
)
if not subscription:
raise ValueError("Subscription not found")
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {}))
request = TriggerHttpRequestCachingService.get_request(event.request_id)
payload = TriggerHttpRequestCachingService.get_payload(event.request_id)
        # invoke trigger
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id, TriggerProviderID(subscription.provider_id)
)
return TriggerManager.invoke_trigger_event(
tenant_id=tenant_id,
user_id=user_id,
provider_id=TriggerProviderID(event.provider_id),
event_name=event.name,
parameters=node_data.resolve_parameters(
parameter_schemas=provider_controller.get_event_parameters(event_name=event.name)
),
credentials=subscription.credentials,
credential_type=CredentialType.of(subscription.credential_type),
subscription=subscription.to_entity(),
request=request,
payload=payload,
)
@classmethod
def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
"""
Extract and process data from incoming endpoint request.
Args:
endpoint_id: Endpoint ID
request: Request
"""
timestamp = int(time.time())
subscription: TriggerSubscription | None = None
try:
subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id)
except PluginNotFoundError:
return Response(status=404, response="Trigger provider not found")
except Exception:
return Response(status=500, response="Failed to get subscription by endpoint")
if not subscription:
return None
provider_id = TriggerProviderID(subscription.provider_id)
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription.tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=subscription.tenant_id,
controller=controller,
subscription=subscription,
)
dispatch_response: TriggerDispatchResponse = controller.dispatch(
request=request,
subscription=subscription.to_entity(),
credentials=encrypter.decrypt(subscription.credentials),
credential_type=CredentialType.of(subscription.credential_type),
)
if dispatch_response.events:
request_id = f"trigger_request_{timestamp}_{secrets.token_hex(6)}"
# save the request and payload to storage as persistent data
TriggerHttpRequestCachingService.persist_request(request_id, request)
TriggerHttpRequestCachingService.persist_payload(request_id, dispatch_response.payload)
# Validate event names
for event_name in dispatch_response.events:
if controller.get_event(event_name) is None:
logger.error(
"Event name %s not found in provider %s for endpoint %s",
event_name,
subscription.provider_id,
endpoint_id,
)
raise ValueError(f"Event name {event_name} not found in provider {subscription.provider_id}")
plugin_trigger_dispatch_data = PluginTriggerDispatchData(
user_id=dispatch_response.user_id,
tenant_id=subscription.tenant_id,
endpoint_id=endpoint_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
timestamp=timestamp,
events=list(dispatch_response.events),
request_id=request_id,
)
dispatch_data = plugin_trigger_dispatch_data.model_dump(mode="json")
dispatch_triggered_workflows_async.delay(dispatch_data)
logger.info(
"Queued async dispatching for %d triggers on endpoint %s with request_id %s",
len(dispatch_response.events),
endpoint_id,
request_id,
)
return dispatch_response.response
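    # Flow sketch for this dispatcher (comment only): request on an endpoint
    # -> resolve subscription -> plugin `dispatch` decides which event names
    # fired -> raw request + payload are persisted once under a shared
    # request_id -> one Celery task fans the events out to subscribed
    # workflows. The provider-supplied `response` is returned to the caller
    # whether or not any event matched.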
@classmethod
def sync_plugin_trigger_relationships(cls, app: App, workflow: Workflow):
"""
Sync plugin trigger relationships in DB.
1. Check if the workflow has any plugin trigger nodes
2. Fetch the nodes from DB, see if there were any plugin trigger records already
3. Diff the nodes and the plugin trigger records, create/update/delete the records as needed
Approach:
            Frequent DB operations can hurt performance, so resolved trigger
            records are cached in Redis; any record found in the DB is cached
            to skip the diff on subsequent syncs.
Limits:
- Maximum 5 plugin trigger nodes per workflow
"""
class Cache(BaseModel):
"""
Cache model for plugin trigger nodes
"""
record_id: str
node_id: str
provider_id: str
event_name: str
subscription_id: str
# Walk nodes to find plugin triggers
nodes_in_graph: list[Mapping[str, Any]] = []
for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN):
# Extract plugin trigger configuration from node
plugin_id = node_config.get("plugin_id", "")
provider_id = node_config.get("provider_id", "")
event_name = node_config.get("event_name", "")
subscription_id = node_config.get("subscription_id", "")
if not subscription_id:
continue
nodes_in_graph.append(
{
"node_id": node_id,
"plugin_id": plugin_id,
"provider_id": provider_id,
"event_name": event_name,
"subscription_id": subscription_id,
}
)
# Check plugin trigger node limit
if len(nodes_in_graph) > cls.MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW:
raise ValueError(
f"Workflow exceeds maximum plugin trigger node limit. "
f"Found {len(nodes_in_graph)} plugin trigger nodes, "
f"maximum allowed is {cls.MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW}"
)
not_found_in_cache: list[Mapping[str, Any]] = []
for node_info in nodes_in_graph:
node_id = node_info["node_id"]
            # first, check whether the node exists in the cache
if not redis_client.get(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}"):
not_found_in_cache.append(node_info)
continue
with Session(db.engine) as session:
try:
                # acquire a per-app lock to serialize concurrent plugin trigger syncs;
                # redis_client.lock() only constructs the lock object, so acquire it explicitly
                lock = redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
                lock.acquire()
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.app_id == app.id,
WorkflowPluginTrigger.tenant_id == app.tenant_id,
)
).all()
nodes_id_in_db = {node.node_id: node for node in all_records}
nodes_id_in_graph = {node["node_id"] for node in nodes_in_graph}
# get the nodes not found both in cache and DB
nodes_not_found = [
node_info for node_info in not_found_in_cache if node_info["node_id"] not in nodes_id_in_db
]
# create new plugin trigger records
for node_info in nodes_not_found:
plugin_trigger = WorkflowPluginTrigger(
app_id=app.id,
tenant_id=app.tenant_id,
node_id=node_info["node_id"],
provider_id=node_info["provider_id"],
event_name=node_info["event_name"],
subscription_id=node_info["subscription_id"],
)
session.add(plugin_trigger)
session.flush() # Get the ID for caching
cache = Cache(
record_id=plugin_trigger.id,
node_id=node_info["node_id"],
provider_id=node_info["provider_id"],
event_name=node_info["event_name"],
subscription_id=node_info["subscription_id"],
)
redis_client.set(
f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_info['node_id']}",
cache.model_dump_json(),
ex=60 * 60,
)
session.commit()
# Update existing records if subscription_id changed
for node_info in nodes_in_graph:
node_id = node_info["node_id"]
if node_id in nodes_id_in_db:
existing_record = nodes_id_in_db[node_id]
if (
existing_record.subscription_id != node_info["subscription_id"]
or existing_record.provider_id != node_info["provider_id"]
or existing_record.event_name != node_info["event_name"]
):
existing_record.subscription_id = node_info["subscription_id"]
existing_record.provider_id = node_info["provider_id"]
existing_record.event_name = node_info["event_name"]
session.add(existing_record)
# Update cache
cache = Cache(
record_id=existing_record.id,
node_id=node_id,
provider_id=node_info["provider_id"],
event_name=node_info["event_name"],
subscription_id=node_info["subscription_id"],
)
redis_client.set(
f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}",
cache.model_dump_json(),
ex=60 * 60,
)
session.commit()
# delete the nodes not found in the graph
for node_id in nodes_id_in_db:
if node_id not in nodes_id_in_graph:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}")
session.commit()
            except Exception:
                # the module-level logger is already available; no need to re-import logging
                logger.exception("Failed to sync plugin trigger relationships for app %s", app.id)
                raise
finally:
redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock")
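    # Cache layout sketch for the sync above (comment only):
    #     plugin_trigger_nodes:<node_id>          -> Cache JSON, 1h TTL (per-node record)
    #     plugin_trigger_nodes:apps:<app_id>:lock -> per-app lock guarding the sync
    # A cache hit for a node short-circuits the DB diff on the next publish.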

View File

@ -0,0 +1,492 @@
import json
import logging
import uuid
from collections.abc import Mapping
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from flask import Request, Response
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse
from core.tools.errors import ToolProviderCredentialValidationError
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
from core.trigger.entities.entities import (
RequestLog,
Subscription,
SubscriptionBuilder,
SubscriptionBuilderUpdater,
SubscriptionConstructor,
)
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import masked_credentials
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url
from extensions.ext_redis import redis_client
from models.provider_ids import TriggerProviderID
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
class TriggerSubscriptionBuilderService:
"""Service for managing trigger providers and credentials"""
##########################
# Trigger provider
##########################
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
##########################
# Builder endpoint
##########################
__BUILDER_CACHE_EXPIRE_SECONDS__ = 30 * 60
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__ = 30 * 60
##########################
# Distributed lock
##########################
__LOCK_EXPIRE_SECONDS__ = 30
@classmethod
def encode_cache_key(cls, subscription_id: str) -> str:
return f"trigger:subscription:builder:{subscription_id}"
@classmethod
def encode_lock_key(cls, subscription_id: str) -> str:
return f"trigger:subscription:builder:lock:{subscription_id}"
@classmethod
@contextmanager
def acquire_builder_lock(cls, subscription_id: str):
"""
Acquire a distributed lock for a subscription builder.
:param subscription_id: The subscription builder ID
"""
lock_key = cls.encode_lock_key(subscription_id)
with redis_client.lock(lock_key, timeout=cls.__LOCK_EXPIRE_SECONDS__):
yield
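    # Usage sketch (comment only): every read-modify-write of a cached builder
    # below runs under this lock, e.g.
    #
    #     with cls.acquire_builder_lock(subscription_builder_id):
    #         builder = cls.get_subscription_builder(subscription_builder_id)
    #         ...  # mutate and re-persist the builder atomically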
@classmethod
def verify_trigger_subscription_builder(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
subscription_builder_id: str,
) -> Mapping[str, Any]:
"""Verify a trigger subscription builder"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
if not subscription_builder:
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
if subscription_builder.credential_type == CredentialType.OAUTH2:
return {"verified": bool(subscription_builder.credentials)}
if subscription_builder.credential_type == CredentialType.API_KEY:
credentials_to_validate = subscription_builder.credentials
try:
provider_controller.validate_credentials(user_id, credentials_to_validate)
except ToolProviderCredentialValidationError as e:
raise ValueError(f"Invalid credentials: {e}")
return {"verified": True}
return {"verified": True}
@classmethod
def build_trigger_subscription_builder(
cls, tenant_id: str, user_id: str, provider_id: TriggerProviderID, subscription_builder_id: str
) -> None:
"""Build a trigger subscription builder"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Acquire lock to prevent concurrent build operations
with cls.acquire_builder_lock(subscription_builder_id):
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
if not subscription_builder:
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
if not subscription_builder.name:
raise ValueError("Subscription builder name is required")
credential_type = CredentialType.of(
subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
)
if credential_type == CredentialType.UNAUTHORIZED:
# manually create
TriggerProviderService.add_trigger_subscription(
subscription_id=subscription_builder.id,
tenant_id=tenant_id,
user_id=user_id,
name=subscription_builder.name,
provider_id=provider_id,
endpoint_id=subscription_builder.endpoint_id,
parameters=subscription_builder.parameters,
properties=subscription_builder.properties,
credential_expires_at=subscription_builder.credential_expires_at or -1,
expires_at=subscription_builder.expires_at,
credentials=subscription_builder.credentials,
credential_type=credential_type,
)
else:
# automatically create
subscription: Subscription = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription_builder.endpoint_id),
parameters=subscription_builder.parameters,
credentials=subscription_builder.credentials,
credential_type=credential_type,
)
TriggerProviderService.add_trigger_subscription(
subscription_id=subscription_builder.id,
tenant_id=tenant_id,
user_id=user_id,
name=subscription_builder.name,
provider_id=provider_id,
endpoint_id=subscription_builder.endpoint_id,
parameters=subscription_builder.parameters,
properties=subscription.properties,
credentials=subscription_builder.credentials,
credential_type=credential_type,
credential_expires_at=subscription_builder.credential_expires_at or -1,
expires_at=subscription_builder.expires_at,
)
# Delete the builder after successful subscription creation
cache_key = cls.encode_cache_key(subscription_builder_id)
redis_client.delete(cache_key)
@classmethod
def create_trigger_subscription_builder(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
credential_type: CredentialType,
) -> SubscriptionBuilderApiEntity:
"""
        Create a new trigger subscription builder.
"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
subscription_constructor: SubscriptionConstructor | None = provider_controller.get_subscription_constructor()
subscription_id = str(uuid.uuid4())
subscription_builder = SubscriptionBuilder(
id=subscription_id,
name=None,
endpoint_id=subscription_id,
tenant_id=tenant_id,
user_id=user_id,
provider_id=str(provider_id),
parameters=subscription_constructor.get_default_parameters() if subscription_constructor else {},
properties=provider_controller.get_subscription_default_properties(),
credentials={},
credential_type=credential_type,
credential_expires_at=-1,
expires_at=-1,
)
cache_key = cls.encode_cache_key(subscription_id)
redis_client.setex(cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder.model_dump_json())
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder)
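    # Builder lifecycle sketch (comment only): create_* seeds a Redis-only
    # draft (30-minute TTL) -> update_* mutates it under the builder lock ->
    # verify_* checks credentials -> build_* materializes the real subscription
    # and deletes the draft. Nothing is written to the DB before the build step.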
@classmethod
def update_trigger_subscription_builder(
cls,
tenant_id: str,
provider_id: TriggerProviderID,
subscription_builder_id: str,
subscription_builder_updater: SubscriptionBuilderUpdater,
) -> SubscriptionBuilderApiEntity:
"""
        Update a trigger subscription builder.
"""
subscription_id = subscription_builder_id
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Acquire lock to prevent concurrent updates
with cls.acquire_builder_lock(subscription_id):
cache_key = cls.encode_cache_key(subscription_id)
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
raise ValueError(f"Subscription {subscription_id} expired or not found")
subscription_builder_updater.update(subscription_builder_cache)
redis_client.setex(
cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
)
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache)
@classmethod
def update_and_verify_builder(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
subscription_builder_id: str,
subscription_builder_updater: SubscriptionBuilderUpdater,
) -> Mapping[str, Any]:
"""
Atomically update and verify a subscription builder.
This ensures the verification is done on the exact data that was just updated.
"""
subscription_id = subscription_builder_id
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Acquire lock for the entire update + verify operation
with cls.acquire_builder_lock(subscription_id):
cache_key = cls.encode_cache_key(subscription_id)
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
raise ValueError(f"Subscription {subscription_id} expired or not found")
# Update
subscription_builder_updater.update(subscription_builder_cache)
redis_client.setex(
cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
)
# Verify (using the just-updated data)
if subscription_builder_cache.credential_type == CredentialType.OAUTH2:
return {"verified": bool(subscription_builder_cache.credentials)}
if subscription_builder_cache.credential_type == CredentialType.API_KEY:
credentials_to_validate = subscription_builder_cache.credentials
try:
provider_controller.validate_credentials(user_id, credentials_to_validate)
except ToolProviderCredentialValidationError as e:
raise ValueError(f"Invalid credentials: {e}")
return {"verified": True}
return {"verified": True}
@classmethod
def update_and_build_builder(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
subscription_builder_id: str,
subscription_builder_updater: SubscriptionBuilderUpdater,
) -> None:
"""
Atomically update and build a subscription builder.
This ensures the build uses the exact data that was just updated.
"""
subscription_id = subscription_builder_id
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Acquire lock for the entire update + build operation
with cls.acquire_builder_lock(subscription_id):
cache_key = cls.encode_cache_key(subscription_id)
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
raise ValueError(f"Subscription {subscription_id} expired or not found")
# Update
subscription_builder_updater.update(subscription_builder_cache)
redis_client.setex(
cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
)
# Re-fetch to ensure we have the latest data
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
if not subscription_builder:
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
if not subscription_builder.name:
raise ValueError("Subscription builder name is required")
# Build
credential_type = CredentialType.of(
subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
)
if credential_type == CredentialType.UNAUTHORIZED:
# manually create
TriggerProviderService.add_trigger_subscription(
subscription_id=subscription_builder.id,
tenant_id=tenant_id,
user_id=user_id,
name=subscription_builder.name,
provider_id=provider_id,
endpoint_id=subscription_builder.endpoint_id,
parameters=subscription_builder.parameters,
properties=subscription_builder.properties,
credential_expires_at=subscription_builder.credential_expires_at or -1,
expires_at=subscription_builder.expires_at,
credentials=subscription_builder.credentials,
credential_type=credential_type,
)
else:
# automatically create
subscription: Subscription = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription_builder.endpoint_id),
parameters=subscription_builder.parameters,
credentials=subscription_builder.credentials,
credential_type=credential_type,
)
TriggerProviderService.add_trigger_subscription(
subscription_id=subscription_builder.id,
tenant_id=tenant_id,
user_id=user_id,
name=subscription_builder.name,
provider_id=provider_id,
endpoint_id=subscription_builder.endpoint_id,
parameters=subscription_builder.parameters,
properties=subscription.properties,
credentials=subscription_builder.credentials,
credential_type=credential_type,
credential_expires_at=subscription_builder.credential_expires_at or -1,
expires_at=subscription_builder.expires_at,
)
# Delete the builder after successful subscription creation
cache_key = cls.encode_cache_key(subscription_builder_id)
redis_client.delete(cache_key)
@classmethod
def builder_to_api_entity(
cls, controller: PluginTriggerProviderController, entity: SubscriptionBuilder
) -> SubscriptionBuilderApiEntity:
credential_type = CredentialType.of(entity.credential_type or CredentialType.UNAUTHORIZED.value)
return SubscriptionBuilderApiEntity(
id=entity.id,
name=entity.name or "",
provider=entity.provider_id,
endpoint=generate_plugin_trigger_endpoint_url(entity.endpoint_id),
parameters=entity.parameters,
properties=entity.properties,
credential_type=credential_type,
credentials=masked_credentials(
schemas=controller.get_credentials_schema(credential_type),
credentials=entity.credentials,
)
if controller.get_subscription_constructor()
else {},
)
@classmethod
def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
"""
        Get a trigger subscription builder by the endpoint ID.
"""
cache_key = cls.encode_cache_key(endpoint_id)
subscription_cache = redis_client.get(cache_key)
if subscription_cache:
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
return None
@classmethod
def append_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
"""Append validation request log to Redis."""
log = RequestLog(
id=str(uuid.uuid4()),
endpoint=endpoint_id,
request={
"method": request.method,
"url": request.url,
"headers": dict(request.headers),
"data": request.get_data(as_text=True),
},
response={
"status_code": response.status_code,
"headers": dict(response.headers),
"data": response.get_data(as_text=True),
},
created_at=datetime.now(),
)
key = f"trigger:subscription:builder:logs:{endpoint_id}"
logs = json.loads(redis_client.get(key) or "[]")
logs.append(log.model_dump(mode="json"))
# Keep last N logs
logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str))
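    # Ring-buffer sketch (comment only): with __VALIDATION_REQUEST_CACHE_COUNT__
    # set to 10, `logs[-10:]` keeps only the ten most recent entries, and the
    # setex call renews the 30-minute TTL on every append.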
@classmethod
def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
"""List request logs for validation endpoint."""
key = f"trigger:subscription:builder:logs:{endpoint_id}"
logs_json = redis_client.get(key)
if not logs_json:
return []
return [RequestLog.model_validate(log) for log in json.loads(logs_json)]
@classmethod
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
"""
Process a temporary endpoint request.
:param endpoint_id: The endpoint identifier
:param request: The Flask request object
:return: The Flask response object
"""
# check if validation endpoint exists
subscription_builder: SubscriptionBuilder | None = cls.get_subscription_builder(endpoint_id)
if not subscription_builder:
return None
# response to validation endpoint
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
)
try:
dispatch_response: TriggerDispatchResponse = controller.dispatch(
request=request,
subscription=subscription_builder.to_subscription(),
credentials={},
credential_type=CredentialType.UNAUTHORIZED,
)
response: Response = dispatch_response.response
# append the request log
cls.append_log(
endpoint_id=endpoint_id,
request=request,
response=response,
)
return response
except Exception:
logger.exception("Error during validation endpoint dispatch for endpoint_id=%s", endpoint_id)
error_response = Response(status=500, response="An internal error has occurred.")
cls.append_log(endpoint_id=endpoint_id, request=request, response=error_response)
return error_response
@classmethod
def get_subscription_builder_by_id(cls, subscription_builder_id: str) -> SubscriptionBuilderApiEntity:
"""Get a trigger subscription builder API entity."""
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
if not subscription_builder:
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
return cls.builder_to_api_entity(
controller=TriggerManager.get_trigger_provider(
subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
),
entity=subscription_builder,
)

View File

@ -0,0 +1,70 @@
from sqlalchemy import and_, select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.enums import AppTriggerStatus
from models.trigger import AppTrigger, WorkflowPluginTrigger
class TriggerSubscriptionOperatorService:
@classmethod
def get_subscriber_triggers(
cls, tenant_id: str, subscription_id: str, event_name: str
) -> list[WorkflowPluginTrigger]:
"""
Get WorkflowPluginTriggers for a subscription and trigger.
Args:
tenant_id: Tenant ID
subscription_id: Subscription ID
event_name: Event name
"""
with Session(db.engine, expire_on_commit=False) as session:
subscribers = session.scalars(
select(WorkflowPluginTrigger)
.join(
AppTrigger,
and_(
AppTrigger.tenant_id == WorkflowPluginTrigger.tenant_id,
AppTrigger.app_id == WorkflowPluginTrigger.app_id,
AppTrigger.node_id == WorkflowPluginTrigger.node_id,
),
)
.where(
WorkflowPluginTrigger.tenant_id == tenant_id,
WorkflowPluginTrigger.subscription_id == subscription_id,
WorkflowPluginTrigger.event_name == event_name,
AppTrigger.status == AppTriggerStatus.ENABLED,
)
).all()
return list(subscribers)
@classmethod
def delete_plugin_trigger_by_subscription(
cls,
session: Session,
tenant_id: str,
subscription_id: str,
) -> None:
"""Delete a plugin trigger by tenant_id and subscription_id within an existing session
Args:
session: Database session
tenant_id: The tenant ID
subscription_id: The subscription ID
Raises:
NotFound: If plugin trigger not found
"""
# Find plugin trigger using indexed columns
plugin_trigger = session.scalar(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.tenant_id == tenant_id,
WorkflowPluginTrigger.subscription_id == subscription_id,
)
)
if not plugin_trigger:
return
session.delete(plugin_trigger)
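# Query-shape sketch (comment only): get_subscriber_triggers joins each
# WorkflowPluginTrigger to its AppTrigger on (tenant_id, app_id, node_id) and
# keeps only rows whose AppTrigger.status is ENABLED, so disabled app triggers
# never receive dispatched events even while their subscription rows persist.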

View File

@ -0,0 +1,871 @@
import json
import logging
import mimetypes
import secrets
from collections.abc import Mapping
from typing import Any
from flask import request
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.types import SegmentType
from core.workflow.enums import NodeType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.workflow.entities import WebhookTriggerData
logger = logging.getLogger(__name__)
class WebhookService:
"""Service for handling webhook operations."""
__WEBHOOK_NODE_CACHE_KEY__ = "webhook_nodes"
MAX_WEBHOOK_NODES_PER_WORKFLOW = 5 # Maximum allowed webhook nodes per workflow
@staticmethod
def _sanitize_key(key: str) -> str:
"""Normalize external keys (headers/params) to workflow-safe variables."""
if not isinstance(key, str):
return key
return key.replace("-", "_")
@classmethod
def get_webhook_trigger_and_workflow(
cls, webhook_id: str, is_debug: bool = False
) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]:
"""Get webhook trigger, workflow, and node configuration.
Args:
webhook_id: The webhook ID to look up
is_debug: If True, use the draft workflow graph and skip the trigger enabled status check
Returns:
A tuple containing:
- WorkflowWebhookTrigger: The webhook trigger object
- Workflow: The associated workflow object
- Mapping[str, Any]: The node configuration data
Raises:
ValueError: If webhook not found, app trigger not found, trigger disabled, or workflow not found
"""
with Session(db.engine) as session:
# Get webhook trigger
webhook_trigger = (
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first()
)
if not webhook_trigger:
raise ValueError(f"Webhook not found: {webhook_id}")
if is_debug:
workflow = (
session.query(Workflow)
.filter(
Workflow.app_id == webhook_trigger.app_id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.order_by(Workflow.created_at.desc())
.first()
)
else:
# Check if the corresponding AppTrigger exists
app_trigger = (
session.query(AppTrigger)
.filter(
AppTrigger.app_id == webhook_trigger.app_id,
AppTrigger.node_id == webhook_trigger.node_id,
AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
)
.first()
)
if not app_trigger:
raise ValueError(f"App trigger not found for webhook {webhook_id}")
# Only check enabled status if not in debug mode
if app_trigger.status != AppTriggerStatus.ENABLED:
raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")
# Get workflow
workflow = (
session.query(Workflow)
.filter(
Workflow.app_id == webhook_trigger.app_id,
Workflow.version != Workflow.VERSION_DRAFT,
)
.order_by(Workflow.created_at.desc())
.first()
)
if not workflow:
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")
node_config = workflow.get_node_config_by_id(webhook_trigger.node_id)
return webhook_trigger, workflow, node_config
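    # Selection sketch (comment only): debug mode loads the newest draft
    # workflow and skips the AppTrigger status check; production mode requires
    # an ENABLED AppTrigger and loads the newest published (non-draft) version.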
@classmethod
def extract_and_validate_webhook_data(
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any]
) -> dict[str, Any]:
"""Extract and validate webhook data in a single unified process.
Args:
webhook_trigger: The webhook trigger object containing metadata
node_config: The node configuration containing validation rules
Returns:
dict[str, Any]: Processed and validated webhook data with correct types
Raises:
ValueError: If validation fails (HTTP method mismatch, missing required fields, type errors)
"""
# Extract raw data first
raw_data = cls.extract_webhook_data(webhook_trigger)
# Validate HTTP metadata (method, content-type)
node_data = node_config.get("data", {})
validation_result = cls._validate_http_metadata(raw_data, node_data)
if not validation_result["valid"]:
raise ValueError(validation_result["error"])
# Process and validate data according to configuration
processed_data = cls._process_and_validate_data(raw_data, node_data)
return processed_data
@classmethod
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
"""Extract raw data from incoming webhook request without type conversion.
Args:
webhook_trigger: The webhook trigger object for file processing context
Returns:
dict[str, Any]: Raw webhook data containing:
- method: HTTP method
- headers: Request headers
- query_params: Query parameters as strings
- body: Request body (varies by content type)
- files: Uploaded files (if any)
"""
cls._validate_content_length()
data = {
"method": request.method,
"headers": dict(request.headers),
"query_params": dict(request.args),
"body": {},
"files": {},
}
# Extract and normalize content type
content_type = cls._extract_content_type(dict(request.headers))
# Route to appropriate extractor based on content type
extractors = {
"application/json": cls._extract_json_body,
"application/x-www-form-urlencoded": cls._extract_form_body,
"multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger),
"application/octet-stream": lambda: cls._extract_octet_stream_body(webhook_trigger),
"text/plain": cls._extract_text_body,
}
extractor = extractors.get(content_type)
if not extractor:
# Default to text/plain for unknown content types
logger.warning("Unknown Content-Type: %s, treating as text/plain", content_type)
extractor = cls._extract_text_body
# Extract body and files
body_data, files_data = extractor()
data["body"] = body_data
data["files"] = files_data
return data
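    # Illustrative request sketch (values assumed, not from this module): a
    # JSON POST such as
    #
    #     POST /webhook/<id>   Content-Type: application/json
    #     {"order_id": 42}
    #
    # is routed to _extract_json_body, yielding data["body"] == {"order_id": 42}
    # and data["files"] == {}; an unrecognized Content-Type falls back to the
    # text/plain extractor, which stores the body under a "raw" key.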
@classmethod
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
"""Process and validate webhook data according to node configuration.
Args:
raw_data: Raw webhook data from extraction
node_data: Node configuration containing validation and type rules
Returns:
dict[str, Any]: Processed data with validated types
Raises:
ValueError: If validation fails or required fields are missing
"""
result = raw_data.copy()
# Validate and process headers
cls._validate_required_headers(raw_data["headers"], node_data.get("headers", []))
# Process query parameters with type conversion and validation
result["query_params"] = cls._process_parameters(
raw_data["query_params"], node_data.get("params", []), is_form_data=True
)
# Process body parameters based on content type
configured_content_type = node_data.get("content_type", "application/json").lower()
result["body"] = cls._process_body_parameters(
raw_data["body"], node_data.get("body", []), configured_content_type
)
return result
@classmethod
def _validate_content_length(cls) -> None:
"""Validate request content length against maximum allowed size."""
content_length = request.content_length
if content_length and content_length > dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE:
raise RequestEntityTooLarge(
f"Webhook request too large: {content_length} bytes exceeds maximum allowed size "
f"of {dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE} bytes"
)
@classmethod
def _extract_json_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
"""Extract JSON body from request.
Returns:
tuple: (body_data, files_data) where:
- body_data: Parsed JSON content or empty dict if parsing fails
- files_data: Empty dict (JSON requests don't contain files)
"""
try:
body = request.get_json() or {}
except Exception:
logger.warning("Failed to parse JSON body")
body = {}
return body, {}
@classmethod
def _extract_form_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
"""Extract form-urlencoded body from request.
Returns:
tuple: (body_data, files_data) where:
- body_data: Form data as key-value pairs
- files_data: Empty dict (form-urlencoded requests don't contain files)
"""
return dict(request.form), {}
@classmethod
def _extract_multipart_body(cls, webhook_trigger: WorkflowWebhookTrigger) -> tuple[dict[str, Any], dict[str, Any]]:
"""Extract multipart/form-data body and files from request.
Args:
webhook_trigger: Webhook trigger for file processing context
Returns:
tuple: (body_data, files_data) where:
- body_data: Form data as key-value pairs
- files_data: Processed file objects indexed by field name
"""
body = dict(request.form)
files = cls._process_file_uploads(request.files, webhook_trigger) if request.files else {}
return body, files
@classmethod
def _extract_octet_stream_body(
cls, webhook_trigger: WorkflowWebhookTrigger
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Extract binary data as file from request.
Args:
webhook_trigger: Webhook trigger for file processing context
Returns:
tuple: (body_data, files_data) where:
- body_data: Dict with 'raw' key containing file object or None
- files_data: Empty dict
"""
try:
file_content = request.get_data()
if file_content:
file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger)
return {"raw": file_obj.to_dict()}, {}
else:
return {"raw": None}, {}
except Exception:
logger.exception("Failed to process octet-stream data")
return {"raw": None}, {}
@classmethod
def _extract_text_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
"""Extract text/plain body from request.
Returns:
tuple: (body_data, files_data) where:
- body_data: Dict with 'raw' key containing text content
- files_data: Empty dict (text requests don't contain files)
"""
try:
body = {"raw": request.get_data(as_text=True)}
except Exception:
logger.warning("Failed to extract text body")
body = {"raw": ""}
return body, {}
@classmethod
def _process_file_uploads(
cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger
) -> dict[str, Any]:
"""Process file uploads using ToolFileManager.
Args:
files: Flask request files object containing uploaded files
webhook_trigger: Webhook trigger for tenant and user context
Returns:
dict[str, Any]: Processed file objects indexed by field name
"""
processed_files = {}
for name, file in files.items():
if file and file.filename:
try:
file_content = file.read()
mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
processed_files[name] = file_obj.to_dict()
except Exception:
logger.exception("Failed to process file upload '%s'", name)
# Continue processing other files
return processed_files
@classmethod
def _create_file_from_binary(
cls, file_content: bytes, mimetype: str, webhook_trigger: WorkflowWebhookTrigger
) -> Any:
"""Create a file object from binary content using ToolFileManager.
Args:
file_content: The binary content of the file
mimetype: The MIME type of the file
webhook_trigger: Webhook trigger for tenant and user context
Returns:
Any: A file object built from the binary content
"""
tool_file_manager = ToolFileManager()
# Create file using ToolFileManager
tool_file = tool_file_manager.create_file_by_raw(
user_id=webhook_trigger.created_by,
tenant_id=webhook_trigger.tenant_id,
conversation_id=None,
file_binary=file_content,
mimetype=mimetype,
)
# Build File object
mapping = {
"tool_file_id": tool_file.id,
"transfer_method": FileTransferMethod.TOOL_FILE.value,
}
return file_factory.build_from_mapping(
mapping=mapping,
tenant_id=webhook_trigger.tenant_id,
)
@classmethod
def _process_parameters(
cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False
) -> dict[str, Any]:
"""Process parameters with unified validation and type conversion.
Args:
raw_params: Raw parameter values as strings
param_configs: List of parameter configuration dictionaries
is_form_data: Whether the parameters are from form data (requiring string conversion)
Returns:
dict[str, Any]: Processed parameters with validated types
Raises:
ValueError: If required parameters are missing or validation fails
"""
processed = {}
configured_params = {config.get("name", ""): config for config in param_configs}
# Process configured parameters
for param_config in param_configs:
name = param_config.get("name", "")
param_type = param_config.get("type", SegmentType.STRING)
required = param_config.get("required", False)
# Check required parameters
if required and name not in raw_params:
raise ValueError(f"Required parameter missing: {name}")
if name in raw_params:
raw_value = raw_params[name]
processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)
# Include unconfigured parameters as strings
for name, value in raw_params.items():
if name not in configured_params:
processed[name] = value
return processed
@classmethod
def _process_body_parameters(
cls, raw_body: dict[str, Any], body_configs: list, content_type: str
) -> dict[str, Any]:
"""Process body parameters based on content type and configuration.
Args:
raw_body: Raw body data from request
body_configs: List of body parameter configuration dictionaries
content_type: The request content type
Returns:
dict[str, Any]: Processed body parameters with validated types
Raises:
ValueError: If required body parameters are missing or validation fails
"""
if content_type in ["text/plain", "application/octet-stream"]:
# For text/plain and octet-stream, validate required content exists
if body_configs and any(config.get("required", False) for config in body_configs):
raw_content = raw_body.get("raw")
if not raw_content:
raise ValueError(f"Required body content missing for {content_type} request")
return raw_body
# For structured data (JSON, form-data, etc.)
processed = {}
configured_params = {config.get("name", ""): config for config in body_configs}
for body_config in body_configs:
name = body_config.get("name", "")
param_type = body_config.get("type", SegmentType.STRING)
required = body_config.get("required", False)
# Handle file parameters for multipart data
if param_type == SegmentType.FILE and content_type == "multipart/form-data":
# File validation is handled separately in extract phase
continue
# Check required parameters
if required and name not in raw_body:
raise ValueError(f"Required body parameter missing: {name}")
if name in raw_body:
raw_value = raw_body[name]
is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"]
processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)
# Include unconfigured parameters
for name, value in raw_body.items():
if name not in configured_params:
processed[name] = value
return processed
@classmethod
def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any:
"""Unified validation and type conversion for parameter values.
Args:
param_name: Name of the parameter for error reporting
value: The value to validate and convert
param_type: The expected parameter type (SegmentType)
is_form_data: Whether the value is from form data (requiring string conversion)
Returns:
Any: The validated and converted value
Raises:
ValueError: If validation or conversion fails
"""
try:
if is_form_data:
# Form data comes as strings and needs conversion
return cls._convert_form_value(param_name, value, param_type)
else:
# JSON data should already be in correct types, just validate
return cls._validate_json_value(param_name, value, param_type)
except Exception as e:
raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}")
@classmethod
def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any:
"""Convert form data string values to specified types.
Args:
param_name: Name of the parameter for error reporting
value: The string value to convert
param_type: The target type to convert to (SegmentType)
Returns:
Any: The converted value in the appropriate type
Raises:
ValueError: If the value cannot be converted to the specified type
"""
if param_type == SegmentType.STRING:
return value
elif param_type == SegmentType.NUMBER:
if not cls._can_convert_to_number(value):
raise ValueError(f"Cannot convert '{value}' to number")
numeric_value = float(value)
return int(numeric_value) if numeric_value.is_integer() else numeric_value
elif param_type == SegmentType.BOOLEAN:
lower_value = value.lower()
bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
if lower_value not in bool_map:
raise ValueError(f"Cannot convert '{value}' to boolean")
return bool_map[lower_value]
else:
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
@classmethod
def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any:
"""Validate JSON values against expected types.
Args:
param_name: Name of the parameter for error reporting
value: The value to validate
param_type: The expected parameter type (SegmentType)
Returns:
Any: The validated value (unchanged if valid)
Raises:
ValueError: If the value type doesn't match the expected type
"""
type_validators = {
SegmentType.STRING: (lambda v: isinstance(v, str), "string"),
SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"),
SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"),
SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"),
SegmentType.ARRAY_STRING: (
lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v),
"array of strings",
),
SegmentType.ARRAY_NUMBER: (
lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v),
"array of numbers",
),
SegmentType.ARRAY_BOOLEAN: (
lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v),
"array of booleans",
),
SegmentType.ARRAY_OBJECT: (
lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v),
"array of objects",
),
}
        try:
            segment_type: SegmentType | None = SegmentType(param_type)
        except ValueError:
            # unknown type strings fall through to the warning below instead of raising
            segment_type = None
        validator_info = type_validators.get(segment_type) if segment_type else None
if not validator_info:
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
return value
validator, expected_type = validator_info
if not validator(value):
actual_type = type(value).__name__
raise ValueError(f"Expected {expected_type}, got {actual_type}")
return value
@classmethod
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None:
"""Validate required headers are present.
Args:
headers: Request headers dictionary
header_configs: List of header configuration dictionaries
Raises:
ValueError: If required headers are missing
"""
headers_lower = {k.lower(): v for k, v in headers.items()}
headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()}
for header_config in header_configs:
if header_config.get("required", False):
header_name = header_config.get("name", "")
sanitized_name = cls._sanitize_key(header_name).lower()
if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized:
raise ValueError(f"Required header missing: {header_name}")
@classmethod
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
"""Validate HTTP method and content-type.
Args:
webhook_data: Extracted webhook data containing method and headers
node_data: Node configuration containing expected method and content-type
Returns:
dict[str, Any]: Validation result with 'valid' key and optional 'error' key
"""
# Validate HTTP method
configured_method = node_data.get("method", "get").upper()
request_method = webhook_data["method"].upper()
if configured_method != request_method:
return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}")
# Validate Content-type
configured_content_type = node_data.get("content_type", "application/json").lower()
request_content_type = cls._extract_content_type(webhook_data["headers"])
if configured_content_type != request_content_type:
return cls._validation_error(
f"Content-type mismatch. Expected {configured_content_type}, got {request_content_type}"
)
return {"valid": True}
@classmethod
def _extract_content_type(cls, headers: dict[str, Any]) -> str:
"""Extract and normalize content-type from headers.
Args:
headers: Request headers dictionary
Returns:
str: Normalized content-type (main type without parameters)
"""
content_type = headers.get("Content-Type", "").lower()
if not content_type:
content_type = headers.get("content-type", "application/json").lower()
# Extract the main content type (ignore parameters like boundary)
return content_type.split(";")[0].strip()
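# Example (illustrative): parameters such as multipart boundaries are stripped,
# so only the main type takes part in the content-type comparison above.
#
#     cls._extract_content_type({"Content-Type": "multipart/form-data; boundary=abc"})
#     # -> "multipart/form-data"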
@classmethod
def _validation_error(cls, error_message: str) -> dict[str, Any]:
"""Create a standard validation error response.
Args:
error_message: The error message to include
Returns:
dict[str, Any]: Validation error response with 'valid' and 'error' keys
"""
return {"valid": False, "error": error_message}
@classmethod
def _can_convert_to_number(cls, value: str) -> bool:
"""Check if a string can be converted to a number."""
try:
float(value)
return True
except ValueError:
return False
@classmethod
def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
"""Construct workflow inputs payload from webhook data.
Args:
webhook_data: Processed webhook data containing headers, query params, and body
Returns:
dict[str, Any]: Workflow inputs formatted for execution
"""
return {
"webhook_data": webhook_data,
"webhook_headers": webhook_data.get("headers", {}),
"webhook_query_params": webhook_data.get("query_params", {}),
"webhook_body": webhook_data.get("body", {}),
}
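# Example (illustrative): the mapping keeps the full payload plus flat views so
# downstream nodes can reference headers, query params, and body directly.
#
#     cls.build_workflow_inputs({"headers": {"X-Token": "abc"}, "query_params": {}, "body": {"k": "v"}})
#     # -> {"webhook_data": {...}, "webhook_headers": {"X-Token": "abc"},
#     #     "webhook_query_params": {}, "webhook_body": {"k": "v"}}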
@classmethod
def trigger_workflow_execution(
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
) -> None:
"""Trigger workflow execution via AsyncWorkflowService.
Args:
webhook_trigger: The webhook trigger object
webhook_data: Processed webhook data for workflow inputs
workflow: The workflow to execute
Raises:
ValueError: If tenant owner is not found
Exception: If workflow execution fails
"""
try:
with Session(db.engine) as session:
# Prepare inputs for the webhook node
# The webhook node expects webhook_data in the inputs
workflow_inputs = cls.build_workflow_inputs(webhook_data)
# Create trigger data
trigger_data = WebhookTriggerData(
app_id=webhook_trigger.app_id,
workflow_id=workflow.id,
root_node_id=webhook_trigger.node_id, # Start from the webhook node
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
)
end_user = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.TRIGGER,
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
user_id=None,
)
# Trigger workflow execution asynchronously
AsyncWorkflowService.trigger_workflow_async(
session,
end_user,
trigger_data,
)
except Exception:
logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
raise
@classmethod
def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]:
"""Generate HTTP response based on node configuration.
Args:
node_config: Node configuration containing response settings
Returns:
tuple[dict[str, Any], int]: Response data and HTTP status code
"""
node_data = node_config.get("data", {})
# Get configured status code and response body
status_code = node_data.get("status_code", 200)
response_body = node_data.get("response_body", "")
# Parse response body as JSON if it's valid JSON, otherwise return as text
try:
if response_body:
try:
response_data = (
json.loads(response_body)
if response_body.strip().startswith(("{", "["))
else {"message": response_body}
)
except json.JSONDecodeError:
response_data = {"message": response_body}
else:
response_data = {"status": "success", "message": "Webhook processed successfully"}
except Exception:
response_data = {"message": response_body or "Webhook processed successfully"}
return response_data, status_code
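# Example (illustrative): the response shape depends on whether the configured
# response_body parses as JSON.
#
#     cls.generate_webhook_response({"data": {"status_code": 201, "response_body": '{"ok": true}'}})
#     # -> ({"ok": True}, 201)
#     cls.generate_webhook_response({"data": {"response_body": "done"}})
#     # -> ({"message": "done"}, 200)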
@classmethod
def sync_webhook_relationships(cls, app: App, workflow: Workflow):
"""
Sync webhook relationships in DB.
1. Collect the webhook trigger nodes from the workflow graph
2. Fetch the existing webhook records from the DB
3. Diff the nodes against the records and create or delete records as needed
Approach:
Frequent DB lookups would be costly, so existing webhook records are cached in Redis.
Limits:
- Maximum 5 webhook nodes per workflow
"""
class Cache(BaseModel):
"""
Cache model for webhook nodes
"""
record_id: str
node_id: str
webhook_id: str
nodes_id_in_graph = [node_id for node_id, _ in workflow.walk_nodes(NodeType.TRIGGER_WEBHOOK)]
# Check webhook node limit
if len(nodes_id_in_graph) > cls.MAX_WEBHOOK_NODES_PER_WORKFLOW:
raise ValueError(
f"Workflow exceeds maximum webhook node limit. "
f"Found {len(nodes_id_in_graph)} webhook nodes, maximum allowed is {cls.MAX_WEBHOOK_NODES_PER_WORKFLOW}"
)
not_found_in_cache: list[str] = []
for node_id in nodes_id_in_graph:
# first, check whether the node is already tracked in the cache
if not redis_client.get(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{node_id}"):
not_found_in_cache.append(node_id)
with Session(db.engine) as session:
try:
# guard against concurrent webhook trigger creation; the lock object must be
# explicitly acquired, creating it alone does nothing
lock = redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
lock.acquire()
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
WorkflowWebhookTrigger.app_id == app.id,
WorkflowWebhookTrigger.tenant_id == app.tenant_id,
)
).all()
nodes_id_in_db = {node.node_id: node for node in all_records}
# collect the nodes found in neither the cache nor the DB
nodes_not_found = [node_id for node_id in not_found_in_cache if node_id not in nodes_id_in_db]
# create new webhook records
for node_id in nodes_not_found:
webhook_record = WorkflowWebhookTrigger(
app_id=app.id,
tenant_id=app.tenant_id,
node_id=node_id,
webhook_id=cls.generate_webhook_id(),
created_by=app.created_by,
)
session.add(webhook_record)
session.flush()
cache = Cache(record_id=webhook_record.id, node_id=node_id, webhook_id=webhook_record.webhook_id)
redis_client.set(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{node_id}", cache.model_dump_json(), ex=60 * 60)
session.commit()
# delete the nodes not found in the graph
for node_id in nodes_id_in_db:
if node_id not in nodes_id_in_graph:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise
finally:
# best-effort unlock; the key also expires via the lock's 10s timeout
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
@classmethod
def generate_webhook_id(cls) -> str:
"""
Generate unique 24-character webhook ID
Deduplication is not needed, DB already has unique constraint on webhook_id.
"""
# 18 random bytes encode to exactly 24 base64url characters; the slice is defensive
return secrets.token_urlsafe(18)[:24]
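# Sanity check (illustrative): 18 random bytes encode to exactly 24 base64url
# characters, so the slice never actually truncates.
#
#     assert len(secrets.token_urlsafe(18)) == 24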

View File

@ -0,0 +1,165 @@
"""
Pydantic models for async workflow trigger system.
"""
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from models.enums import AppTriggerType, WorkflowRunTriggeredFrom
class AsyncTriggerStatus(StrEnum):
"""Async trigger execution status"""
COMPLETED = "completed"
FAILED = "failed"
TIMEOUT = "timeout"
class TriggerMetadata(BaseModel):
"""Trigger metadata"""
type: AppTriggerType = Field(default=AppTriggerType.UNKNOWN)
class TriggerData(BaseModel):
"""Base trigger data model for async workflow execution"""
app_id: str
tenant_id: str
workflow_id: str | None = None
root_node_id: str
inputs: Mapping[str, Any]
files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
trigger_type: AppTriggerType
trigger_from: WorkflowRunTriggeredFrom
trigger_metadata: TriggerMetadata | None = None
model_config = ConfigDict(use_enum_values=True)
class WebhookTriggerData(TriggerData):
"""Webhook-specific trigger data"""
trigger_type: AppTriggerType = AppTriggerType.TRIGGER_WEBHOOK
trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK
class ScheduleTriggerData(TriggerData):
"""Schedule-specific trigger data"""
trigger_type: AppTriggerType = AppTriggerType.TRIGGER_SCHEDULE
trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE
class PluginTriggerMetadata(TriggerMetadata):
"""Plugin trigger metadata"""
type: AppTriggerType = AppTriggerType.TRIGGER_PLUGIN
endpoint_id: str
plugin_unique_identifier: str
provider_id: str
event_name: str
icon_filename: str
icon_dark_filename: str
class PluginTriggerData(TriggerData):
"""Plugin webhook trigger data"""
trigger_type: AppTriggerType = AppTriggerType.TRIGGER_PLUGIN
trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN
plugin_id: str
endpoint_id: str
class PluginTriggerDispatchData(BaseModel):
"""Plugin trigger dispatch data for Celery tasks"""
user_id: str
tenant_id: str
endpoint_id: str
provider_id: str
subscription_id: str
timestamp: int
events: list[str]
request_id: str
class WorkflowTaskData(BaseModel):
"""Lightweight data structure for Celery workflow tasks"""
workflow_trigger_log_id: str # Primary tracking ID - all other data can be fetched from DB
model_config = ConfigDict(arbitrary_types_allowed=True)
class AsyncTriggerExecutionResult(BaseModel):
"""Result from async trigger-based workflow execution"""
execution_id: str
status: AsyncTriggerStatus
result: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
model_config = ConfigDict(use_enum_values=True)
class AsyncTriggerResponse(BaseModel):
"""Response from triggering an async workflow"""
workflow_trigger_log_id: str
task_id: str
status: str
queue: str
model_config = ConfigDict(use_enum_values=True)
class TriggerLogResponse(BaseModel):
"""Response model for trigger log data"""
id: str
tenant_id: str
app_id: str
workflow_id: str
trigger_type: WorkflowRunTriggeredFrom
status: str
queue_name: str
retry_count: int
celery_task_id: str | None = None
workflow_run_id: str | None = None
error: str | None = None
outputs: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
created_at: str | None = None
triggered_at: str | None = None
finished_at: str | None = None
model_config = ConfigDict(use_enum_values=True)
class WorkflowScheduleCFSPlanEntity(BaseModel):
"""
CFS plan entity.
Ensures each workflow run inside Dify is associated with a CFS (Completely Fair Scheduler) plan.
"""
class Strategy(StrEnum):
"""
CFS plan strategy.
"""
TimeSlice = "time-slice" # time-slice based plan
Nop = "nop" # no plan, just run the workflow
schedule_strategy: Strategy
granularity: int = Field(default=-1) # -1 means infinite
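# Example (illustrative; field values are placeholders): subclasses pre-fill the
# trigger enums, and use_enum_values stores them as plain strings on validation.
#
#     data = WebhookTriggerData(
#         app_id="app-id",
#         tenant_id="tenant-id",
#         root_node_id="webhook-node-id",
#         inputs={"webhook_body": {}},
#     )
#     data.trigger_type  # the webhook enum's string value, not the enum member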

View File

@ -0,0 +1,151 @@
"""
Queue dispatcher system for async workflow execution.
Implements an ABC-based pattern for handling different subscription tiers
with appropriate queue routing and rate limiting.
"""
from abc import ABC, abstractmethod
from enum import StrEnum
from configs import dify_config
from extensions.ext_redis import redis_client
from services.billing_service import BillingService
from services.workflow.rate_limiter import TenantDailyRateLimiter
class QueuePriority(StrEnum):
"""Queue priorities for different subscription tiers"""
PROFESSIONAL = "workflow_professional" # Highest priority
TEAM = "workflow_team"
SANDBOX = "workflow_sandbox" # Free tier
class BaseQueueDispatcher(ABC):
"""Abstract base class for queue dispatchers"""
def __init__(self):
self.rate_limiter = TenantDailyRateLimiter(redis_client)
@abstractmethod
def get_queue_name(self) -> str:
"""Get the queue name for this dispatcher"""
pass
@abstractmethod
def get_daily_limit(self) -> int:
"""Get daily execution limit"""
pass
@abstractmethod
def get_priority(self) -> int:
"""Get task priority level"""
pass
def check_daily_quota(self, tenant_id: str) -> bool:
"""
Check if tenant has remaining daily quota
Args:
tenant_id: The tenant identifier
Returns:
True if quota available, False otherwise
"""
# Check without consuming
remaining = self.rate_limiter.get_remaining_quota(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit())
return remaining > 0
def consume_quota(self, tenant_id: str) -> bool:
"""
Consume one execution from daily quota
Args:
tenant_id: The tenant identifier
Returns:
True if quota consumed successfully, False if limit reached
"""
return self.rate_limiter.check_and_consume(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit())
class ProfessionalQueueDispatcher(BaseQueueDispatcher):
"""Dispatcher for professional tier"""
def get_queue_name(self) -> str:
return QueuePriority.PROFESSIONAL
def get_daily_limit(self) -> int:
return int(1e9)  # effectively unlimited
def get_priority(self) -> int:
return 100
class TeamQueueDispatcher(BaseQueueDispatcher):
"""Dispatcher for team tier"""
def get_queue_name(self) -> str:
return QueuePriority.TEAM
def get_daily_limit(self) -> int:
return int(1e9)  # effectively unlimited
def get_priority(self) -> int:
return 50
class SandboxQueueDispatcher(BaseQueueDispatcher):
"""Dispatcher for free/sandbox tier"""
def get_queue_name(self) -> str:
return QueuePriority.SANDBOX
def get_daily_limit(self) -> int:
return dify_config.APP_DAILY_RATE_LIMIT
def get_priority(self) -> int:
return 10
class QueueDispatcherManager:
"""Factory for creating appropriate dispatcher based on tenant subscription"""
# Mapping of billing plans to dispatchers
PLAN_DISPATCHER_MAP = {
"professional": ProfessionalQueueDispatcher,
"team": TeamQueueDispatcher,
"sandbox": SandboxQueueDispatcher,
# Add new tiers here as they're created
# For any unknown plan, default to sandbox
}
@classmethod
def get_dispatcher(cls, tenant_id: str) -> BaseQueueDispatcher:
"""
Get dispatcher based on tenant's subscription plan
Args:
tenant_id: The tenant identifier
Returns:
Appropriate queue dispatcher instance
"""
if dify_config.BILLING_ENABLED:
try:
billing_info = BillingService.get_info(tenant_id)
plan = billing_info.get("subscription", {}).get("plan", "sandbox")
except Exception:
# If billing service fails, default to sandbox
plan = "sandbox"
else:
# If billing is disabled, use team tier as default
plan = "team"
dispatcher_class = cls.PLAN_DISPATCHER_MAP.get(
plan,
SandboxQueueDispatcher, # Default to sandbox for unknown plans
)
return dispatcher_class() # type: ignore
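# Example (illustrative): dispatch flow used before enqueueing a workflow task.
# The queue name and priority are presumably what the async workflow service
# passes to Celery when routing the task.
#
#     dispatcher = QueueDispatcherManager.get_dispatcher(tenant_id)
#     if dispatcher.consume_quota(tenant_id):
#         queue, priority = dispatcher.get_queue_name(), dispatcher.get_priority()
#     else:
#         ...  # reject: daily quota exhausted for this tenant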

View File

@ -0,0 +1,183 @@
"""
Day-based rate limiter for workflow executions.
Implements UTC-based daily quotas that reset at midnight UTC for consistent rate limiting.
"""
from datetime import UTC, datetime, time, timedelta
from typing import Union
import pytz
from redis import Redis
from sqlalchemy import select
from extensions.ext_database import db
from extensions.ext_redis import RedisClientWrapper
from models.account import Account, TenantAccountJoin, TenantAccountRole
class TenantDailyRateLimiter:
"""
Day-based rate limiter that resets at midnight UTC
This class provides Redis-based rate limiting with the following features:
- Daily quotas that reset at midnight UTC for consistency
- Atomic check-and-consume operations
- Automatic cleanup of stale counters
- Timezone-aware error messages for better UX
"""
def __init__(self, redis_client: Union[Redis, RedisClientWrapper]):
self.redis = redis_client
def get_tenant_owner_timezone(self, tenant_id: str) -> str:
"""
Get timezone of tenant owner
Args:
tenant_id: The tenant identifier
Returns:
Timezone string (e.g., 'America/New_York', 'UTC')
"""
# Query to get tenant owner's timezone using scalar and select
owner = db.session.scalar(
select(Account)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == TenantAccountRole.OWNER)
)
if not owner:
return "UTC"
return owner.timezone or "UTC"
def _get_day_key(self, tenant_id: str) -> str:
"""
Get Redis key for current UTC day
Args:
tenant_id: The tenant identifier
Returns:
Redis key for the current UTC day
"""
utc_now = datetime.now(UTC)
date_str = utc_now.strftime("%Y-%m-%d")
return f"workflow:daily_limit:{tenant_id}:{date_str}"
def _get_ttl_seconds(self) -> int:
"""
Calculate seconds until UTC midnight
Returns:
Number of seconds until UTC midnight
"""
utc_now = datetime.now(UTC)
# Get next midnight in UTC
next_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min)
next_midnight = next_midnight.replace(tzinfo=UTC)
return int((next_midnight - utc_now).total_seconds())
def check_and_consume(self, tenant_id: str, max_daily_limit: int) -> bool:
"""
Check if quota available and consume one execution
Args:
tenant_id: The tenant identifier
max_daily_limit: Maximum daily limit
Returns:
True if quota consumed successfully, False if limit reached
"""
key = self._get_day_key(tenant_id)
ttl = self._get_ttl_seconds()
# Check current usage
current = self.redis.get(key)
if current is None:
# First execution of the day - initialize atomically; NX avoids a double-init race
if self.redis.set(key, 1, ex=ttl, nx=True):
return True
# Another client initialized the counter first; fall through to the increment path
current = self.redis.get(key) or 0
current_count = int(current)
if current_count < max_daily_limit:
# Within limit, increment
new_count = self.redis.incr(key)
# Update TTL
self.redis.expire(key, ttl)
# Double-check in case of race condition
if new_count <= max_daily_limit:
return True
else:
# Race condition occurred, decrement back
self.redis.decr(key)
return False
else:
# Limit exceeded
return False
def get_remaining_quota(self, tenant_id: str, max_daily_limit: int) -> int:
"""
Get remaining quota for the day
Args:
tenant_id: The tenant identifier
max_daily_limit: Maximum daily limit
Returns:
Number of remaining executions for the day
"""
key = self._get_day_key(tenant_id)
used = int(self.redis.get(key) or 0)
return max(0, max_daily_limit - used)
def get_current_usage(self, tenant_id: str) -> int:
"""
Get current usage for the day
Args:
tenant_id: The tenant identifier
Returns:
Number of executions used today
"""
key = self._get_day_key(tenant_id)
return int(self.redis.get(key) or 0)
def reset_quota(self, tenant_id: str) -> bool:
"""
Reset quota for testing purposes
Args:
tenant_id: The tenant identifier
Returns:
True if key was deleted, False if key didn't exist
"""
key = self._get_day_key(tenant_id)
return bool(self.redis.delete(key))
def get_quota_reset_time(self, tenant_id: str, timezone_str: str) -> datetime:
"""
Get the time when the quota will reset (next UTC midnight, expressed in the tenant's timezone)
Args:
tenant_id: The tenant identifier
timezone_str: Tenant's timezone for display purposes
Returns:
Datetime when quota resets (next UTC midnight in tenant's timezone)
"""
tz = pytz.timezone(timezone_str)
utc_now = datetime.now(UTC)
# Get next midnight in UTC, then convert to tenant's timezone
next_utc_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min)
next_utc_midnight = pytz.UTC.localize(next_utc_midnight)
return next_utc_midnight.astimezone(tz)
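# Example (illustrative): the atomic check-and-consume pattern used by the queue
# dispatchers before a workflow execution.
#
#     limiter = TenantDailyRateLimiter(redis_client)
#     if limiter.check_and_consume(tenant_id="tenant-id", max_daily_limit=100):
#         ...  # proceed; one execution consumed from today's UTC quota
#     else:
#         tz = limiter.get_tenant_owner_timezone("tenant-id")
#         reset_at = limiter.get_quota_reset_time("tenant-id", tz)  # for the error message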

View File

@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from enum import StrEnum
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
class SchedulerCommand(StrEnum):
"""
Scheduler command.
"""
RESOURCE_LIMIT_REACHED = "resource_limit_reached"
NONE = "none"
class CFSPlanScheduler(ABC):
"""
CFS plan scheduler.
"""
def __init__(self, plan: WorkflowScheduleCFSPlanEntity):
"""
Initialize the CFS plan scheduler.
Args:
plan: The CFS plan.
"""
self.plan = plan
@abstractmethod
def can_schedule(self) -> SchedulerCommand:
"""
Decide whether a workflow run can be scheduled, expressed as a SchedulerCommand.
"""

View File

@ -1,12 +1,37 @@
import json
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowExecutionStatus
from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
from services.workflow.entities import TriggerMetadata
# The workflow_app_log table has exceeded 100 million records, so we extend it through an additional details field
class LogView:
"""Lightweight wrapper for WorkflowAppLog with computed details.
- Exposes `details_` for marshalling to `details` in API response
- Proxies all other attributes to the underlying `WorkflowAppLog`
"""
def __init__(self, log: WorkflowAppLog, details: dict | None):
self.log = log
self.details_ = details
@property
def details(self) -> dict | None:
return self.details_
def __getattr__(self, name):
return getattr(self.log, name)
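# Example (illustrative): LogView behaves like the wrapped WorkflowAppLog while
# exposing the computed details for marshalling.
#
#     view = LogView(log, {"trigger_metadata": {...}})
#     view.created_at  # proxied to log.created_at via __getattr__
#     view.details     # -> {"trigger_metadata": {...}}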
class WorkflowAppService:
@ -21,6 +46,7 @@ class WorkflowAppService:
created_at_after: datetime | None = None,
page: int = 1,
limit: int = 20,
detail: bool = False,
created_by_end_user_session_id: str | None = None,
created_by_account: str | None = None,
):
@ -34,6 +60,7 @@ class WorkflowAppService:
:param created_at_after: filter logs created after this timestamp
:param page: page number
:param limit: items per page
:param detail: whether to return detailed logs
:param created_by_end_user_session_id: filter by end user session id
:param created_by_account: filter by account email
:return: Pagination object
@ -43,8 +70,20 @@ class WorkflowAppService:
WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id
)
if detail:
# Simple left join by workflow_run_id to fetch trigger_metadata
stmt = stmt.outerjoin(
WorkflowTriggerLog,
and_(
WorkflowTriggerLog.tenant_id == app_model.tenant_id,
WorkflowTriggerLog.app_id == app_model.id,
WorkflowTriggerLog.workflow_run_id == WorkflowAppLog.workflow_run_id,
),
).add_columns(WorkflowTriggerLog.trigger_metadata)
if keyword or status:
stmt = stmt.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
# Join to workflow run for filtering when needed.
if keyword:
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
@ -108,9 +147,17 @@ class WorkflowAppService:
# Apply pagination limits
offset_stmt = stmt.offset((page - 1) * limit).limit(limit)
# wrapper moved to module scope as `LogView`
# Execute query and get items
if detail:
rows = session.execute(offset_stmt).all()
items = [
LogView(log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)})
for log, meta_val in rows
]
else:
items = [LogView(log, None) for log in session.scalars(offset_stmt).all()]
return {
"page": page,
"limit": limit,
@ -119,6 +166,31 @@ class WorkflowAppService:
"data": items,
}
def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
if not metadata:
return {}
trigger_metadata = TriggerMetadata.model_validate(metadata)
if trigger_metadata.type == AppTriggerType.TRIGGER_PLUGIN:
icon = metadata.get("icon_filename")
icon_dark = metadata.get("icon_dark_filename")
metadata["icon"] = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=icon) if icon else None
metadata["icon_dark"] = (
PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=icon_dark) if icon_dark else None
)
return metadata
@staticmethod
def _safe_json_loads(val):
if not val:
return None
if isinstance(val, str):
try:
return json.loads(val)
except Exception:
return None
return val
@staticmethod
def _safe_parse_uuid(value: str):
# fast check

View File

@ -1026,7 +1026,7 @@ class DraftVariableSaver:
return
if self._node_type == NodeType.VARIABLE_ASSIGNER:
draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data)
elif self._node_type == NodeType.START or self._node_type.is_trigger_node:
draft_vars = self._build_variables_from_start_mapping(outputs)
else:
draft_vars = self._build_variables_from_mapping(outputs)

View File

@ -10,20 +10,22 @@ from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool, WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@ -32,6 +34,7 @@ from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import UserFrom
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
@ -211,6 +214,9 @@ class WorkflowService:
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
# validate graph structure
self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=graph)
# create draft workflow if not found
if not workflow:
workflow = Workflow(
@ -267,6 +273,9 @@ class WorkflowService:
if FeatureService.get_system_features().plugin_manager.enabled:
self._validate_workflow_credentials(draft_workflow)
# validate graph structure
self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=draft_workflow.graph_dict)
# create new workflow
workflow = Workflow.new(
tenant_id=app_model.tenant_id,
@ -623,7 +632,7 @@ class WorkflowService:
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
node_data = node_config.get("data", {})
if node_type.is_start_node:
with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
conversation_id = draft_var_srv.get_or_create_conversation(
@ -631,10 +640,11 @@ class WorkflowService:
app=app_model,
workflow=draft_workflow,
)
if node_type is NodeType.START:
start_data = StartNodeData.model_validate(node_data)
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
)
# init variable pool
variable_pool = _setup_variable_pool(
query=query,
@ -895,6 +905,43 @@ class WorkflowService:
return new_app
def validate_graph_structure(self, user_id: str, app_model: App, graph: Mapping[str, Any]):
"""
Validate workflow graph structure by instantiating the Graph object.
This leverages the built-in graph validators (including trigger/UserInput exclusivity)
and raises any structural errors before persisting the workflow.
"""
node_configs = graph.get("nodes", [])
node_configs = cast(list[dict[str, object]], node_configs)
# nothing to validate for an empty graph
if not node_configs:
return
workflow_id = app_model.workflow_id or "UNKNOWN"
Graph.init(
graph_config=graph,
# TODO(Mairuis): Add root node id
root_node_id=None,
node_factory=DifyNodeFactory(
graph_init_params=GraphInitParams(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
workflow_id=workflow_id,
graph_config=graph,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.VALIDATION,
call_depth=0,
),
graph_runtime_state=GraphRuntimeState(
variable_pool=VariablePool(),
start_at=time.perf_counter(),
),
),
)
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
@ -997,10 +1044,11 @@ def _setup_variable_pool(
conversation_variables: list[Variable],
):
# Only inject system variables for start and trigger node types.
if node_type == NodeType.START or node_type.is_trigger_node:
system_variable = SystemVariable(
user_id=user_id,
app_id=workflow.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=workflow.id,
files=files or [],
workflow_execution_id=str(uuid.uuid4()),