mirror of
https://github.com/langgenius/dify.git
synced 2026-03-27 17:19:55 +08:00
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
209 lines
7.5 KiB
Python
209 lines
7.5 KiB
Python
"""Human Input node entities.

The graph package owns the workflow-facing form schema and keeps it transportable
across runtimes. Dify-specific delivery surface and recipient translation stay
outside `dify_graph`.
"""
|
|
|
|
import re
|
|
from collections.abc import Mapping, Sequence
|
|
from datetime import datetime, timedelta
|
|
from typing import Any, Self
|
|
|
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
|
|
from dify_graph.entities.base_node_data import BaseNodeData
|
|
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
|
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
|
from dify_graph.variables.consts import SELECTORS_LENGTH
|
|
|
|
from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit
|
|
|
|
# Matches output placeholders of the form `{{#$output.<field_name>#}}` and
# captures `<field_name>`: a letter or underscore followed by up to 29 letters,
# digits, or underscores (1-30 characters in total).
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
|
|
|
|
|
class FormInputDefault(BaseModel):
    """Default value configuration for a form input.

    The default is either a constant (`value`) or a reference to a variable
    (`selector`); `type` tells which of the two is in effect.
    """

    # NOTE: A discriminated union would be the natural way to model this.
    # However, the UI needs the previous value to survive a switch between
    # `VARIABLE` and `CONSTANT`, so all fields are kept side by side.

    type: PlaceholderType

    # Selector of the default variable; used when `type` is `VARIABLE`.
    selector: Sequence[str] = Field(default_factory=tuple)

    # Literal default value; used when `type` is `CONSTANT`.
    # TODO: How should we express JSON values?
    value: str = ""

    @model_validator(mode="after")
    def _validate_selector(self) -> Self:
        """Require a long-enough selector unless the default is a constant.

        Raises:
            ValueError: if `type` is not `CONSTANT` and `selector` has fewer
                than `SELECTORS_LENGTH` parts.
        """
        is_constant = self.type == PlaceholderType.CONSTANT
        if not is_constant and len(self.selector) < SELECTORS_LENGTH:
            raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
        return self
|
|
|
|
|
|
class FormInput(BaseModel):
    """Definition of a single input field of the human input form."""

    # Kind of input field.
    type: FormInputType

    # Name under which the submitted value is keyed; submissions are checked
    # for the presence of this name, and it must be unique among a node's inputs.
    output_variable_name: str

    # Optional default (constant or variable reference) for the field.
    default: FormInputDefault | None = None
|
|
|
|
|
|
# Identifier shape: a letter or underscore followed by any number of letters,
# digits, or underscores.
# NOTE: `$` also matches just before a trailing newline, so strict validation
# should use `fullmatch` rather than `match`/`search` with this pattern.
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
|
|
|
|
|
class UserAction(BaseModel):
    """User action configuration."""

    # `id` identifies this action and also serves as the identifier of the
    # corresponding output handle.
    #
    # It must be a valid identifier, i.e. satisfy `_IDENTIFIER_PATTERN`.
    id: str = Field(max_length=20)
    title: str = Field(max_length=20)
    button_style: ButtonStyle = ButtonStyle.DEFAULT

    @field_validator("id")
    @classmethod
    def _validate_id(cls, value: str) -> str:
        """Reject ids that are not plain identifiers.

        Args:
            value: candidate action id.

        Returns:
            The unchanged `value` when it is a valid identifier.

        Raises:
            ValueError: if `value` is not a letter or underscore followed only
                by letters, digits, or underscores.
        """
        # Use `fullmatch`, not `match`: the pattern ends in `$`, which also
        # matches just before a trailing newline, so `match` would wrongly
        # accept an id such as "approve\n".
        if _IDENTIFIER_PATTERN.fullmatch(value) is None:
            raise ValueError(
                f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
                "and contain only letters, numbers, or underscores."
            )
        return value
|
|
|
|
|
|
class HumanInputNodeData(BaseNodeData):
    """Configuration data of a Human Input node.

    Holds the form template, its input fields, the available user actions and
    the timeout after which the form expires.
    """

    type: NodeType = BuiltinNodeTypes.HUMAN_INPUT
    form_content: str = ""
    inputs: list[FormInput] = Field(default_factory=list)
    user_actions: list[UserAction] = Field(default_factory=list)
    # Timeout amount, interpreted in `timeout_unit`.
    timeout: int = 36
    timeout_unit: TimeoutUnit = TimeoutUnit.HOUR

    @field_validator("inputs")
    @classmethod
    def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
        """Ensure every input has a distinct `output_variable_name`.

        Raises:
            ValueError: on the first repeated `output_variable_name`.
        """
        used_names: set[str] = set()
        for item in inputs:
            variable_name = item.output_variable_name
            if variable_name in used_names:
                raise ValueError(f"duplicated output_variable_name '{variable_name}' in inputs")
            used_names.add(variable_name)
        return inputs

    @field_validator("user_actions")
    @classmethod
    def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
        """Ensure every user action has a distinct `id`.

        Raises:
            ValueError: on the first repeated action id.
        """
        used_ids: set[str] = set()
        for user_action in user_actions:
            current_id = user_action.id
            if current_id in used_ids:
                raise ValueError(f"duplicated user action id '{current_id}'")
            used_ids.add(current_id)
        return user_actions

    def expiration_time(self, start_time: datetime) -> datetime:
        """Return `start_time` shifted forward by the configured timeout.

        Raises:
            AssertionError: if `timeout_unit` is neither `HOUR` nor `DAY`.
        """
        unit = self.timeout_unit
        if unit == TimeoutUnit.HOUR:
            delta = timedelta(hours=self.timeout)
        elif unit == TimeoutUnit.DAY:
            delta = timedelta(days=self.timeout)
        else:
            raise AssertionError("unknown timeout unit.")
        return start_time + delta

    def outputs_field_names(self) -> Sequence[str]:
        """List the field names of all `{{#$output.<name>#}}` placeholders in `form_content`.

        Names are returned in order of appearance; repeats are kept.
        """
        return [found.group("field_name") for found in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content)]

    def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
        """Collect the variable selectors this node refers to.

        Selectors come from two places: variables referenced in `form_content`
        and the selectors of non-constant input defaults. Each entry is keyed
        as `"<node_id>.#<dot-joined selector>#"`.
        """
        mapping: dict[str, Sequence[str]] = {}

        # Selectors referenced by the form template. Those shorter than
        # SELECTORS_LENGTH are skipped; longer ones are truncated to it.
        parser = VariableTemplateParser(template=self.form_content)
        template_selectors = [entry.value_selector for entry in parser.extract_variable_selectors()]
        for template_selector in template_selectors:
            if len(template_selector) < SELECTORS_LENGTH:
                continue
            head = list(template_selector[:SELECTORS_LENGTH])
            mapping[f"{node_id}.#{'.'.join(head)}#"] = head

        # Selectors of variable-backed defaults are recorded in full.
        for form_input in self.inputs:
            default = form_input.default
            if default is None or default.type == PlaceholderType.CONSTANT:
                continue
            joined = ".".join(default.selector)
            mapping[f"{node_id}.#{joined}#"] = default.selector

        return mapping

    def find_action_text(self, action_id: str) -> str:
        """
        Resolve action display text by id.

        Returns the title of the first action whose id equals `action_id`,
        or `action_id` itself when no action matches.
        """
        titles = (candidate.title for candidate in self.user_actions if candidate.id == action_id)
        return next(titles, action_id)
|
|
|
|
|
|
class FormDefinition(BaseModel):
    """Definition of a human input form: template, fields, actions and expiry."""

    form_content: str
    inputs: list[FormInput] = Field(default_factory=list)
    user_actions: list[UserAction] = Field(default_factory=list)
    # NOTE(review): presumably `form_content` after rendering; the rendering
    # step is not part of this module, so confirm against the producer.
    rendered_content: str
    expiration_time: datetime

    # Stores the resolved default values.
    default_values: dict[str, Any] = Field(default_factory=dict)

    # Title of the HumanInput node, if recorded.
    node_title: str | None = None

    # Controls whether the form should be displayed in UI surfaces.
    display_in_ui: bool | None = None
|
|
|
|
|
|
class HumanInputSubmissionValidationError(ValueError):
    """Raised when a human input form submission fails validation."""
|
|
|
|
|
|
def validate_human_input_submission(
    *,
    inputs: Sequence[FormInput],
    user_actions: Sequence[UserAction],
    selected_action_id: str,
    form_data: Mapping[str, Any],
) -> None:
    """Validate a submitted human input form.

    Args:
        inputs: input definitions of the form.
        user_actions: actions the form offers.
        selected_action_id: id of the action chosen by the submitter.
        form_data: submitted values keyed by `output_variable_name`.

    Raises:
        HumanInputSubmissionValidationError: if `selected_action_id` is not the
            id of any action in `user_actions`, or if `form_data` lacks a key
            for one or more inputs. The action check runs first.
    """
    valid_action_ids = frozenset(user_action.id for user_action in user_actions)
    if selected_action_id not in valid_action_ids:
        raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")

    # Only the presence of a key is checked here, not the submitted value.
    submitted_names = frozenset(form_data.keys())
    expected_names = (form_input.output_variable_name for form_input in inputs)
    absent = [expected for expected in expected_names if expected not in submitted_names]
    if not absent:
        return

    joined = ", ".join(absent)
    raise HumanInputSubmissionValidationError(f"Missing required inputs: {joined}")
|