feat: introduce trigger functionality (#27644)

Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: Stream <Stream_2@qq.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
Co-authored-by: Harry <xh001x@hotmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
Co-authored-by: hjlarry <hjlarry@163.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WTW0313 <twwu@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yeuoly
2025-11-12 17:59:37 +08:00
committed by CodingOnStar
parent 73ef99c25a
commit 244bd08db2
785 changed files with 41186 additions and 3724 deletions

View File

@ -37,6 +37,7 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import (
NodeType,
@ -303,6 +304,11 @@ class WorkflowResponseConverter:
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
elif event.node_type == NodeType.TRIGGER_PLUGIN:
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
return response

View File

@ -27,6 +27,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -38,10 +39,16 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@overload
def generate(
self,
@ -53,7 +60,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
) -> Generator[Mapping | str, None, None]: ...
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
def generate(
@ -66,6 +76,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any]: ...
@overload
@ -79,7 +92,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
self,
@ -91,7 +107,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
@ -126,17 +145,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
# for trigger debug run, not prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
inputs=inputs,
files=list(system_files),
user_id=user.id,
stream=streaming,
@ -155,7 +177,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
if triggered_from is not None:
# Use explicitly provided triggered_from (for async triggers)
workflow_triggered_from = triggered_from
elif invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
@ -182,8 +207,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
)
def resume(self, *, workflow_run_id: str) -> None:
"""
@TBD
"""
pass
def _generate(
self,
*,
@ -196,6 +229,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -231,8 +266,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": graph_engine_layers,
},
)
@ -426,6 +463,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> None:
"""
Generate worker in a new thread.
@ -469,6 +508,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
)
try:

View File

@ -1,5 +1,6 @@
import logging
import time
from collections.abc import Sequence
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -8,6 +9,7 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -16,6 +18,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from models.workflow import Workflow
@ -35,17 +38,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
root_node_id: str | None = None,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
graph_engine_layers=graph_engine_layers,
)
self.application_generate_entity = application_generate_entity
self._workflow = workflow
self._sys_user_id = system_user_id
self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
@ -60,6 +67,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
@ -92,6 +100,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
root_node_id=self._root_node_id,
)
# RUN WORKFLOW

View File

@ -84,6 +84,7 @@ class WorkflowBasedAppRunner:
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
root_node_id: str | None = None,
) -> Graph:
"""
Init graph
@ -117,7 +118,7 @@ class WorkflowBasedAppRunner:
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
if not graph:
raise ValueError("graph not found in workflow")

View File

@ -32,6 +32,10 @@ class InvokeFrom(StrEnum):
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
WEB_APP = "web-app"
# TRIGGER indicates that this invocation is from a trigger.
# this is used for plugin trigger and webhook trigger.
TRIGGER = "trigger"
# EXPLORE indicates that this invocation is from
# the workflow (or chatflow) explore page.
EXPLORE = "explore"
@ -40,6 +44,9 @@ class InvokeFrom(StrEnum):
DEBUGGER = "debugger"
PUBLISHED = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str):
"""
@ -65,6 +72,8 @@ class InvokeFrom(StrEnum):
return "dev"
elif self == InvokeFrom.EXPLORE:
return "explore_app"
elif self == InvokeFrom.TRIGGER:
return "trigger"
elif self == InvokeFrom.SERVICE_API:
return "api"

View File

@ -2,7 +2,7 @@ from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@ -55,7 +55,7 @@ class WorkflowResumptionContext(BaseModel):
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(
self,
session_factory: Engine | sessionmaker,
session_factory: Engine | sessionmaker[Session],
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
state_owner_user_id: str,
):
@ -103,10 +103,8 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
elif isinstance(self._generate_entity, AdvancedChatAppGenerateEntity):
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
else:
raise AssertionError(f"unknown entity type: type={type(self._generate_entity)}")
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
state = WorkflowResumptionContext(
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),

View File

@ -0,0 +1,21 @@
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
class SuspendLayer(GraphEngineLayer):
""" """
def on_graph_start(self):
pass
def on_event(self, event: GraphEngineEvent):
"""
Handle the paused event, stash runtime state into storage and wait for resume.
"""
if isinstance(event, GraphRunPausedEvent):
pass
def on_graph_end(self, error: Exception | None):
""" """
pass

View File

@ -0,0 +1,88 @@
import logging
import uuid
from typing import ClassVar
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
logger = logging.getLogger(__name__)
class TimeSliceLayer(GraphEngineLayer):
"""
CFS plan scheduler to control the timeslice of the workflow.
"""
scheduler: ClassVar[BackgroundScheduler] = BackgroundScheduler()
def __init__(self, cfs_plan_scheduler: CFSPlanScheduler) -> None:
"""
CFS plan scheduler allows to control the timeslice of the workflow.
"""
if not TimeSliceLayer.scheduler.running:
TimeSliceLayer.scheduler.start()
super().__init__()
self.cfs_plan_scheduler = cfs_plan_scheduler
self.stopped = False
self.schedule_id = ""
def _checker_job(self, schedule_id: str):
"""
Check if the workflow need to be suspended.
"""
try:
if self.stopped:
self.scheduler.remove_job(schedule_id)
return
if self.cfs_plan_scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED:
# remove the job
self.scheduler.remove_job(schedule_id)
if not self.command_channel:
logger.exception("No command channel to stop the workflow")
return
# send command to pause the workflow
self.command_channel.send_command(
GraphEngineCommand(
command_type=CommandType.PAUSE,
payload={
"reason": SchedulerCommand.RESOURCE_LIMIT_REACHED,
},
)
)
except Exception:
logger.exception("scheduler error during check if the workflow need to be suspended")
def on_graph_start(self):
"""
Start timer to check if the workflow need to be suspended.
"""
if self.cfs_plan_scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice:
self.schedule_id = uuid.uuid4().hex
self.scheduler.add_job(
lambda: self._checker_job(self.schedule_id),
"interval",
seconds=self.cfs_plan_scheduler.plan.granularity,
id=self.schedule_id,
)
def on_event(self, event: GraphEngineEvent):
pass
def on_graph_end(self, error: Exception | None) -> None:
self.stopped = True
# remove the scheduler
if self.schedule_id:
self.scheduler.remove_job(self.schedule_id)

View File

@ -0,0 +1,88 @@
import logging
from datetime import UTC, datetime
from typing import Any, ClassVar
from pydantic import TypeAdapter
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
from models.enums import WorkflowTriggerStatus
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity
logger = logging.getLogger(__name__)
class TriggerPostLayer(GraphEngineLayer):
"""
Trigger post layer.
"""
_STATUS_MAP: ClassVar[dict[type[GraphEngineEvent], WorkflowTriggerStatus]] = {
GraphRunSucceededEvent: WorkflowTriggerStatus.SUCCEEDED,
GraphRunFailedEvent: WorkflowTriggerStatus.FAILED,
GraphRunPausedEvent: WorkflowTriggerStatus.PAUSED,
}
def __init__(
self,
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
start_time: datetime,
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
self.session_maker = session_maker
def on_graph_start(self):
pass
def on_event(self, event: GraphEngineEvent):
"""
Update trigger log with success or failure.
"""
if isinstance(event, tuple(self._STATUS_MAP.keys())):
with self.session_maker() as session:
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = repo.get_by_id(self.trigger_log_id)
if not trigger_log:
logger.exception("Trigger log not found: %s", self.trigger_log_id)
return
# Calculate elapsed time
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
if not self.graph_runtime_state:
logger.exception("Graph runtime state is not set")
return
outputs = self.graph_runtime_state.outputs
# BASICLY, workflow_execution_id is the same as workflow_run_id
workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id, "Workflow run id is not set"
total_tokens = self.graph_runtime_state.total_tokens
# Update trigger log with success
trigger_log.status = self._STATUS_MAP[type(event)]
trigger_log.workflow_run_id = workflow_run_id
trigger_log.outputs = TypeAdapter(dict[str, Any]).dump_json(outputs).decode()
if trigger_log.elapsed_time is None:
trigger_log.elapsed_time = elapsed_time
else:
trigger_log.elapsed_time += elapsed_time
trigger_log.total_tokens = total_tokens
trigger_log.finished_at = datetime.now(UTC)
repo.update(trigger_log)
session.commit()
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -14,6 +14,7 @@ class CommonParameterType(StrEnum):
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
CHECKBOX = "checkbox"
ANY = auto()
# Dynamic select parameter

View File

@ -107,7 +107,7 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict | None = None
credentials: dict | None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = []
@ -207,6 +207,7 @@ class ProviderConfig(BasicProviderConfig):
required: bool = False
default: Union[int, str, float, bool] | None = None
options: list[Option] | None = None
multiple: bool | None = False
label: I18nObject | None = None
help: I18nObject | None = None
url: str | None = None

View File

@ -3,7 +3,7 @@ import re
from collections.abc import Sequence
from typing import Any
from core.tools.entities.tool_entities import CredentialType
from core.plugin.entities.plugin_daemon import CredentialType
logger = logging.getLogger(__name__)

View File

@ -0,0 +1,129 @@
import contextlib
from collections.abc import Mapping
from copy import deepcopy
from typing import Any, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> dict[str, Any] | None:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = dict(self._deep_copy(data))
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
"""
mask credentials
return a deep copy of credentials with masked values
"""
data = dict(self._deep_copy(data))
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def mask_plugin_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
return self.mask_credentials(data)
def decrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = dict(self._deep_copy(data))
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
self.provider_config_cache.set(dict(data))
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache

View File

@ -4,7 +4,6 @@ from typing import Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
@ -16,6 +15,7 @@ from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from extensions.ext_database import db
from models import Account
from models.model import App, AppMode, EndUser
from services.end_user_service import EndUserService
class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
@ -64,7 +64,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
app = cls._get_app(app_id, tenant_id)
if not user_id:
user = create_or_update_end_user_for_user_id(app)
user = EndUserService.get_or_create_end_user(app)
else:
user = cls._get_user(user_id)

View File

@ -39,7 +39,7 @@ class PluginParameterType(StrEnum):
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
ANY = CommonParameterType.ANY
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
CHECKBOX = CommonParameterType.CHECKBOX
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
@ -94,6 +94,7 @@ def as_normal_type(typ: StrEnum):
if typ.value in {
PluginParameterType.SECRET_INPUT,
PluginParameterType.SELECT,
PluginParameterType.CHECKBOX,
}:
return "string"
return typ.value
@ -102,7 +103,13 @@ def as_normal_type(typ: StrEnum):
def cast_parameter_value(typ: StrEnum, value: Any, /):
try:
match typ.value:
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
case (
PluginParameterType.STRING
| PluginParameterType.SECRET_INPUT
| PluginParameterType.SELECT
| PluginParameterType.CHECKBOX
| PluginParameterType.DYNAMIC_SELECT
):
if value is None:
return ""
else:

View File

@ -13,6 +13,7 @@ from core.plugin.entities.base import BasePluginEntity
from core.plugin.entities.endpoint import EndpointProviderDeclaration
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity
from core.trigger.entities.entities import TriggerProviderEntity
class PluginInstallationSource(StrEnum):
@ -63,6 +64,7 @@ class PluginCategory(StrEnum):
Extension = auto()
AgentStrategy = "agent-strategy"
Datasource = "datasource"
Trigger = "trigger"
class PluginDeclaration(BaseModel):
@ -71,6 +73,7 @@ class PluginDeclaration(BaseModel):
models: list[str] | None = Field(default_factory=list[str])
endpoints: list[str] | None = Field(default_factory=list[str])
datasources: list[str] | None = Field(default_factory=list[str])
triggers: list[str] | None = Field(default_factory=list[str])
class Meta(BaseModel):
minimum_dify_version: str | None = Field(default=None)
@ -106,6 +109,7 @@ class PluginDeclaration(BaseModel):
endpoint: EndpointProviderDeclaration | None = None
agent_strategy: AgentStrategyProviderEntity | None = None
datasource: DatasourceProviderEntity | None = None
trigger: TriggerProviderEntity | None = None
meta: Meta
@field_validator("version")
@ -129,6 +133,8 @@ class PluginDeclaration(BaseModel):
values["category"] = PluginCategory.Datasource
elif values.get("agent_strategy"):
values["category"] = PluginCategory.AgentStrategy
elif values.get("trigger"):
values["category"] = PluginCategory.Trigger
else:
values["category"] = PluginCategory.Extension
return values

View File

@ -1,3 +1,4 @@
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
@ -14,6 +15,7 @@ from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
@ -205,3 +207,53 @@ class PluginListResponse(BaseModel):
class PluginDynamicSelectOptionsResponse(BaseModel):
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
class PluginTriggerProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
declaration: TriggerProviderEntity
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
UNAUTHORIZED = "unauthorized"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
elif self == CredentialType.UNAUTHORIZED:
return "UNAUTHORIZED"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name in {"api-key", "api_key"}:
return cls.API_KEY
elif type_name in {"oauth2", "oauth"}:
return cls.OAUTH2
elif type_name == "unauthorized":
return cls.UNAUTHORIZED
else:
raise ValueError(f"Invalid credential type: {credential_type}")
class PluginReadmeResponse(BaseModel):
content: str = Field(description="The readme of the plugin.")
language: str = Field(description="The language of the readme.")

View File

@ -1,5 +1,9 @@
import binascii
import json
from collections.abc import Mapping
from typing import Any, Literal
from flask import Response
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.provider_entities import BasicProviderConfig
@ -13,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.utils.http_parser import deserialize_response
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
@ -237,3 +242,43 @@ class RequestFetchAppInfo(BaseModel):
"""
app_id: str
class TriggerInvokeEventResponse(BaseModel):
variables: Mapping[str, Any] = Field(default_factory=dict)
cancelled: bool = Field(default=False)
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
@field_validator("variables", mode="before")
@classmethod
def convert_variables(cls, v):
if isinstance(v, str):
return json.loads(v)
else:
return v
class TriggerSubscriptionResponse(BaseModel):
subscription: dict[str, Any]
class TriggerValidateProviderCredentialsResponse(BaseModel):
result: bool
class TriggerDispatchResponse(BaseModel):
user_id: str
events: list[str]
response: Response
payload: Mapping[str, Any] = Field(default_factory=dict)
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
@field_validator("response", mode="before")
@classmethod
def convert_response(cls, v: str):
try:
return deserialize_response(binascii.unhexlify(v.encode()))
except Exception as e:
raise ValueError("Failed to deserialize response from hex string") from e

View File

@ -10,3 +10,13 @@ class PluginAssetManager(BasePluginClient):
if response.status_code != 200:
raise ValueError(f"can not found asset {id}")
return response.content
def extract_asset(self, tenant_id: str, plugin_unique_identifier: str, filename: str) -> bytes:
response = self._request(
method="GET",
path=f"plugin/{tenant_id}/extract-asset/",
params={"plugin_unique_identifier": plugin_unique_identifier, "file_path": filename},
)
if response.status_code != 200:
raise ValueError(f"can not found asset {plugin_unique_identifier}, {str(response.status_code)}")
return response.content

View File

@ -29,6 +29,12 @@ from core.plugin.impl.exc import (
PluginPermissionDeniedError,
PluginUniqueIdentifierError,
)
from core.trigger.errors import (
EventIgnoreError,
TriggerInvokeError,
TriggerPluginInvokeError,
TriggerProviderCredentialValidationError,
)
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
@ -43,7 +49,7 @@ elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
else:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
logger = logging.getLogger(__name__)
@ -53,10 +59,10 @@ class BasePluginClient:
self,
method: str,
path: str,
headers: dict | None = None,
data: bytes | dict | str | None = None,
params: dict | None = None,
files: dict | None = None,
headers: dict[str, str] | None = None,
data: bytes | dict[str, Any] | str | None = None,
params: dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
) -> httpx.Response:
"""
Make a request to the plugin daemon inner API.
@ -87,17 +93,17 @@ class BasePluginClient:
def _prepare_request(
self,
path: str,
headers: dict | None,
data: bytes | dict | str | None,
params: dict | None,
files: dict | None,
) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]:
headers: dict[str, str] | None,
data: bytes | dict[str, Any] | str | None,
params: dict[str, Any] | None,
files: dict[str, Any] | None,
) -> tuple[str, dict[str, str], bytes | dict[str, Any] | str | None, dict[str, Any] | None, dict[str, Any] | None]:
url = plugin_daemon_inner_api_baseurl / path
prepared_headers = dict(headers or {})
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
prepared_data: bytes | dict | str | None = (
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
if isinstance(data, dict):
@ -112,10 +118,10 @@ class BasePluginClient:
self,
method: str,
path: str,
params: dict | None = None,
headers: dict | None = None,
data: bytes | dict | None = None,
files: dict | None = None,
params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
data: bytes | dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
) -> Generator[str, None, None]:
"""
Make a stream request to the plugin daemon inner API
@ -138,7 +144,7 @@ class BasePluginClient:
try:
with httpx.stream(**stream_kwargs) as response:
for raw_line in response.iter_lines():
if raw_line is None:
if not raw_line:
continue
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
line = line.strip()
@ -155,10 +161,10 @@ class BasePluginClient:
method: str,
path: str,
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
files: dict | None = None,
headers: dict[str, str] | None = None,
data: bytes | dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
) -> Generator[T, None, None]:
"""
Make a stream request to the plugin daemon inner API and yield the response as a model.
@ -171,10 +177,10 @@ class BasePluginClient:
method: str,
path: str,
type_: type[T],
headers: dict | None = None,
headers: dict[str, str] | None = None,
data: bytes | None = None,
params: dict | None = None,
files: dict | None = None,
params: dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
) -> T:
"""
Make a request to the plugin daemon inner API and return the response as a model.
@ -187,11 +193,11 @@ class BasePluginClient:
method: str,
path: str,
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
files: dict | None = None,
transformer: Callable[[dict], dict] | None = None,
headers: dict[str, str] | None = None,
data: bytes | dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
transformer: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
) -> T:
"""
Make a request to the plugin daemon inner API and return the response as a model.
@ -239,10 +245,10 @@ class BasePluginClient:
method: str,
path: str,
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
files: dict | None = None,
headers: dict[str, str] | None = None,
data: bytes | dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
files: dict[str, Any] | None = None,
) -> Generator[T, None, None]:
"""
Make a stream request to the plugin daemon inner API and yield the response as a model.
@ -302,6 +308,14 @@ class BasePluginClient:
raise CredentialsValidateFailedError(error_object.get("message"))
case EndpointSetupFailedError.__name__:
raise EndpointSetupFailedError(error_object.get("message"))
case TriggerProviderCredentialValidationError.__name__:
raise TriggerProviderCredentialValidationError(error_object.get("message"))
case TriggerPluginInvokeError.__name__:
raise TriggerPluginInvokeError(description=error_object.get("description"))
case TriggerInvokeError.__name__:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
raise EventIgnoreError(description=error_object.get("description"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:

View File

@ -15,6 +15,7 @@ class DynamicSelectClient(BasePluginClient):
provider: str,
action: str,
credentials: Mapping[str, Any],
credential_type: str,
parameter: str,
) -> PluginDynamicSelectOptionsResponse:
"""
@ -29,6 +30,7 @@ class DynamicSelectClient(BasePluginClient):
"data": {
"provider": GenericProviderID(provider).provider_name,
"credentials": credentials,
"credential_type": credential_type,
"provider_action": action,
"parameter": parameter,
},

View File

@ -58,6 +58,20 @@ class PluginInvokeError(PluginDaemonClientSideError, ValueError):
except Exception:
return self.description
def to_user_friendly_error(self, plugin_name: str = "currently running plugin") -> str:
"""
Convert the error to a user-friendly error message.
:param plugin_name: The name of the plugin that caused the error.
:return: A user-friendly error message.
"""
return (
f"An error occurred in the {plugin_name}, "
f"please contact the author of {plugin_name} for help, "
f"error type: {self.get_error_type()}, "
f"error details: {self.get_error_message()}"
)
class PluginUniqueIdentifierError(PluginDaemonClientSideError):
description: str = "Unique Identifier Error"

View File

@ -1,5 +1,7 @@
from collections.abc import Sequence
from requests import HTTPError
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
MissingPluginDependency,
@ -13,12 +15,35 @@ from core.plugin.entities.plugin_daemon import (
PluginInstallTask,
PluginInstallTaskStartResponse,
PluginListResponse,
PluginReadmeResponse,
)
from core.plugin.impl.base import BasePluginClient
from models.provider_ids import GenericProviderID
class PluginInstaller(BasePluginClient):
def fetch_plugin_readme(self, tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
"""
Fetch plugin readme
"""
try:
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/fetch/readme",
PluginReadmeResponse,
params={
"tenant_id": tenant_id,
"plugin_unique_identifier": plugin_unique_identifier,
"language": language,
},
)
return response.content
except HTTPError as e:
message = e.args[0]
if "404" in message:
return ""
raise e
def fetch_plugin_by_identifier(
self,
tenant_id: str,

View File

@ -3,14 +3,12 @@ from typing import Any
from pydantic import BaseModel
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginToolProviderEntity,
)
# from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.schemas.resolver import resolve_dify_schema_refs
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from models.provider_ids import GenericProviderID, ToolProviderID

View File

@ -0,0 +1,305 @@
import binascii
from collections.abc import Generator, Mapping
from typing import Any
from flask import Request
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
from core.plugin.entities.request import (
TriggerDispatchResponse,
TriggerInvokeEventResponse,
TriggerSubscriptionResponse,
TriggerValidateProviderCredentialsResponse,
)
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.http_parser import serialize_request
from core.trigger.entities.entities import Subscription
from models.provider_ids import TriggerProviderID
class PluginTriggerClient(BasePluginClient):
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
"""
Fetch trigger providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_id = provider.get("plugin_id") + "/" + provider.get("provider")
for event in declaration.get("events", []):
event["identity"]["provider"] = provider_id
return json_response
response: list[PluginTriggerProviderEntity] = self._request_with_plugin_daemon_response(
method="GET",
path=f"plugin/{tenant_id}/management/triggers",
type_=list[PluginTriggerProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each trigger to plugin_id/provider_name
for event in provider.declaration.events:
event.identity.provider = provider.declaration.identity.name
return response
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
"""
Fetch trigger provider for the given tenant and plugin.
"""
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
data = json_response.get("data")
if data:
for event in data.get("declaration", {}).get("events", []):
event["identity"]["provider"] = str(provider_id)
return json_response
response: PluginTriggerProviderEntity = self._request_with_plugin_daemon_response(
method="GET",
path=f"plugin/{tenant_id}/management/trigger",
type_=PluginTriggerProviderEntity,
params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
transformer=transformer,
)
response.declaration.identity.name = str(provider_id)
# override the provider name for each trigger to plugin_id/provider_name
for event in response.declaration.events:
event.identity.provider = str(provider_id)
return response
def invoke_trigger_event(
self,
tenant_id: str,
user_id: str,
provider: str,
event_name: str,
credentials: Mapping[str, str],
credential_type: CredentialType,
request: Request,
parameters: Mapping[str, Any],
subscription: Subscription,
payload: Mapping[str, Any],
) -> TriggerInvokeEventResponse:
"""
Invoke a trigger with the given parameters.
"""
provider_id = TriggerProviderID(provider)
response: Generator[TriggerInvokeEventResponse, None, None] = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/trigger/invoke_event",
type_=TriggerInvokeEventResponse,
data={
"user_id": user_id,
"data": {
"provider": provider_id.provider_name,
"event": event_name,
"credentials": credentials,
"credential_type": credential_type,
"subscription": subscription.model_dump(),
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
"parameters": parameters,
"payload": payload,
},
},
headers={
"X-Plugin-ID": provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for invoke trigger")
def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
) -> bool:
"""
Validate the credentials of the trigger provider.
"""
provider_id = TriggerProviderID(provider)
response: Generator[TriggerValidateProviderCredentialsResponse, None, None] = (
self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
type_=TriggerValidateProviderCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider_id.provider_name,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": provider_id.plugin_id,
"Content-Type": "application/json",
},
)
)
for resp in response:
return resp.result
raise ValueError("No response received from plugin daemon for validate provider credentials")
def dispatch_event(
self,
tenant_id: str,
provider: str,
subscription: Mapping[str, Any],
request: Request,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> TriggerDispatchResponse:
"""
Dispatch an event to triggers.
"""
provider_id = TriggerProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
type_=TriggerDispatchResponse,
data={
"data": {
"provider": provider_id.provider_name,
"subscription": subscription,
"credentials": credentials,
"credential_type": credential_type,
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
},
},
headers={
"X-Plugin-ID": provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for dispatch event")
def subscribe(
self,
tenant_id: str,
user_id: str,
provider: str,
credentials: Mapping[str, str],
credential_type: CredentialType,
endpoint: str,
parameters: Mapping[str, Any],
) -> TriggerSubscriptionResponse:
"""
Subscribe to a trigger.
"""
provider_id = TriggerProviderID(provider)
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/trigger/subscribe",
type_=TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": provider_id.provider_name,
"credentials": credentials,
"credential_type": credential_type,
"endpoint": endpoint,
"parameters": parameters,
},
},
headers={
"X-Plugin-ID": provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for subscribe")
def unsubscribe(
self,
tenant_id: str,
user_id: str,
provider: str,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> TriggerSubscriptionResponse:
"""
Unsubscribe from a trigger.
"""
provider_id = TriggerProviderID(provider)
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
type_=TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": provider_id.provider_name,
"subscription": subscription.model_dump(),
"credentials": credentials,
"credential_type": credential_type,
},
},
headers={
"X-Plugin-ID": provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for unsubscribe")
def refresh(
self,
tenant_id: str,
user_id: str,
provider: str,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> TriggerSubscriptionResponse:
"""
Refresh a trigger subscription.
"""
provider_id = TriggerProviderID(provider)
response: Generator[TriggerSubscriptionResponse, None, None] = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/trigger/refresh",
type_=TriggerSubscriptionResponse,
data={
"user_id": user_id,
"data": {
"provider": provider_id.provider_name,
"subscription": subscription.model_dump(),
"credentials": credentials,
"credential_type": credential_type,
},
},
headers={
"X-Plugin-ID": provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh")

View File

@ -0,0 +1,163 @@
from io import BytesIO
from flask import Request, Response
from werkzeug.datastructures import Headers
def serialize_request(request: Request) -> bytes:
method = request.method
path = request.full_path.rstrip("?")
raw = f"{method} {path} HTTP/1.1\r\n".encode()
for name, value in request.headers.items():
raw += f"{name}: {value}\r\n".encode()
raw += b"\r\n"
body = request.get_data(as_text=False)
if body:
raw += body
return raw
def deserialize_request(raw_data: bytes) -> Request:
header_end = raw_data.find(b"\r\n\r\n")
if header_end == -1:
header_end = raw_data.find(b"\n\n")
if header_end == -1:
header_data = raw_data
body = b""
else:
header_data = raw_data[:header_end]
body = raw_data[header_end + 2 :]
else:
header_data = raw_data[:header_end]
body = raw_data[header_end + 4 :]
lines = header_data.split(b"\r\n")
if len(lines) == 1 and b"\n" in lines[0]:
lines = header_data.split(b"\n")
if not lines or not lines[0]:
raise ValueError("Empty HTTP request")
request_line = lines[0].decode("utf-8", errors="ignore")
parts = request_line.split(" ", 2)
if len(parts) < 2:
raise ValueError(f"Invalid request line: {request_line}")
method = parts[0]
full_path = parts[1]
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
if "?" in full_path:
path, query_string = full_path.split("?", 1)
else:
path = full_path
query_string = ""
headers = Headers()
for line in lines[1:]:
if not line:
continue
line_str = line.decode("utf-8", errors="ignore")
if ":" not in line_str:
continue
name, value = line_str.split(":", 1)
headers.add(name, value.strip())
host = headers.get("Host", "localhost")
if ":" in host:
server_name, server_port = host.rsplit(":", 1)
else:
server_name = host
server_port = "80"
environ = {
"REQUEST_METHOD": method,
"PATH_INFO": path,
"QUERY_STRING": query_string,
"SERVER_NAME": server_name,
"SERVER_PORT": server_port,
"SERVER_PROTOCOL": protocol,
"wsgi.input": BytesIO(body),
"wsgi.url_scheme": "http",
}
if "Content-Type" in headers:
content_type = headers.get("Content-Type")
if content_type is not None:
environ["CONTENT_TYPE"] = content_type
if "Content-Length" in headers:
content_length = headers.get("Content-Length")
if content_length is not None:
environ["CONTENT_LENGTH"] = content_length
elif body:
environ["CONTENT_LENGTH"] = str(len(body))
for name, value in headers.items():
if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
continue
env_name = f"HTTP_{name.upper().replace('-', '_')}"
environ[env_name] = value
return Request(environ)
def serialize_response(response: Response) -> bytes:
raw = f"HTTP/1.1 {response.status}\r\n".encode()
for name, value in response.headers.items():
raw += f"{name}: {value}\r\n".encode()
raw += b"\r\n"
body = response.get_data(as_text=False)
if body:
raw += body
return raw
def deserialize_response(raw_data: bytes) -> Response:
header_end = raw_data.find(b"\r\n\r\n")
if header_end == -1:
header_end = raw_data.find(b"\n\n")
if header_end == -1:
header_data = raw_data
body = b""
else:
header_data = raw_data[:header_end]
body = raw_data[header_end + 2 :]
else:
header_data = raw_data[:header_end]
body = raw_data[header_end + 4 :]
lines = header_data.split(b"\r\n")
if len(lines) == 1 and b"\n" in lines[0]:
lines = header_data.split(b"\n")
if not lines or not lines[0]:
raise ValueError("Empty HTTP response")
status_line = lines[0].decode("utf-8", errors="ignore")
parts = status_line.split(" ", 2)
if len(parts) < 2:
raise ValueError(f"Invalid status line: {status_line}")
status_code = int(parts[1])
response = Response(response=body, status=status_code)
for line in lines[1:]:
if not line:
continue
line_str = line.decode("utf-8", errors="ignore")
if ":" not in line_str:
continue
name, value = line_str.split(":", 1)
response.headers[name] = value.strip()
return response

View File

@ -3,7 +3,8 @@ from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.tool_entities import ToolInvokeFrom
class ToolRuntime(BaseModel):

View File

@ -4,11 +4,11 @@ from typing import Any
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema,
ToolEntity,
ToolProviderEntity,

View File

@ -6,9 +6,10 @@ from pydantic import BaseModel, Field, field_validator
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
from core.tools.entities.tool_entities import ToolProviderType
class ToolApiEntity(BaseModel):

View File

@ -268,6 +268,7 @@ class ToolParameter(PluginParameter):
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
CHECKBOX = PluginParameterType.CHECKBOX
APP_SELECTOR = PluginParameterType.APP_SELECTOR
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
ANY = PluginParameterType.ANY
@ -489,36 +490,3 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
class CredentialType(StrEnum):
API_KEY = "api-key"
OAUTH2 = auto()
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name in {"api-key", "api_key"}:
return cls.API_KEY
elif type_name in {"oauth2", "oauth"}:
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@ -8,7 +8,6 @@ from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session
from yarl import URL
@ -39,6 +38,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import Tool
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -49,7 +49,6 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
@ -289,10 +288,8 @@ class ToolManager:
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
builtin_provider.encrypted_credentials = json.dumps(
encrypter.encrypt(refreshed_credentials.credentials)
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
@ -322,7 +319,7 @@ class ToolManager:
return api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=encrypter.decrypt(credentials),
credentials=dict(encrypter.decrypt(credentials)),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
@ -833,7 +830,7 @@ class ToolManager:
controller=controller,
)
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)

View File

@ -1,137 +1,24 @@
import contextlib
from copy import deepcopy
from typing import Any, Protocol
# Import generic components from provider_encryption module
from core.helper.provider_encryption import (
ProviderConfigCache,
ProviderConfigEncrypter,
create_provider_encrypter,
)
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
# Re-export for backward compatibility
__all__ = [
"ProviderConfigCache",
"ProviderConfigEncrypter",
"create_provider_encrypter",
"create_tool_provider_encrypter",
]
# Tool-specific imports
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> dict | None:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]):
"""Cache provider configuration"""
...
def delete(self):
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(
tenant_id: str, controller: ToolProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,

View File

@ -0,0 +1 @@
# Core trigger module initialization

View File

@ -0,0 +1,124 @@
import hashlib
import logging
from typing import TypeVar
from redis import RedisError
from core.trigger.debug.events import BaseDebugEvent
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
TRIGGER_DEBUG_EVENT_TTL = 300
TTriggerDebugEvent = TypeVar("TTriggerDebugEvent", bound="BaseDebugEvent")
class TriggerDebugEventBus:
"""
Unified Redis-based trigger debug service with polling support.
Uses {tenant_id} hash tags for Redis Cluster compatibility.
Supports multiple event types through a generic dispatch/poll interface.
"""
# LUA_SELECT: Atomic poll or register for event
# KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id}
# KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:...
# ARGV[1] = address_id
LUA_SELECT = (
"local v=redis.call('GET',KEYS[1]);"
"if v then redis.call('DEL',KEYS[1]);return v end;"
"redis.call('SADD',KEYS[2],ARGV[1]);"
f"redis.call('EXPIRE',KEYS[2],{TRIGGER_DEBUG_EVENT_TTL});"
"return false"
)
# LUA_DISPATCH: Dispatch event to all waiting addresses
# KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:...
# ARGV[1] = tenant_id
# ARGV[2] = event_json
LUA_DISPATCH = (
"local a=redis.call('SMEMBERS',KEYS[1]);"
"if #a==0 then return 0 end;"
"redis.call('DEL',KEYS[1]);"
"for i=1,#a do "
f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
"end;"
"return #a"
)
@classmethod
def dispatch(
cls,
tenant_id: str,
event: BaseDebugEvent,
pool_key: str,
) -> int:
"""
Dispatch event to all waiting addresses in the pool.
Args:
tenant_id: Tenant ID for hash tag
event: Event object to dispatch
pool_key: Pool key (generate using build_{?}_pool_key(...))
Returns:
Number of addresses the event was dispatched to
"""
event_data = event.model_dump_json()
try:
result = redis_client.eval(
cls.LUA_DISPATCH,
1,
pool_key,
tenant_id,
event_data,
)
return int(result)
except RedisError:
logger.exception("Failed to dispatch event to pool: %s", pool_key)
return 0
@classmethod
def poll(
cls,
event_type: type[TTriggerDebugEvent],
pool_key: str,
tenant_id: str,
user_id: str,
app_id: str,
node_id: str,
) -> TTriggerDebugEvent | None:
"""
Poll for an event or register to the waiting pool.
If an event is available in the inbox, return it immediately.
Otherwise, register the address to the waiting pool for future dispatch.
Args:
event_class: Event class for deserialization and type safety
pool_key: Pool key (generate using build_{?}_pool_key(...))
tenant_id: Tenant ID
user_id: User ID for address calculation
app_id: App ID for address calculation
node_id: Node ID for address calculation
Returns:
Event object if available, None otherwise
"""
address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest()
address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}"
try:
event_data = redis_client.eval(
cls.LUA_SELECT,
2,
address,
pool_key,
address_id,
)
return event_type.model_validate_json(json_data=event_data) if event_data else None
except RedisError:
logger.exception("Failed to poll event from pool: %s", pool_key)
return None

View File

@ -0,0 +1,243 @@
"""Trigger debug service supporting plugin and webhook debugging in draft workflows."""
import hashlib
import logging
import time
from abc import ABC, abstractmethod
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import BaseModel
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import (
PluginTriggerDebugEvent,
ScheduleDebugEvent,
WebhookDebugEvent,
build_plugin_pool_key,
build_webhook_pool_key,
)
from core.workflow.enums import NodeType
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
from extensions.ext_redis import redis_client
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from libs.schedule_utils import calculate_next_run_at
from models.model import App
from models.provider_ids import TriggerProviderID
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class TriggerDebugEvent(BaseModel):
workflow_args: Mapping[str, Any]
node_id: str
class TriggerDebugEventPoller(ABC):
app_id: str
user_id: str
tenant_id: str
node_config: Mapping[str, Any]
node_id: str
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
self.tenant_id = tenant_id
self.user_id = user_id
self.app_id = app_id
self.node_config = node_config
self.node_id = node_id
@abstractmethod
def poll(self) -> TriggerDebugEvent | None:
raise NotImplementedError
class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
def poll(self) -> TriggerDebugEvent | None:
from services.trigger.trigger_service import TriggerService
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
pool_key: str = build_plugin_pool_key(
name=plugin_trigger_data.event_name,
provider_id=str(provider_id),
tenant_id=self.tenant_id,
subscription_id=plugin_trigger_data.subscription_id,
)
plugin_trigger_event: PluginTriggerDebugEvent | None = TriggerDebugEventBus.poll(
event_type=PluginTriggerDebugEvent,
pool_key=pool_key,
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
node_id=self.node_id,
)
if not plugin_trigger_event:
return None
trigger_event_response: TriggerInvokeEventResponse = TriggerService.invoke_trigger_event(
event=plugin_trigger_event,
user_id=plugin_trigger_event.user_id,
tenant_id=self.tenant_id,
node_config=self.node_config,
)
if trigger_event_response.cancelled:
return None
return TriggerDebugEvent(
workflow_args={
"inputs": trigger_event_response.variables,
"files": [],
},
node_id=self.node_id,
)
class WebhookTriggerDebugEventPoller(TriggerDebugEventPoller):
def poll(self) -> TriggerDebugEvent | None:
pool_key = build_webhook_pool_key(
tenant_id=self.tenant_id,
app_id=self.app_id,
node_id=self.node_id,
)
webhook_event: WebhookDebugEvent | None = TriggerDebugEventBus.poll(
event_type=WebhookDebugEvent,
pool_key=pool_key,
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
node_id=self.node_id,
)
if not webhook_event:
return None
from services.trigger.webhook_service import WebhookService
payload = webhook_event.payload or {}
workflow_inputs = payload.get("inputs")
if workflow_inputs is None:
webhook_data = payload.get("webhook_data", {})
workflow_inputs = WebhookService.build_workflow_inputs(webhook_data)
workflow_args: Mapping[str, Any] = {
"inputs": workflow_inputs or {},
"files": [],
}
return TriggerDebugEvent(workflow_args=workflow_args, node_id=self.node_id)
class ScheduleTriggerDebugEventPoller(TriggerDebugEventPoller):
"""
Poller for schedule trigger debug events.
This poller will simulate the schedule trigger event by creating a schedule debug runtime cache
and calculating the next run at.
"""
RUNTIME_CACHE_TTL = 60 * 5
class ScheduleDebugRuntime(BaseModel):
cache_key: str
timezone: str
cron_expression: str
next_run_at: datetime
def schedule_debug_runtime_key(self, cron_hash: str) -> str:
return f"schedule_debug_runtime:{self.tenant_id}:{self.user_id}:{self.app_id}:{self.node_id}:{cron_hash}"
def get_or_create_schedule_debug_runtime(self):
from services.trigger.schedule_service import ScheduleService
schedule_config: ScheduleConfig = ScheduleService.to_schedule_config(self.node_config)
cron_hash = hashlib.sha256(schedule_config.cron_expression.encode()).hexdigest()
cache_key = self.schedule_debug_runtime_key(cron_hash)
runtime_cache = redis_client.get(cache_key)
if runtime_cache is None:
schedule_debug_runtime = self.ScheduleDebugRuntime(
cron_expression=schedule_config.cron_expression,
timezone=schedule_config.timezone,
cache_key=cache_key,
next_run_at=ensure_naive_utc(
calculate_next_run_at(schedule_config.cron_expression, schedule_config.timezone)
),
)
redis_client.setex(
name=self.schedule_debug_runtime_key(cron_hash),
time=self.RUNTIME_CACHE_TTL,
value=schedule_debug_runtime.model_dump_json(),
)
return schedule_debug_runtime
else:
redis_client.expire(cache_key, self.RUNTIME_CACHE_TTL)
runtime = self.ScheduleDebugRuntime.model_validate_json(runtime_cache)
runtime.next_run_at = ensure_naive_utc(runtime.next_run_at)
return runtime
def create_schedule_event(self, schedule_debug_runtime: ScheduleDebugRuntime) -> ScheduleDebugEvent:
redis_client.delete(schedule_debug_runtime.cache_key)
return ScheduleDebugEvent(
timestamp=int(time.time()),
node_id=self.node_id,
inputs={},
)
def poll(self) -> TriggerDebugEvent | None:
schedule_debug_runtime = self.get_or_create_schedule_debug_runtime()
if schedule_debug_runtime.next_run_at > naive_utc_now():
return None
schedule_event: ScheduleDebugEvent = self.create_schedule_event(schedule_debug_runtime)
workflow_args: Mapping[str, Any] = {
"inputs": schedule_event.inputs or {},
"files": [],
}
return TriggerDebugEvent(workflow_args=workflow_args, node_id=self.node_id)
def create_event_poller(
draft_workflow: Workflow, tenant_id: str, user_id: str, app_id: str, node_id: str
) -> TriggerDebugEventPoller:
node_config = draft_workflow.get_node_config_by_id(node_id=node_id)
if not node_config:
raise ValueError("Node data not found for node %s", node_id)
node_type = draft_workflow.get_node_type_from_node_config(node_config)
match node_type:
case NodeType.TRIGGER_PLUGIN:
return PluginTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case NodeType.TRIGGER_WEBHOOK:
return WebhookTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case NodeType.TRIGGER_SCHEDULE:
return ScheduleTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case _:
raise ValueError("unable to create event poller for node type %s", node_type)
def select_trigger_debug_events(
draft_workflow: Workflow, app_model: App, user_id: str, node_ids: list[str]
) -> TriggerDebugEvent | None:
event: TriggerDebugEvent | None = None
for node_id in node_ids:
node_config = draft_workflow.get_node_config_by_id(node_id=node_id)
if not node_config:
raise ValueError("Node data not found for node %s", node_id)
poller: TriggerDebugEventPoller = create_event_poller(
draft_workflow=draft_workflow,
tenant_id=app_model.tenant_id,
user_id=user_id,
app_id=app_model.id,
node_id=node_id,
)
event = poller.poll()
if event is not None:
return event
return None

View File

@ -0,0 +1,67 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class TriggerDebugPoolKey(StrEnum):
"""Trigger debug pool key."""
SCHEDULE = "schedule_trigger_debug_waiting_pool"
WEBHOOK = "webhook_trigger_debug_waiting_pool"
PLUGIN = "plugin_trigger_debug_waiting_pool"
class BaseDebugEvent(BaseModel):
"""Base class for all debug events."""
timestamp: int
class ScheduleDebugEvent(BaseDebugEvent):
"""Debug event for schedule triggers."""
node_id: str
inputs: Mapping[str, Any]
class WebhookDebugEvent(BaseDebugEvent):
"""Debug event for webhook triggers."""
request_id: str
node_id: str
payload: dict[str, Any] = Field(default_factory=dict)
def build_webhook_pool_key(tenant_id: str, app_id: str, node_id: str) -> str:
"""Generate pool key for webhook events.
Args:
tenant_id: Tenant ID
app_id: App ID
node_id: Node ID
"""
return f"{TriggerDebugPoolKey.WEBHOOK}:{tenant_id}:{app_id}:{node_id}"
class PluginTriggerDebugEvent(BaseDebugEvent):
"""Debug event for plugin triggers."""
name: str
user_id: str = Field(description="This is end user id, only for trigger the event. no related with account user id")
request_id: str
subscription_id: str
provider_id: str
def build_plugin_pool_key(tenant_id: str, provider_id: str, subscription_id: str, name: str) -> str:
"""Generate pool key for plugin trigger events.
Args:
name: Event name
tenant_id: Tenant ID
provider_id: Provider ID
subscription_id: Subscription ID
"""
return f"{TriggerDebugPoolKey.PLUGIN}:{tenant_id}:{str(provider_id)}:{subscription_id}:{name}"

View File

@ -0,0 +1,76 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.common_entities import I18nObject
from core.trigger.entities.entities import (
EventIdentity,
EventParameter,
SubscriptionConstructor,
TriggerCreationMethod,
)
class TriggerProviderSubscriptionApiEntity(BaseModel):
id: str = Field(description="The unique id of the subscription")
name: str = Field(description="The name of the subscription")
provider: str = Field(description="The provider id of the subscription")
credential_type: CredentialType = Field(description="The type of the credential")
credentials: dict[str, Any] = Field(description="The credentials of the subscription")
endpoint: str = Field(description="The endpoint of the subscription")
parameters: dict[str, Any] = Field(description="The parameters of the subscription")
properties: dict[str, Any] = Field(description="The properties of the subscription")
workflows_in_use: int = Field(description="The number of workflows using this subscription")
class EventApiEntity(BaseModel):
name: str = Field(description="The name of the trigger")
identity: EventIdentity = Field(description="The identity of the trigger")
description: I18nObject = Field(description="The description of the trigger")
parameters: list[EventParameter] = Field(description="The parameters of the trigger")
output_schema: Mapping[str, Any] | None = Field(description="The output schema of the trigger")
class TriggerProviderApiEntity(BaseModel):
author: str = Field(..., description="The author of the trigger provider")
name: str = Field(..., description="The name of the trigger provider")
label: I18nObject = Field(..., description="The label of the trigger provider")
description: I18nObject = Field(..., description="The description of the trigger provider")
icon: str | None = Field(default=None, description="The icon of the trigger provider")
icon_dark: str | None = Field(default=None, description="The dark icon of the trigger provider")
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
plugin_id: str | None = Field(default="", description="The plugin id of the tool")
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
supported_creation_methods: list[TriggerCreationMethod] = Field(
default_factory=list,
description="Supported creation methods for the trigger provider. like 'OAUTH', 'APIKEY', 'MANUAL'.",
)
subscription_constructor: SubscriptionConstructor | None = Field(
default=None, description="The subscription constructor of the trigger provider"
)
subscription_schema: list[ProviderConfig] = Field(
default_factory=list,
description="The subscription schema of the trigger provider",
)
events: list[EventApiEntity] = Field(description="The events of the trigger provider")
class SubscriptionBuilderApiEntity(BaseModel):
id: str = Field(description="The id of the subscription builder")
name: str = Field(description="The name of the subscription builder")
provider: str = Field(description="The provider id of the subscription builder")
endpoint: str = Field(description="The endpoint id of the subscription builder")
parameters: Mapping[str, Any] = Field(description="The parameters of the subscription builder")
properties: Mapping[str, Any] = Field(description="The properties of the subscription builder")
credentials: Mapping[str, str] = Field(description="The credentials of the subscription builder")
credential_type: CredentialType = Field(description="The credential type of the subscription builder")
__all__ = ["EventApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"]

View File

@ -0,0 +1,288 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.parameters import (
PluginParameterAutoGenerate,
PluginParameterOption,
PluginParameterTemplate,
PluginParameterType,
)
from core.tools.entities.common_entities import I18nObject
class EventParameterType(StrEnum):
"""The type of the parameter"""
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
APP_SELECTOR = PluginParameterType.APP_SELECTOR
OBJECT = PluginParameterType.OBJECT
ARRAY = PluginParameterType.ARRAY
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT
CHECKBOX = PluginParameterType.CHECKBOX
class EventParameter(BaseModel):
"""
The parameter of the event
"""
name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user")
type: EventParameterType = Field(..., description="The type of the parameter")
auto_generate: PluginParameterAutoGenerate | None = Field(
default=None, description="The auto generate of the parameter"
)
template: PluginParameterTemplate | None = Field(default=None, description="The template of the parameter")
scope: str | None = None
required: bool | None = False
multiple: bool | None = Field(
default=False,
description="Whether the parameter is multiple select, only valid for select or dynamic-select type",
)
default: Union[int, float, str, list[Any], None] = None
min: Union[float, int, None] = None
max: Union[float, int, None] = None
precision: int | None = None
options: list[PluginParameterOption] | None = None
description: I18nObject | None = None
class TriggerProviderIdentity(BaseModel):
"""
The identity of the trigger provider
"""
author: str = Field(..., description="The author of the trigger provider")
name: str = Field(..., description="The name of the trigger provider")
label: I18nObject = Field(..., description="The label of the trigger provider")
description: I18nObject = Field(..., description="The description of the trigger provider")
icon: str | None = Field(default=None, description="The icon of the trigger provider")
icon_dark: str | None = Field(default=None, description="The dark icon of the trigger provider")
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
class EventIdentity(BaseModel):
"""
The identity of the event
"""
author: str = Field(..., description="The author of the event")
name: str = Field(..., description="The name of the event")
label: I18nObject = Field(..., description="The label of the event")
provider: str | None = Field(default=None, description="The provider of the event")
class EventEntity(BaseModel):
"""
The configuration of an event
"""
identity: EventIdentity = Field(..., description="The identity of the event")
parameters: list[EventParameter] = Field(
default_factory=list[EventParameter], description="The parameters of the event"
)
description: I18nObject = Field(..., description="The description of the event")
output_schema: Mapping[str, Any] | None = Field(
default=None, description="The output schema that this event produces"
)
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[EventParameter]:
return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class SubscriptionConstructor(BaseModel):
"""
The subscription constructor of the trigger provider
"""
parameters: list[EventParameter] = Field(
default_factory=list, description="The parameters schema of the subscription constructor"
)
credentials_schema: list[ProviderConfig] = Field(
default_factory=list,
description="The credentials schema of the subscription constructor",
)
oauth_schema: OAuthSchema | None = Field(
default=None,
description="The OAuth schema of the subscription constructor if OAuth is supported",
)
def get_default_parameters(self) -> Mapping[str, Any]:
"""Get the default parameters from the parameters schema"""
if not self.parameters:
return {}
return {param.name: param.default for param in self.parameters if param.default}
class TriggerProviderEntity(BaseModel):
"""
The configuration of a trigger provider
"""
identity: TriggerProviderIdentity = Field(..., description="The identity of the trigger provider")
subscription_schema: list[ProviderConfig] = Field(
default_factory=list,
description="The configuration schema stored in the subscription entity",
)
subscription_constructor: SubscriptionConstructor | None = Field(
default=None,
description="The subscription constructor of the trigger provider",
)
events: list[EventEntity] = Field(default_factory=list, description="The events of the trigger provider")
class Subscription(BaseModel):
"""
Result of a successful trigger subscription operation.
Contains all information needed to manage the subscription lifecycle.
"""
expires_at: int = Field(
..., description="The timestamp when the subscription will expire, this for refresh the subscription"
)
endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events")
parameters: Mapping[str, Any] = Field(
default_factory=dict, description="The parameters of the subscription constructor"
)
properties: Mapping[str, Any] = Field(
..., description="Subscription data containing all properties and provider-specific information"
)
class UnsubscribeResult(BaseModel):
"""
Result of a trigger unsubscription operation.
Provides detailed information about the unsubscription attempt,
including success status and error details if failed.
"""
success: bool = Field(..., description="Whether the unsubscription was successful")
message: str | None = Field(
None,
description="Human-readable message about the operation result. "
"Success message for successful operations, "
"detailed error information for failures.",
)
class RequestLog(BaseModel):
id: str = Field(..., description="The id of the request log")
endpoint: str = Field(..., description="The endpoint of the request log")
request: dict[str, Any] = Field(..., description="The request of the request log")
response: dict[str, Any] = Field(..., description="The response of the request log")
created_at: datetime = Field(..., description="The created at of the request log")
class SubscriptionBuilder(BaseModel):
id: str = Field(..., description="The id of the subscription builder")
name: str | None = Field(default=None, description="The name of the subscription builder")
tenant_id: str = Field(..., description="The tenant id of the subscription builder")
user_id: str = Field(..., description="The user id of the subscription builder")
provider_id: str = Field(..., description="The provider id of the subscription builder")
endpoint_id: str = Field(..., description="The endpoint id of the subscription builder")
parameters: Mapping[str, Any] = Field(..., description="The parameters of the subscription builder")
properties: Mapping[str, Any] = Field(..., description="The properties of the subscription builder")
credentials: Mapping[str, Any] = Field(..., description="The credentials of the subscription builder")
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
credential_expires_at: int | None = Field(
default=None, description="The credential expires at of the subscription builder"
)
expires_at: int = Field(..., description="The expires at of the subscription builder")
def to_subscription(self) -> Subscription:
return Subscription(
expires_at=self.expires_at,
endpoint=self.endpoint_id,
properties=self.properties,
)
class SubscriptionBuilderUpdater(BaseModel):
name: str | None = Field(default=None, description="The name of the subscription builder")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
credentials: Mapping[str, Any] | None = Field(
default=None, description="The credentials of the subscription builder"
)
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
credential_expires_at: int | None = Field(
default=None, description="The credential expires at of the subscription builder"
)
expires_at: int | None = Field(default=None, description="The expires at of the subscription builder")
def update(self, subscription_builder: SubscriptionBuilder) -> None:
if self.name is not None:
subscription_builder.name = self.name
if self.parameters is not None:
subscription_builder.parameters = self.parameters
if self.properties is not None:
subscription_builder.properties = self.properties
if self.credentials is not None:
subscription_builder.credentials = self.credentials
if self.credential_type is not None:
subscription_builder.credential_type = self.credential_type
if self.credential_expires_at is not None:
subscription_builder.credential_expires_at = self.credential_expires_at
if self.expires_at is not None:
subscription_builder.expires_at = self.expires_at
class TriggerEventData(BaseModel):
"""Event data dispatched to trigger sessions."""
subscription_id: str
events: list[str]
request_id: str
timestamp: float
model_config = ConfigDict(arbitrary_types_allowed=True)
class TriggerCreationMethod(StrEnum):
OAUTH = "OAUTH"
APIKEY = "APIKEY"
MANUAL = "MANUAL"
# Export all entities
__all__: list[str] = [
"EventEntity",
"EventIdentity",
"EventParameter",
"EventParameterType",
"OAuthSchema",
"RequestLog",
"Subscription",
"SubscriptionBuilder",
"TriggerCreationMethod",
"TriggerEventData",
"TriggerProviderEntity",
"TriggerProviderIdentity",
"UnsubscribeResult",
]

View File

@ -0,0 +1,19 @@
from core.plugin.impl.exc import PluginInvokeError
class TriggerProviderCredentialValidationError(ValueError):
pass
class TriggerPluginInvokeError(PluginInvokeError):
pass
class TriggerInvokeError(PluginInvokeError):
pass
class EventIgnoreError(TriggerInvokeError):
"""
Trigger event ignore error
"""

View File

@ -0,0 +1,421 @@
"""
Trigger Provider Controller for managing trigger providers
"""
import logging
from collections.abc import Mapping
from typing import Any
from flask import Request
from core.entities.provider_entities import BasicProviderConfig
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import (
TriggerDispatchResponse,
TriggerInvokeEventResponse,
TriggerSubscriptionResponse,
)
from core.plugin.impl.trigger import PluginTriggerClient
from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity
from core.trigger.entities.entities import (
EventEntity,
EventParameter,
ProviderConfig,
Subscription,
SubscriptionConstructor,
TriggerCreationMethod,
TriggerProviderEntity,
TriggerProviderIdentity,
UnsubscribeResult,
)
from core.trigger.errors import TriggerProviderCredentialValidationError
from models.provider_ids import TriggerProviderID
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
class PluginTriggerProviderController:
"""
Controller for plugin trigger providers
"""
def __init__(
self,
entity: TriggerProviderEntity,
plugin_id: str,
plugin_unique_identifier: str,
provider_id: TriggerProviderID,
tenant_id: str,
):
"""
Initialize plugin trigger provider controller
:param entity: Trigger provider entity
:param plugin_id: Plugin ID
:param plugin_unique_identifier: Plugin unique identifier
:param provider_id: Provider ID
:param tenant_id: Tenant ID
"""
self.entity = entity
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.provider_id = provider_id
self.plugin_unique_identifier = plugin_unique_identifier
def get_provider_id(self) -> TriggerProviderID:
"""
Get provider ID
"""
return self.provider_id
def to_api_entity(self) -> TriggerProviderApiEntity:
"""
Convert to API entity
"""
icon = (
PluginService.get_plugin_icon_url(self.tenant_id, self.entity.identity.icon)
if self.entity.identity.icon
else None
)
icon_dark = (
PluginService.get_plugin_icon_url(self.tenant_id, self.entity.identity.icon_dark)
if self.entity.identity.icon_dark
else None
)
subscription_constructor = self.entity.subscription_constructor
supported_creation_methods = [TriggerCreationMethod.MANUAL]
if subscription_constructor and subscription_constructor.oauth_schema:
supported_creation_methods.append(TriggerCreationMethod.OAUTH)
if subscription_constructor and subscription_constructor.credentials_schema:
supported_creation_methods.append(TriggerCreationMethod.APIKEY)
return TriggerProviderApiEntity(
author=self.entity.identity.author,
name=self.entity.identity.name,
label=self.entity.identity.label,
description=self.entity.identity.description,
icon=icon,
icon_dark=icon_dark,
tags=self.entity.identity.tags,
plugin_id=self.plugin_id,
plugin_unique_identifier=self.plugin_unique_identifier,
subscription_constructor=subscription_constructor,
subscription_schema=self.entity.subscription_schema,
supported_creation_methods=supported_creation_methods,
events=[
EventApiEntity(
name=event.identity.name,
identity=event.identity,
description=event.description,
parameters=event.parameters,
output_schema=event.output_schema,
)
for event in self.entity.events
],
)
@property
def identity(self) -> TriggerProviderIdentity:
"""Get provider identity"""
return self.entity.identity
def get_events(self) -> list[EventEntity]:
"""
Get all events for this provider
:return: List of event entities
"""
return self.entity.events
def get_event(self, event_name: str) -> EventEntity | None:
"""
Get a specific event by name
:param event_name: Event name
:return: Event entity or None
"""
for event in self.entity.events:
if event.identity.name == event_name:
return event
return None
def get_subscription_default_properties(self) -> Mapping[str, Any]:
"""
Get default properties for this provider
:return: Default properties
"""
return {prop.name: prop.default for prop in self.entity.subscription_schema if prop.default}
def get_subscription_constructor(self) -> SubscriptionConstructor | None:
"""
Get subscription constructor for this provider
:return: Subscription constructor
"""
return self.entity.subscription_constructor
def validate_credentials(self, user_id: str, credentials: Mapping[str, str]) -> None:
"""
Validate credentials against schema
:param credentials: Credentials to validate
:return: Validation response
"""
# First validate against schema
subscription_constructor: SubscriptionConstructor | None = self.entity.subscription_constructor
if not subscription_constructor:
raise ValueError("Subscription constructor not found")
for config in subscription_constructor.credentials_schema or []:
if config.required and config.name not in credentials:
raise TriggerProviderCredentialValidationError(f"Missing required credential field: {config.name}")
# Then validate with the plugin daemon
manager = PluginTriggerClient()
provider_id = self.get_provider_id()
response = manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
credentials=credentials,
)
if not response:
raise TriggerProviderCredentialValidationError(
"Invalid credentials",
)
def get_supported_credential_types(self) -> list[CredentialType]:
"""
Get supported credential types for this provider.
:return: List of supported credential types
"""
types: list[CredentialType] = []
subscription_constructor = self.entity.subscription_constructor
if subscription_constructor and subscription_constructor.oauth_schema:
types.append(CredentialType.OAUTH2)
if subscription_constructor and subscription_constructor.credentials_schema:
types.append(CredentialType.API_KEY)
return types
def get_credentials_schema(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
"""
Get credentials schema by credential type
:param credential_type: The type of credential (oauth or api_key)
:return: List of provider config schemas
"""
subscription_constructor = self.entity.subscription_constructor
if not subscription_constructor:
return []
credential_type = CredentialType.of(credential_type)
if credential_type == CredentialType.OAUTH2:
return (
subscription_constructor.oauth_schema.credentials_schema.copy()
if subscription_constructor and subscription_constructor.oauth_schema
else []
)
if credential_type == CredentialType.API_KEY:
return (
subscription_constructor.credentials_schema.copy() or []
if subscription_constructor and subscription_constructor.credentials_schema
else []
)
if credential_type == CredentialType.UNAUTHORIZED:
return []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
"""
Get credential schema config by credential type
"""
return [x.to_basic_provider_config() for x in self.get_credentials_schema(credential_type)]
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
Get OAuth client schema for this provider
:return: List of OAuth client config schemas
"""
subscription_constructor = self.entity.subscription_constructor
return (
subscription_constructor.oauth_schema.client_schema.copy()
if subscription_constructor and subscription_constructor.oauth_schema
else []
)
def get_properties_schema(self) -> list[BasicProviderConfig]:
"""
Get properties schema for this provider
:return: List of properties config schemas
"""
return (
[x.to_basic_provider_config() for x in self.entity.subscription_schema.copy()]
if self.entity.subscription_schema
else []
)
def get_event_parameters(self, event_name: str) -> Mapping[str, EventParameter]:
"""
Get event parameters for this provider
"""
event = self.get_event(event_name)
if not event:
return {}
return {parameter.name: parameter for parameter in event.parameters}
def dispatch(
self,
request: Request,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> TriggerDispatchResponse:
"""
Dispatch a trigger through plugin runtime
:param user_id: User ID
:param request: Flask request object
:param subscription: Subscription
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Dispatch response with triggers and raw HTTP response
"""
manager = PluginTriggerClient()
provider_id: TriggerProviderID = self.get_provider_id()
response: TriggerDispatchResponse = manager.dispatch_event(
tenant_id=self.tenant_id,
provider=str(provider_id),
subscription=subscription.model_dump(),
request=request,
credentials=credentials,
credential_type=credential_type,
)
return response
def invoke_trigger_event(
self,
user_id: str,
event_name: str,
parameters: Mapping[str, Any],
credentials: Mapping[str, str],
credential_type: CredentialType,
subscription: Subscription,
request: Request,
payload: Mapping[str, Any],
) -> TriggerInvokeEventResponse:
"""
Execute a trigger through plugin runtime
:param user_id: User ID
:param event_name: Event name
:param parameters: Trigger parameters
:param credentials: Provider credentials
:param credential_type: Credential type
:param request: Request
:param payload: Payload
:return: Trigger execution result
"""
manager = PluginTriggerClient()
provider_id: TriggerProviderID = self.get_provider_id()
return manager.invoke_trigger_event(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
event_name=event_name,
credentials=credentials,
credential_type=credential_type,
request=request,
parameters=parameters,
subscription=subscription,
payload=payload,
)
def subscribe_trigger(
self,
user_id: str,
endpoint: str,
parameters: Mapping[str, Any],
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> Subscription:
"""
Subscribe to a trigger through plugin runtime
:param user_id: User ID
:param endpoint: Subscription endpoint
:param subscription_params: Subscription parameters
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Subscription result
"""
manager = PluginTriggerClient()
provider_id: TriggerProviderID = self.get_provider_id()
response: TriggerSubscriptionResponse = manager.subscribe(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
endpoint=endpoint,
parameters=parameters,
credentials=credentials,
credential_type=credential_type,
)
return Subscription.model_validate(response.subscription)
def unsubscribe_trigger(
self, user_id: str, subscription: Subscription, credentials: Mapping[str, str], credential_type: CredentialType
) -> UnsubscribeResult:
"""
Unsubscribe from a trigger through plugin runtime
:param user_id: User ID
:param subscription: Subscription metadata
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Unsubscribe result
"""
manager = PluginTriggerClient()
provider_id: TriggerProviderID = self.get_provider_id()
response: TriggerSubscriptionResponse = manager.unsubscribe(
tenant_id=self.tenant_id,
user_id=user_id,
provider=str(provider_id),
subscription=subscription,
credentials=credentials,
credential_type=credential_type,
)
return UnsubscribeResult.model_validate(response.subscription)
def refresh_trigger(
self, subscription: Subscription, credentials: Mapping[str, str], credential_type: CredentialType
) -> Subscription:
"""
Refresh a trigger subscription through plugin runtime
:param subscription: Subscription metadata
:param credentials: Provider credentials
:return: Refreshed subscription result
"""
manager = PluginTriggerClient()
provider_id: TriggerProviderID = self.get_provider_id()
response: TriggerSubscriptionResponse = manager.refresh(
tenant_id=self.tenant_id,
user_id="system", # System refresh
provider=str(provider_id),
subscription=subscription,
credentials=credentials,
credential_type=credential_type,
)
return Subscription.model_validate(response.subscription)
__all__ = ["PluginTriggerProviderController"]

View File

@ -0,0 +1,285 @@
"""
Trigger Manager for loading and managing trigger providers and triggers
"""
import logging
from collections.abc import Mapping
from threading import Lock
from typing import Any
from flask import Request
import contexts
from configs import dify_config
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
from core.plugin.impl.trigger import PluginTriggerClient
from core.trigger.entities.entities import (
EventEntity,
Subscription,
UnsubscribeResult,
)
from core.trigger.errors import EventIgnoreError
from core.trigger.provider import PluginTriggerProviderController
from models.provider_ids import TriggerProviderID
logger = logging.getLogger(__name__)
class TriggerManager:
"""
Manager for trigger providers and triggers
"""
@classmethod
def get_trigger_plugin_icon(cls, tenant_id: str, provider_id: str) -> str:
"""
Get the icon of a trigger plugin
"""
manager = PluginTriggerClient()
provider: PluginTriggerProviderEntity = manager.fetch_trigger_provider(
tenant_id=tenant_id, provider_id=TriggerProviderID(provider_id)
)
filename = provider.declaration.identity.icon
base_url = f"{dify_config.CONSOLE_API_URL}/console/api/workspaces/current/plugin/icon"
return f"{base_url}?tenant_id={tenant_id}&filename={filename}"
@classmethod
def list_plugin_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
"""
List all plugin trigger providers for a tenant
:param tenant_id: Tenant ID
:return: List of trigger provider controllers
"""
manager = PluginTriggerClient()
provider_entities = manager.fetch_trigger_providers(tenant_id)
controllers: list[PluginTriggerProviderController] = []
for provider in provider_entities:
try:
controller = PluginTriggerProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
provider_id=TriggerProviderID(provider.provider),
tenant_id=tenant_id,
)
controllers.append(controller)
except Exception:
logger.exception("Failed to load trigger provider %s", provider.plugin_id)
continue
return controllers
@classmethod
def get_trigger_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderController:
"""
Get a specific plugin trigger provider
:param tenant_id: Tenant ID
:param provider_id: Provider ID
:return: Trigger provider controller or None
"""
# check if context is set
try:
contexts.plugin_trigger_providers.get()
except LookupError:
contexts.plugin_trigger_providers.set({})
contexts.plugin_trigger_providers_lock.set(Lock())
plugin_trigger_providers = contexts.plugin_trigger_providers.get()
provider_id_str = str(provider_id)
if provider_id_str in plugin_trigger_providers:
return plugin_trigger_providers[provider_id_str]
with contexts.plugin_trigger_providers_lock.get():
# double check
plugin_trigger_providers = contexts.plugin_trigger_providers.get()
if provider_id_str in plugin_trigger_providers:
return plugin_trigger_providers[provider_id_str]
try:
manager = PluginTriggerClient()
provider = manager.fetch_trigger_provider(tenant_id, provider_id)
if not provider:
raise ValueError(f"Trigger provider {provider_id} not found")
controller = PluginTriggerProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
provider_id=provider_id,
tenant_id=tenant_id,
)
plugin_trigger_providers[provider_id_str] = controller
return controller
except PluginNotFoundError as e:
raise ValueError(f"Trigger provider {provider_id} not found") from e
except PluginDaemonError as e:
raise e
except Exception as e:
logger.exception("Failed to load trigger provider")
raise e
@classmethod
def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
"""
List all trigger providers (plugin)
:param tenant_id: Tenant ID
:return: List of all trigger provider controllers
"""
return cls.list_plugin_trigger_providers(tenant_id)
@classmethod
def list_triggers_by_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[EventEntity]:
"""
List all triggers for a specific provider
:param tenant_id: Tenant ID
:param provider_id: Provider ID
:return: List of trigger entities
"""
provider = cls.get_trigger_provider(tenant_id, provider_id)
return provider.get_events()
@classmethod
def invoke_trigger_event(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
event_name: str,
parameters: Mapping[str, Any],
credentials: Mapping[str, str],
credential_type: CredentialType,
subscription: Subscription,
request: Request,
payload: Mapping[str, Any],
) -> TriggerInvokeEventResponse:
"""
Execute a trigger
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider ID
:param event_name: Event name
:param parameters: Trigger parameters
:param credentials: Provider credentials
:param credential_type: Credential type
:param subscription: Subscription
:param request: Request
:param payload: Payload
:return: Trigger execution result
"""
provider: PluginTriggerProviderController = cls.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
try:
return provider.invoke_trigger_event(
user_id=user_id,
event_name=event_name,
parameters=parameters,
credentials=credentials,
credential_type=credential_type,
subscription=subscription,
request=request,
payload=payload,
)
except EventIgnoreError:
return TriggerInvokeEventResponse(variables={}, cancelled=True)
except Exception as e:
raise e
@classmethod
def subscribe_trigger(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
endpoint: str,
parameters: Mapping[str, Any],
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> Subscription:
"""
Subscribe to a trigger (e.g., register webhook)
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider ID
:param endpoint: Subscription endpoint
:param parameters: Subscription parameters
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Subscription result
"""
provider: PluginTriggerProviderController = cls.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
return provider.subscribe_trigger(
user_id=user_id,
endpoint=endpoint,
parameters=parameters,
credentials=credentials,
credential_type=credential_type,
)
@classmethod
def unsubscribe_trigger(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> UnsubscribeResult:
"""
Unsubscribe from a trigger
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider ID
:param subscription: Subscription metadata from subscribe operation
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Unsubscription result
"""
provider: PluginTriggerProviderController = cls.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
return provider.unsubscribe_trigger(
user_id=user_id,
subscription=subscription,
credentials=credentials,
credential_type=credential_type,
)
@classmethod
def refresh_trigger(
cls,
tenant_id: str,
provider_id: TriggerProviderID,
subscription: Subscription,
credentials: Mapping[str, str],
credential_type: CredentialType,
) -> Subscription:
"""
Refresh a trigger subscription
:param tenant_id: Tenant ID
:param provider_id: Provider ID
:param subscription: Subscription metadata from subscribe operation
:param credentials: Provider credentials
:param credential_type: Credential type
:return: Refreshed subscription result
"""
# TODO you should update the subscription using the return value of the refresh_trigger
return cls.get_trigger_provider(tenant_id=tenant_id, provider_id=provider_id).refresh_trigger(
subscription=subscription, credentials=credentials, credential_type=credential_type
)

View File

@ -0,0 +1,145 @@
from collections.abc import Mapping
from typing import Union
from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
from core.helper.provider_cache import ProviderCredentialsCache
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
from core.trigger.provider import PluginTriggerProviderController
from models.trigger import TriggerSubscription
class TriggerProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for trigger provider credentials"""
def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_id = kwargs["provider_id"]
credential_id = kwargs["credential_id"]
return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
"""Cache for trigger provider OAuth client"""
def __init__(self, tenant_id: str, provider_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_id = kwargs["provider_id"]
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
class TriggerProviderPropertiesCache(ProviderCredentialsCache):
"""Cache for trigger provider properties"""
def __init__(self, tenant_id: str, provider_id: str, subscription_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_id = kwargs["provider_id"]
subscription_id = kwargs["subscription_id"]
return f"trigger_properties:tenant_id:{tenant_id}:provider_id:{provider_id}:subscription_id:{subscription_id}"
def create_trigger_provider_encrypter_for_subscription(
tenant_id: str,
controller: PluginTriggerProviderController,
subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
cache = TriggerProviderCredentialsCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=subscription.id,
)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_credential_schema_config(subscription.credential_type),
cache=cache,
)
return encrypter, cache
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
cache = TriggerProviderCredentialsCache(
tenant_id=tenant_id,
provider_id=provider_id,
credential_id=subscription_id,
)
cache.delete()
def create_trigger_provider_encrypter_for_properties(
tenant_id: str,
controller: PluginTriggerProviderController,
subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
cache = TriggerProviderPropertiesCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
subscription_id=subscription.id,
)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_properties_schema(),
cache=cache,
)
return encrypter, cache
def create_trigger_provider_encrypter(
tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
cache = TriggerProviderCredentialsCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential_id,
)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_credential_schema_config(credential_type),
cache=cache,
)
return encrypter, cache
def create_trigger_provider_oauth_encrypter(
tenant_id: str, controller: PluginTriggerProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
cache = TriggerProviderOAuthClientParamsCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
cache=cache,
)
return encrypter, cache
def masked_credentials(
schemas: list[ProviderConfig],
credentials: Mapping[str, str],
) -> Mapping[str, str]:
masked_credentials = {}
configs = {x.name: x.to_basic_provider_config() for x in schemas}
for key, value in credentials.items():
config = configs.get(key)
if not config:
masked_credentials[key] = value
continue
if config.type == BasicProviderConfig.Type.SECRET_INPUT:
if len(value) <= 4:
masked_credentials[key] = "*" * len(value)
else:
masked_credentials[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
else:
masked_credentials[key] = value
return masked_credentials

View File

@ -0,0 +1,24 @@
from yarl import URL
from configs import dify_config
"""
Basic URL for thirdparty trigger services
"""
base_url = URL(dify_config.TRIGGER_URL)
def generate_plugin_trigger_endpoint_url(endpoint_id: str) -> str:
"""
Generate url for plugin trigger endpoint url
"""
return str(base_url / "triggers" / "plugin" / endpoint_id)
def generate_webhook_trigger_endpoint(webhook_id: str, debug: bool = False) -> str:
"""
Generate url for webhook trigger endpoint url
"""
return str(base_url / "triggers" / ("webhook-debug" if debug else "webhook") / webhook_id)

View File

@ -0,0 +1,12 @@
from collections.abc import Sequence
from itertools import starmap
def build_trigger_refresh_lock_key(tenant_id: str, subscription_id: str) -> str:
"""Build the Redis lock key for trigger subscription refresh in-flight protection."""
return f"trigger_provider_refresh_lock:{tenant_id}_{subscription_id}"
def build_trigger_refresh_lock_keys(pairs: Sequence[tuple[str, str]]) -> list[str]:
"""Build Redis lock keys for a sequence of (tenant_id, subscription_id) pairs."""
return list(starmap(build_trigger_refresh_lock_key, pairs))

View File

@ -22,6 +22,7 @@ class SystemVariableKey(StrEnum):
APP_ID = "app_id"
WORKFLOW_ID = "workflow_id"
WORKFLOW_EXECUTION_ID = "workflow_run_id"
TIMESTAMP = "timestamp"
# RAG Pipeline
DOCUMENT_ID = "document_id"
ORIGINAL_DOCUMENT_ID = "original_document_id"
@ -58,8 +59,31 @@ class NodeType(StrEnum):
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
TRIGGER_WEBHOOK = "trigger-webhook"
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
@property
def is_trigger_node(self) -> bool:
"""Check if this node type is a trigger node."""
return self in [
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
@property
def is_start_node(self) -> bool:
"""Check if this node type can serve as a workflow entry point."""
return self in [
NodeType.START,
NodeType.DATASOURCE,
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
class NodeExecutionType(StrEnum):
"""Node execution type classification."""
@ -208,6 +232,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
TRIGGER_INFO = "trigger_info"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"

View File

@ -117,7 +117,7 @@ class Graph:
node_type = node_data.get("type")
if not isinstance(node_type, str):
continue
if node_type in [NodeType.START, NodeType.DATASOURCE]:
if NodeType(node_type).is_start_node:
start_node_id = nid
break

View File

@ -114,9 +114,45 @@ class GraphValidator:
raise GraphValidationError(issues)
@dataclass(frozen=True, slots=True)
class _TriggerStartExclusivityValidator:
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
start_node_id: str | None = None
trigger_node_ids: list[str] = []
for node in graph.nodes.values():
node_type = getattr(node, "node_type", None)
if not isinstance(node_type, NodeType):
continue
if node_type == NodeType.START:
start_node_id = node.id
elif node_type.is_trigger_node:
trigger_node_ids.append(node.id)
if start_node_id and trigger_node_ids:
trigger_list = ", ".join(trigger_node_ids)
return [
GraphValidationIssue(
code=self.conflict_code,
message=(
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
),
node_id=start_node_id,
)
]
return []
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
_TriggerStartExclusivityValidator(),
)

View File

@ -126,6 +126,12 @@ class Node:
start_event.provider_id = f"{plugin_id}/{provider_name}"
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
if isinstance(self, TriggerEventNode):
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
from typing import cast
from core.workflow.nodes.agent.agent_node import AgentNode

View File

@ -22,6 +22,9 @@ from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.trigger_plugin import TriggerEventNode
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
@ -147,4 +150,16 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
NodeType.TRIGGER_WEBHOOK: {
LATEST_VERSION: TriggerWebhookNode,
"1": TriggerWebhookNode,
},
NodeType.TRIGGER_PLUGIN: {
LATEST_VERSION: TriggerEventNode,
"1": TriggerEventNode,
},
NodeType.TRIGGER_SCHEDULE: {
LATEST_VERSION: TriggerScheduleNode,
"1": TriggerScheduleNode,
},
}

View File

@ -164,10 +164,7 @@ class ToolNode(Node):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error="An error occurred in the plugin, "
f"please contact the author of {node_data.provider_name} for help, "
f"error type: {e.get_error_type()}, "
f"error details: {e.get_error_message()}",
error=e.to_user_friendly_error(plugin_name=node_data.provider_name),
error_type=type(e).__name__,
)
)

View File

@ -0,0 +1,3 @@
from .trigger_event_node import TriggerEventNode
__all__ = ["TriggerEventNode"]

View File

@ -0,0 +1,77 @@
from collections.abc import Mapping
from typing import Any, Literal, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.trigger.entities.entities import EventParameter
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError
class TriggerEventNodeData(BaseNodeData):
"""Plugin trigger node data"""
class TriggerEventInput(BaseModel):
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
type = value
value = validation_info.data.get("value")
if value is None:
return type
if type == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
if type == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
if type == "constant" and not isinstance(value, str | int | float | bool | dict | list):
raise ValueError("value must be a string, int, float, bool or dict")
return type
title: str
desc: str | None = None
plugin_id: str = Field(..., description="Plugin ID")
provider_id: str = Field(..., description="Provider ID")
event_name: str = Field(..., description="Event name")
subscription_id: str = Field(..., description="Subscription ID")
plugin_unique_identifier: str = Field(..., description="Plugin unique identifier")
event_parameters: Mapping[str, TriggerEventInput] = Field(default_factory=dict, description="Trigger parameters")
def resolve_parameters(
self,
*,
parameter_schemas: Mapping[str, EventParameter],
) -> Mapping[str, Any]:
"""
Generate parameters based on the given plugin trigger parameters.
Args:
parameter_schemas (Mapping[str, EventParameter]): The mapping of parameter schemas.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
result: dict[str, Any] = {}
for parameter_name in self.event_parameters:
parameter: EventParameter | None = parameter_schemas.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
event_input = self.event_parameters[parameter_name]
# trigger node only supports constant input
if event_input.type != "constant":
raise TriggerEventParameterError(f"Unknown plugin trigger input type '{event_input.type}'")
result[parameter_name] = event_input.value
return result

View File

@ -0,0 +1,10 @@
class TriggerEventNodeError(ValueError):
"""Base exception for plugin trigger node errors."""
pass
class TriggerEventParameterError(TriggerEventNodeError):
"""Exception raised for errors in plugin trigger parameters."""
pass

View File

@ -0,0 +1,89 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import TriggerEventNodeData
class TriggerEventNode(Node):
node_type = NodeType.TRIGGER_PLUGIN
execution_type = NodeExecutionType.ROOT
_node_data: TriggerEventNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TriggerEventNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "plugin",
"config": {
"title": "",
"plugin_id": "",
"provider_id": "",
"event_name": "",
"subscription_id": "",
"plugin_unique_identifier": "",
"event_parameters": {},
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the plugin trigger node.
This node invokes the trigger to convert request data into events
and makes them available to downstream nodes.
"""
# Get trigger data passed when workflow was triggered
metadata = {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self._node_data.provider_id,
"event_name": self._node_data.event_name,
"plugin_unique_identifier": self._node_data.plugin_unique_identifier,
},
}
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
outputs=outputs,
metadata=metadata,
)

View File

@ -0,0 +1,3 @@
from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
__all__ = ["TriggerScheduleNode"]

View File

@ -0,0 +1,49 @@
from typing import Literal, Union
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
class TriggerScheduleNodeData(BaseNodeData):
"""
Trigger Schedule Node Data
"""
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
visual_config: dict | None = Field(default=None, description="Visual configuration details")
timezone: str = Field(default="UTC", description="Timezone for schedule execution")
class ScheduleConfig(BaseModel):
node_id: str
cron_expression: str
timezone: str = "UTC"
class SchedulePlanUpdate(BaseModel):
node_id: str | None = None
cron_expression: str | None = None
timezone: str | None = None
class VisualConfig(BaseModel):
"""Visual configuration for schedule trigger"""
# For hourly frequency
on_minute: int | None = Field(default=0, ge=0, le=59, description="Minute of the hour (0-59)")
# For daily, weekly, monthly frequencies
time: str | None = Field(default="12:00 AM", description="Time in 12-hour format (e.g., '2:30 PM')")
# For weekly frequency
weekdays: list[Literal["sun", "mon", "tue", "wed", "thu", "fri", "sat"]] | None = Field(
default=None, description="List of weekdays to run on"
)
# For monthly frequency
monthly_days: list[Union[int, Literal["last"]]] | None = Field(
default=None, description="Days of month to run on (1-31 or 'last')"
)

View File

@ -0,0 +1,31 @@
from core.workflow.nodes.base.exc import BaseNodeError
class ScheduleNodeError(BaseNodeError):
"""Base schedule node error."""
pass
class ScheduleNotFoundError(ScheduleNodeError):
"""Schedule not found error."""
pass
class ScheduleConfigError(ScheduleNodeError):
"""Schedule configuration error."""
pass
class ScheduleExecutionError(ScheduleNodeError):
"""Schedule execution error."""
pass
class TenantOwnerNotFoundError(ScheduleExecutionError):
"""Tenant owner not found error for schedule execution."""
pass

View File

@ -0,0 +1,69 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData
class TriggerScheduleNode(Node):
node_type = NodeType.TRIGGER_SCHEDULE
execution_type = NodeExecutionType.ROOT
_node_data: TriggerScheduleNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TriggerScheduleNodeData(**data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "trigger-schedule",
"config": {
"mode": "visual",
"frequency": "daily",
"visual_config": {"time": "12:00 AM", "on_minute": 0, "weekdays": ["sun"], "monthly_days": [1]},
"timezone": "UTC",
},
}
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
outputs=outputs,
)

View File

@ -0,0 +1,3 @@
from .node import TriggerWebhookNode
__all__ = ["TriggerWebhookNode"]

View File

@ -0,0 +1,79 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel, Field, field_validator
from core.workflow.nodes.base import BaseNodeData
class Method(StrEnum):
GET = "get"
POST = "post"
HEAD = "head"
PATCH = "patch"
PUT = "put"
DELETE = "delete"
class ContentType(StrEnum):
JSON = "application/json"
FORM_DATA = "multipart/form-data"
FORM_URLENCODED = "application/x-www-form-urlencoded"
TEXT = "text/plain"
BINARY = "application/octet-stream"
class WebhookParameter(BaseModel):
"""Parameter definition for headers, query params, or body."""
name: str
required: bool = False
class WebhookBodyParameter(BaseModel):
"""Body parameter with type information."""
name: str
type: Literal[
"string",
"number",
"boolean",
"object",
"array[string]",
"array[number]",
"array[boolean]",
"array[object]",
"file",
] = "string"
required: bool = False
class WebhookData(BaseNodeData):
"""
Webhook Node Data.
"""
class SyncMode(StrEnum):
SYNC = "async" # only support
method: Method = Method.GET
content_type: ContentType = Field(default=ContentType.JSON)
headers: Sequence[WebhookParameter] = Field(default_factory=list)
params: Sequence[WebhookParameter] = Field(default_factory=list) # query parameters
body: Sequence[WebhookBodyParameter] = Field(default_factory=list)
@field_validator("method", mode="before")
@classmethod
def normalize_method(cls, v) -> str:
"""Normalize HTTP method to lowercase to support both uppercase and lowercase input."""
if isinstance(v, str):
return v.lower()
return v
status_code: int = 200 # Expected status code for response
response_body: str = "" # Template for response body
# Webhook specific fields (not from client data, set internally)
webhook_id: str | None = None # Set when webhook trigger is created
timeout: int = 30 # Timeout in seconds to wait for webhook response

View File

@ -0,0 +1,25 @@
from core.workflow.nodes.base.exc import BaseNodeError
class WebhookNodeError(BaseNodeError):
"""Base webhook node error."""
pass
class WebhookTimeoutError(WebhookNodeError):
"""Webhook timeout error."""
pass
class WebhookNotFoundError(WebhookNodeError):
"""Webhook not found error."""
pass
class WebhookConfigError(WebhookNodeError):
"""Webhook configuration error."""
pass

View File

@ -0,0 +1,148 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import ContentType, WebhookData
class TriggerWebhookNode(Node):
node_type = NodeType.TRIGGER_WEBHOOK
execution_type = NodeExecutionType.ROOT
_node_data: WebhookData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = WebhookData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "webhook",
"config": {
"method": "get",
"content_type": "application/json",
"headers": [],
"params": [],
"body": [],
"async_mode": True,
"status_code": 200,
"response_body": "",
"timeout": 30,
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the webhook node.
Like the start node, this simply takes the webhook data from the variable pool
and makes it available to downstream nodes. The actual webhook handling
happens in the trigger controller.
"""
# Get webhook data from variable pool (injected by Celery task)
webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
# Extract webhook-specific outputs based on node configuration
outputs = self._extract_configured_outputs(webhook_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
outputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=webhook_inputs,
outputs=outputs,
)
def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
"""Extract outputs based on node configuration from webhook inputs."""
outputs = {}
# Get the raw webhook data (should be injected by Celery task)
webhook_data = webhook_inputs.get("webhook_data", {})
def _to_sanitized(name: str) -> str:
return name.replace("-", "_")
def _get_normalized(mapping: dict[str, Any], key: str) -> Any:
if not isinstance(mapping, dict):
return None
if key in mapping:
return mapping[key]
alternate = key.replace("-", "_") if "-" in key else key.replace("_", "-")
if alternate in mapping:
return mapping[alternate]
return None
# Extract configured headers (case-insensitive)
webhook_headers = webhook_data.get("headers", {})
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
for header in self._node_data.headers:
header_name = header.name
value = _get_normalized(webhook_headers, header_name)
if value is None:
value = _get_normalized(webhook_headers_lower, header_name.lower())
sanitized_name = _to_sanitized(header_name)
outputs[sanitized_name] = value
# Extract configured query parameters
for param in self._node_data.params:
param_name = param.name
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
# Extract configured body parameters
for body_param in self._node_data.body:
param_name = body_param.name
param_type = body_param.type
if self._node_data.content_type == ContentType.TEXT:
# For text/plain, the entire body is a single string parameter
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self._node_data.content_type == ContentType.BINARY:
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
continue
if param_type == "file":
# Get File object (already processed by webhook controller)
file_obj = webhook_data.get("files", {}).get(param_name)
outputs[param_name] = file_obj
else:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
# Include raw webhook data for debugging/advanced use
outputs["_webhook_raw"] = webhook_data
return outputs

View File

@ -29,6 +29,8 @@ class SystemVariable(BaseModel):
app_id: str | None = None
workflow_id: str | None = None
timestamp: int | None = None
files: Sequence[File] = Field(default_factory=list)
# NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`.
@ -108,6 +110,8 @@ class SystemVariable(BaseModel):
d[SystemVariableKey.DATASOURCE_INFO] = self.datasource_info
if self.invoke_from is not None:
d[SystemVariableKey.INVOKE_FROM] = self.invoke_from
if self.timestamp is not None:
d[SystemVariableKey.TIMESTAMP] = self.timestamp
return d
def as_view(self) -> "SystemVariableReadOnlyView":