feat: introduce trigger functionality (#27644)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: Stream <Stream_2@qq.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
Co-authored-by: Harry <xh001x@hotmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
Co-authored-by: hjlarry <hjlarry@163.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WTW0313 <twwu@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@@ -26,6 +26,7 @@ from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.tool.entities import ToolNodeData
+from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
 from events.app_event import app_model_config_was_updated, app_was_created
 from extensions.ext_redis import redis_client
 from factories import variable_factory
@@ -43,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
 CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
 IMPORT_INFO_REDIS_EXPIRY = 10 * 60  # 10 minutes
 DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
-CURRENT_DSL_VERSION = "0.4.0"
+CURRENT_DSL_VERSION = "0.5.0"


 class ImportMode(StrEnum):
@@ -599,6 +600,16 @@ class AppDslService:
             if not include_secret and data_type == NodeType.AGENT:
                 for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
                     tool.pop("credential_id", None)
+            if data_type == NodeType.TRIGGER_SCHEDULE.value:
+                # override the config with the default config
+                node_data["config"] = TriggerScheduleNode.get_default_config()["config"]
+            if data_type == NodeType.TRIGGER_WEBHOOK.value:
+                # clear the webhook_url
+                node_data["webhook_url"] = ""
+                node_data["webhook_debug_url"] = ""
+            if data_type == NodeType.TRIGGER_PLUGIN.value:
+                # clear the subscription_id
+                node_data["subscription_id"] = ""

         export_data["workflow"] = workflow_dict
         dependencies = cls._extract_dependencies_from_workflow(workflow)
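For context, the export hunk above strips instance-specific trigger state so exported DSL never leaks webhook endpoints or subscription references. A minimal standalone sketch of the same idea; the node-type string values and the sanitize_trigger_nodes helper are illustrative, not part of this PR:

def sanitize_trigger_nodes(nodes: list[dict]) -> list[dict]:
    """Blank out deployment-specific trigger fields before export (illustrative)."""
    for node in nodes:
        data = node.get("data", {})
        node_type = data.get("type")  # assumed enum values, e.g. "trigger-webhook"
        if node_type == "trigger-webhook":
            # webhook URLs are deployment-specific; blank them on export
            data["webhook_url"] = ""
            data["webhook_debug_url"] = ""
        elif node_type == "trigger-plugin":
            # subscriptions belong to the source tenant; drop the reference
            data["subscription_id"] = ""
    return nodes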
@@ -31,6 +31,7 @@ class AppGenerateService:
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
+        root_node_id: str | None = None,
     ):
         """
         App Content Generate
@@ -114,6 +115,7 @@ class AppGenerateService:
                     args=args,
                     invoke_from=invoke_from,
                     streaming=streaming,
+                    root_node_id=root_node_id,
                     call_depth=0,
                 ),
             ),
api/services/async_workflow_service.py (new file, 323 lines)
@@ -0,0 +1,323 @@
"""
Universal async workflow execution service.

This service provides a centralized entry point for triggering workflows asynchronously
with support for different subscription tiers, rate limiting, and execution tracking.
"""

import json
from datetime import UTC, datetime
from typing import Any, Union

from celery.result import AsyncResult
from sqlalchemy import select
from sqlalchemy.orm import Session

from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import InvokeDailyRateLimitError, WorkflowNotFoundError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow.rate_limiter import TenantDailyRateLimiter
from services.workflow_service import WorkflowService
from tasks.async_workflow_tasks import (
    execute_workflow_professional,
    execute_workflow_sandbox,
    execute_workflow_team,
)


class AsyncWorkflowService:
    """
    Universal entry point for async workflow execution - ALL METHODS ARE NON-BLOCKING

    This service handles:
    - Trigger data validation and processing
    - Queue routing based on subscription tier
    - Daily rate limiting with timezone support
    - Execution tracking and logging
    - Retry mechanisms for failed executions

    Important: All trigger methods return immediately after queuing tasks.
    Actual workflow execution happens asynchronously in background Celery workers.
    Use trigger log IDs to monitor execution status and results.
    """

    @classmethod
    def trigger_workflow_async(
        cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
    ) -> AsyncTriggerResponse:
        """
        Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK

        Creates a trigger log and dispatches to the appropriate queue based on subscription tier.
        The workflow execution happens asynchronously in the background via Celery workers.
        This method returns immediately after queuing the task, not after execution completion.

        Args:
            session: Database session to use for operations
            user: User (Account or EndUser) who initiated the workflow trigger
            trigger_data: Validated Pydantic model containing trigger information

        Returns:
            AsyncTriggerResponse with workflow_trigger_log_id, task_id, status="queued", and queue
            Note: The actual workflow execution status must be checked separately via workflow_trigger_log_id

        Raises:
            WorkflowNotFoundError: If app or workflow not found
            InvokeDailyRateLimitError: If daily rate limit exceeded

        Behavior:
            - Non-blocking: Returns immediately after queuing
            - Asynchronous: Actual execution happens in background Celery workers
            - Status tracking: Use workflow_trigger_log_id to monitor progress
            - Queue-based: Routes to different queues based on subscription tier
        """
        trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
        dispatcher_manager = QueueDispatcherManager()
        workflow_service = WorkflowService()
        rate_limiter = TenantDailyRateLimiter(redis_client)

        # 1. Validate app exists
        app_model = session.scalar(select(App).where(App.id == trigger_data.app_id))
        if not app_model:
            raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")

        # 2. Get workflow
        workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)

        # 3. Get dispatcher based on tenant subscription
        dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)

        # 4. Rate limiting check will be done without timezone first

        # 5. Determine user role and ID
        if isinstance(user, Account):
            created_by_role = CreatorUserRole.ACCOUNT
            created_by = user.id
        else:  # EndUser
            created_by_role = CreatorUserRole.END_USER
            created_by = user.id

        # 6. Create trigger log entry first (for tracking)
        trigger_log = WorkflowTriggerLog(
            tenant_id=trigger_data.tenant_id,
            app_id=trigger_data.app_id,
            workflow_id=workflow.id,
            root_node_id=trigger_data.root_node_id,
            trigger_metadata=(
                trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}"
            ),
            trigger_type=trigger_data.trigger_type,
            trigger_data=trigger_data.model_dump_json(),
            inputs=json.dumps(dict(trigger_data.inputs)),
            status=WorkflowTriggerStatus.PENDING,
            queue_name=dispatcher.get_queue_name(),
            retry_count=0,
            created_by_role=created_by_role,
            created_by=created_by,
        )

        trigger_log = trigger_log_repo.create(trigger_log)
        session.commit()

        # 7. Check and consume daily quota
        if not dispatcher.consume_quota(trigger_data.tenant_id):
            # Update trigger log status
            trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
            trigger_log.error = f"Daily limit reached for {dispatcher.get_queue_name()}"
            trigger_log_repo.update(trigger_log)
            session.commit()

            tenant_owner_tz = rate_limiter.get_tenant_owner_timezone(trigger_data.tenant_id)
            remaining = rate_limiter.get_remaining_quota(trigger_data.tenant_id, dispatcher.get_daily_limit())
            reset_time = rate_limiter.get_quota_reset_time(trigger_data.tenant_id, tenant_owner_tz)

            raise InvokeDailyRateLimitError(
                f"Daily workflow execution limit reached. "
                f"Limit resets at {reset_time.strftime('%Y-%m-%d %H:%M:%S %Z')}. "
                f"Remaining quota: {remaining}"
            )

        # 8. Create task data
        queue_name = dispatcher.get_queue_name()
        task_data = WorkflowTaskData(workflow_trigger_log_id=trigger_log.id)

        # 9. Dispatch to appropriate queue
        task_data_dict = task_data.model_dump(mode="json")

        task: AsyncResult[Any] | None = None
        if queue_name == QueuePriority.PROFESSIONAL:
            task = execute_workflow_professional.delay(task_data_dict)  # type: ignore
        elif queue_name == QueuePriority.TEAM:
            task = execute_workflow_team.delay(task_data_dict)  # type: ignore
        else:  # SANDBOX
            task = execute_workflow_sandbox.delay(task_data_dict)  # type: ignore

        # 10. Update trigger log with task info
        trigger_log.status = WorkflowTriggerStatus.QUEUED
        trigger_log.celery_task_id = task.id
        trigger_log.triggered_at = datetime.now(UTC)
        trigger_log_repo.update(trigger_log)
        session.commit()

        return AsyncTriggerResponse(
            workflow_trigger_log_id=trigger_log.id,
            task_id=task.id,  # type: ignore
            status="queued",
            queue=queue_name,
        )
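For reviewers, a minimal sketch of how a caller is expected to drive this entry point and then poll for the outcome. The user and trigger_data construction is elided, and fire_and_track is a hypothetical wrapper, not part of the PR:

from sqlalchemy.orm import Session

from extensions.ext_database import db
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import InvokeDailyRateLimitError


def fire_and_track(user, trigger_data):
    # Returns as soon as the Celery task is queued; nothing here blocks on execution.
    with Session(db.engine) as session:
        try:
            response = AsyncWorkflowService.trigger_workflow_async(session, user, trigger_data)
        except InvokeDailyRateLimitError as e:
            return {"status": "rate_limited", "detail": str(e)}
    # The real outcome lands in the trigger log, written by the background worker.
    log = AsyncWorkflowService.get_trigger_log(response.workflow_trigger_log_id)
    return {"status": response.status, "log": log}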
    @classmethod
    def reinvoke_trigger(
        cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
    ) -> AsyncTriggerResponse:
        """
        Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK

        Updates the existing trigger log to retry status and creates a new async execution.
        Returns immediately after queuing the retry, not after execution completion.

        Args:
            session: Database session to use for operations
            user: User (Account or EndUser) who initiated the retry
            workflow_trigger_log_id: ID of the trigger log to re-invoke

        Returns:
            AsyncTriggerResponse with new execution information (status="queued")
            Note: This creates a new trigger log entry for the retry attempt

        Raises:
            ValueError: If trigger log not found

        Behavior:
            - Non-blocking: Returns immediately after queuing retry
            - Creates new trigger log: Original log marked as retrying, new log for execution
            - Preserves original trigger data: Uses same inputs and configuration
        """
        trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)

        trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id)
        if not trigger_log:
            raise ValueError(f"Trigger log not found: {workflow_trigger_log_id}")

        # Reconstruct trigger data from log
        trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)

        # Reset log for retry
        trigger_log.status = WorkflowTriggerStatus.RETRYING
        trigger_log.retry_count += 1
        trigger_log.error = None
        trigger_log.triggered_at = datetime.now(UTC)
        trigger_log_repo.update(trigger_log)
        session.commit()

        # Re-trigger workflow (this will create a new trigger log)
        return cls.trigger_workflow_async(session, user, trigger_data)

    @classmethod
    def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None:
        """
        Get trigger log by ID

        Args:
            workflow_trigger_log_id: ID of the trigger log
            tenant_id: Optional tenant ID for security check

        Returns:
            Trigger log as dictionary or None if not found
        """
        with Session(db.engine) as session:
            trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
            trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)

            if not trigger_log:
                return None

            return trigger_log.to_dict()

    @classmethod
    def get_recent_logs(
        cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
    ) -> list[dict[str, Any]]:
        """
        Get recent trigger logs

        Args:
            tenant_id: Tenant ID
            app_id: Application ID
            hours: Number of hours to look back
            limit: Maximum number of results
            offset: Number of results to skip

        Returns:
            List of trigger logs as dictionaries
        """
        with Session(db.engine) as session:
            trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
            logs = trigger_log_repo.get_recent_logs(
                tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
            )

            return [log.to_dict() for log in logs]

    @classmethod
    def get_failed_logs_for_retry(
        cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100
    ) -> list[dict[str, Any]]:
        """
        Get failed logs eligible for retry

        Args:
            tenant_id: Tenant ID
            max_retry_count: Maximum retry count
            limit: Maximum number of results

        Returns:
            List of failed trigger logs as dictionaries
        """
        with Session(db.engine) as session:
            trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
            logs = trigger_log_repo.get_failed_for_retry(
                tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
            )

            return [log.to_dict() for log in logs]

    @staticmethod
    def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow:
        """
        Get workflow for the app

        Args:
            workflow_service: Workflow service instance
            app_model: App model instance
            workflow_id: Optional specific workflow ID

        Returns:
            Workflow instance

        Raises:
            WorkflowNotFoundError: If workflow not found
        """
        if workflow_id:
            # Get specific published workflow
            workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
            if not workflow:
                raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
        else:
            # Get default published workflow
            workflow = workflow_service.get_published_workflow(app_model)
            if not workflow:
                raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")

        return workflow
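TenantDailyRateLimiter itself is not part of this excerpt. Inferred from the calls above (consume_quota, get_remaining_quota, get_quota_reset_time), a plausible Redis sketch of the daily-quota mechanics follows; the key layout and method shapes are assumptions:

from datetime import datetime, timedelta
from zoneinfo import ZoneInfo


class DailyQuotaSketch:
    """Hypothetical per-tenant daily counter; the real class may differ."""

    def __init__(self, redis_client):
        self._redis = redis_client

    def consume(self, tenant_id: str, daily_limit: int, tz: str) -> bool:
        # One counter per tenant per local calendar day; INCR is atomic.
        today = datetime.now(ZoneInfo(tz)).strftime("%Y-%m-%d")
        key = f"workflow_daily_quota:{tenant_id}:{today}"
        used = self._redis.incr(key)
        if used == 1:
            # First hit today: let the key expire so stale counters vanish.
            self._redis.expire(key, int(timedelta(days=2).total_seconds()))
        return used <= daily_limit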
@@ -11,9 +11,9 @@ from core.helper import encrypter
 from core.helper.name_generator import generate_incremental_name
 from core.helper.provider_cache import NoOpProviderCredentialCache
 from core.model_runtime.entities.provider_entities import FormType
+from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.datasource import PluginDatasourceManager
 from core.plugin.impl.oauth import OAuthHandler
-from core.tools.entities.tool_entities import CredentialType
 from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -338,7 +338,7 @@ class DatasourceProviderService:
             key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
             for key, value in client_params.items()
         }
-        tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
+        tenant_oauth_client_params.client_params = dict(encrypter.encrypt(new_params))

        if enabled is not None:
            tenant_oauth_client_params.enabled = enabled
@@ -374,7 +374,7 @@ class DatasourceProviderService:

    def get_tenant_oauth_client(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
-    ) -> dict[str, Any] | None:
+    ) -> Mapping[str, Any] | None:
        """
        get tenant oauth client
        """
@@ -390,7 +390,7 @@ class DatasourceProviderService:
         if tenant_oauth_client_params:
             encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
             if mask:
-                return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
+                return encrypter.mask_plugin_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
             else:
                 return encrypter.decrypt(tenant_oauth_client_params.client_params)
         return None
@@ -434,7 +434,7 @@ class DatasourceProviderService:
         )
         if tenant_oauth_client_params:
             encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
-            return encrypter.decrypt(tenant_oauth_client_params.client_params)
+            return dict(encrypter.decrypt(tenant_oauth_client_params.client_params))

         provider_controller = self.provider_manager.fetch_datasource_provider(
             tenant_id=tenant_id, provider_id=str(datasource_provider_id)
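The @@ -338 hunk touches the standard mask/restore round-trip for secrets: the UI echoes a sentinel back for untouched fields, and the server swaps the stored originals back in before re-encrypting. The pattern in isolation, with an illustrative sentinel value and function name:

HIDDEN_SENTINEL = "[__HIDDEN__]"  # stand-in for the real HIDDEN_VALUE constant


def merge_submitted_params(submitted: dict[str, str], original: dict[str, str]) -> dict[str, str]:
    # Keep the stored secret wherever the client sent the mask back unchanged.
    return {
        key: original.get(key, value) if value == HIDDEN_SENTINEL else value
        for key, value in submitted.items()
    }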
api/services/end_user_service.py (new file, 141 lines)
@@ -0,0 +1,141 @@
from collections.abc import Mapping

from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.model import App, DefaultEndUserSessionID, EndUser


class EndUserService:
    """
    Service for managing end users.
    """

    @classmethod
    def get_or_create_end_user(cls, app_model: App, user_id: str | None = None) -> EndUser:
        """
        Get or create an end user for a given app.
        """
        return cls.get_or_create_end_user_by_type(InvokeFrom.SERVICE_API, app_model.tenant_id, app_model.id, user_id)

    @classmethod
    def get_or_create_end_user_by_type(
        cls, type: InvokeFrom, tenant_id: str, app_id: str, user_id: str | None = None
    ) -> EndUser:
        """
        Get or create an end user for a given app and type.
        """
        if not user_id:
            user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID

        with Session(db.engine, expire_on_commit=False) as session:
            end_user = (
                session.query(EndUser)
                .where(
                    EndUser.tenant_id == tenant_id,
                    EndUser.app_id == app_id,
                    EndUser.session_id == user_id,
                    EndUser.type == type,
                )
                .first()
            )

            if end_user is None:
                end_user = EndUser(
                    tenant_id=tenant_id,
                    app_id=app_id,
                    type=type,
                    is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
                    session_id=user_id,
                    external_user_id=user_id,
                )
                session.add(end_user)
                session.commit()

            return end_user

    @classmethod
    def create_end_user_batch(
        cls, type: InvokeFrom, tenant_id: str, app_ids: list[str], user_id: str
    ) -> Mapping[str, EndUser]:
        """Create end users in batch.

        Creates end users for the specified tenant and application IDs using a constant
        number of database queries, regardless of how many apps are involved.

        This batch creation is necessary because trigger subscriptions can span multiple applications,
        and trigger events may be dispatched to multiple applications simultaneously.

        For each app_id in app_ids, check if an `EndUser` with the given
        `user_id` (as session_id/external_user_id) already exists for the
        tenant/app and type `type`. If it exists, return it; otherwise,
        create it. Operates with minimal DB I/O by querying and inserting in
        batches.

        Returns a mapping of `app_id -> EndUser`.
        """
        # Normalize user_id to default if empty
        if not user_id:
            user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID

        # Deduplicate app_ids while preserving input order
        seen: set[str] = set()
        unique_app_ids: list[str] = []
        for app_id in app_ids:
            if app_id not in seen:
                seen.add(app_id)
                unique_app_ids.append(app_id)

        # Result is a simple app_id -> EndUser mapping
        result: dict[str, EndUser] = {}
        if not unique_app_ids:
            return result

        with Session(db.engine, expire_on_commit=False) as session:
            # Fetch existing end users for all target apps in a single query
            existing_end_users: list[EndUser] = (
                session.query(EndUser)
                .where(
                    EndUser.tenant_id == tenant_id,
                    EndUser.app_id.in_(unique_app_ids),
                    EndUser.session_id == user_id,
                    EndUser.type == type,
                )
                .all()
            )

            found_app_ids: set[str] = set()
            for eu in existing_end_users:
                # If duplicates exist due to weak DB constraints, prefer the first
                if eu.app_id not in result:
                    result[eu.app_id] = eu
                    found_app_ids.add(eu.app_id)

            # Determine which apps still need an EndUser created
            missing_app_ids = [app_id for app_id in unique_app_ids if app_id not in found_app_ids]

            if missing_app_ids:
                new_end_users: list[EndUser] = []
                is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
                for app_id in missing_app_ids:
                    new_end_users.append(
                        EndUser(
                            tenant_id=tenant_id,
                            app_id=app_id,
                            type=type,
                            is_anonymous=is_anonymous,
                            session_id=user_id,
                            external_user_id=user_id,
                        )
                    )

                session.add_all(new_end_users)
                session.commit()

                for eu in new_end_users:
                    result[eu.app_id] = eu

        return result
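A quick illustration of the batch call from a dispatch path; the IDs are made up:

from core.app.entities.app_invoke_entities import InvokeFrom
from services.end_user_service import EndUserService

# One SELECT plus one bulk INSERT, no matter how many apps the
# trigger subscription fans out to.
end_users = EndUserService.create_end_user_batch(
    type=InvokeFrom.SERVICE_API,
    tenant_id="tenant-123",
    app_ids=["app-a", "app-b", "app-b"],  # duplicates are collapsed
    user_id="webhook-caller-42",
)
for app_id, end_user in end_users.items():
    print(app_id, end_user.session_id)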
@@ -16,3 +16,9 @@ class WorkflowNotFoundError(Exception):

 class WorkflowIdFormatError(Exception):
     pass
+
+
+class InvokeDailyRateLimitError(Exception):
+    """Raised when daily rate limit is exceeded for workflow invocations."""
+
+    pass
@@ -16,6 +16,7 @@ class OAuthProxyService(BasePluginClient):
         tenant_id: str,
         plugin_id: str,
         provider: str,
         extra_data: dict = {},
+        credential_id: str | None = None,
     ):
         """
@@ -32,6 +33,7 @@ class OAuthProxyService(BasePluginClient):
         """
         context_id = str(uuid.uuid4())
         data = {
             **extra_data,
             "user_id": user_id,
             "plugin_id": plugin_id,
             "tenant_id": tenant_id,
@@ -4,11 +4,16 @@ from typing import Any, Literal
 from sqlalchemy.orm import Session

 from core.plugin.entities.parameters import PluginParameterOption
+from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.dynamic_select import DynamicSelectClient
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.encryption import create_tool_provider_encrypter
+from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
+from core.trigger.entities.entities import SubscriptionBuilder
 from extensions.ext_database import db
 from models.tools import BuiltinToolProvider
+from services.trigger.trigger_provider_service import TriggerProviderService
+from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService


 class PluginParameterService:
@@ -20,7 +25,8 @@ class PluginParameterService:
         provider: str,
         action: str,
         parameter: str,
-        provider_type: Literal["tool"],
+        credential_id: str | None,
+        provider_type: Literal["tool", "trigger"],
     ) -> Sequence[PluginParameterOption]:
         """
         Get dynamic select options for a plugin parameter.
@@ -33,7 +39,7 @@ class PluginParameterService:
             parameter: The parameter name.
         """
         credentials: Mapping[str, Any] = {}
-
+        credential_type: str = CredentialType.UNAUTHORIZED.value
         match provider_type:
             case "tool":
                 provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@@ -49,24 +55,53 @@ class PluginParameterService:
                 else:
                     # fetch credentials from db
                     with Session(db.engine) as session:
-                        db_record = (
-                            session.query(BuiltinToolProvider)
-                            .where(
-                                BuiltinToolProvider.tenant_id == tenant_id,
-                                BuiltinToolProvider.provider == provider,
-                            )
-                            .first()
-                        )
+                        if credential_id:
+                            db_record = (
+                                session.query(BuiltinToolProvider)
+                                .where(
+                                    BuiltinToolProvider.tenant_id == tenant_id,
+                                    BuiltinToolProvider.provider == provider,
+                                    BuiltinToolProvider.id == credential_id,
+                                )
+                                .first()
+                            )
+                        else:
+                            db_record = (
+                                session.query(BuiltinToolProvider)
+                                .where(
+                                    BuiltinToolProvider.tenant_id == tenant_id,
+                                    BuiltinToolProvider.provider == provider,
+                                )
+                                .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
+                                .first()
+                            )

                         if db_record is None:
                             raise ValueError(f"Builtin provider {provider} not found when fetching credentials")

                         credentials = encrypter.decrypt(db_record.credentials)
+                        credential_type = db_record.credential_type
+            case "trigger":
+                subscription: TriggerProviderSubscriptionApiEntity | SubscriptionBuilder | None
+                if credential_id:
+                    subscription = TriggerSubscriptionBuilderService.get_subscription_builder(credential_id)
+                    if not subscription:
+                        trigger_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
+                        subscription = trigger_subscription.to_api_entity() if trigger_subscription else None
+                else:
+                    trigger_subscription = TriggerProviderService.get_subscription_by_id(tenant_id)
+                    subscription = trigger_subscription.to_api_entity() if trigger_subscription else None
+
+                if subscription is None:
+                    raise ValueError(f"Subscription {credential_id} not found")
+
+                credentials = subscription.credentials
+                credential_type = subscription.credential_type or CredentialType.UNAUTHORIZED
             case _:
                 raise ValueError(f"Invalid provider type: {provider_type}")

         return (
             DynamicSelectClient()
-            .fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter)
+            .fetch_dynamic_select_options(
+                tenant_id, user_id, plugin_id, provider, action, credentials, credential_type, parameter
+            )
             .options
         )
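One behavioral detail in the tool branch above: with no credential_id, the query orders by is_default descending and then created_at ascending, so the workspace default credential wins and the oldest one is the fallback. The same rule in plain Python, purely illustrative:

def pick_credential(records: list) -> object | None:
    # records carry .is_default (bool) and .created_at (datetime)
    if not records:
        return None
    return sorted(records, key=lambda r: (not r.is_default, r.created_at))[0]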
@@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
 from mimetypes import guess_type

 from pydantic import BaseModel
+from yarl import URL

 from configs import dify_config
 from core.helper import marketplace
@@ -175,6 +176,13 @@ class PluginService:
         manager = PluginInstaller()
         return manager.fetch_plugin_installation_by_ids(tenant_id, ids)

+    @classmethod
+    def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
+        url_prefix = (
+            URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
+        )
+        return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
+
     @staticmethod
     def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
         """
@@ -185,6 +193,11 @@ class PluginService:
         mime_type, _ = guess_type(asset_file)
         return manager.fetch_asset(tenant_id, asset_file), mime_type or "application/octet-stream"

+    @staticmethod
+    def extract_asset(tenant_id: str, plugin_unique_identifier: str, file_name: str) -> bytes:
+        manager = PluginAssetManager()
+        return manager.extract_asset(tenant_id, plugin_unique_identifier, file_name)
+
     @staticmethod
     def check_plugin_unique_identifier(tenant_id: str, plugin_unique_identifier: str) -> bool:
         """
@@ -502,3 +515,11 @@ class PluginService:
         """
         manager = PluginInstaller()
         return manager.check_tools_existence(tenant_id, provider_ids)
+
+    @staticmethod
+    def fetch_plugin_readme(tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
+        """
+        Fetch plugin readme
+        """
+        manager = PluginInstaller()
+        return manager.fetch_plugin_readme(tenant_id, plugin_unique_identifier, language)
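The new PluginService.get_plugin_icon_url builds the icon endpoint with yarl, where the % operator fills in query parameters. For instance, assuming CONSOLE_API_URL is set to https://cloud.example.com:

from yarl import URL

url = URL("https://cloud.example.com") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
print(url % {"tenant_id": "t1", "filename": "icon.svg"})
# https://cloud.example.com/console/api/workspaces/current/plugin/icon?tenant_id=t1&filename=icon.svg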
@@ -300,13 +300,13 @@ class ApiToolManageService:
         )

         original_credentials = encrypter.decrypt(provider.credentials)
-        masked_credentials = encrypter.mask_tool_credentials(original_credentials)
+        masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
         # check if the credential has changed, save the original credential
         for name, value in credentials.items():
             if name in masked_credentials and value == masked_credentials[name]:
                 credentials[name] = original_credentials[name]

-        credentials = encrypter.encrypt(credentials)
+        credentials = dict(encrypter.encrypt(credentials))
         provider.credentials_str = json.dumps(credentials)

         db.session.add(provider)
@@ -417,7 +417,7 @@ class ApiToolManageService:
         )
         decrypted_credentials = encrypter.decrypt(credentials)
         # check if the credential has changed, save the original credential
-        masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
+        masked_credentials = encrypter.mask_plugin_credentials(decrypted_credentials)
         for name, value in credentials.items():
             if name in masked_credentials and value == masked_credentials[name]:
                 credentials[name] = decrypted_credentials[name]
@@ -12,6 +12,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
 from core.helper.name_generator import generate_incremental_name
 from core.helper.position_helper import is_filtered
 from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
+from core.plugin.entities.plugin_daemon import CredentialType
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
 from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
 from core.tools.entities.api_entities import (
@@ -20,7 +21,6 @@ from core.tools.entities.api_entities import (
     ToolProviderCredentialApiEntity,
     ToolProviderCredentialInfoApiEntity,
 )
-from core.tools.entities.tool_entities import CredentialType
 from core.tools.errors import ToolProviderNotFoundError
 from core.tools.plugin_tool.provider import PluginToolProviderController
 from core.tools.tool_label_manager import ToolLabelManager
@@ -39,7 +39,6 @@ logger = logging.getLogger(__name__)

 class BuiltinToolManageService:
     __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
-    __DEFAULT_EXPIRES_AT__ = 2147483647

     @staticmethod
     def delete_custom_oauth_client_params(tenant_id: str, provider: str):
@@ -278,9 +277,7 @@ class BuiltinToolManageService:
                 encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
                 credential_type=api_type.value,
                 name=name,
-                expires_at=expires_at
-                if expires_at is not None
-                else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
+                expires_at=expires_at if expires_at is not None else -1,
             )

             session.add(db_provider)
@@ -353,10 +350,10 @@ class BuiltinToolManageService:
             encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
                 tenant_id, provider, provider.provider, provider_controller
             )
-            decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
+            decrypt_credential = encrypter.mask_plugin_credentials(encrypter.decrypt(provider.credentials))
             credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
                 provider=provider,
-                credentials=decrypt_credential,
+                credentials=dict(decrypt_credential),
             )
             credentials.append(credential_entity)
         return credentials
@@ -727,4 +724,4 @@ class BuiltinToolManageService:
             cache=NoOpProviderCredentialCache(),
         )

-        return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
+        return encrypter.mask_plugin_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
@@ -1,6 +1,7 @@
 import hashlib
 import json
 import logging
+from collections.abc import Mapping
 from datetime import datetime
 from enum import StrEnum
 from typing import Any
@@ -420,7 +421,7 @@ class MCPToolManageService:
             return json.dumps({"content": icon, "background": icon_background})
         return icon

-    def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> dict[str, str]:
+    def _encrypt_dict_fields(self, data: dict[str, Any], secret_fields: list[str], tenant_id: str) -> Mapping[str, str]:
        """Encrypt specified fields in a dictionary.

        Args:
@@ -9,7 +9,7 @@ from yarl import URL
 from configs import dify_config
 from core.helper.provider_cache import ToolProviderCredentialsCache
 from core.mcp.types import Tool as MCPTool
-from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
+from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -19,7 +19,6 @@ from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
-    CredentialType,
     ToolParameter,
     ToolProviderType,
 )
@@ -28,18 +27,12 @@ from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
 from core.tools.workflow_as_tool.tool import WorkflowTool
 from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
+from services.plugin.plugin_service import PluginService

 logger = logging.getLogger(__name__)


 class ToolTransformService:
-    @classmethod
-    def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
-        url_prefix = (
-            URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
-        )
-        return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
-
     @classmethod
     def get_tool_provider_icon_url(
         cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
@@ -79,11 +72,9 @@ class ToolTransformService:
         elif isinstance(provider, ToolProviderApiEntity):
             if provider.plugin_id:
                 if isinstance(provider.icon, str):
-                    provider.icon = ToolTransformService.get_plugin_icon_url(
-                        tenant_id=tenant_id, filename=provider.icon
-                    )
+                    provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
                 if isinstance(provider.icon_dark, str) and provider.icon_dark:
-                    provider.icon_dark = ToolTransformService.get_plugin_icon_url(
+                    provider.icon_dark = PluginService.get_plugin_icon_url(
                         tenant_id=tenant_id, filename=provider.icon_dark
                     )
             else:
@@ -97,7 +88,7 @@ class ToolTransformService:
         elif isinstance(provider, PluginDatasourceProviderEntity):
             if provider.plugin_id:
                 if isinstance(provider.declaration.identity.icon, str):
-                    provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
+                    provider.declaration.identity.icon = PluginService.get_plugin_icon_url(
                         tenant_id=tenant_id, filename=provider.declaration.identity.icon
                     )

@@ -172,7 +163,7 @@ class ToolTransformService:
         )
         # decrypt the credentials and mask the credentials
         decrypted_credentials = encrypter.decrypt(data=credentials)
-        masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
+        masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)

         result.masked_credentials = masked_credentials
         result.original_credentials = decrypted_credentials
@@ -345,7 +336,7 @@ class ToolTransformService:

         # decrypt the credentials and mask the credentials
         decrypted_credentials = encrypter.decrypt(data=credentials)
-        masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
+        masked_credentials = encrypter.mask_plugin_credentials(data=decrypted_credentials)

         result.masked_credentials = masked_credentials
api/services/trigger/schedule_service.py (new file, 312 lines)
@@ -0,0 +1,312 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.workflow.nodes import NodeType
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
from models.account import Account, TenantAccountJoin
from models.trigger import WorkflowSchedulePlan
from models.workflow import Workflow
from services.errors.account import AccountNotFoundError

logger = logging.getLogger(__name__)


class ScheduleService:
    @staticmethod
    def create_schedule(
        session: Session,
        tenant_id: str,
        app_id: str,
        config: ScheduleConfig,
    ) -> WorkflowSchedulePlan:
        """
        Create a new schedule with validated configuration.

        Args:
            session: Database session
            tenant_id: Tenant ID
            app_id: Application ID
            config: Validated schedule configuration

        Returns:
            Created WorkflowSchedulePlan instance
        """
        next_run_at = calculate_next_run_at(
            config.cron_expression,
            config.timezone,
        )

        schedule = WorkflowSchedulePlan(
            tenant_id=tenant_id,
            app_id=app_id,
            node_id=config.node_id,
            cron_expression=config.cron_expression,
            timezone=config.timezone,
            next_run_at=next_run_at,
        )

        session.add(schedule)
        session.flush()

        return schedule

    @staticmethod
    def update_schedule(
        session: Session,
        schedule_id: str,
        updates: SchedulePlanUpdate,
    ) -> WorkflowSchedulePlan:
        """
        Update an existing schedule with validated configuration.

        Args:
            session: Database session
            schedule_id: Schedule ID to update
            updates: Validated update configuration

        Raises:
            ScheduleNotFoundError: If schedule not found

        Returns:
            Updated WorkflowSchedulePlan instance
        """
        schedule = session.get(WorkflowSchedulePlan, schedule_id)
        if not schedule:
            raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")

        # If time-related fields are updated, synchronously update the next_run_at.
        time_fields_updated = False

        if updates.node_id is not None:
            schedule.node_id = updates.node_id

        if updates.cron_expression is not None:
            schedule.cron_expression = updates.cron_expression
            time_fields_updated = True

        if updates.timezone is not None:
            schedule.timezone = updates.timezone
            time_fields_updated = True

        if time_fields_updated:
            schedule.next_run_at = calculate_next_run_at(
                schedule.cron_expression,
                schedule.timezone,
            )

        session.flush()
        return schedule

    @staticmethod
    def delete_schedule(
        session: Session,
        schedule_id: str,
    ) -> None:
        """
        Delete a schedule plan.

        Args:
            session: Database session
            schedule_id: Schedule ID to delete

        Raises:
            ScheduleNotFoundError: If schedule not found
        """
        schedule = session.get(WorkflowSchedulePlan, schedule_id)
        if not schedule:
            raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")

        session.delete(schedule)
        session.flush()

    @staticmethod
    def get_tenant_owner(session: Session, tenant_id: str) -> Account:
        """
        Returns an account to execute scheduled workflows on behalf of the tenant.
        Prioritizes owner over admin to ensure proper authorization hierarchy.
        """
        result = session.execute(
            select(TenantAccountJoin)
            .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == "owner")
            .limit(1)
        ).scalar_one_or_none()

        if not result:
            # Owner may not exist in some tenant configurations; fall back to admin
            result = session.execute(
                select(TenantAccountJoin)
                .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == "admin")
                .limit(1)
            ).scalar_one_or_none()

        if result:
            account = session.get(Account, result.account_id)
            if not account:
                raise AccountNotFoundError(f"Account not found: {result.account_id}")
            return account
        else:
            raise AccountNotFoundError(f"Account not found for tenant: {tenant_id}")

    @staticmethod
    def update_next_run_at(
        session: Session,
        schedule_id: str,
    ) -> datetime:
        """
        Advances the schedule to its next execution time after a successful trigger.
        Uses the current time as the base to avoid missing executions during delays.
        """
        schedule = session.get(WorkflowSchedulePlan, schedule_id)
        if not schedule:
            raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")

        # Base on current time to handle execution delays gracefully
        next_run_at = calculate_next_run_at(
            schedule.cron_expression,
            schedule.timezone,
        )

        schedule.next_run_at = next_run_at
        session.flush()
        return next_run_at
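calculate_next_run_at is imported from libs.schedule_utils and not shown in this diff; since the service leans on croniter's extended syntax (see the monthly "L" day below), its contract is plausibly along these lines. A sketch only, under those assumptions:

from datetime import datetime
from zoneinfo import ZoneInfo

from croniter import croniter


def next_run_at_sketch(cron_expression: str, tz: str) -> datetime:
    # Sketch only; the real helper in libs/schedule_utils.py may differ.
    now = datetime.now(ZoneInfo(tz))  # evaluate in the schedule's own timezone
    return croniter(cron_expression, now).get_next(datetime)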
    @staticmethod
    def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig:
        """
        Builds a ScheduleConfig from a schedule trigger node's configuration,
        converting visual-mode settings to a cron expression where needed.
        """
        node_data = node_config.get("data", {})
        mode = node_data.get("mode", "visual")
        timezone = node_data.get("timezone", "UTC")
        node_id = node_config.get("id", "start")

        cron_expression = None
        if mode == "cron":
            cron_expression = node_data.get("cron_expression")
            if not cron_expression:
                raise ScheduleConfigError("Cron expression is required for cron mode")
        elif mode == "visual":
            frequency = node_data.get("frequency")
            if not frequency:
                raise ScheduleConfigError("Frequency is required for visual mode")
            visual_config = VisualConfig(**node_data.get("visual_config", {}))
            cron_expression = ScheduleService.visual_to_cron(frequency=str(frequency), visual_config=visual_config)
            if not cron_expression:
                raise ScheduleConfigError("Cron expression is required for visual mode")
        else:
            raise ScheduleConfigError(f"Invalid schedule mode: {mode}")
        return ScheduleConfig(node_id=node_id, cron_expression=cron_expression, timezone=timezone)

    @staticmethod
    def extract_schedule_config(workflow: Workflow) -> ScheduleConfig | None:
        """
        Extracts schedule configuration from the workflow graph.

        Searches for the first schedule trigger node in the workflow and converts
        its configuration (either visual or cron mode) into a unified ScheduleConfig.

        Args:
            workflow: The workflow containing the graph definition

        Returns:
            ScheduleConfig if a valid schedule node is found, None if no schedule node exists

        Raises:
            ScheduleConfigError: If graph parsing fails or schedule configuration is invalid

        Note:
            Currently only returns the first schedule node found.
            Multiple schedule nodes in the same workflow are not supported.
        """
        try:
            graph_data = workflow.graph_dict
        except (json.JSONDecodeError, TypeError, AttributeError) as e:
            raise ScheduleConfigError(f"Failed to parse workflow graph: {e}")

        if not graph_data:
            raise ScheduleConfigError("Workflow graph is empty")

        nodes = graph_data.get("nodes", [])
        for node in nodes:
            node_data = node.get("data", {})

            if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value:
                continue

            mode = node_data.get("mode", "visual")
            timezone = node_data.get("timezone", "UTC")
            node_id = node.get("id", "start")

            cron_expression = None
            if mode == "cron":
                cron_expression = node_data.get("cron_expression")
                if not cron_expression:
                    raise ScheduleConfigError("Cron expression is required for cron mode")
            elif mode == "visual":
                frequency = node_data.get("frequency")
                visual_config_dict = node_data.get("visual_config", {})
                visual_config = VisualConfig(**visual_config_dict)
                cron_expression = ScheduleService.visual_to_cron(frequency, visual_config)
            else:
                raise ScheduleConfigError(f"Invalid schedule mode: {mode}")

            return ScheduleConfig(node_id=node_id, cron_expression=cron_expression, timezone=timezone)

        return None

    @staticmethod
    def visual_to_cron(frequency: str, visual_config: VisualConfig) -> str:
        """
        Converts user-friendly visual schedule settings to a cron expression.
        Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
        """
        if frequency == "hourly":
            if visual_config.on_minute is None:
                raise ScheduleConfigError("on_minute is required for hourly schedules")
            return f"{visual_config.on_minute} * * * *"

        elif frequency == "daily":
            if not visual_config.time:
                raise ScheduleConfigError("time is required for daily schedules")
            hour, minute = convert_12h_to_24h(visual_config.time)
            return f"{minute} {hour} * * *"

        elif frequency == "weekly":
            if not visual_config.time:
                raise ScheduleConfigError("time is required for weekly schedules")
            if not visual_config.weekdays:
                raise ScheduleConfigError("Weekdays are required for weekly schedules")
            hour, minute = convert_12h_to_24h(visual_config.time)
            weekday_map = {"sun": "0", "mon": "1", "tue": "2", "wed": "3", "thu": "4", "fri": "5", "sat": "6"}
            cron_weekdays = [weekday_map[day] for day in visual_config.weekdays]
            return f"{minute} {hour} * * {','.join(sorted(cron_weekdays))}"

        elif frequency == "monthly":
            if not visual_config.time:
                raise ScheduleConfigError("time is required for monthly schedules")
            if not visual_config.monthly_days:
                raise ScheduleConfigError("Monthly days are required for monthly schedules")
            hour, minute = convert_12h_to_24h(visual_config.time)

            numeric_days: list[int] = []
            has_last = False
            for day in visual_config.monthly_days:
                if day == "last":
                    has_last = True
                else:
                    numeric_days.append(day)

            result_days = [str(d) for d in sorted(set(numeric_days))]
            if has_last:
                result_days.append("L")

            return f"{minute} {hour} {','.join(result_days)} * *"

        else:
            raise ScheduleConfigError(f"Unsupported frequency: {frequency}")
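Two concrete conversions through visual_to_cron, with assumed VisualConfig field formats (the 12-hour time strings follow convert_12h_to_24h's apparent contract; treat the exact field shapes as illustrative):

from core.workflow.nodes.trigger_schedule.entities import VisualConfig
from services.trigger.schedule_service import ScheduleService

# 9:30 AM every Monday and Friday
weekly = ScheduleService.visual_to_cron(
    "weekly", VisualConfig(time="9:30 AM", weekdays=["mon", "fri"])
)
assert weekly == "30 9 * * 1,5"  # expected given the weekday map above

# 11:00 PM on the 1st, 15th, and last day of the month
monthly = ScheduleService.visual_to_cron(
    "monthly", VisualConfig(time="11:00 PM", monthly_days=[1, 15, "last"])
)
assert monthly == "0 23 1,15,L * *"  # "last" becomes croniter's "L"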
api/services/trigger/trigger_provider_service.py (new file, 687 lines)
@@ -0,0 +1,687 @@
import json
import logging
import time as _time
import uuid
from collections.abc import Mapping
from typing import Any

from sqlalchemy import desc, func
from sqlalchemy.orm import Session

from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import (
    TriggerProviderApiEntity,
    TriggerProviderSubscriptionApiEntity,
)
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
    create_trigger_provider_encrypter_for_properties,
    create_trigger_provider_encrypter_for_subscription,
    delete_cache_for_subscription,
)
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import TriggerProviderID
from models.trigger import (
    TriggerOAuthSystemClient,
    TriggerOAuthTenantClient,
    TriggerSubscription,
    WorkflowPluginTrigger,
)
from services.plugin.plugin_service import PluginService

logger = logging.getLogger(__name__)


class TriggerProviderService:
    """Service for managing trigger providers and credentials"""

    ##########################
    # Trigger provider
    ##########################
    __MAX_TRIGGER_PROVIDER_COUNT__ = 10

    @classmethod
    def get_trigger_provider(cls, tenant_id: str, provider: TriggerProviderID) -> TriggerProviderApiEntity:
        """Get info for a trigger provider"""
        return TriggerManager.get_trigger_provider(tenant_id, provider).to_api_entity()

    @classmethod
    def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
        """List all trigger providers for the current tenant"""
        return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]

    @classmethod
    def list_trigger_provider_subscriptions(
        cls, tenant_id: str, provider_id: TriggerProviderID
    ) -> list[TriggerProviderSubscriptionApiEntity]:
        """List all trigger subscriptions for the current tenant"""
        subscriptions: list[TriggerProviderSubscriptionApiEntity] = []
        workflows_in_use_map: dict[str, int] = {}
        with Session(db.engine, expire_on_commit=False) as session:
            # Get all subscriptions
            subscriptions_db = (
                session.query(TriggerSubscription)
                .filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
                .order_by(desc(TriggerSubscription.created_at))
                .all()
            )
            subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
            if not subscriptions:
                return []
            usage_counts = (
                session.query(
                    WorkflowPluginTrigger.subscription_id,
                    func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
                )
                .filter(
                    WorkflowPluginTrigger.tenant_id == tenant_id,
                    WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
                )
                .group_by(WorkflowPluginTrigger.subscription_id)
                .all()
            )
            workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}

        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        for subscription in subscriptions:
            encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                tenant_id=tenant_id,
                controller=provider_controller,
                subscription=subscription,
            )
            subscription.credentials = dict(
                encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
            )
            subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
            subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
            count = workflows_in_use_map.get(subscription.id)
            subscription.workflows_in_use = count if count is not None else 0

        return subscriptions
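The usage-count query above does one grouped aggregate for all subscriptions instead of a count per row. In plain Python, the computation it performs is:

from collections import defaultdict


def count_apps_per_subscription(plugin_triggers: list) -> dict[str, int]:
    # Mirrors COUNT(DISTINCT app_id) ... GROUP BY subscription_id.
    apps: defaultdict[str, set[str]] = defaultdict(set)
    for trigger in plugin_triggers:  # rows with .subscription_id and .app_id
        apps[trigger.subscription_id].add(trigger.app_id)
    return {sub_id: len(ids) for sub_id, ids in apps.items()}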
    @classmethod
    def add_trigger_subscription(
        cls,
        tenant_id: str,
        user_id: str,
        name: str,
        provider_id: TriggerProviderID,
        endpoint_id: str,
        credential_type: CredentialType,
        parameters: Mapping[str, Any],
        properties: Mapping[str, Any],
        credentials: Mapping[str, str],
        subscription_id: str | None = None,
        credential_expires_at: int = -1,
        expires_at: int = -1,
    ) -> Mapping[str, Any]:
        """
        Add a new trigger subscription with credentials.
        Supports multiple credential instances per provider.

        :param tenant_id: Tenant ID
        :param user_id: User ID creating the subscription
        :param name: Name for this credential instance
        :param provider_id: Provider identifier (e.g., "plugin_id/provider_name")
        :param endpoint_id: Endpoint ID bound to this subscription
        :param credential_type: Type of credential (oauth, api_key, or unauthorized)
        :param parameters: Subscription parameters
        :param properties: Provider properties to encrypt and store
        :param credentials: Credential data to encrypt and store
        :param subscription_id: Optional explicit subscription ID
        :param credential_expires_at: OAuth token expiration timestamp (-1 for never)
        :param expires_at: Subscription expiration timestamp (-1 for never)
        :return: Success response
        """
        try:
            provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
            with Session(db.engine, expire_on_commit=False) as session:
                # Use distributed lock to prevent race conditions
                lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
                with redis_client.lock(lock_key, timeout=20):
                    # Check provider count limit
                    provider_count = (
                        session.query(TriggerSubscription)
                        .filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
                        .count()
                    )

                    if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
                        raise ValueError(
                            f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) "
                            f"reached for {provider_id}"
                        )

                    # Check if name already exists
                    existing = (
                        session.query(TriggerSubscription)
                        .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
                        .first()
                    )
                    if existing:
                        raise ValueError(f"Credential name '{name}' already exists for this provider")

                    credential_encrypter: ProviderConfigEncrypter | None = None
                    if credential_type != CredentialType.UNAUTHORIZED:
                        credential_encrypter, _ = create_provider_encrypter(
                            tenant_id=tenant_id,
                            config=provider_controller.get_credential_schema_config(credential_type),
                            cache=NoOpProviderCredentialCache(),
                        )

                    properties_encrypter, _ = create_provider_encrypter(
                        tenant_id=tenant_id,
                        config=provider_controller.get_properties_schema(),
                        cache=NoOpProviderCredentialCache(),
                    )

                    # Create the subscription record
                    subscription = TriggerSubscription(
                        id=subscription_id or str(uuid.uuid4()),
                        tenant_id=tenant_id,
                        user_id=user_id,
                        name=name,
                        endpoint_id=endpoint_id,
                        provider_id=str(provider_id),
                        parameters=parameters,
                        properties=properties_encrypter.encrypt(dict(properties)),
                        credentials=credential_encrypter.encrypt(dict(credentials)) if credential_encrypter else {},
                        credential_type=credential_type.value,
                        credential_expires_at=credential_expires_at,
                        expires_at=expires_at,
                    )

                    session.add(subscription)
                    session.commit()

                    return {
                        "result": "success",
                        "id": str(subscription.id),
                    }

        except Exception as e:
            logger.exception("Failed to add trigger provider")
            raise ValueError(str(e)) from e

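For orientation, here is a minimal sketch of how a caller might register a subscription through this method; every identifier and parameter value below is hypothetical, not taken from this PR:

# Hypothetical caller of add_trigger_subscription; IDs and values are illustrative only.
result = TriggerProviderService.add_trigger_subscription(
    tenant_id="tenant-123",
    user_id="user-456",
    name="github-issues",
    provider_id=TriggerProviderID("langgenius/github"),
    endpoint_id="endpoint-789",
    credential_type=CredentialType.API_KEY,
    parameters={"repository": "langgenius/dify"},  # provider-specific, illustrative
    properties={},
    credentials={"api_key": "sk-..."},  # encrypted at rest by the service
)
assert result["result"] == "success"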
    @classmethod
    def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
        """
        Get a trigger subscription by ID; if no ID is given, fall back to the tenant's first subscription.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            subscription: TriggerSubscription | None = None
            if subscription_id:
                subscription = (
                    session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
                )
            else:
                subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
            if subscription:
                provider_controller = TriggerManager.get_trigger_provider(
                    tenant_id, TriggerProviderID(subscription.provider_id)
                )
                encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                    tenant_id=tenant_id,
                    controller=provider_controller,
                    subscription=subscription,
                )
                subscription.credentials = dict(encrypter.decrypt(subscription.credentials))
                properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
                    tenant_id=subscription.tenant_id,
                    controller=provider_controller,
                    subscription=subscription,
                )
                subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
            return subscription

    @classmethod
    def delete_trigger_provider(cls, session: Session, tenant_id: str, subscription_id: str):
        """
        Delete a trigger provider subscription within an existing session.

        :param session: Database session
        :param tenant_id: Tenant ID
        :param subscription_id: Subscription instance ID
        """
        subscription: TriggerSubscription | None = (
            session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
        )
        if not subscription:
            raise ValueError(f"Trigger provider subscription {subscription_id} not found")

        credential_type: CredentialType = CredentialType.of(subscription.credential_type)
        is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
        if is_auto_created:
            provider_id = TriggerProviderID(subscription.provider_id)
            provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=tenant_id, provider_id=provider_id
            )
            encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                tenant_id=tenant_id,
                controller=provider_controller,
                subscription=subscription,
            )
            try:
                TriggerManager.unsubscribe_trigger(
                    tenant_id=tenant_id,
                    user_id=subscription.user_id,
                    provider_id=provider_id,
                    subscription=subscription.to_entity(),
                    credentials=encrypter.decrypt(subscription.credentials),
                    credential_type=credential_type,
                )
            except Exception:
                logger.exception("Error unsubscribing trigger")

        # Delete the record and clear its cache
        session.delete(subscription)
        delete_cache_for_subscription(
            tenant_id=tenant_id,
            provider_id=subscription.provider_id,
            subscription_id=subscription.id,
        )

    @classmethod
    def refresh_oauth_token(
        cls,
        tenant_id: str,
        subscription_id: str,
    ) -> Mapping[str, Any]:
        """
        Refresh OAuth token for a trigger provider.

        :param tenant_id: Tenant ID
        :param subscription_id: Subscription instance ID
        :return: New token info
        """
        with Session(db.engine) as session:
            subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()

            if not subscription:
                raise ValueError(f"Trigger provider subscription {subscription_id} not found")

            if subscription.credential_type != CredentialType.OAUTH2.value:
                raise ValueError("Only OAuth credentials can be refreshed")

            provider_id = TriggerProviderID(subscription.provider_id)
            provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=tenant_id, provider_id=provider_id
            )
            # Create encrypter
            encrypter, cache = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                cache=NoOpProviderCredentialCache(),
            )

            # Decrypt current credentials
            current_credentials = encrypter.decrypt(subscription.credentials)

            # Get OAuth client configuration
            redirect_uri = (
                f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{subscription.provider_id}/trigger/callback"
            )
            system_credentials = cls.get_oauth_client(tenant_id, provider_id)

            # Refresh token
            oauth_handler = OAuthHandler()
            refreshed_credentials = oauth_handler.refresh_credentials(
                tenant_id=tenant_id,
                user_id=subscription.user_id,
                plugin_id=provider_id.plugin_id,
                provider=provider_id.provider_name,
                redirect_uri=redirect_uri,
                system_credentials=system_credentials or {},
                credentials=current_credentials,
            )

            # Update credentials
            subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials)))
            subscription.credential_expires_at = refreshed_credentials.expires_at
            session.commit()

            # Clear cache
            cache.delete()

            return {
                "result": "success",
                "expires_at": refreshed_credentials.expires_at,
            }

    @classmethod
    def refresh_subscription(
        cls,
        tenant_id: str,
        subscription_id: str,
        now: int | None = None,
    ) -> Mapping[str, Any]:
        """
        Refresh trigger subscription if expired.

        Args:
            tenant_id: Tenant ID
            subscription_id: Subscription instance ID
            now: Current timestamp, defaults to `int(time.time())`

        Returns:
            Mapping with keys: `result` ("success"|"skipped") and `expires_at` (new or existing value)
        """
        now_ts: int = int(now if now is not None else _time.time())

        with Session(db.engine) as session:
            subscription: TriggerSubscription | None = (
                session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
            )
            if subscription is None:
                raise ValueError(f"Trigger provider subscription {subscription_id} not found")

            if subscription.expires_at == -1 or int(subscription.expires_at) > now_ts:
                logger.debug(
                    "Subscription not due for refresh: tenant=%s id=%s expires_at=%s now=%s",
                    tenant_id,
                    subscription_id,
                    subscription.expires_at,
                    now_ts,
                )
                return {"result": "skipped", "expires_at": int(subscription.expires_at)}

            provider_id = TriggerProviderID(subscription.provider_id)
            controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=tenant_id, provider_id=provider_id
            )

            # Decrypt credentials and properties for runtime
            credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                tenant_id=tenant_id,
                controller=controller,
                subscription=subscription,
            )
            properties_encrypter, properties_cache = create_trigger_provider_encrypter_for_properties(
                tenant_id=tenant_id,
                controller=controller,
                subscription=subscription,
            )

            decrypted_credentials = credential_encrypter.decrypt(subscription.credentials)
            decrypted_properties = properties_encrypter.decrypt(subscription.properties)

            sub_entity: TriggerSubscriptionEntity = TriggerSubscriptionEntity(
                expires_at=int(subscription.expires_at),
                endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
                parameters=subscription.parameters,
                properties=decrypted_properties,
            )

            refreshed: TriggerSubscriptionEntity = controller.refresh_trigger(
                subscription=sub_entity,
                credentials=decrypted_credentials,
                credential_type=CredentialType.of(subscription.credential_type),
            )

            # Persist refreshed properties and expires_at
            subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties)))
            subscription.expires_at = int(refreshed.expires_at)
            session.commit()
            properties_cache.delete()

            logger.info(
                "Subscription refreshed (service): tenant=%s id=%s new_expires_at=%s",
                tenant_id,
                subscription_id,
                subscription.expires_at,
            )

            return {"result": "success", "expires_at": int(refreshed.expires_at)}

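A minimal sketch of how a scheduled job might drive this refresh path; the wrapper function below is an assumption for illustration, not part of the PR:

# Hypothetical periodic refresh driver built on refresh_subscription.
import time

def refresh_if_due(tenant_id: str, subscription_id: str) -> int:
    outcome = TriggerProviderService.refresh_subscription(
        tenant_id=tenant_id,
        subscription_id=subscription_id,
        now=int(time.time()),
    )
    # "skipped" means expires_at is -1 (never expires) or still in the future;
    # "success" means properties and expires_at were re-persisted by the service.
    return int(outcome["expires_at"])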
    @classmethod
    def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any] | None:
        """
        Get OAuth client configuration for a provider.
        First tries tenant-level OAuth, then falls back to system OAuth.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: OAuth client configuration or None
        """
        provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id=tenant_id, provider_id=provider_id
        )
        with Session(db.engine, expire_on_commit=False) as session:
            tenant_client: TriggerOAuthTenantClient | None = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=provider_id.provider_name,
                    plugin_id=provider_id.plugin_id,
                    enabled=True,
                )
                .first()
            )

            oauth_params: Mapping[str, Any] | None = None
            if tenant_client:
                encrypter, _ = create_provider_encrypter(
                    tenant_id=tenant_id,
                    config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                    cache=NoOpProviderCredentialCache(),
                )
                oauth_params = encrypter.decrypt(dict(tenant_client.oauth_params))
                return oauth_params

            is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
            if not is_verified:
                return None

            # Check for system-level OAuth client
            system_client: TriggerOAuthSystemClient | None = (
                session.query(TriggerOAuthSystemClient)
                .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
                .first()
            )

            if system_client:
                try:
                    oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
                except Exception as e:
                    raise ValueError(f"Error decrypting system oauth params: {e}")

            return oauth_params

    @classmethod
    def is_oauth_system_client_exists(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
        """
        Check if system OAuth client exists for a trigger provider.
        """
        is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
        if not is_verified:
            return False
        with Session(db.engine, expire_on_commit=False) as session:
            system_client: TriggerOAuthSystemClient | None = (
                session.query(TriggerOAuthSystemClient)
                .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
                .first()
            )
            return system_client is not None

    @classmethod
    def save_custom_oauth_client_params(
        cls,
        tenant_id: str,
        provider_id: TriggerProviderID,
        client_params: Mapping[str, Any] | None = None,
        enabled: bool | None = None,
    ) -> Mapping[str, Any]:
        """
        Save or update custom OAuth client parameters for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :param client_params: OAuth client parameters (client_id, client_secret, etc.)
        :param enabled: Enable/disable the custom OAuth client
        :return: Success response
        """
        if client_params is None and enabled is None:
            return {"result": "success"}

        # Get provider controller to access schema
        provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id=tenant_id, provider_id=provider_id
        )

        with Session(db.engine) as session:
            # Find existing custom client params
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                .first()
            )

            # Create a new record if one doesn't exist
            if custom_client is None:
                custom_client = TriggerOAuthTenantClient(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                session.add(custom_client)

            # Update client params if provided
            if client_params is None:
                custom_client.encrypted_oauth_params = json.dumps({})
            else:
                encrypter, cache = create_provider_encrypter(
                    tenant_id=tenant_id,
                    config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                    cache=NoOpProviderCredentialCache(),
                )

                # Handle hidden values
                original_params = encrypter.decrypt(dict(custom_client.oauth_params))
                new_params: dict[str, Any] = {
                    key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
                    for key, value in client_params.items()
                }
                custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
                cache.delete()

            # Update enabled status if provided
            if enabled is not None:
                custom_client.enabled = enabled

            session.commit()

        return {"result": "success"}

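The HIDDEN_VALUE branch above lets a client echo back the masked parameters it received without wiping stored secrets. A sketch of the intended round trip, with illustrative values:

# Illustrative only: masked fields echoed back as HIDDEN_VALUE keep their stored values.
TriggerProviderService.save_custom_oauth_client_params(
    tenant_id="tenant-123",  # hypothetical
    provider_id=TriggerProviderID("langgenius/github"),  # hypothetical
    client_params={
        "client_id": "new-client-id",   # replaced in storage
        "client_secret": HIDDEN_VALUE,  # preserved from the previously stored value
    },
    enabled=True,
)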
    @classmethod
    def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
        """
        Get custom OAuth client parameters for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: Masked OAuth client parameters
        """
        with Session(db.engine) as session:
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                )
                .first()
            )

            if custom_client is None:
                return {}

            # Get provider controller to access schema
            provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=tenant_id, provider_id=provider_id
            )

            # Create encrypter to decrypt and mask values
            encrypter, _ = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                cache=NoOpProviderCredentialCache(),
            )

            return encrypter.mask_plugin_credentials(encrypter.decrypt(dict(custom_client.oauth_params)))

    @classmethod
    def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> Mapping[str, Any]:
        """
        Delete custom OAuth client parameters for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: Success response
        """
        with Session(db.engine) as session:
            session.query(TriggerOAuthTenantClient).filter_by(
                tenant_id=tenant_id,
                provider=provider_id.provider_name,
                plugin_id=provider_id.plugin_id,
            ).delete()
            session.commit()

        return {"result": "success"}

    @classmethod
    def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
        """
        Check if custom OAuth client is enabled for a trigger provider.

        :param tenant_id: Tenant ID
        :param provider_id: Provider identifier
        :return: True if enabled, False otherwise
        """
        with Session(db.engine, expire_on_commit=False) as session:
            custom_client = (
                session.query(TriggerOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=provider_id.plugin_id,
                    provider=provider_id.provider_name,
                    enabled=True,
                )
                .first()
            )
            return custom_client is not None

    @classmethod
    def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
        """
        Get a trigger subscription by the endpoint ID.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
            if not subscription:
                return None
            provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
                tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
            )
            credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
                tenant_id=subscription.tenant_id,
                controller=provider_controller,
                subscription=subscription,
            )
            subscription.credentials = dict(credential_encrypter.decrypt(subscription.credentials))

            properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
                tenant_id=subscription.tenant_id,
                controller=provider_controller,
                subscription=subscription,
            )
            subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
            return subscription
65
api/services/trigger/trigger_request_service.py
Normal file
@ -0,0 +1,65 @@
from collections.abc import Mapping
from typing import Any

from flask import Request
from pydantic import TypeAdapter

from core.plugin.utils.http_parser import deserialize_request, serialize_request
from extensions.ext_storage import storage


class TriggerHttpRequestCachingService:
    """
    Service for caching trigger requests.
    """

    _TRIGGER_STORAGE_PATH = "triggers"

    @classmethod
    def get_request(cls, request_id: str) -> Request:
        """
        Get the request object from the storage.

        Args:
            request_id: The ID of the request.

        Returns:
            The request object.
        """
        return deserialize_request(storage.load_once(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.raw"))

    @classmethod
    def get_payload(cls, request_id: str) -> Mapping[str, Any]:
        """
        Get the payload from the storage.

        Args:
            request_id: The ID of the request.

        Returns:
            The payload.
        """
        return TypeAdapter(Mapping[str, Any]).validate_json(
            storage.load_once(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.payload")
        )

    @classmethod
    def persist_request(cls, request_id: str, request: Request) -> None:
        """
        Persist the request in the storage.

        Args:
            request_id: The ID of the request.
            request: The request object.
        """
        storage.save(f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.raw", serialize_request(request))

    @classmethod
    def persist_payload(cls, request_id: str, payload: Mapping[str, Any]) -> None:
        """
        Persist the payload in the storage.
        """
        storage.save(
            f"{cls._TRIGGER_STORAGE_PATH}/{request_id}.payload",
            TypeAdapter(Mapping[str, Any]).dump_json(payload),  # type: ignore
        )
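A minimal round-trip sketch for this caching service, built on Flask's test utilities; the request ID is illustrative (production uses a timestamped token), and the exact load-once semantics of the storage backend are assumed:

# Illustrative persist/load round trip through the storage backend.
from flask import Flask, request

app = Flask(__name__)
with app.test_request_context("/hook", method="POST", json={"event": "ping"}):
    request_id = "trigger_request_demo"  # hypothetical ID
    TriggerHttpRequestCachingService.persist_request(request_id, request)
    TriggerHttpRequestCachingService.persist_payload(request_id, {"event": "ping"})

restored_request = TriggerHttpRequestCachingService.get_request(request_id)
restored_payload = TriggerHttpRequestCachingService.get_payload(request_id)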
307
api/services/trigger/trigger_service.py
Normal file
@ -0,0 +1,307 @@
import logging
import secrets
import time
from collections.abc import Mapping
from typing import Any

from flask import Request, Response
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginNotFoundError
from core.trigger.debug.events import PluginTriggerDebugEvent
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription
from core.workflow.enums import NodeType
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import App
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription, WorkflowPluginTrigger
from models.workflow import Workflow
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
from services.workflow.entities import PluginTriggerDispatchData
from tasks.trigger_processing_tasks import dispatch_triggered_workflows_async

logger = logging.getLogger(__name__)


class TriggerService:
    __TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000
    __ENDPOINT_REQUEST_CACHE_COUNT__ = 10
    __ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
    __PLUGIN_TRIGGER_NODE_CACHE_KEY__ = "plugin_trigger_nodes"
    MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW = 5  # Maximum allowed plugin trigger nodes per workflow

    @classmethod
    def invoke_trigger_event(
        cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent
    ) -> TriggerInvokeEventResponse:
        """Invoke a trigger event."""
        subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
            tenant_id=tenant_id,
            subscription_id=event.subscription_id,
        )
        if not subscription:
            raise ValueError("Subscription not found")
        node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {}))
        request = TriggerHttpRequestCachingService.get_request(event.request_id)
        payload = TriggerHttpRequestCachingService.get_payload(event.request_id)
        # invoke the trigger
        provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id, TriggerProviderID(subscription.provider_id)
        )
        return TriggerManager.invoke_trigger_event(
            tenant_id=tenant_id,
            user_id=user_id,
            provider_id=TriggerProviderID(event.provider_id),
            event_name=event.name,
            parameters=node_data.resolve_parameters(
                parameter_schemas=provider_controller.get_event_parameters(event_name=event.name)
            ),
            credentials=subscription.credentials,
            credential_type=CredentialType.of(subscription.credential_type),
            subscription=subscription.to_entity(),
            request=request,
            payload=payload,
        )

    @classmethod
    def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
        """
        Extract and process data from an incoming endpoint request.

        Args:
            endpoint_id: Endpoint ID
            request: Request
        """
        timestamp = int(time.time())
        subscription: TriggerSubscription | None = None
        try:
            subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id)
        except PluginNotFoundError:
            return Response(status=404, response="Trigger provider not found")
        except Exception:
            return Response(status=500, response="Failed to get subscription by endpoint")

        if not subscription:
            return None

        provider_id = TriggerProviderID(subscription.provider_id)
        controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id=subscription.tenant_id, provider_id=provider_id
        )
        encrypter, _ = create_trigger_provider_encrypter_for_subscription(
            tenant_id=subscription.tenant_id,
            controller=controller,
            subscription=subscription,
        )
        dispatch_response: TriggerDispatchResponse = controller.dispatch(
            request=request,
            subscription=subscription.to_entity(),
            credentials=encrypter.decrypt(subscription.credentials),
            credential_type=CredentialType.of(subscription.credential_type),
        )

        if dispatch_response.events:
            request_id = f"trigger_request_{timestamp}_{secrets.token_hex(6)}"

            # save the request and payload to storage as persistent data
            TriggerHttpRequestCachingService.persist_request(request_id, request)
            TriggerHttpRequestCachingService.persist_payload(request_id, dispatch_response.payload)

            # Validate event names
            for event_name in dispatch_response.events:
                if controller.get_event(event_name) is None:
                    logger.error(
                        "Event name %s not found in provider %s for endpoint %s",
                        event_name,
                        subscription.provider_id,
                        endpoint_id,
                    )
                    raise ValueError(f"Event name {event_name} not found in provider {subscription.provider_id}")

            plugin_trigger_dispatch_data = PluginTriggerDispatchData(
                user_id=dispatch_response.user_id,
                tenant_id=subscription.tenant_id,
                endpoint_id=endpoint_id,
                provider_id=subscription.provider_id,
                subscription_id=subscription.id,
                timestamp=timestamp,
                events=list(dispatch_response.events),
                request_id=request_id,
            )
            dispatch_data = plugin_trigger_dispatch_data.model_dump(mode="json")
            dispatch_triggered_workflows_async.delay(dispatch_data)

            logger.info(
                "Queued async dispatching for %d triggers on endpoint %s with request_id %s",
                len(dispatch_response.events),
                endpoint_id,
                request_id,
            )
        return dispatch_response.response

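A sketch of how an HTTP controller might hand endpoint traffic to this dispatcher; the route path is an assumption for illustration, not the actual controller registration in the PR:

# Hypothetical controller wiring; the real route registration lives in the API layer.
from flask import Flask, Response, request

app = Flask(__name__)

@app.route("/triggers/<endpoint_id>", methods=["GET", "POST"])
def trigger_endpoint(endpoint_id: str) -> Response:
    response = TriggerService.process_endpoint(endpoint_id, request)
    if response is None:
        # No subscription is bound to this endpoint.
        return Response(status=404, response="Endpoint not found")
    return response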
    @classmethod
    def sync_plugin_trigger_relationships(cls, app: App, workflow: Workflow):
        """
        Sync plugin trigger relationships in DB.

        1. Check if the workflow has any plugin trigger nodes
        2. Fetch the nodes from DB, see if there are any plugin trigger records already
        3. Diff the nodes and the plugin trigger records, create/update/delete the records as needed

        Approach:
            Frequent DB operations may cause performance issues, so Redis is used as a cache.
            If a record exists, cache it.

        Limits:
            - Maximum 5 plugin trigger nodes per workflow
        """

        class Cache(BaseModel):
            """
            Cache model for plugin trigger nodes
            """

            record_id: str
            node_id: str
            provider_id: str
            event_name: str
            subscription_id: str

        # Walk nodes to find plugin triggers
        nodes_in_graph: list[Mapping[str, Any]] = []
        for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN):
            # Extract plugin trigger configuration from node
            plugin_id = node_config.get("plugin_id", "")
            provider_id = node_config.get("provider_id", "")
            event_name = node_config.get("event_name", "")
            subscription_id = node_config.get("subscription_id", "")

            if not subscription_id:
                continue

            nodes_in_graph.append(
                {
                    "node_id": node_id,
                    "plugin_id": plugin_id,
                    "provider_id": provider_id,
                    "event_name": event_name,
                    "subscription_id": subscription_id,
                }
            )

        # Check plugin trigger node limit
        if len(nodes_in_graph) > cls.MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW:
            raise ValueError(
                f"Workflow exceeds maximum plugin trigger node limit. "
                f"Found {len(nodes_in_graph)} plugin trigger nodes, "
                f"maximum allowed is {cls.MAX_PLUGIN_TRIGGER_NODES_PER_WORKFLOW}"
            )

        not_found_in_cache: list[Mapping[str, Any]] = []
        for node_info in nodes_in_graph:
            node_id = node_info["node_id"]
            # first check if the node exists in the cache
            if not redis_client.get(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}"):
                not_found_in_cache.append(node_info)
                continue

        with Session(db.engine) as session:
            try:
                # lock against concurrent plugin trigger creation
                # (the lock must be explicitly acquired; it is released via the key delete in `finally`)
                lock = redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
                lock.acquire()
                # fetch the non-cached nodes from DB
                all_records = session.scalars(
                    select(WorkflowPluginTrigger).where(
                        WorkflowPluginTrigger.app_id == app.id,
                        WorkflowPluginTrigger.tenant_id == app.tenant_id,
                    )
                ).all()

                nodes_id_in_db = {node.node_id: node for node in all_records}
                nodes_id_in_graph = {node["node_id"] for node in nodes_in_graph}

                # get the nodes found in neither the cache nor the DB
                nodes_not_found = [
                    node_info for node_info in not_found_in_cache if node_info["node_id"] not in nodes_id_in_db
                ]

                # create new plugin trigger records
                for node_info in nodes_not_found:
                    plugin_trigger = WorkflowPluginTrigger(
                        app_id=app.id,
                        tenant_id=app.tenant_id,
                        node_id=node_info["node_id"],
                        provider_id=node_info["provider_id"],
                        event_name=node_info["event_name"],
                        subscription_id=node_info["subscription_id"],
                    )
                    session.add(plugin_trigger)
                    session.flush()  # Get the ID for caching

                    cache = Cache(
                        record_id=plugin_trigger.id,
                        node_id=node_info["node_id"],
                        provider_id=node_info["provider_id"],
                        event_name=node_info["event_name"],
                        subscription_id=node_info["subscription_id"],
                    )
                    redis_client.set(
                        f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_info['node_id']}",
                        cache.model_dump_json(),
                        ex=60 * 60,
                    )
                session.commit()

                # Update existing records if subscription_id changed
                for node_info in nodes_in_graph:
                    node_id = node_info["node_id"]
                    if node_id in nodes_id_in_db:
                        existing_record = nodes_id_in_db[node_id]
                        if (
                            existing_record.subscription_id != node_info["subscription_id"]
                            or existing_record.provider_id != node_info["provider_id"]
                            or existing_record.event_name != node_info["event_name"]
                        ):
                            existing_record.subscription_id = node_info["subscription_id"]
                            existing_record.provider_id = node_info["provider_id"]
                            existing_record.event_name = node_info["event_name"]
                            session.add(existing_record)

                            # Update cache
                            cache = Cache(
                                record_id=existing_record.id,
                                node_id=node_id,
                                provider_id=node_info["provider_id"],
                                event_name=node_info["event_name"],
                                subscription_id=node_info["subscription_id"],
                            )
                            redis_client.set(
                                f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}",
                                cache.model_dump_json(),
                                ex=60 * 60,
                            )
                session.commit()

                # delete the nodes not found in the graph
                for node_id in nodes_id_in_db:
                    if node_id not in nodes_id_in_graph:
                        session.delete(nodes_id_in_db[node_id])
                        redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{node_id}")
                session.commit()
            except Exception:
                logger.exception("Failed to sync plugin trigger relationships for app %s", app.id)
                raise
            finally:
                redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock")
492
api/services/trigger/trigger_subscription_builder_service.py
Normal file
@ -0,0 +1,492 @@
import json
import logging
import uuid
from collections.abc import Mapping
from contextlib import contextmanager
from datetime import datetime
from typing import Any

from flask import Request, Response

from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse
from core.tools.errors import ToolProviderCredentialValidationError
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
from core.trigger.entities.entities import (
    RequestLog,
    Subscription,
    SubscriptionBuilder,
    SubscriptionBuilderUpdater,
    SubscriptionConstructor,
)
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import masked_credentials
from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url
from extensions.ext_redis import redis_client
from models.provider_ids import TriggerProviderID
from services.trigger.trigger_provider_service import TriggerProviderService

logger = logging.getLogger(__name__)


class TriggerSubscriptionBuilderService:
    """Service for managing trigger providers and credentials"""

    ##########################
    # Trigger provider
    ##########################
    __MAX_TRIGGER_PROVIDER_COUNT__ = 10

    ##########################
    # Builder endpoint
    ##########################
    __BUILDER_CACHE_EXPIRE_SECONDS__ = 30 * 60

    __VALIDATION_REQUEST_CACHE_COUNT__ = 10
    __VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__ = 30 * 60

    ##########################
    # Distributed lock
    ##########################
    __LOCK_EXPIRE_SECONDS__ = 30

    @classmethod
    def encode_cache_key(cls, subscription_id: str) -> str:
        return f"trigger:subscription:builder:{subscription_id}"

    @classmethod
    def encode_lock_key(cls, subscription_id: str) -> str:
        return f"trigger:subscription:builder:lock:{subscription_id}"

    @classmethod
    @contextmanager
    def acquire_builder_lock(cls, subscription_id: str):
        """
        Acquire a distributed lock for a subscription builder.

        :param subscription_id: The subscription builder ID
        """
        lock_key = cls.encode_lock_key(subscription_id)
        with redis_client.lock(lock_key, timeout=cls.__LOCK_EXPIRE_SECONDS__):
            yield

    @classmethod
    def verify_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
    ) -> Mapping[str, Any]:
        """Verify a trigger subscription builder"""
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        if not subscription_builder:
            raise ValueError(f"Subscription builder {subscription_builder_id} not found")

        if subscription_builder.credential_type == CredentialType.OAUTH2:
            return {"verified": bool(subscription_builder.credentials)}

        if subscription_builder.credential_type == CredentialType.API_KEY:
            credentials_to_validate = subscription_builder.credentials
            try:
                provider_controller.validate_credentials(user_id, credentials_to_validate)
            except ToolProviderCredentialValidationError as e:
                raise ValueError(f"Invalid credentials: {e}")
            return {"verified": True}

        return {"verified": True}

    @classmethod
    def build_trigger_subscription_builder(
        cls, tenant_id: str, user_id: str, provider_id: TriggerProviderID, subscription_builder_id: str
    ) -> None:
        """Build a trigger subscription builder"""
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        # Acquire lock to prevent concurrent build operations
        with cls.acquire_builder_lock(subscription_builder_id):
            subscription_builder = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder:
                raise ValueError(f"Subscription builder {subscription_builder_id} not found")

            if not subscription_builder.name:
                raise ValueError("Subscription builder name is required")

            credential_type = CredentialType.of(
                subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
            )
            if credential_type == CredentialType.UNAUTHORIZED:
                # manually create
                TriggerProviderService.add_trigger_subscription(
                    subscription_id=subscription_builder.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    name=subscription_builder.name,
                    provider_id=provider_id,
                    endpoint_id=subscription_builder.endpoint_id,
                    parameters=subscription_builder.parameters,
                    properties=subscription_builder.properties,
                    credential_expires_at=subscription_builder.credential_expires_at or -1,
                    expires_at=subscription_builder.expires_at,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                )
            else:
                # automatically create
                subscription: Subscription = TriggerManager.subscribe_trigger(
                    tenant_id=tenant_id,
                    user_id=user_id,
                    provider_id=provider_id,
                    endpoint=generate_plugin_trigger_endpoint_url(subscription_builder.endpoint_id),
                    parameters=subscription_builder.parameters,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                )

                TriggerProviderService.add_trigger_subscription(
                    subscription_id=subscription_builder.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    name=subscription_builder.name,
                    provider_id=provider_id,
                    endpoint_id=subscription_builder.endpoint_id,
                    parameters=subscription_builder.parameters,
                    properties=subscription.properties,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                    credential_expires_at=subscription_builder.credential_expires_at or -1,
                    expires_at=subscription_builder.expires_at,
                )

            # Delete the builder after successful subscription creation
            cache_key = cls.encode_cache_key(subscription_builder_id)
            redis_client.delete(cache_key)

    @classmethod
    def create_trigger_subscription_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        credential_type: CredentialType,
    ) -> SubscriptionBuilderApiEntity:
        """
        Create a new trigger subscription builder.
        """
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        subscription_constructor: SubscriptionConstructor | None = provider_controller.get_subscription_constructor()
        subscription_id = str(uuid.uuid4())
        subscription_builder = SubscriptionBuilder(
            id=subscription_id,
            name=None,
            endpoint_id=subscription_id,
            tenant_id=tenant_id,
            user_id=user_id,
            provider_id=str(provider_id),
            parameters=subscription_constructor.get_default_parameters() if subscription_constructor else {},
            properties=provider_controller.get_subscription_default_properties(),
            credentials={},
            credential_type=credential_type,
            credential_expires_at=-1,
            expires_at=-1,
        )
        cache_key = cls.encode_cache_key(subscription_id)
        redis_client.setex(cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder.model_dump_json())
        return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder)

    @classmethod
    def update_trigger_subscription_builder(
        cls,
        tenant_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
        subscription_builder_updater: SubscriptionBuilderUpdater,
    ) -> SubscriptionBuilderApiEntity:
        """
        Update a trigger subscription builder.
        """
        subscription_id = subscription_builder_id
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        # Acquire lock to prevent concurrent updates
        with cls.acquire_builder_lock(subscription_id):
            cache_key = cls.encode_cache_key(subscription_id)
            subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
                raise ValueError(f"Subscription {subscription_id} expired or not found")

            subscription_builder_updater.update(subscription_builder_cache)

            redis_client.setex(
                cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
            )
            return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache)

    @classmethod
    def update_and_verify_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
        subscription_builder_updater: SubscriptionBuilderUpdater,
    ) -> Mapping[str, Any]:
        """
        Atomically update and verify a subscription builder.
        This ensures the verification is done on the exact data that was just updated.
        """
        subscription_id = subscription_builder_id
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        # Acquire lock for the entire update + verify operation
        with cls.acquire_builder_lock(subscription_id):
            cache_key = cls.encode_cache_key(subscription_id)
            subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
                raise ValueError(f"Subscription {subscription_id} expired or not found")

            # Update
            subscription_builder_updater.update(subscription_builder_cache)
            redis_client.setex(
                cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
            )

            # Verify (using the just-updated data)
            if subscription_builder_cache.credential_type == CredentialType.OAUTH2:
                return {"verified": bool(subscription_builder_cache.credentials)}

            if subscription_builder_cache.credential_type == CredentialType.API_KEY:
                credentials_to_validate = subscription_builder_cache.credentials
                try:
                    provider_controller.validate_credentials(user_id, credentials_to_validate)
                except ToolProviderCredentialValidationError as e:
                    raise ValueError(f"Invalid credentials: {e}")
                return {"verified": True}

            return {"verified": True}

    @classmethod
    def update_and_build_builder(
        cls,
        tenant_id: str,
        user_id: str,
        provider_id: TriggerProviderID,
        subscription_builder_id: str,
        subscription_builder_updater: SubscriptionBuilderUpdater,
    ) -> None:
        """
        Atomically update and build a subscription builder.
        This ensures the build uses the exact data that was just updated.
        """
        subscription_id = subscription_builder_id
        provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
        if not provider_controller:
            raise ValueError(f"Provider {provider_id} not found")

        # Acquire lock for the entire update + build operation
        with cls.acquire_builder_lock(subscription_id):
            cache_key = cls.encode_cache_key(subscription_id)
            subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
                raise ValueError(f"Subscription {subscription_id} expired or not found")

            # Update
            subscription_builder_updater.update(subscription_builder_cache)
            redis_client.setex(
                cache_key, cls.__BUILDER_CACHE_EXPIRE_SECONDS__, subscription_builder_cache.model_dump_json()
            )

            # Re-fetch to ensure we have the latest data
            subscription_builder = cls.get_subscription_builder(subscription_builder_id)
            if not subscription_builder:
                raise ValueError(f"Subscription builder {subscription_builder_id} not found")

            if not subscription_builder.name:
                raise ValueError("Subscription builder name is required")

            # Build
            credential_type = CredentialType.of(
                subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
            )
            if credential_type == CredentialType.UNAUTHORIZED:
                # manually create
                TriggerProviderService.add_trigger_subscription(
                    subscription_id=subscription_builder.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    name=subscription_builder.name,
                    provider_id=provider_id,
                    endpoint_id=subscription_builder.endpoint_id,
                    parameters=subscription_builder.parameters,
                    properties=subscription_builder.properties,
                    credential_expires_at=subscription_builder.credential_expires_at or -1,
                    expires_at=subscription_builder.expires_at,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                )
            else:
                # automatically create
                subscription: Subscription = TriggerManager.subscribe_trigger(
                    tenant_id=tenant_id,
                    user_id=user_id,
                    provider_id=provider_id,
                    endpoint=generate_plugin_trigger_endpoint_url(subscription_builder.endpoint_id),
                    parameters=subscription_builder.parameters,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                )

                TriggerProviderService.add_trigger_subscription(
                    subscription_id=subscription_builder.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    name=subscription_builder.name,
                    provider_id=provider_id,
                    endpoint_id=subscription_builder.endpoint_id,
                    parameters=subscription_builder.parameters,
                    properties=subscription.properties,
                    credentials=subscription_builder.credentials,
                    credential_type=credential_type,
                    credential_expires_at=subscription_builder.credential_expires_at or -1,
                    expires_at=subscription_builder.expires_at,
                )

            # Delete the builder after successful subscription creation
            cache_key = cls.encode_cache_key(subscription_builder_id)
            redis_client.delete(cache_key)

    @classmethod
    def builder_to_api_entity(
        cls, controller: PluginTriggerProviderController, entity: SubscriptionBuilder
    ) -> SubscriptionBuilderApiEntity:
        credential_type = CredentialType.of(entity.credential_type or CredentialType.UNAUTHORIZED.value)
        return SubscriptionBuilderApiEntity(
            id=entity.id,
            name=entity.name or "",
            provider=entity.provider_id,
            endpoint=generate_plugin_trigger_endpoint_url(entity.endpoint_id),
            parameters=entity.parameters,
            properties=entity.properties,
            credential_type=credential_type,
            credentials=masked_credentials(
                schemas=controller.get_credentials_schema(credential_type),
                credentials=entity.credentials,
            )
            if controller.get_subscription_constructor()
            else {},
        )

    @classmethod
    def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
        """
        Get a trigger subscription builder by the endpoint ID.
        """
        cache_key = cls.encode_cache_key(endpoint_id)
        subscription_cache = redis_client.get(cache_key)
        if subscription_cache:
            return SubscriptionBuilder.model_validate(json.loads(subscription_cache))

        return None

    @classmethod
    def append_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
        """Append validation request log to Redis."""
        log = RequestLog(
            id=str(uuid.uuid4()),
            endpoint=endpoint_id,
            request={
                "method": request.method,
                "url": request.url,
                "headers": dict(request.headers),
                "data": request.get_data(as_text=True),
            },
            response={
                "status_code": response.status_code,
                "headers": dict(response.headers),
                "data": response.get_data(as_text=True),
            },
            created_at=datetime.now(),
        )

        key = f"trigger:subscription:builder:logs:{endpoint_id}"
        logs = json.loads(redis_client.get(key) or "[]")
        logs.append(log.model_dump(mode="json"))

        # Keep last N logs
        logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
        redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_SECONDS__, json.dumps(logs, default=str))

    @classmethod
    def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
        """List request logs for validation endpoint."""
        key = f"trigger:subscription:builder:logs:{endpoint_id}"
        logs_json = redis_client.get(key)
        if not logs_json:
            return []
        return [RequestLog.model_validate(log) for log in json.loads(logs_json)]

    @classmethod
    def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
        """
        Process a temporary endpoint request.

        :param endpoint_id: The endpoint identifier
        :param request: The Flask request object
        :return: The Flask response object
        """
        # check if the validation endpoint exists
        subscription_builder: SubscriptionBuilder | None = cls.get_subscription_builder(endpoint_id)
        if not subscription_builder:
            return None

        # respond to the validation endpoint
        controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
            tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
        )
        try:
            dispatch_response: TriggerDispatchResponse = controller.dispatch(
                request=request,
                subscription=subscription_builder.to_subscription(),
                credentials={},
                credential_type=CredentialType.UNAUTHORIZED,
            )
            response: Response = dispatch_response.response
            # append the request log
            cls.append_log(
                endpoint_id=endpoint_id,
                request=request,
                response=response,
            )
            return response
        except Exception:
            logger.exception("Error during validation endpoint dispatch for endpoint_id=%s", endpoint_id)
            error_response = Response(status=500, response="An internal error has occurred.")
            cls.append_log(endpoint_id=endpoint_id, request=request, response=error_response)
            return error_response

    @classmethod
    def get_subscription_builder_by_id(cls, subscription_builder_id: str) -> SubscriptionBuilderApiEntity:
        """Get a trigger subscription builder API entity."""
        subscription_builder = cls.get_subscription_builder(subscription_builder_id)
        if not subscription_builder:
            raise ValueError(f"Subscription builder {subscription_builder_id} not found")
        return cls.builder_to_api_entity(
            controller=TriggerManager.get_trigger_provider(
                subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
            ),
            entity=subscription_builder,
        )
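Read together, the builder methods above form a create, update, verify, build lifecycle. A condensed sketch with hypothetical identifiers; the SubscriptionBuilderUpdater field names are assumptions for illustration:

# Hypothetical end-to-end builder flow; all IDs and field names are illustrative.
provider_id = TriggerProviderID("langgenius/github")  # hypothetical
builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
    tenant_id="tenant-123",
    user_id="user-456",
    provider_id=provider_id,
    credential_type=CredentialType.API_KEY,
)

updater = SubscriptionBuilderUpdater(  # field names assumed for illustration
    name="github-issues",
    credentials={"api_key": "sk-..."},
)
verified = TriggerSubscriptionBuilderService.update_and_verify_builder(
    tenant_id="tenant-123",
    user_id="user-456",
    provider_id=provider_id,
    subscription_builder_id=builder.id,
    subscription_builder_updater=updater,
)

if verified["verified"]:
    # Promotes the cached builder into a persistent TriggerSubscription and drops the cache entry.
    TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
        tenant_id="tenant-123",
        user_id="user-456",
        provider_id=provider_id,
        subscription_builder_id=builder.id,
    )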
@ -0,0 +1,70 @@
from sqlalchemy import and_, select
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.enums import AppTriggerStatus
from models.trigger import AppTrigger, WorkflowPluginTrigger


class TriggerSubscriptionOperatorService:
    @classmethod
    def get_subscriber_triggers(
        cls, tenant_id: str, subscription_id: str, event_name: str
    ) -> list[WorkflowPluginTrigger]:
        """
        Get the WorkflowPluginTriggers subscribed to a subscription and event.

        Args:
            tenant_id: Tenant ID
            subscription_id: Subscription ID
            event_name: Event name
        """
        with Session(db.engine, expire_on_commit=False) as session:
            subscribers = session.scalars(
                select(WorkflowPluginTrigger)
                .join(
                    AppTrigger,
                    and_(
                        AppTrigger.tenant_id == WorkflowPluginTrigger.tenant_id,
                        AppTrigger.app_id == WorkflowPluginTrigger.app_id,
                        AppTrigger.node_id == WorkflowPluginTrigger.node_id,
                    ),
                )
                .where(
                    WorkflowPluginTrigger.tenant_id == tenant_id,
                    WorkflowPluginTrigger.subscription_id == subscription_id,
                    WorkflowPluginTrigger.event_name == event_name,
                    AppTrigger.status == AppTriggerStatus.ENABLED,
                )
            ).all()
            return list(subscribers)

    @classmethod
    def delete_plugin_trigger_by_subscription(
        cls,
        session: Session,
        tenant_id: str,
        subscription_id: str,
    ) -> None:
        """Delete a plugin trigger by tenant_id and subscription_id within an existing session.

        Args:
            session: Database session
            tenant_id: The tenant ID
            subscription_id: The subscription ID

        Returns silently if no matching plugin trigger exists.
        """
        # Find the plugin trigger using indexed columns
        plugin_trigger = session.scalar(
            select(WorkflowPluginTrigger).where(
                WorkflowPluginTrigger.tenant_id == tenant_id,
                WorkflowPluginTrigger.subscription_id == subscription_id,
            )
        )

        if not plugin_trigger:
            return

        session.delete(plugin_trigger)
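A short usage sketch for the operator service above; the caller function is hypothetical and not part of this commit, only get_subscriber_triggers is taken from the file. It illustrates the intended fan-out: resolve every enabled plugin trigger bound to a subscription/event pair, then enqueue one workflow run per subscriber.

# Hypothetical caller, for illustration only.
def fan_out_event(tenant_id: str, subscription_id: str, event_name: str) -> int:
    triggers = TriggerSubscriptionOperatorService.get_subscriber_triggers(
        tenant_id=tenant_id,
        subscription_id=subscription_id,
        event_name=event_name,
    )
    for trigger in triggers:
        # each trigger carries app_id/node_id, enough to start a workflow run
        print(f"would enqueue app={trigger.app_id} node={trigger.node_id}")
    return len(triggers)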
871
api/services/trigger/webhook_service.py
Normal file
@ -0,0 +1,871 @@
import json
import logging
import mimetypes
import secrets
from collections.abc import Mapping
from typing import Any

from flask import request
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.types import SegmentType
from core.workflow.enums import NodeType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.workflow.entities import WebhookTriggerData

logger = logging.getLogger(__name__)


class WebhookService:
    """Service for handling webhook operations."""

    __WEBHOOK_NODE_CACHE_KEY__ = "webhook_nodes"
    MAX_WEBHOOK_NODES_PER_WORKFLOW = 5  # Maximum allowed webhook nodes per workflow

    @staticmethod
    def _sanitize_key(key: str) -> str:
        """Normalize external keys (headers/params) to workflow-safe variables."""
        if not isinstance(key, str):
            return key
        return key.replace("-", "_")

    @classmethod
    def get_webhook_trigger_and_workflow(
        cls, webhook_id: str, is_debug: bool = False
    ) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]:
        """Get webhook trigger, workflow, and node configuration.

        Args:
            webhook_id: The webhook ID to look up
            is_debug: If True, use the draft workflow graph and skip the trigger enabled status check

        Returns:
            A tuple containing:
            - WorkflowWebhookTrigger: The webhook trigger object
            - Workflow: The associated workflow object
            - Mapping[str, Any]: The node configuration data

        Raises:
            ValueError: If webhook not found, app trigger not found, trigger disabled, or workflow not found
        """
        with Session(db.engine) as session:
            # Get webhook trigger
            webhook_trigger = (
                session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first()
            )
            if not webhook_trigger:
                raise ValueError(f"Webhook not found: {webhook_id}")

            if is_debug:
                workflow = (
                    session.query(Workflow)
                    .filter(
                        Workflow.app_id == webhook_trigger.app_id,
                        Workflow.version == Workflow.VERSION_DRAFT,
                    )
                    .order_by(Workflow.created_at.desc())
                    .first()
                )
            else:
                # Check if the corresponding AppTrigger exists
                app_trigger = (
                    session.query(AppTrigger)
                    .filter(
                        AppTrigger.app_id == webhook_trigger.app_id,
                        AppTrigger.node_id == webhook_trigger.node_id,
                        AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
                    )
                    .first()
                )

                if not app_trigger:
                    raise ValueError(f"App trigger not found for webhook {webhook_id}")

                # Only check enabled status if not in debug mode
                if app_trigger.status != AppTriggerStatus.ENABLED:
                    raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")

                # Get workflow
                workflow = (
                    session.query(Workflow)
                    .filter(
                        Workflow.app_id == webhook_trigger.app_id,
                        Workflow.version != Workflow.VERSION_DRAFT,
                    )
                    .order_by(Workflow.created_at.desc())
                    .first()
                )
            if not workflow:
                raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")

            node_config = workflow.get_node_config_by_id(webhook_trigger.node_id)

            return webhook_trigger, workflow, node_config

    @classmethod
    def extract_and_validate_webhook_data(
        cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any]
    ) -> dict[str, Any]:
        """Extract and validate webhook data in a single unified process.

        Args:
            webhook_trigger: The webhook trigger object containing metadata
            node_config: The node configuration containing validation rules

        Returns:
            dict[str, Any]: Processed and validated webhook data with correct types

        Raises:
            ValueError: If validation fails (HTTP method mismatch, missing required fields, type errors)
        """
        # Extract raw data first
        raw_data = cls.extract_webhook_data(webhook_trigger)

        # Validate HTTP metadata (method, content-type)
        node_data = node_config.get("data", {})
        validation_result = cls._validate_http_metadata(raw_data, node_data)
        if not validation_result["valid"]:
            raise ValueError(validation_result["error"])

        # Process and validate data according to configuration
        processed_data = cls._process_and_validate_data(raw_data, node_data)

        return processed_data

    @classmethod
    def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
        """Extract raw data from incoming webhook request without type conversion.

        Args:
            webhook_trigger: The webhook trigger object for file processing context

        Returns:
            dict[str, Any]: Raw webhook data containing:
            - method: HTTP method
            - headers: Request headers
            - query_params: Query parameters as strings
            - body: Request body (varies by content type)
            - files: Uploaded files (if any)
        """
        cls._validate_content_length()

        data = {
            "method": request.method,
            "headers": dict(request.headers),
            "query_params": dict(request.args),
            "body": {},
            "files": {},
        }

        # Extract and normalize content type
        content_type = cls._extract_content_type(dict(request.headers))

        # Route to appropriate extractor based on content type
        extractors = {
            "application/json": cls._extract_json_body,
            "application/x-www-form-urlencoded": cls._extract_form_body,
            "multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger),
            "application/octet-stream": lambda: cls._extract_octet_stream_body(webhook_trigger),
            "text/plain": cls._extract_text_body,
        }

        extractor = extractors.get(content_type)
        if not extractor:
            # Default to text/plain for unknown content types
            logger.warning("Unknown Content-Type: %s, treating as text/plain", content_type)
            extractor = cls._extract_text_body

        # Extract body and files
        body_data, files_data = extractor()
        data["body"] = body_data
        data["files"] = files_data

        return data
    @classmethod
    def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
        """Process and validate webhook data according to node configuration.

        Args:
            raw_data: Raw webhook data from extraction
            node_data: Node configuration containing validation and type rules

        Returns:
            dict[str, Any]: Processed data with validated types

        Raises:
            ValueError: If validation fails or required fields are missing
        """
        result = raw_data.copy()

        # Validate and process headers
        cls._validate_required_headers(raw_data["headers"], node_data.get("headers", []))

        # Process query parameters with type conversion and validation
        result["query_params"] = cls._process_parameters(
            raw_data["query_params"], node_data.get("params", []), is_form_data=True
        )

        # Process body parameters based on content type
        configured_content_type = node_data.get("content_type", "application/json").lower()
        result["body"] = cls._process_body_parameters(
            raw_data["body"], node_data.get("body", []), configured_content_type
        )

        return result

    @classmethod
    def _validate_content_length(cls) -> None:
        """Validate request content length against maximum allowed size."""
        content_length = request.content_length
        if content_length and content_length > dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE:
            raise RequestEntityTooLarge(
                f"Webhook request too large: {content_length} bytes exceeds maximum allowed size "
                f"of {dify_config.WEBHOOK_REQUEST_BODY_MAX_SIZE} bytes"
            )

    @classmethod
    def _extract_json_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract JSON body from request.

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Parsed JSON content or empty dict if parsing fails
            - files_data: Empty dict (JSON requests don't contain files)
        """
        try:
            body = request.get_json() or {}
        except Exception:
            logger.warning("Failed to parse JSON body")
            body = {}
        return body, {}

    @classmethod
    def _extract_form_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract form-urlencoded body from request.

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Form data as key-value pairs
            - files_data: Empty dict (form-urlencoded requests don't contain files)
        """
        return dict(request.form), {}

    @classmethod
    def _extract_multipart_body(cls, webhook_trigger: WorkflowWebhookTrigger) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract multipart/form-data body and files from request.

        Args:
            webhook_trigger: Webhook trigger for file processing context

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Form data as key-value pairs
            - files_data: Processed file objects indexed by field name
        """
        body = dict(request.form)
        files = cls._process_file_uploads(request.files, webhook_trigger) if request.files else {}
        return body, files

    @classmethod
    def _extract_octet_stream_body(
        cls, webhook_trigger: WorkflowWebhookTrigger
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract binary data as file from request.

        Args:
            webhook_trigger: Webhook trigger for file processing context

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Dict with 'raw' key containing file object or None
            - files_data: Empty dict
        """
        try:
            file_content = request.get_data()
            if file_content:
                file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger)
                return {"raw": file_obj.to_dict()}, {}
            else:
                return {"raw": None}, {}
        except Exception:
            logger.exception("Failed to process octet-stream data")
            return {"raw": None}, {}

    @classmethod
    def _extract_text_body(cls) -> tuple[dict[str, Any], dict[str, Any]]:
        """Extract text/plain body from request.

        Returns:
            tuple: (body_data, files_data) where:
            - body_data: Dict with 'raw' key containing text content
            - files_data: Empty dict (text requests don't contain files)
        """
        try:
            body = {"raw": request.get_data(as_text=True)}
        except Exception:
            logger.warning("Failed to extract text body")
            body = {"raw": ""}
        return body, {}

    @classmethod
    def _process_file_uploads(
        cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger
    ) -> dict[str, Any]:
        """Process file uploads using ToolFileManager.

        Args:
            files: Flask request files object containing uploaded files
            webhook_trigger: Webhook trigger for tenant and user context

        Returns:
            dict[str, Any]: Processed file objects indexed by field name
        """
        processed_files = {}

        for name, file in files.items():
            if file and file.filename:
                try:
                    file_content = file.read()
                    mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
                    file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
                    processed_files[name] = file_obj.to_dict()
                except Exception:
                    logger.exception("Failed to process file upload '%s'", name)
                    # Continue processing other files

        return processed_files

    @classmethod
    def _create_file_from_binary(
        cls, file_content: bytes, mimetype: str, webhook_trigger: WorkflowWebhookTrigger
    ) -> Any:
        """Create a file object from binary content using ToolFileManager.

        Args:
            file_content: The binary content of the file
            mimetype: The MIME type of the file
            webhook_trigger: Webhook trigger for tenant and user context

        Returns:
            Any: A file object built from the binary content
        """
        tool_file_manager = ToolFileManager()

        # Create file using ToolFileManager
        tool_file = tool_file_manager.create_file_by_raw(
            user_id=webhook_trigger.created_by,
            tenant_id=webhook_trigger.tenant_id,
            conversation_id=None,
            file_binary=file_content,
            mimetype=mimetype,
        )

        # Build File object
        mapping = {
            "tool_file_id": tool_file.id,
            "transfer_method": FileTransferMethod.TOOL_FILE.value,
        }
        return file_factory.build_from_mapping(
            mapping=mapping,
            tenant_id=webhook_trigger.tenant_id,
        )
    @classmethod
    def _process_parameters(
        cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False
    ) -> dict[str, Any]:
        """Process parameters with unified validation and type conversion.

        Args:
            raw_params: Raw parameter values as strings
            param_configs: List of parameter configuration dictionaries
            is_form_data: Whether the parameters are from form data (requiring string conversion)

        Returns:
            dict[str, Any]: Processed parameters with validated types

        Raises:
            ValueError: If required parameters are missing or validation fails
        """
        processed = {}
        configured_params = {config.get("name", ""): config for config in param_configs}

        # Process configured parameters
        for param_config in param_configs:
            name = param_config.get("name", "")
            param_type = param_config.get("type", SegmentType.STRING)
            required = param_config.get("required", False)

            # Check required parameters
            if required and name not in raw_params:
                raise ValueError(f"Required parameter missing: {name}")

            if name in raw_params:
                raw_value = raw_params[name]
                processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)

        # Include unconfigured parameters as strings
        for name, value in raw_params.items():
            if name not in configured_params:
                processed[name] = value

        return processed

    @classmethod
    def _process_body_parameters(
        cls, raw_body: dict[str, Any], body_configs: list, content_type: str
    ) -> dict[str, Any]:
        """Process body parameters based on content type and configuration.

        Args:
            raw_body: Raw body data from request
            body_configs: List of body parameter configuration dictionaries
            content_type: The request content type

        Returns:
            dict[str, Any]: Processed body parameters with validated types

        Raises:
            ValueError: If required body parameters are missing or validation fails
        """
        if content_type in ["text/plain", "application/octet-stream"]:
            # For text/plain and octet-stream, validate required content exists
            if body_configs and any(config.get("required", False) for config in body_configs):
                raw_content = raw_body.get("raw")
                if not raw_content:
                    raise ValueError(f"Required body content missing for {content_type} request")
            return raw_body

        # For structured data (JSON, form-data, etc.)
        processed = {}
        configured_params = {config.get("name", ""): config for config in body_configs}

        for body_config in body_configs:
            name = body_config.get("name", "")
            param_type = body_config.get("type", SegmentType.STRING)
            required = body_config.get("required", False)

            # Handle file parameters for multipart data
            if param_type == SegmentType.FILE and content_type == "multipart/form-data":
                # File validation is handled separately in the extract phase
                continue

            # Check required parameters
            if required and name not in raw_body:
                raise ValueError(f"Required body parameter missing: {name}")

            if name in raw_body:
                raw_value = raw_body[name]
                is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"]
                processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)

        # Include unconfigured parameters
        for name, value in raw_body.items():
            if name not in configured_params:
                processed[name] = value

        return processed

    @classmethod
    def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any:
        """Unified validation and type conversion for parameter values.

        Args:
            param_name: Name of the parameter for error reporting
            value: The value to validate and convert
            param_type: The expected parameter type (SegmentType)
            is_form_data: Whether the value is from form data (requiring string conversion)

        Returns:
            Any: The validated and converted value

        Raises:
            ValueError: If validation or conversion fails
        """
        try:
            if is_form_data:
                # Form data comes as strings and needs conversion
                return cls._convert_form_value(param_name, value, param_type)
            else:
                # JSON data should already be in correct types, just validate
                return cls._validate_json_value(param_name, value, param_type)
        except Exception as e:
            raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}")

    @classmethod
    def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any:
        """Convert form data string values to specified types.

        Args:
            param_name: Name of the parameter for error reporting
            value: The string value to convert
            param_type: The target type to convert to (SegmentType)

        Returns:
            Any: The converted value in the appropriate type

        Raises:
            ValueError: If the value cannot be converted to the specified type
        """
        if param_type == SegmentType.STRING:
            return value
        elif param_type == SegmentType.NUMBER:
            if not cls._can_convert_to_number(value):
                raise ValueError(f"Cannot convert '{value}' to number")
            numeric_value = float(value)
            return int(numeric_value) if numeric_value.is_integer() else numeric_value
        elif param_type == SegmentType.BOOLEAN:
            lower_value = value.lower()
            bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
            if lower_value not in bool_map:
                raise ValueError(f"Cannot convert '{value}' to boolean")
            return bool_map[lower_value]
        else:
            raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
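    # Illustrative conversions for _convert_form_value (comment-only sketch,
    # not part of the original file):
    #   _convert_form_value("n", "42", SegmentType.NUMBER)      -> 42 (int, since 42.0 is integral)
    #   _convert_form_value("n", "4.2", SegmentType.NUMBER)     -> 4.2 (float)
    #   _convert_form_value("flag", "yes", SegmentType.BOOLEAN) -> True
    #   _convert_form_value("flag", "2", SegmentType.BOOLEAN)   -> raises ValueError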
    @classmethod
    def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any:
        """Validate JSON values against expected types.

        Args:
            param_name: Name of the parameter for error reporting
            value: The value to validate
            param_type: The expected parameter type (SegmentType)

        Returns:
            Any: The validated value (unchanged if valid)

        Raises:
            ValueError: If the value type doesn't match the expected type
        """
        type_validators = {
            SegmentType.STRING: (lambda v: isinstance(v, str), "string"),
            SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"),
            SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"),
            SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"),
            SegmentType.ARRAY_STRING: (
                lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v),
                "array of strings",
            ),
            SegmentType.ARRAY_NUMBER: (
                lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v),
                "array of numbers",
            ),
            SegmentType.ARRAY_BOOLEAN: (
                lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v),
                "array of booleans",
            ),
            SegmentType.ARRAY_OBJECT: (
                lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v),
                "array of objects",
            ),
        }

        validator_info = type_validators.get(SegmentType(param_type))
        if not validator_info:
            logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
            return value

        validator, expected_type = validator_info
        if not validator(value):
            actual_type = type(value).__name__
            raise ValueError(f"Expected {expected_type}, got {actual_type}")

        return value
    @classmethod
    def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None:
        """Validate that required headers are present.

        Args:
            headers: Request headers dictionary
            header_configs: List of header configuration dictionaries

        Raises:
            ValueError: If required headers are missing
        """
        headers_lower = {k.lower(): v for k, v in headers.items()}
        headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()}
        for header_config in header_configs:
            if header_config.get("required", False):
                header_name = header_config.get("name", "")
                sanitized_name = cls._sanitize_key(header_name).lower()
                if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized:
                    raise ValueError(f"Required header missing: {header_name}")

    @classmethod
    def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
        """Validate HTTP method and content-type.

        Args:
            webhook_data: Extracted webhook data containing method and headers
            node_data: Node configuration containing expected method and content-type

        Returns:
            dict[str, Any]: Validation result with 'valid' key and optional 'error' key
        """
        # Validate HTTP method
        configured_method = node_data.get("method", "get").upper()
        request_method = webhook_data["method"].upper()
        if configured_method != request_method:
            return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}")

        # Validate content-type
        configured_content_type = node_data.get("content_type", "application/json").lower()
        request_content_type = cls._extract_content_type(webhook_data["headers"])

        if configured_content_type != request_content_type:
            return cls._validation_error(
                f"Content-type mismatch. Expected {configured_content_type}, got {request_content_type}"
            )

        return {"valid": True}

    @classmethod
    def _extract_content_type(cls, headers: dict[str, Any]) -> str:
        """Extract and normalize the content-type from headers.

        Args:
            headers: Request headers dictionary

        Returns:
            str: Normalized content-type (main type without parameters)
        """
        content_type = headers.get("Content-Type", "").lower()
        if not content_type:
            content_type = headers.get("content-type", "application/json").lower()
        # Extract the main content type (ignore parameters like boundary)
        return content_type.split(";")[0].strip()

    @classmethod
    def _validation_error(cls, error_message: str) -> dict[str, Any]:
        """Create a standard validation error response.

        Args:
            error_message: The error message to include

        Returns:
            dict[str, Any]: Validation error response with 'valid' and 'error' keys
        """
        return {"valid": False, "error": error_message}

    @classmethod
    def _can_convert_to_number(cls, value: str) -> bool:
        """Check if a string can be converted to a number."""
        try:
            float(value)
            return True
        except ValueError:
            return False

    @classmethod
    def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
        """Construct the workflow inputs payload from webhook data.

        Args:
            webhook_data: Processed webhook data containing headers, query params, and body

        Returns:
            dict[str, Any]: Workflow inputs formatted for execution
        """
        return {
            "webhook_data": webhook_data,
            "webhook_headers": webhook_data.get("headers", {}),
            "webhook_query_params": webhook_data.get("query_params", {}),
            "webhook_body": webhook_data.get("body", {}),
        }

    @classmethod
    def trigger_workflow_execution(
        cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
    ) -> None:
        """Trigger workflow execution via AsyncWorkflowService.

        Args:
            webhook_trigger: The webhook trigger object
            webhook_data: Processed webhook data for workflow inputs
            workflow: The workflow to execute

        Raises:
            ValueError: If the tenant owner is not found
            Exception: If workflow execution fails
        """
        try:
            with Session(db.engine) as session:
                # Prepare inputs for the webhook node.
                # The webhook node expects webhook_data in the inputs.
                workflow_inputs = cls.build_workflow_inputs(webhook_data)

                # Create trigger data
                trigger_data = WebhookTriggerData(
                    app_id=webhook_trigger.app_id,
                    workflow_id=workflow.id,
                    root_node_id=webhook_trigger.node_id,  # Start from the webhook node
                    inputs=workflow_inputs,
                    tenant_id=webhook_trigger.tenant_id,
                )

                end_user = EndUserService.get_or_create_end_user_by_type(
                    type=InvokeFrom.TRIGGER,
                    tenant_id=webhook_trigger.tenant_id,
                    app_id=webhook_trigger.app_id,
                    user_id=None,
                )

                # Trigger workflow execution asynchronously
                AsyncWorkflowService.trigger_workflow_async(
                    session,
                    end_user,
                    trigger_data,
                )

        except Exception:
            logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)
            raise

    @classmethod
    def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]:
        """Generate the HTTP response based on node configuration.

        Args:
            node_config: Node configuration containing response settings

        Returns:
            tuple[dict[str, Any], int]: Response data and HTTP status code
        """
        node_data = node_config.get("data", {})

        # Get the configured status code and response body
        status_code = node_data.get("status_code", 200)
        response_body = node_data.get("response_body", "")

        # Parse the response body as JSON if it's valid JSON, otherwise return it as text
        try:
            if response_body:
                try:
                    response_data = (
                        json.loads(response_body)
                        if response_body.strip().startswith(("{", "["))
                        else {"message": response_body}
                    )
                except json.JSONDecodeError:
                    response_data = {"message": response_body}
            else:
                response_data = {"status": "success", "message": "Webhook processed successfully"}
        except Exception:
            response_data = {"message": response_body or "Webhook processed successfully"}

        return response_data, status_code
    @classmethod
    def sync_webhook_relationships(cls, app: App, workflow: Workflow):
        """
        Sync webhook relationships in the DB.

        1. Check if the workflow has any webhook trigger nodes
        2. Fetch the nodes from the DB and see whether any webhook records already exist
        3. Diff the nodes against the webhook records, then create/update/delete the webhook records as needed

        Approach:
        Frequent DB operations may cause performance issues, so Redis is used as a cache instead.
        If any record exists, cache it.

        Limits:
        - Maximum 5 webhook nodes per workflow
        """

        class Cache(BaseModel):
            """
            Cache model for webhook nodes
            """

            record_id: str
            node_id: str
            webhook_id: str

        nodes_id_in_graph = [node_id for node_id, _ in workflow.walk_nodes(NodeType.TRIGGER_WEBHOOK)]

        # Check the webhook node limit
        if len(nodes_id_in_graph) > cls.MAX_WEBHOOK_NODES_PER_WORKFLOW:
            raise ValueError(
                f"Workflow exceeds maximum webhook node limit. "
                f"Found {len(nodes_id_in_graph)} webhook nodes, maximum allowed is {cls.MAX_WEBHOOK_NODES_PER_WORKFLOW}"
            )

        not_found_in_cache: list[str] = []
        for node_id in nodes_id_in_graph:
            # first, check if the node exists in the cache
            if not redis_client.get(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{node_id}"):
                not_found_in_cache.append(node_id)
                continue

        with Session(db.engine) as session:
            try:
                # lock against concurrent webhook trigger creation
                redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
                # fetch the non-cached nodes from the DB
                all_records = session.scalars(
                    select(WorkflowWebhookTrigger).where(
                        WorkflowWebhookTrigger.app_id == app.id,
                        WorkflowWebhookTrigger.tenant_id == app.tenant_id,
                    )
                ).all()

                nodes_id_in_db = {node.node_id: node for node in all_records}

                # get the nodes found neither in the cache nor in the DB
                nodes_not_found = [node_id for node_id in not_found_in_cache if node_id not in nodes_id_in_db]

                # create new webhook records
                for node_id in nodes_not_found:
                    webhook_record = WorkflowWebhookTrigger(
                        app_id=app.id,
                        tenant_id=app.tenant_id,
                        node_id=node_id,
                        webhook_id=cls.generate_webhook_id(),
                        created_by=app.created_by,
                    )
                    session.add(webhook_record)
                    session.flush()
                    cache = Cache(record_id=webhook_record.id, node_id=node_id, webhook_id=webhook_record.webhook_id)
                    redis_client.set(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{node_id}", cache.model_dump_json(), ex=60 * 60)
                session.commit()

                # delete the nodes not found in the graph
                for node_id in nodes_id_in_db:
                    if node_id not in nodes_id_in_graph:
                        session.delete(nodes_id_in_db[node_id])
                        redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{node_id}")
                session.commit()
            except Exception:
                logger.exception("Failed to sync webhook relationships for app %s", app.id)
                raise
            finally:
                redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")

    @classmethod
    def generate_webhook_id(cls) -> str:
        """
        Generate a unique 24-character webhook ID.

        Deduplication is not needed; the DB already has a unique constraint on webhook_id.
        """
        # Generate a 24-character random string
        return secrets.token_urlsafe(18)[:24]  # token_urlsafe gives base64url; take the first 24 chars
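Taken together, the service above forms a four-step pipeline: resolve the trigger, extract and validate the request, enqueue the async run, and answer with the node's configured response. A minimal sketch of controller glue wiring those steps; the function name is illustrative and not part of this commit, and it assumes it runs inside a Flask request context:

# Hypothetical controller glue for WebhookService (illustrative only).
def handle_webhook(webhook_id: str):
    trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
    try:
        data = WebhookService.extract_and_validate_webhook_data(trigger, node_config)
    except ValueError as e:
        return {"error": str(e)}, 400
    WebhookService.trigger_workflow_execution(trigger, data, workflow)
    body, status = WebhookService.generate_webhook_response(node_config)
    return body, status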
165
api/services/workflow/entities.py
Normal file
@ -0,0 +1,165 @@
"""
Pydantic models for async workflow trigger system.
"""

from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, ConfigDict, Field

from models.enums import AppTriggerType, WorkflowRunTriggeredFrom


class AsyncTriggerStatus(StrEnum):
    """Async trigger execution status"""

    COMPLETED = "completed"
    FAILED = "failed"
    TIMEOUT = "timeout"


class TriggerMetadata(BaseModel):
    """Trigger metadata"""

    type: AppTriggerType = Field(default=AppTriggerType.UNKNOWN)


class TriggerData(BaseModel):
    """Base trigger data model for async workflow execution"""

    app_id: str
    tenant_id: str
    workflow_id: str | None = None
    root_node_id: str
    inputs: Mapping[str, Any]
    files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
    trigger_type: AppTriggerType
    trigger_from: WorkflowRunTriggeredFrom
    trigger_metadata: TriggerMetadata | None = None

    model_config = ConfigDict(use_enum_values=True)


class WebhookTriggerData(TriggerData):
    """Webhook-specific trigger data"""

    trigger_type: AppTriggerType = AppTriggerType.TRIGGER_WEBHOOK
    trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.WEBHOOK


class ScheduleTriggerData(TriggerData):
    """Schedule-specific trigger data"""

    trigger_type: AppTriggerType = AppTriggerType.TRIGGER_SCHEDULE
    trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.SCHEDULE


class PluginTriggerMetadata(TriggerMetadata):
    """Plugin trigger metadata"""

    type: AppTriggerType = AppTriggerType.TRIGGER_PLUGIN

    endpoint_id: str
    plugin_unique_identifier: str
    provider_id: str
    event_name: str
    icon_filename: str
    icon_dark_filename: str


class PluginTriggerData(TriggerData):
    """Plugin webhook trigger data"""

    trigger_type: AppTriggerType = AppTriggerType.TRIGGER_PLUGIN
    trigger_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.PLUGIN
    plugin_id: str
    endpoint_id: str


class PluginTriggerDispatchData(BaseModel):
    """Plugin trigger dispatch data for Celery tasks"""

    user_id: str
    tenant_id: str
    endpoint_id: str
    provider_id: str
    subscription_id: str
    timestamp: int
    events: list[str]
    request_id: str


class WorkflowTaskData(BaseModel):
    """Lightweight data structure for Celery workflow tasks"""

    workflow_trigger_log_id: str  # Primary tracking ID - all other data can be fetched from DB

    model_config = ConfigDict(arbitrary_types_allowed=True)


class AsyncTriggerExecutionResult(BaseModel):
    """Result from async trigger-based workflow execution"""

    execution_id: str
    status: AsyncTriggerStatus
    result: Mapping[str, Any] | None = None
    error: str | None = None
    elapsed_time: float | None = None
    total_tokens: int | None = None

    model_config = ConfigDict(use_enum_values=True)


class AsyncTriggerResponse(BaseModel):
    """Response from triggering an async workflow"""

    workflow_trigger_log_id: str
    task_id: str
    status: str
    queue: str

    model_config = ConfigDict(use_enum_values=True)


class TriggerLogResponse(BaseModel):
    """Response model for trigger log data"""

    id: str
    tenant_id: str
    app_id: str
    workflow_id: str
    trigger_type: WorkflowRunTriggeredFrom
    status: str
    queue_name: str
    retry_count: int
    celery_task_id: str | None = None
    workflow_run_id: str | None = None
    error: str | None = None
    outputs: str | None = None
    elapsed_time: float | None = None
    total_tokens: int | None = None
    created_at: str | None = None
    triggered_at: str | None = None
    finished_at: str | None = None

    model_config = ConfigDict(use_enum_values=True)


class WorkflowScheduleCFSPlanEntity(BaseModel):
    """
    CFS plan entity.
    Ensures each workflow run inside Dify is associated with a CFS (Completely Fair Scheduler) plan.
    """

    class Strategy(StrEnum):
        """
        CFS plan strategy.
        """

        TimeSlice = "time-slice"  # time-slice based plan
        Nop = "nop"  # no plan, just run the workflow

    schedule_strategy: Strategy
    granularity: int = Field(default=-1)  # -1 means infinite
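Because every concrete trigger model pins trigger_type and trigger_from, callers only supply the run-specific fields. A minimal construction sketch (the field values are made up):

# Illustrative only: build the payload a webhook dispatch would hand to
# AsyncWorkflowService; use_enum_values=True serializes the enums as strings.
data = WebhookTriggerData(
    app_id="app-123",
    tenant_id="tenant-456",
    workflow_id="wf-789",
    root_node_id="webhook-node-1",
    inputs={"webhook_body": {"hello": "world"}},
)
print(data.model_dump_json())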
151
api/services/workflow/queue_dispatcher.py
Normal file
@ -0,0 +1,151 @@
"""
Queue dispatcher system for async workflow execution.

Implements an ABC-based pattern for handling different subscription tiers
with appropriate queue routing and rate limiting.
"""

from abc import ABC, abstractmethod
from enum import StrEnum

from configs import dify_config
from extensions.ext_redis import redis_client
from services.billing_service import BillingService
from services.workflow.rate_limiter import TenantDailyRateLimiter


class QueuePriority(StrEnum):
    """Queue priorities for different subscription tiers"""

    PROFESSIONAL = "workflow_professional"  # Highest priority
    TEAM = "workflow_team"
    SANDBOX = "workflow_sandbox"  # Free tier


class BaseQueueDispatcher(ABC):
    """Abstract base class for queue dispatchers"""

    def __init__(self):
        self.rate_limiter = TenantDailyRateLimiter(redis_client)

    @abstractmethod
    def get_queue_name(self) -> str:
        """Get the queue name for this dispatcher"""
        pass

    @abstractmethod
    def get_daily_limit(self) -> int:
        """Get daily execution limit"""
        pass

    @abstractmethod
    def get_priority(self) -> int:
        """Get task priority level"""
        pass

    def check_daily_quota(self, tenant_id: str) -> bool:
        """
        Check if tenant has remaining daily quota

        Args:
            tenant_id: The tenant identifier

        Returns:
            True if quota available, False otherwise
        """
        # Check without consuming
        remaining = self.rate_limiter.get_remaining_quota(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit())
        return remaining > 0

    def consume_quota(self, tenant_id: str) -> bool:
        """
        Consume one execution from daily quota

        Args:
            tenant_id: The tenant identifier

        Returns:
            True if quota consumed successfully, False if limit reached
        """
        return self.rate_limiter.check_and_consume(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit())


class ProfessionalQueueDispatcher(BaseQueueDispatcher):
    """Dispatcher for professional tier"""

    def get_queue_name(self) -> str:
        return QueuePriority.PROFESSIONAL

    def get_daily_limit(self) -> int:
        return int(1e9)

    def get_priority(self) -> int:
        return 100


class TeamQueueDispatcher(BaseQueueDispatcher):
    """Dispatcher for team tier"""

    def get_queue_name(self) -> str:
        return QueuePriority.TEAM

    def get_daily_limit(self) -> int:
        return int(1e9)

    def get_priority(self) -> int:
        return 50


class SandboxQueueDispatcher(BaseQueueDispatcher):
    """Dispatcher for free/sandbox tier"""

    def get_queue_name(self) -> str:
        return QueuePriority.SANDBOX

    def get_daily_limit(self) -> int:
        return dify_config.APP_DAILY_RATE_LIMIT

    def get_priority(self) -> int:
        return 10


class QueueDispatcherManager:
    """Factory for creating the appropriate dispatcher based on tenant subscription"""

    # Mapping of billing plans to dispatchers
    PLAN_DISPATCHER_MAP = {
        "professional": ProfessionalQueueDispatcher,
        "team": TeamQueueDispatcher,
        "sandbox": SandboxQueueDispatcher,
        # Add new tiers here as they're created
        # For any unknown plan, default to sandbox
    }

    @classmethod
    def get_dispatcher(cls, tenant_id: str) -> BaseQueueDispatcher:
        """
        Get dispatcher based on tenant's subscription plan

        Args:
            tenant_id: The tenant identifier

        Returns:
            Appropriate queue dispatcher instance
        """
        if dify_config.BILLING_ENABLED:
            try:
                billing_info = BillingService.get_info(tenant_id)
                plan = billing_info.get("subscription", {}).get("plan", "sandbox")
            except Exception:
                # If billing service fails, default to sandbox
                plan = "sandbox"
        else:
            # If billing is disabled, use team tier as default
            plan = "team"

        dispatcher_class = cls.PLAN_DISPATCHER_MAP.get(
            plan,
            SandboxQueueDispatcher,  # Default to sandbox for unknown plans
        )

        return dispatcher_class()  # type: ignore
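The dispatch decision therefore reduces to two calls: pick the dispatcher for the tenant's plan, then consume quota before enqueueing. A hedged sketch of the expected call site; the final enqueue line is a placeholder, not an API defined in this commit:

# Illustrative only: how a trigger entry point is expected to use the factory.
dispatcher = QueueDispatcherManager.get_dispatcher(tenant_id="tenant-456")
if not dispatcher.consume_quota("tenant-456"):
    raise RuntimeError("daily workflow quota exhausted")
queue = dispatcher.get_queue_name()  # e.g. "workflow_team"
# celery_task.apply_async(args=[...], queue=queue, priority=dispatcher.get_priority())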
183
api/services/workflow/rate_limiter.py
Normal file
@ -0,0 +1,183 @@
"""
Day-based rate limiter for workflow executions.

Implements UTC-based daily quotas that reset at midnight UTC for consistent rate limiting.
"""

from datetime import UTC, datetime, time, timedelta
from typing import Union

import pytz
from redis import Redis
from sqlalchemy import select

from extensions.ext_database import db
from extensions.ext_redis import RedisClientWrapper
from models.account import Account, TenantAccountJoin, TenantAccountRole


class TenantDailyRateLimiter:
    """
    Day-based rate limiter that resets at midnight UTC

    This class provides Redis-based rate limiting with the following features:
    - Daily quotas that reset at midnight UTC for consistency
    - Atomic check-and-consume operations
    - Automatic cleanup of stale counters
    - Timezone-aware error messages for better UX
    """

    def __init__(self, redis_client: Union[Redis, RedisClientWrapper]):
        self.redis = redis_client

    def get_tenant_owner_timezone(self, tenant_id: str) -> str:
        """
        Get timezone of tenant owner

        Args:
            tenant_id: The tenant identifier

        Returns:
            Timezone string (e.g., 'America/New_York', 'UTC')
        """
        # Query to get tenant owner's timezone using scalar and select
        owner = db.session.scalar(
            select(Account)
            .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
            .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == TenantAccountRole.OWNER)
        )

        if not owner:
            return "UTC"

        return owner.timezone or "UTC"

    def _get_day_key(self, tenant_id: str) -> str:
        """
        Get Redis key for current UTC day

        Args:
            tenant_id: The tenant identifier

        Returns:
            Redis key for the current UTC day
        """
        utc_now = datetime.now(UTC)
        date_str = utc_now.strftime("%Y-%m-%d")
        return f"workflow:daily_limit:{tenant_id}:{date_str}"

    def _get_ttl_seconds(self) -> int:
        """
        Calculate seconds until UTC midnight

        Returns:
            Number of seconds until UTC midnight
        """
        utc_now = datetime.now(UTC)

        # Get next midnight in UTC
        next_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min)
        next_midnight = next_midnight.replace(tzinfo=UTC)

        return int((next_midnight - utc_now).total_seconds())

    def check_and_consume(self, tenant_id: str, max_daily_limit: int) -> bool:
        """
        Check if quota available and consume one execution

        Args:
            tenant_id: The tenant identifier
            max_daily_limit: Maximum daily limit

        Returns:
            True if quota consumed successfully, False if limit reached
        """
        key = self._get_day_key(tenant_id)
        ttl = self._get_ttl_seconds()

        # Check current usage
        current = self.redis.get(key)

        if current is None:
            # First execution of the day - set to 1
            self.redis.setex(key, ttl, 1)
            return True

        current_count = int(current)
        if current_count < max_daily_limit:
            # Within limit, increment
            new_count = self.redis.incr(key)
            # Update TTL
            self.redis.expire(key, ttl)

            # Double-check in case of race condition
            if new_count <= max_daily_limit:
                return True
            else:
                # Race condition occurred, decrement back
                self.redis.decr(key)
                return False
        else:
            # Limit exceeded
            return False

    def get_remaining_quota(self, tenant_id: str, max_daily_limit: int) -> int:
        """
        Get remaining quota for the day

        Args:
            tenant_id: The tenant identifier
            max_daily_limit: Maximum daily limit

        Returns:
            Number of remaining executions for the day
        """
        key = self._get_day_key(tenant_id)
        used = int(self.redis.get(key) or 0)
        return max(0, max_daily_limit - used)

    def get_current_usage(self, tenant_id: str) -> int:
        """
        Get current usage for the day

        Args:
            tenant_id: The tenant identifier

        Returns:
            Number of executions used today
        """
        key = self._get_day_key(tenant_id)
        return int(self.redis.get(key) or 0)

    def reset_quota(self, tenant_id: str) -> bool:
        """
        Reset quota for testing purposes

        Args:
            tenant_id: The tenant identifier

        Returns:
            True if key was deleted, False if key didn't exist
        """
        key = self._get_day_key(tenant_id)
        return bool(self.redis.delete(key))

    def get_quota_reset_time(self, tenant_id: str, timezone_str: str) -> datetime:
        """
        Get the time when quota will reset (next UTC midnight in tenant's timezone)

        Args:
            tenant_id: The tenant identifier
            timezone_str: Tenant's timezone for display purposes

        Returns:
            Datetime when quota resets (next UTC midnight in tenant's timezone)
        """
        tz = pytz.timezone(timezone_str)
        utc_now = datetime.now(UTC)

        # Get next midnight in UTC, then convert to tenant's timezone
        next_utc_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min)
        next_utc_midnight = pytz.UTC.localize(next_utc_midnight)

        return next_utc_midnight.astimezone(tz)
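Note that check_and_consume reads the counter before incrementing and then repairs over-counts after the fact; the same quota can also be enforced with a single atomic INCR, a common simplification when the double-check branch feels fragile. A sketch of that alternative, shown for comparison only (it is not what the class above does):

# Alternative INCR-first variant, illustrative only.
def atomic_check_and_consume(redis, key: str, ttl: int, max_daily_limit: int) -> bool:
    new_count = redis.incr(key)  # atomically creates the key at 1 if missing
    if new_count == 1:
        redis.expire(key, ttl)  # first hit of the day sets the midnight TTL
    if new_count > max_daily_limit:
        redis.decr(key)  # roll back the over-consumed slot
        return False
    return True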
34
api/services/workflow/scheduler.py
Normal file
@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from enum import StrEnum

from services.workflow.entities import WorkflowScheduleCFSPlanEntity


class SchedulerCommand(StrEnum):
    """
    Scheduler command.
    """

    RESOURCE_LIMIT_REACHED = "resource_limit_reached"
    NONE = "none"


class CFSPlanScheduler(ABC):
    """
    CFS plan scheduler.
    """

    def __init__(self, plan: WorkflowScheduleCFSPlanEntity):
        """
        Initialize the CFS plan scheduler.

        Args:
            plan: The CFS plan.
        """
        self.plan = plan

    @abstractmethod
    def can_schedule(self) -> SchedulerCommand:
        """
        Whether a workflow run can be scheduled.
        """
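CFSPlanScheduler leaves can_schedule abstract, so each strategy ships as a subclass. A minimal concrete sketch for the "nop" strategy; the class name is hypothetical, since this commit only defines the ABC here:

# Hypothetical subclass illustrating the can_schedule contract.
class NopCFSPlanScheduler(CFSPlanScheduler):
    """Scheduler for the Nop strategy: never throttles."""

    def can_schedule(self) -> SchedulerCommand:
        # with no plan constraints, a run is always schedulable
        return SchedulerCommand.NONE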
@ -1,12 +1,37 @@
import json
import uuid
from datetime import datetime
from typing import Any

from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session

from core.workflow.enums import WorkflowExecutionStatus
from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
from services.workflow.entities import TriggerMetadata


# Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it
class LogView:
    """Lightweight wrapper for WorkflowAppLog with computed details.

    - Exposes `details_` for marshalling to `details` in API response
    - Proxies all other attributes to the underlying `WorkflowAppLog`
    """

    def __init__(self, log: WorkflowAppLog, details: dict | None):
        self.log = log
        self.details_ = details

    @property
    def details(self) -> dict | None:
        return self.details_

    def __getattr__(self, name):
        return getattr(self.log, name)


class WorkflowAppService:
@ -21,6 +46,7 @@ class WorkflowAppService:
        created_at_after: datetime | None = None,
        page: int = 1,
        limit: int = 20,
        detail: bool = False,
        created_by_end_user_session_id: str | None = None,
        created_by_account: str | None = None,
    ):
@ -34,6 +60,7 @@ class WorkflowAppService:
        :param created_at_after: filter logs created after this timestamp
        :param page: page number
        :param limit: items per page
        :param detail: whether to return detailed logs
        :param created_by_end_user_session_id: filter by end user session id
        :param created_by_account: filter by account email
        :return: Pagination object
@ -43,8 +70,20 @@ class WorkflowAppService:
            WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id
        )

        if detail:
            # Simple left join by workflow_run_id to fetch trigger_metadata
            stmt = stmt.outerjoin(
                WorkflowTriggerLog,
                and_(
                    WorkflowTriggerLog.tenant_id == app_model.tenant_id,
                    WorkflowTriggerLog.app_id == app_model.id,
                    WorkflowTriggerLog.workflow_run_id == WorkflowAppLog.workflow_run_id,
                ),
            ).add_columns(WorkflowTriggerLog.trigger_metadata)

        if keyword or status:
            stmt = stmt.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
            # Join to workflow run for filtering when needed.

        if keyword:
            keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
@ -108,9 +147,17 @@ class WorkflowAppService:
        # Apply pagination limits
        offset_stmt = stmt.offset((page - 1) * limit).limit(limit)

        # Execute query and get items
        items = list(session.scalars(offset_stmt).all())
        # wrapper moved to module scope as `LogView`

        # Execute query and get items
        if detail:
            rows = session.execute(offset_stmt).all()
            items = [
                LogView(log, {"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, meta_val)})
                for log, meta_val in rows
            ]
        else:
            items = [LogView(log, None) for log in session.scalars(offset_stmt).all()]
        return {
            "page": page,
            "limit": limit,
@ -119,6 +166,31 @@ class WorkflowAppService:
            "data": items,
        }

    def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]:
        metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
        if not metadata:
            return {}
        trigger_metadata = TriggerMetadata.model_validate(metadata)
        if trigger_metadata.type == AppTriggerType.TRIGGER_PLUGIN:
            icon = metadata.get("icon_filename")
            icon_dark = metadata.get("icon_dark_filename")
            metadata["icon"] = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=icon) if icon else None
            metadata["icon_dark"] = (
                PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=icon_dark) if icon_dark else None
            )
        return metadata

    @staticmethod
    def _safe_json_loads(val):
        if not val:
            return None
        if isinstance(val, str):
            try:
                return json.loads(val)
            except Exception:
                return None
        return val

    @staticmethod
    def _safe_parse_uuid(value: str):
        # fast check
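LogView keeps the marshalled shape of WorkflowAppLog intact while bolting on the computed details dict: unknown attributes fall through __getattr__ to the wrapped row. A tiny illustration; the `log` object and the metadata value are assumed, not taken from this commit:

# Illustrative only: attribute access proxies to the wrapped WorkflowAppLog.
view = LogView(log, {"trigger_metadata": {"type": "plugin"}})
view.details          # -> {"trigger_metadata": {"type": "plugin"}}
view.workflow_run_id  # -> proxied from the underlying WorkflowAppLog row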
@ -1026,7 +1026,7 @@ class DraftVariableSaver:
            return
        if self._node_type == NodeType.VARIABLE_ASSIGNER:
            draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data)
        elif self._node_type == NodeType.START:
        elif self._node_type == NodeType.START or self._node_type.is_trigger_node:
            draft_vars = self._build_variables_from_start_mapping(outputs)
        else:
            draft_vars = self._build_variables_from_mapping(outputs)
@ -10,20 +10,22 @@ from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool, WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@ -32,6 +34,7 @@ from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import UserFrom
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
@ -211,6 +214,9 @@ class WorkflowService:
        # validate features structure
        self.validate_features_structure(app_model=app_model, features=features)

        # validate graph structure
        self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=graph)

        # create draft workflow if not found
        if not workflow:
            workflow = Workflow(
@ -267,6 +273,9 @@ class WorkflowService:
        if FeatureService.get_system_features().plugin_manager.enabled:
            self._validate_workflow_credentials(draft_workflow)

        # validate graph structure
        self.validate_graph_structure(user_id=account.id, app_model=app_model, graph=draft_workflow.graph_dict)

        # create new workflow
        workflow = Workflow.new(
            tenant_id=app_model.tenant_id,
@ -623,7 +632,7 @@ class WorkflowService:
        node_config = draft_workflow.get_node_config_by_id(node_id)
        node_type = Workflow.get_node_type_from_node_config(node_config)
        node_data = node_config.get("data", {})
        if node_type == NodeType.START:
        if node_type.is_start_node:
            with Session(bind=db.engine) as session, session.begin():
                draft_var_srv = WorkflowDraftVariableService(session)
                conversation_id = draft_var_srv.get_or_create_conversation(
@ -631,10 +640,11 @@ class WorkflowService:
                    app=app_model,
                    workflow=draft_workflow,
                )
            start_data = StartNodeData.model_validate(node_data)
            user_inputs = _rebuild_file_for_user_inputs_in_start_node(
                tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
            )
            if node_type is NodeType.START:
                start_data = StartNodeData.model_validate(node_data)
                user_inputs = _rebuild_file_for_user_inputs_in_start_node(
                    tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
                )
        # init variable pool
        variable_pool = _setup_variable_pool(
            query=query,
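        # Note the two-level check above: `is_start_node` presumably covers
        # START and all trigger node types, so conversation setup runs for any
        # entry node, while the inner `is NodeType.START` guard keeps file
        # rebuilding specific to the classic Start node, whose StartNodeData
        # schema trigger nodes do not share.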
@ -895,6 +905,43 @@ class WorkflowService:

        return new_app

    def validate_graph_structure(self, user_id: str, app_model: App, graph: Mapping[str, Any]):
        """
        Validate workflow graph structure by instantiating the Graph object.

        This leverages the built-in graph validators (including trigger/UserInput exclusivity)
        and raises any structural errors before persisting the workflow.
        """
        node_configs = graph.get("nodes", [])
        node_configs = cast(list[dict[str, object]], node_configs)

        # Empty graphs have nothing to validate.
        if not node_configs:
            return

        workflow_id = app_model.workflow_id or "UNKNOWN"
        Graph.init(
            graph_config=graph,
            # TODO(Mairuis): Add root node id
            root_node_id=None,
            node_factory=DifyNodeFactory(
                graph_init_params=GraphInitParams(
                    tenant_id=app_model.tenant_id,
                    app_id=app_model.id,
                    workflow_id=workflow_id,
                    graph_config=graph,
                    user_id=user_id,
                    user_from=UserFrom.ACCOUNT,
                    invoke_from=InvokeFrom.VALIDATION,
                    call_depth=0,
                ),
                graph_runtime_state=GraphRuntimeState(
                    variable_pool=VariablePool(),
                    start_at=time.perf_counter(),
                ),
            ),
        )

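    # Usage sketch (illustrative): both draft sync and publish call the
    # validator with the raw graph mapping, so a structurally invalid graph
    # fails before any workflow row is written, e.g.
    #
    #   service.validate_graph_structure(
    #       user_id=account.id,
    #       app_model=app_model,
    #       graph={"nodes": [...], "edges": [...]},
    #   )
    #
    # Graph.init raises on structural errors such as combining a trigger node
    # with a user-input Start node (the trigger/UserInput exclusivity the
    # docstring mentions).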
    def validate_features_structure(self, app_model: App, features: dict):
        if app_model.mode == AppMode.ADVANCED_CHAT:
            return AdvancedChatAppConfigManager.config_validate(
@ -997,10 +1044,11 @@ def _setup_variable_pool(
    conversation_variables: list[Variable],
):
    # Only inject system variables for START and trigger node types.
    if node_type == NodeType.START:
    if node_type == NodeType.START or node_type.is_trigger_node:
        system_variable = SystemVariable(
            user_id=user_id,
            app_id=workflow.app_id,
            timestamp=int(naive_utc_now().timestamp()),
            workflow_id=workflow.id,
            files=files or [],
            workflow_execution_id=str(uuid.uuid4()),

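    # Consequence sketch (illustrative; the selector API is an assumption): a
    # run started by a webhook or schedule trigger now receives the same
    # system variables a Start-node run would, so downstream nodes can read
    # them uniformly:
    #
    #   pool = _setup_variable_pool(..., node_type=NodeType.TRIGGER_WEBHOOK)
    #   pool.get(["sys", "workflow_id"])            # -> workflow.id
    #   pool.get(["sys", "workflow_execution_id"])  # -> fresh uuid4 per run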