Files
dify/api/dify_graph/nodes/human_input/entities.py

351 lines
12 KiB
Python

"""
Human Input node entities.
"""
import re
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from dify_graph.nodes.base import BaseNodeData
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.runtime import VariablePool
from dify_graph.variables.consts import SELECTORS_LENGTH
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
class _WebAppDeliveryConfig(BaseModel):
"""Configuration for webapp delivery method."""
pass # Empty for webapp delivery
class MemberRecipient(BaseModel):
"""Member recipient for email delivery."""
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
user_id: str
class ExternalRecipient(BaseModel):
"""External recipient for email delivery."""
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
email: str
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
class EmailRecipients(BaseModel):
"""Email recipients configuration."""
# When true, recipients are the union of all workspace members and external items.
# Member items are ignored because they are already covered by the workspace scope.
# De-duplication is applied by email, with member recipients taking precedence.
whole_workspace: bool = False
items: list[EmailRecipient] = Field(default_factory=list)
class EmailDeliveryConfig(BaseModel):
"""Configuration for email delivery method."""
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
recipients: EmailRecipients
# the subject of email
subject: str
# Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which
# represent the url to submit the form.
#
# It may also reference the output variable of the previous node with the syntax
# `{{#<node_id>.<field_name>#}}`.
body: str
debug_mode: bool = False
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
if not user_id:
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
return self.model_copy(update={"recipients": debug_recipients})
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
return self.model_copy(update={"recipients": debug_recipients})
@classmethod
def replace_url_placeholder(cls, body: str, url: str | None) -> str:
"""Replace the url placeholder with provided value."""
return body.replace(cls.URL_PLACEHOLDER, url or "")
@classmethod
def render_body_template(
cls,
*,
body: str,
url: str | None,
variable_pool: VariablePool | None = None,
) -> str:
"""Render email body by replacing placeholders with runtime values."""
templated_body = cls.replace_url_placeholder(body, url)
if variable_pool is None:
return templated_body
return variable_pool.convert_template(templated_body).text
class _DeliveryMethodBase(BaseModel):
"""Base delivery method configuration."""
enabled: bool = True
id: uuid.UUID = Field(default_factory=uuid.uuid4)
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
return ()
class WebAppDeliveryMethod(_DeliveryMethodBase):
"""Webapp delivery method configuration."""
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
# The config field is not used currently.
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
class EmailDeliveryMethod(_DeliveryMethodBase):
"""Email delivery method configuration."""
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
config: EmailDeliveryConfig
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
variable_template_parser = VariableTemplateParser(template=self.config.body)
selectors: list[Sequence[str]] = []
for variable_selector in variable_template_parser.extract_variable_selectors():
value_selector = list(variable_selector.value_selector)
if len(value_selector) < SELECTORS_LENGTH:
continue
selectors.append(value_selector[:SELECTORS_LENGTH])
return selectors
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
def apply_debug_email_recipient(
method: DeliveryChannelConfig,
*,
enabled: bool,
user_id: str,
) -> DeliveryChannelConfig:
if not enabled:
return method
if not isinstance(method, EmailDeliveryMethod):
return method
if not method.config.debug_mode:
return method
debug_config = method.config.with_debug_recipient(user_id or "")
return method.model_copy(update={"config": debug_config})
class FormInputDefault(BaseModel):
"""Default configuration for form inputs."""
# NOTE: Ideally, a discriminated union would be used to model
# FormInputDefault. However, the UI requires preserving the previous
# value when switching between `VARIABLE` and `CONSTANT` types. This
# necessitates retaining all fields, making a discriminated union unsuitable.
type: PlaceholderType
# The selector of default variable, used when `type` is `VARIABLE`.
selector: Sequence[str] = Field(default_factory=tuple) #
# The value of the default, used when `type` is `CONSTANT`.
# TODO: How should we express JSON values?
value: str = ""
@model_validator(mode="after")
def _validate_selector(self) -> Self:
if self.type == PlaceholderType.CONSTANT:
return self
if len(self.selector) < SELECTORS_LENGTH:
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
return self
class FormInput(BaseModel):
"""Form input definition."""
type: FormInputType
output_variable_name: str
default: FormInputDefault | None = None
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
class UserAction(BaseModel):
"""User action configuration."""
# id is the identifier for this action.
# It also serves as the identifiers of output handle.
#
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
id: str = Field(max_length=20)
title: str = Field(max_length=20)
button_style: ButtonStyle = ButtonStyle.DEFAULT
@field_validator("id")
@classmethod
def _validate_id(cls, value: str) -> str:
if not _IDENTIFIER_PATTERN.match(value):
raise ValueError(
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
f"and contain only letters, numbers, or underscores."
)
return value
class HumanInputNodeData(BaseNodeData):
"""Human Input node data."""
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
form_content: str = ""
inputs: list[FormInput] = Field(default_factory=list)
user_actions: list[UserAction] = Field(default_factory=list)
timeout: int = 36
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
@field_validator("inputs")
@classmethod
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
seen_names: set[str] = set()
for form_input in inputs:
name = form_input.output_variable_name
if name in seen_names:
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
seen_names.add(name)
return inputs
@field_validator("user_actions")
@classmethod
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
seen_ids: set[str] = set()
for action in user_actions:
action_id = action.id
if action_id in seen_ids:
raise ValueError(f"duplicated user action id '{action_id}'")
seen_ids.add(action_id)
return user_actions
def is_webapp_enabled(self) -> bool:
for dm in self.delivery_methods:
if not dm.enabled:
continue
if dm.type == DeliveryMethodType.WEBAPP:
return True
return False
def expiration_time(self, start_time: datetime) -> datetime:
if self.timeout_unit == TimeoutUnit.HOUR:
return start_time + timedelta(hours=self.timeout)
elif self.timeout_unit == TimeoutUnit.DAY:
return start_time + timedelta(days=self.timeout)
else:
raise AssertionError("unknown timeout unit.")
def outputs_field_names(self) -> Sequence[str]:
field_names = []
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
field_names.append(match.group("field_name"))
return field_names
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
variable_mappings: dict[str, Sequence[str]] = {}
def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
for selector in selectors:
if len(selector) < SELECTORS_LENGTH:
continue
qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
form_template_parser = VariableTemplateParser(template=self.form_content)
_add_variable_selectors(
[selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
)
for delivery_method in self.delivery_methods:
if not delivery_method.enabled:
continue
_add_variable_selectors(delivery_method.extract_variable_selectors())
for input in self.inputs:
default_value = input.default
if default_value is None:
continue
if default_value.type == PlaceholderType.CONSTANT:
continue
default_value_key = ".".join(default_value.selector)
qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
variable_mappings[qualified_variable_mapping_key] = default_value.selector
return variable_mappings
def find_action_text(self, action_id: str) -> str:
"""
Resolve action display text by id.
"""
for action in self.user_actions:
if action.id == action_id:
return action.title
return action_id
class FormDefinition(BaseModel):
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
user_actions: list[UserAction] = Field(default_factory=list)
rendered_content: str
expiration_time: datetime
# this is used to store the resolved default values
default_values: dict[str, Any] = Field(default_factory=dict)
# node_title records the title of the HumanInput node.
node_title: str | None = None
# display_in_ui controls whether the form should be displayed in UI surfaces.
display_in_ui: bool | None = None
class HumanInputSubmissionValidationError(ValueError):
pass
def validate_human_input_submission(
*,
inputs: Sequence[FormInput],
user_actions: Sequence[UserAction],
selected_action_id: str,
form_data: Mapping[str, Any],
) -> None:
available_actions = {action.id for action in user_actions}
if selected_action_id not in available_actions:
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
provided_inputs = set(form_data.keys())
missing_inputs = [
form_input.output_variable_name
for form_input in inputs
if form_input.output_variable_name not in provided_inputs
]
if missing_inputs:
missing_list = ", ".join(missing_inputs)
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")