feat: Human Input Node (#32060)

The frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@ -44,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.5.0"
CURRENT_DSL_VERSION = "0.6.0"
class ImportMode(StrEnum):

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any, Union
from configs import dify_config
@ -9,22 +11,63 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task
logger = logging.getLogger(__name__)
SSE_TASK_START_FALLBACK_MS = 200
if TYPE_CHECKING:
from controllers.console.app.workflow import LoopNodeRunPayload
class AppGenerateService:
@staticmethod
def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
started = False
lock = threading.Lock()
def _try_start() -> bool:
nonlocal started
with lock:
if started:
return True
try:
start_task()
except Exception:
logger.exception("Failed to enqueue streaming task")
return False
started = True
return True
# XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber.
# The Celery task may publish the first event before the API side actually subscribes,
# causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe,
# but also use a short fallback timer so the task still runs if the client never consumes.
timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
timer.daemon = True
timer.start()
def _on_subscribe() -> None:
if _try_start():
timer.cancel()
return _on_subscribe
@classmethod
@trace_span(AppGenerateHandler)
def generate(
@ -88,15 +131,29 @@ class AppGenerateService:
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
return rate_limit.generate(
AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id=request_id,
@ -104,6 +161,40 @@ class AppGenerateService:
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
root_node_id=root_node_id,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
MessageBasedAppGenerator.retrieve_events(
AppMode.WORKFLOW,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
@ -112,9 +203,10 @@ class AppGenerateService:
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
streaming=False,
root_node_id=root_node_id,
call_depth=0,
pause_state_config=pause_config,
),
),
request_id,
@ -248,3 +340,19 @@ class AppGenerateService:
raise ValueError("Workflow not published")
return workflow
@classmethod
def get_response_generator(
cls,
app_model: App,
workflow_run: WorkflowRun,
):
if workflow_run.status.is_ended():
# TODO(QuantumGhost): handled the ended scenario.
pass
generator = AdvancedChatAppGenerator()
return generator.convert_to_event_stream(
generator.retrieve_events(AppMode(app_model.mode), workflow_run.id),
)

View File

@ -136,7 +136,7 @@ class AudioService:
message = db.session.query(Message).where(Message.id == message_id).first()
if message is None:
return None
if message.answer == "" and message.status == MessageStatus.NORMAL:
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
return None
else:

View File

@ -138,6 +138,8 @@ class FeatureModel(BaseModel):
is_allow_transfer_workspace: bool = True
trigger_event: Quota = Quota(usage=0, limit=3000, reset_date=0)
api_rate_limit: Quota = Quota(usage=0, limit=5000, reset_date=0)
# Controls whether email delivery is allowed for HumanInput nodes.
human_input_email_delivery_enabled: bool = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
@ -191,6 +193,11 @@ class FeatureService:
features.knowledge_pipeline.publish_enabled = True
cls._fulfill_params_from_workspace_info(features, tenant_id)
features.human_input_email_delivery_enabled = cls._resolve_human_input_email_delivery_enabled(
features=features,
tenant_id=tenant_id,
)
return features
@classmethod
@ -203,6 +210,17 @@ class FeatureService:
knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", CloudPlan.SANDBOX)
return knowledge_rate_limit
@classmethod
def _resolve_human_input_email_delivery_enabled(cls, *, features: FeatureModel, tenant_id: str | None) -> bool:
if dify_config.ENTERPRISE_ENABLED or not dify_config.BILLING_ENABLED:
return True
if not tenant_id:
return False
return features.billing.enabled and features.billing.subscription.plan in (
CloudPlan.PROFESSIONAL,
CloudPlan.TEAM,
)
@classmethod
def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
system_features = SystemFeatureModel()

View File

@ -0,0 +1,249 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Protocol
from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
ExternalRecipient,
MemberRecipient,
)
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_template_renderer import render_email_template
from models import Account, TenantAccountJoin
from services.feature_service import FeatureService
class DeliveryTestStatus(StrEnum):
OK = "ok"
FAILED = "failed"
@dataclass(frozen=True)
class DeliveryTestEmailRecipient:
email: str
form_token: str
@dataclass(frozen=True)
class DeliveryTestContext:
tenant_id: str
app_id: str
node_id: str
node_title: str | None
rendered_content: str
template_vars: dict[str, str] = field(default_factory=dict)
recipients: list[DeliveryTestEmailRecipient] = field(default_factory=list)
variable_pool: VariablePool | None = None
@dataclass(frozen=True)
class DeliveryTestResult:
status: DeliveryTestStatus
delivered_to: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
class DeliveryTestError(Exception):
pass
class DeliveryTestUnsupportedError(DeliveryTestError):
pass
def _build_form_link(token: str | None) -> str | None:
if not token:
return None
base_url = dify_config.APP_WEB_URL
if not base_url:
return None
return f"{base_url.rstrip('/')}/form/{token}"
class DeliveryTestHandler(Protocol):
def supports(self, method: DeliveryChannelConfig) -> bool: ...
def send_test(
self,
*,
context: DeliveryTestContext,
method: DeliveryChannelConfig,
) -> DeliveryTestResult: ...
class DeliveryTestRegistry:
def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None:
self._handlers = list(handlers or [])
def register(self, handler: DeliveryTestHandler) -> None:
self._handlers.append(handler)
def dispatch(
self,
*,
context: DeliveryTestContext,
method: DeliveryChannelConfig,
) -> DeliveryTestResult:
for handler in self._handlers:
if handler.supports(method):
return handler.send_test(context=context, method=method)
raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
@classmethod
def default(cls) -> DeliveryTestRegistry:
return cls([EmailDeliveryTestHandler()])
class HumanInputDeliveryTestService:
def __init__(self, registry: DeliveryTestRegistry | None = None) -> None:
self._registry = registry or DeliveryTestRegistry.default()
def send_test(
self,
*,
context: DeliveryTestContext,
method: DeliveryChannelConfig,
) -> DeliveryTestResult:
return self._registry.dispatch(context=context, method=method)
class EmailDeliveryTestHandler:
def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
if session_factory is None:
session_factory = sessionmaker(bind=db.engine)
elif isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def supports(self, method: DeliveryChannelConfig) -> bool:
return isinstance(method, EmailDeliveryMethod)
def send_test(
self,
*,
context: DeliveryTestContext,
method: DeliveryChannelConfig,
) -> DeliveryTestResult:
if not isinstance(method, EmailDeliveryMethod):
raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
features = FeatureService.get_features(context.tenant_id)
if not features.human_input_email_delivery_enabled:
raise DeliveryTestError("Email delivery is not available for current plan.")
if not mail.is_inited():
raise DeliveryTestError("Mail client is not initialized.")
recipients = self._resolve_recipients(
tenant_id=context.tenant_id,
method=method,
)
if not recipients:
raise DeliveryTestError("No recipients configured for delivery method.")
delivered: list[str] = []
for recipient_email in recipients:
substitutions = self._build_substitutions(
context=context,
recipient_email=recipient_email,
)
subject = render_email_template(method.config.subject, substitutions)
templated_body = EmailDeliveryConfig.render_body_template(
body=method.config.body,
url=substitutions.get("form_link"),
variable_pool=context.variable_pool,
)
body = render_email_template(templated_body, substitutions)
mail.send(
to=recipient_email,
subject=subject,
html=body,
)
delivered.append(recipient_email)
return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=delivered)
def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]:
recipients = method.config.recipients
emails: list[str] = []
member_user_ids: list[str] = []
for recipient in recipients.items:
if isinstance(recipient, MemberRecipient):
member_user_ids.append(recipient.user_id)
elif isinstance(recipient, ExternalRecipient):
if recipient.email:
emails.append(recipient.email)
if recipients.whole_workspace:
member_user_ids = []
member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None)
emails.extend(member_emails.values())
elif member_user_ids:
member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids)
for user_id in member_user_ids:
email = member_emails.get(user_id)
if email:
emails.append(email)
return list(dict.fromkeys([email for email in emails if email]))
def _query_workspace_member_emails(
self,
*,
tenant_id: str,
user_ids: list[str] | None,
) -> dict[str, str]:
if user_ids is None:
unique_ids = None
else:
unique_ids = {user_id for user_id in user_ids if user_id}
if not unique_ids:
return {}
stmt = (
select(Account.id, Account.email)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == tenant_id)
)
if unique_ids is not None:
stmt = stmt.where(Account.id.in_(unique_ids))
with self._session_factory() as session:
rows = session.execute(stmt).all()
return dict(rows)
@staticmethod
def _build_substitutions(
*,
context: DeliveryTestContext,
recipient_email: str,
) -> dict[str, str]:
raw_values: dict[str, str | None] = {
"form_id": "",
"node_title": context.node_title,
"workflow_run_id": "",
"form_token": "",
"form_link": "",
"form_content": context.rendered_content,
"recipient_email": recipient_email,
}
substitutions = {key: value or "" for key, value in raw_values.items()}
if context.template_vars:
substitutions.update({key: value for key, value in context.template_vars.items() if value is not None})
token = next(
(recipient.form_token for recipient in context.recipients if recipient.email == recipient_email),
None,
)
if token:
substitutions["form_token"] = token
substitutions["form_link"] = _build_form_link(token) or ""
return substitutions

View File

@ -0,0 +1,250 @@
import logging
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Any
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormSubmissionRepository,
)
from core.workflow.nodes.human_input.entities import (
FormDefinition,
HumanInputSubmissionValidationError,
validate_human_input_submission,
)
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from libs.exception import BaseHTTPException
from models.human_input import RecipientType
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution
class Form:
def __init__(self, record: HumanInputFormRecord):
self._record = record
def get_definition(self) -> FormDefinition:
return self._record.definition
@property
def submitted(self) -> bool:
return self._record.submitted
@property
def id(self) -> str:
return self._record.form_id
@property
def workflow_run_id(self) -> str | None:
"""Workflow run id for runtime forms; None for delivery tests."""
return self._record.workflow_run_id
@property
def tenant_id(self) -> str:
return self._record.tenant_id
@property
def app_id(self) -> str:
return self._record.app_id
@property
def recipient_id(self) -> str | None:
return self._record.recipient_id
@property
def recipient_type(self) -> RecipientType | None:
return self._record.recipient_type
@property
def status(self) -> HumanInputFormStatus:
return self._record.status
@property
def form_kind(self) -> HumanInputFormKind:
return self._record.form_kind
@property
def created_at(self) -> "datetime":
return self._record.created_at
@property
def expiration_time(self) -> "datetime":
return self._record.expiration_time
class HumanInputError(Exception):
pass
class FormSubmittedError(HumanInputError, BaseHTTPException):
error_code = "human_input_form_submitted"
description = "This form has already been submitted by another user, form_id={form_id}"
code = 412
def __init__(self, form_id: str):
template = self.description or "This form has already been submitted by another user, form_id={form_id}"
description = template.format(form_id=form_id)
super().__init__(description=description)
class FormNotFoundError(HumanInputError, BaseHTTPException):
error_code = "human_input_form_not_found"
code = 404
class InvalidFormDataError(HumanInputError, BaseHTTPException):
error_code = "invalid_form_data"
code = 400
def __init__(self, description: str):
super().__init__(description=description)
class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
pass
class FormExpiredError(HumanInputError, BaseHTTPException):
error_code = "human_input_form_expired"
code = 412
def __init__(self, form_id: str):
super().__init__(description=f"This form has expired, form_id={form_id}")
logger = logging.getLogger(__name__)
class HumanInputService:
def __init__(
self,
session_factory: sessionmaker[Session] | Engine,
form_repository: HumanInputFormSubmissionRepository | None = None,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory)
def get_form_by_token(self, form_token: str) -> Form | None:
record = self._form_repository.get_by_token(form_token)
if record is None:
return None
return Form(record)
def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None:
form = self.get_form_by_token(form_token)
if form is None or form.recipient_type != recipient_type:
return None
self._ensure_not_submitted(form)
return form
def get_form_definition_by_token_for_console(self, form_token: str) -> Form | None:
form = self.get_form_by_token(form_token)
if form is None or form.recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
return None
self._ensure_not_submitted(form)
return form
def submit_form_by_token(
self,
recipient_type: RecipientType,
form_token: str,
selected_action_id: str,
form_data: Mapping[str, Any],
submission_end_user_id: str | None = None,
submission_user_id: str | None = None,
):
form = self.get_form_by_token(form_token)
if form is None or form.recipient_type != recipient_type:
raise WebAppDeliveryNotEnabledError()
self.ensure_form_active(form)
self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)
result = self._form_repository.mark_submitted(
form_id=form.id,
recipient_id=form.recipient_id,
selected_action_id=selected_action_id,
form_data=form_data,
submission_user_id=submission_user_id,
submission_end_user_id=submission_end_user_id,
)
if result.form_kind != HumanInputFormKind.RUNTIME:
return
if result.workflow_run_id is None:
return
self.enqueue_resume(result.workflow_run_id)
def ensure_form_active(self, form: Form) -> None:
if form.submitted:
raise FormSubmittedError(form.id)
if form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
raise FormExpiredError(form.id)
now = naive_utc_now()
if ensure_naive_utc(form.expiration_time) <= now:
raise FormExpiredError(form.id)
if self._is_globally_expired(form, now=now):
raise FormExpiredError(form.id)
def _ensure_not_submitted(self, form: Form) -> None:
if form.submitted:
raise FormSubmittedError(form.id)
def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
definition = form.get_definition()
try:
validate_human_input_submission(
inputs=definition.inputs,
user_actions=definition.user_actions,
selected_action_id=selected_action_id,
form_data=form_data,
)
except HumanInputSubmissionValidationError as exc:
raise InvalidFormDataError(str(exc)) from exc
def enqueue_resume(self, workflow_run_id: str) -> None:
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(workflow_run_id)
if workflow_run is None:
raise AssertionError(f"WorkflowRun not found, id={workflow_run_id}")
with self._session_factory(expire_on_commit=False) as session:
app_query = select(App).where(App.id == workflow_run.app_id)
app = session.execute(app_query).scalar_one_or_none()
if app is None:
logger.error(
"App not found for WorkflowRun, workflow_run_id=%s, app_id=%s", workflow_run_id, workflow_run.app_id
)
return
if app.mode in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
payload = {"workflow_run_id": workflow_run_id}
try:
resume_app_execution.apply_async(
kwargs={"payload": payload},
queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE,
)
except Exception: # pragma: no cover
logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
return
logger.warning("App mode %s does not support resume for workflow run %s", app.mode, workflow_run_id)
def _is_globally_expired(self, form: Form, *, now: datetime | None = None) -> bool:
global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS
if global_timeout_seconds <= 0:
return False
if form.workflow_run_id is None:
return False
current = now or naive_utc_now()
created_at = ensure_naive_utc(form.created_at)
global_deadline = created_at + timedelta(seconds=global_timeout_seconds)
return global_deadline <= current

View File

@ -1,6 +1,9 @@
import json
from collections.abc import Sequence
from typing import Union
from sqlalchemy.orm import sessionmaker
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
@ -14,6 +17,10 @@ from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
from repositories.sqlalchemy_execution_extra_content_repository import (
SQLAlchemyExecutionExtraContentRepository,
)
from services.conversation_service import ConversationService
from services.errors.message import (
FirstMessageNotExistsError,
@ -24,6 +31,23 @@ from services.errors.message import (
from services.workflow_service import WorkflowService
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
return SQLAlchemyExecutionExtraContentRepository(session_maker=session_maker)
def attach_message_extra_contents(messages: Sequence[Message]) -> None:
if not messages:
return
repository = _create_execution_extra_content_repository()
extra_contents_lists = repository.get_by_message_ids([message.id for message in messages])
for index, message in enumerate(messages):
contents = extra_contents_lists[index] if index < len(extra_contents_lists) else []
message.set_extra_contents([content.model_dump(mode="json", exclude_none=True) for content in contents])
class MessageService:
@classmethod
def pagination_by_first_id(
@ -85,6 +109,8 @@ class MessageService:
if order == "asc":
history_messages = list(reversed(history_messages))
attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
@classmethod

View File

@ -10,6 +10,7 @@ from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
@ -63,6 +64,8 @@ class WorkflowToolManageService:
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
@ -152,6 +155,8 @@ class WorkflowToolManageService:
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)

View File

@ -98,6 +98,12 @@ class WorkflowTaskData(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class WorkflowResumeTaskData(BaseModel):
"""Payload for workflow resumption tasks."""
workflow_run_id: str
class AsyncTriggerExecutionResult(BaseModel):
"""Result from async trigger-based workflow execution"""

View File

@ -0,0 +1,460 @@
from __future__ import annotations
import json
import logging
import queue
import threading
import time
from collections.abc import Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import Any
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
MessageReplaceStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
StreamEvent,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.entities import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models.model import AppMode, Message
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.factory import DifyAPIRepositoryFactory
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class MessageContext:
conversation_id: str
message_id: str
created_at: int
answer: str | None = None
@dataclass
class BufferState:
queue: queue.Queue[Mapping[str, Any]]
stop_event: threading.Event
done_event: threading.Event
task_id_ready: threading.Event
task_id_hint: str | None = None
def build_workflow_event_stream(
*,
app_mode: AppMode,
workflow_run: WorkflowRun,
tenant_id: str,
app_id: str,
session_maker: sessionmaker[Session],
idle_timeout: float = 300,
ping_interval: float = 10.0,
) -> Generator[Mapping[str, Any] | str, None, None]:
topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
message_context = (
_get_message_context(session_maker, workflow_run.id) if app_mode == AppMode.ADVANCED_CHAT else None
)
pause_entity: WorkflowPauseEntity | None = None
if workflow_run.status == WorkflowExecutionStatus.PAUSED:
try:
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run.id)
except Exception:
logger.exception("Failed to load workflow pause for run %s", workflow_run.id)
pause_entity = None
resumption_context = _load_resumption_context(pause_entity)
node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_run.workflow_id,
# NOTE(QuantumGhost): for events resumption, we only care about
# the execution records from `WORKFLOW_RUN`.
#
# Ideally filtering with `workflow_run_id` is enough. However,
# due to the index of `WorkflowNodeExecution` table, we have to
# add a filter condition of `triggered_from` to
# ensure that we can utilize the index.
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=workflow_run.id,
)
def _generate() -> Generator[Mapping[str, Any] | str, None, None]:
# send a PING event immediately to prevent the connection staying in pending state for a long time.
#
# This simplify the debugging process as the DevTools in Chrome does not
# provide complete curl command for pending connections.
yield StreamEvent.PING.value
last_msg_time = time.time()
last_ping_time = last_msg_time
with topic.subscribe() as sub:
buffer_state = _start_buffering(sub)
try:
task_id = _resolve_task_id(resumption_context, buffer_state, workflow_run.id)
snapshot_events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=node_snapshots,
task_id=task_id,
message_context=message_context,
pause_entity=pause_entity,
resumption_context=resumption_context,
)
for event in snapshot_events:
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event, include_paused=True):
return
while True:
if buffer_state.done_event.is_set() and buffer_state.queue.empty():
return
try:
event = buffer_state.queue.get(timeout=0.1)
except queue.Empty:
current_time = time.time()
if current_time - last_msg_time > idle_timeout:
logger.debug(
"No workflow events received for %s seconds, keeping stream open",
idle_timeout,
)
last_msg_time = current_time
if current_time - last_ping_time >= ping_interval:
yield StreamEvent.PING.value
last_ping_time = current_time
continue
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event, include_paused=True):
return
finally:
buffer_state.stop_event.set()
return _generate()
def _get_message_context(session_maker: sessionmaker[Session], workflow_run_id: str) -> MessageContext | None:
with session_maker() as session:
stmt = select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(desc(Message.created_at))
message = session.scalar(stmt)
if message is None:
return None
created_at = int(message.created_at.timestamp()) if message.created_at else 0
return MessageContext(
conversation_id=message.conversation_id,
message_id=message.id,
created_at=created_at,
answer=message.answer,
)
def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> WorkflowResumptionContext | None:
if pause_entity is None:
return None
try:
raw_state = pause_entity.get_state().decode()
return WorkflowResumptionContext.loads(raw_state)
except Exception:
logger.exception("Failed to load resumption context")
return None
def _resolve_task_id(
resumption_context: WorkflowResumptionContext | None,
buffer_state: BufferState | None,
workflow_run_id: str,
wait_timeout: float = 0.2,
) -> str:
if resumption_context is not None:
generate_entity = resumption_context.get_generate_entity()
if generate_entity.task_id:
return generate_entity.task_id
if buffer_state is None:
return workflow_run_id
if buffer_state.task_id_hint is None:
buffer_state.task_id_ready.wait(timeout=wait_timeout)
if buffer_state.task_id_hint:
return buffer_state.task_id_hint
return workflow_run_id
def _build_snapshot_events(
*,
workflow_run: WorkflowRun,
node_snapshots: Sequence[WorkflowNodeExecutionSnapshot],
task_id: str,
message_context: MessageContext | None,
pause_entity: WorkflowPauseEntity | None,
resumption_context: WorkflowResumptionContext | None,
) -> list[Mapping[str, Any]]:
events: list[Mapping[str, Any]] = []
workflow_started = _build_workflow_started_event(
workflow_run=workflow_run,
task_id=task_id,
)
_apply_message_context(workflow_started, message_context)
events.append(workflow_started)
if message_context is not None and message_context.answer is not None:
message_replace = _build_message_replace_event(task_id=task_id, answer=message_context.answer)
_apply_message_context(message_replace, message_context)
events.append(message_replace)
for snapshot in node_snapshots:
node_started = _build_node_started_event(
workflow_run_id=workflow_run.id,
task_id=task_id,
snapshot=snapshot,
)
_apply_message_context(node_started, message_context)
events.append(node_started)
if snapshot.status != WorkflowNodeExecutionStatus.RUNNING.value:
node_finished = _build_node_finished_event(
workflow_run_id=workflow_run.id,
task_id=task_id,
snapshot=snapshot,
)
_apply_message_context(node_finished, message_context)
events.append(node_finished)
if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
pause_event = _build_pause_event(
workflow_run=workflow_run,
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
resumption_context=resumption_context,
)
if pause_event is not None:
_apply_message_context(pause_event, message_context)
events.append(pause_event)
return events
def _build_workflow_started_event(
*,
workflow_run: WorkflowRun,
task_id: str,
) -> dict[str, Any]:
response = WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=WorkflowStartStreamResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
inputs=workflow_run.inputs_dict or {},
created_at=int(workflow_run.created_at.timestamp()),
reason=WorkflowStartReason.INITIAL,
),
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
return payload
def _build_message_replace_event(*, task_id: str, answer: str) -> dict[str, Any]:
response = MessageReplaceStreamResponse(
task_id=task_id,
answer=answer,
reason="",
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
return payload
def _build_node_started_event(
*,
workflow_run_id: str,
task_id: str,
snapshot: WorkflowNodeExecutionSnapshot,
) -> dict[str, Any]:
created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=NodeStartStreamResponse.Data(
id=snapshot.execution_id,
node_id=snapshot.node_id,
node_type=snapshot.node_type,
title=snapshot.title,
index=snapshot.index,
predecessor_node_id=None,
inputs=None,
created_at=created_at,
extras={},
iteration_id=snapshot.iteration_id,
loop_id=snapshot.loop_id,
),
)
return response.to_ignore_detail_dict()
def _build_node_finished_event(
*,
workflow_run_id: str,
task_id: str,
snapshot: WorkflowNodeExecutionSnapshot,
) -> dict[str, Any]:
created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
finished_at = int(snapshot.finished_at.timestamp()) if snapshot.finished_at else created_at
response = NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=NodeFinishStreamResponse.Data(
id=snapshot.execution_id,
node_id=snapshot.node_id,
node_type=snapshot.node_type,
title=snapshot.title,
index=snapshot.index,
predecessor_node_id=None,
inputs=None,
process_data=None,
outputs=None,
status=WorkflowNodeExecutionStatus(snapshot.status),
error=None,
elapsed_time=snapshot.elapsed_time,
execution_metadata=None,
created_at=created_at,
finished_at=finished_at,
files=[],
iteration_id=snapshot.iteration_id,
loop_id=snapshot.loop_id,
),
)
return response.to_ignore_detail_dict()
def _build_pause_event(
*,
workflow_run: WorkflowRun,
workflow_run_id: str,
task_id: str,
pause_entity: WorkflowPauseEntity,
resumption_context: WorkflowResumptionContext | None,
) -> dict[str, Any] | None:
paused_nodes: list[str] = []
outputs: dict[str, Any] = {}
if resumption_context is not None:
state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
paused_nodes = state.get_paused_nodes()
outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
response = WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=WorkflowPauseStreamResponse.Data(
workflow_run_id=workflow_run_id,
paused_nodes=paused_nodes,
outputs=outputs,
reasons=reasons,
status=workflow_run.status,
created_at=int(workflow_run.created_at.timestamp()),
elapsed_time=float(workflow_run.elapsed_time or 0.0),
total_tokens=int(workflow_run.total_tokens or 0),
total_steps=int(workflow_run.total_steps or 0),
),
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
return payload
def _apply_message_context(payload: dict[str, Any], message_context: MessageContext | None) -> None:
if message_context is None:
return
payload["conversation_id"] = message_context.conversation_id
payload["message_id"] = message_context.message_id
payload["created_at"] = message_context.created_at
def _start_buffering(subscription) -> BufferState:
buffer_state = BufferState(
queue=queue.Queue(maxsize=2048),
stop_event=threading.Event(),
done_event=threading.Event(),
task_id_ready=threading.Event(),
)
def _worker() -> None:
dropped_count = 0
try:
while not buffer_state.stop_event.is_set():
msg = subscription.receive(timeout=0.1)
if msg is None:
continue
event = _parse_event_message(msg)
if event is None:
continue
task_id = event.get("task_id")
if task_id and buffer_state.task_id_hint is None:
buffer_state.task_id_hint = str(task_id)
buffer_state.task_id_ready.set()
try:
buffer_state.queue.put_nowait(event)
except queue.Full:
dropped_count += 1
try:
buffer_state.queue.get_nowait()
except queue.Empty:
pass
try:
buffer_state.queue.put_nowait(event)
except queue.Full:
continue
logger.warning("Dropped buffered workflow event, total_dropped=%s", dropped_count)
except Exception:
logger.exception("Failed while buffering workflow events")
finally:
buffer_state.done_event.set()
thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True)
thread.start()
return buffer_state
def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
try:
event = json.loads(message)
except json.JSONDecodeError:
logger.warning("Failed to decode workflow event payload")
return None
if not isinstance(event, dict):
return None
return event
def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool:
if not isinstance(event, Mapping):
return False
event_type = event.get("event")
if event_type == StreamEvent.WORKFLOW_FINISHED.value:
return True
if include_paused:
return event_type == StreamEvent.WORKFLOW_PAUSED.value
return False

View File

@ -1,4 +1,5 @@
import json
import logging
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
@ -11,21 +12,34 @@ from configs import dify_config
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.variables import VariableBase
from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.entities import GraphInitParams, WorkflowNodeExecution
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
HumanInputNodeData,
apply_debug_email_recipient,
validate_human_input_submission,
)
from core.workflow.nodes.human_input.enums import HumanInputFormKind
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.repositories.human_input_form_repository import FormCreateParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@ -34,6 +48,8 @@ from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import UserFrom
from models.human_input import HumanInputFormRecipient, RecipientType
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
@ -44,6 +60,13 @@ from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededEr
from services.workflow.workflow_converter import WorkflowConverter
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .human_input_delivery_test_service import (
DeliveryTestContext,
DeliveryTestEmailRecipient,
DeliveryTestError,
DeliveryTestUnsupportedError,
HumanInputDeliveryTestService,
)
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
@ -744,6 +767,344 @@ class WorkflowService:
return workflow_node_execution
def get_human_input_form_preview(
self,
*,
app_model: App,
account: Account,
node_id: str,
inputs: Mapping[str, Any] | None = None,
) -> Mapping[str, Any]:
"""
Build a human input form preview for a draft workflow.
Args:
app_model: Target application model.
account: Current account.
node_id: Human input node ID.
inputs: Values used to fill missing upstream variables referenced in form_content.
"""
draft_workflow = self.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
if node_type is not NodeType.HUMAN_INPUT:
raise ValueError("Node type must be human-input.")
# inputs: values used to fill missing upstream variables referenced in form_content.
variable_pool = self._build_human_input_variable_pool(
app_model=app_model,
workflow=draft_workflow,
node_config=node_config,
manual_inputs=inputs or {},
)
node = self._build_human_input_node(
workflow=draft_workflow,
account=account,
node_config=node_config,
variable_pool=variable_pool,
)
rendered_content = node.render_form_content_before_submission()
resolved_default_values = node.resolve_default_values()
node_data = node.node_data
human_input_required = HumanInputRequired(
form_id=node_id,
form_content=rendered_content,
inputs=node_data.inputs,
actions=node_data.user_actions,
node_id=node_id,
node_title=node.title,
resolved_default_values=resolved_default_values,
form_token=None,
)
return human_input_required.model_dump(mode="json")
def submit_human_input_form_preview(
self,
*,
app_model: App,
account: Account,
node_id: str,
form_inputs: Mapping[str, Any],
inputs: Mapping[str, Any] | None = None,
action: str,
) -> Mapping[str, Any]:
"""
Submit a human input form preview for a draft workflow.
Args:
app_model: Target application model.
account: Current account.
node_id: Human input node ID.
form_inputs: Values the user provides for the form's own fields.
inputs: Values used to fill missing upstream variables referenced in form_content.
action: Selected action ID.
"""
draft_workflow = self.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
if node_type is not NodeType.HUMAN_INPUT:
raise ValueError("Node type must be human-input.")
# inputs: values used to fill missing upstream variables referenced in form_content.
# form_inputs: values the user provides for the form's own fields.
variable_pool = self._build_human_input_variable_pool(
app_model=app_model,
workflow=draft_workflow,
node_config=node_config,
manual_inputs=inputs or {},
)
node = self._build_human_input_node(
workflow=draft_workflow,
account=account,
node_config=node_config,
variable_pool=variable_pool,
)
node_data = node.node_data
validate_human_input_submission(
inputs=node_data.inputs,
user_actions=node_data.user_actions,
selected_action_id=action,
form_data=form_inputs,
)
rendered_content = node.render_form_content_before_submission()
outputs: dict[str, Any] = dict(form_inputs)
outputs["__action_id"] = action
outputs["__rendered_content"] = node.render_form_content_with_outputs(
rendered_content, outputs, node_data.outputs_field_names()
)
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
app_id=app_model.id,
node_id=node_id,
node_type=NodeType.HUMAN_INPUT,
node_execution_id=str(uuid.uuid4()),
user=account,
enclosing_node_id=enclosing_node_id,
)
draft_var_saver.save(outputs=outputs, process_data={})
session.commit()
return outputs
def test_human_input_delivery(
self,
*,
app_model: App,
account: Account,
node_id: str,
delivery_method_id: str,
inputs: Mapping[str, Any] | None = None,
) -> None:
draft_workflow = self.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
if node_type is not NodeType.HUMAN_INPUT:
raise ValueError("Node type must be human-input.")
node_data = HumanInputNodeData.model_validate(node_config.get("data", {}))
delivery_method = self._resolve_human_input_delivery_method(
node_data=node_data,
delivery_method_id=delivery_method_id,
)
if delivery_method is None:
raise ValueError("Delivery method not found.")
delivery_method = apply_debug_email_recipient(
delivery_method,
enabled=True,
user_id=account.id or "",
)
variable_pool = self._build_human_input_variable_pool(
app_model=app_model,
workflow=draft_workflow,
node_config=node_config,
manual_inputs=inputs or {},
)
node = self._build_human_input_node(
workflow=draft_workflow,
account=account,
node_config=node_config,
variable_pool=variable_pool,
)
rendered_content = node.render_form_content_before_submission()
resolved_default_values = node.resolve_default_values()
form_id, recipients = self._create_human_input_delivery_test_form(
app_model=app_model,
node_id=node_id,
node_data=node_data,
delivery_method=delivery_method,
rendered_content=rendered_content,
resolved_default_values=resolved_default_values,
)
test_service = HumanInputDeliveryTestService()
context = DeliveryTestContext(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
node_id=node_id,
node_title=node_data.title,
rendered_content=rendered_content,
template_vars={"form_id": form_id},
recipients=recipients,
variable_pool=variable_pool,
)
try:
test_service.send_test(context=context, method=delivery_method)
except DeliveryTestUnsupportedError as exc:
raise ValueError("Delivery method does not support test send.") from exc
except DeliveryTestError as exc:
raise ValueError(str(exc)) from exc
@staticmethod
def _resolve_human_input_delivery_method(
*,
node_data: HumanInputNodeData,
delivery_method_id: str,
) -> DeliveryChannelConfig | None:
for method in node_data.delivery_methods:
if str(method.id) == delivery_method_id:
return method
return None
def _create_human_input_delivery_test_form(
self,
*,
app_model: App,
node_id: str,
node_data: HumanInputNodeData,
delivery_method: DeliveryChannelConfig,
rendered_content: str,
resolved_default_values: Mapping[str, Any],
) -> tuple[str, list[DeliveryTestEmailRecipient]]:
repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id)
params = FormCreateParams(
app_id=app_model.id,
workflow_execution_id=None,
node_id=node_id,
form_config=node_data,
rendered_content=rendered_content,
delivery_methods=[delivery_method],
display_in_ui=False,
resolved_default_values=resolved_default_values,
form_kind=HumanInputFormKind.DELIVERY_TEST,
)
form_entity = repo.create_form(params)
return form_entity.id, self._load_email_recipients(form_entity.id)
@staticmethod
def _load_email_recipients(form_id: str) -> list[DeliveryTestEmailRecipient]:
logger = logging.getLogger(__name__)
with Session(bind=db.engine) as session:
recipients = session.scalars(
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id)
).all()
recipients_data: list[DeliveryTestEmailRecipient] = []
for recipient in recipients:
if recipient.recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}:
continue
if not recipient.access_token:
continue
try:
payload = json.loads(recipient.recipient_payload)
except Exception:
logger.exception("Failed to parse human input recipient payload for delivery test.")
continue
email = payload.get("email")
if email:
recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token))
return recipients_data
def _build_human_input_node(
self,
*,
workflow: Workflow,
account: Account,
node_config: Mapping[str, Any],
variable_pool: VariablePool,
) -> HumanInputNode:
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
user_id=account.id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.DEBUGGER.value,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
node = HumanInputNode(
id=node_config.get("id", str(uuid.uuid4())),
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
return node
def _build_human_input_variable_pool(
self,
*,
app_model: App,
workflow: Workflow,
node_config: Mapping[str, Any],
manual_inputs: Mapping[str, Any],
) -> VariablePool:
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
variable_loader = DraftVarLoader(
engine=db.engine,
app_id=app_model.id,
tenant_id=app_model.tenant_id,
)
variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=node_config,
)
normalized_user_inputs: dict[str, Any] = dict(manual_inputs)
load_into_variable_pool(
variable_loader=variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=normalized_user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=normalized_user_inputs,
variable_pool=variable_pool,
tenant_id=app_model.tenant_id,
)
return variable_pool
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
@ -945,6 +1306,13 @@ class WorkflowService:
if any(nt.is_trigger_node for nt in node_types):
raise ValueError("Start node and trigger nodes cannot coexist in the same workflow")
for node in node_configs:
node_data = node.get("data", {})
node_type = node_data.get("type")
if node_type == NodeType.HUMAN_INPUT:
self._validate_human_input_node_data(node_data)
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
@ -957,6 +1325,23 @@ class WorkflowService:
else:
raise ValueError(f"Invalid app mode: {app_model.mode}")
def _validate_human_input_node_data(self, node_data: dict) -> None:
"""
Validate HumanInput node data format.
Args:
node_data: The node data dictionary
Raises:
ValueError: If the node data format is invalid
"""
from core.workflow.nodes.human_input.entities import HumanInputNodeData
try:
HumanInputNodeData.model_validate(node_data)
except Exception as e:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Workflow | None: