mirror of
https://github.com/langgenius/dify.git
synced 2026-03-14 03:18:36 +08:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@ -18,6 +18,8 @@ from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
@ -34,6 +36,8 @@ from core.workflow.graph_events import (
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
@ -61,6 +65,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Node(Generic[NodeDataT]):
|
||||
"""BaseNode serves as the foundational class for all node implementations.
|
||||
|
||||
Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
|
||||
attribute to track files generated by the LLM). However, these states are not persisted
|
||||
when the workflow is suspended or resumed. If a node needs its state to be preserved
|
||||
across workflow suspension and resumption, it should include the relevant state data
|
||||
in its output.
|
||||
"""
|
||||
|
||||
node_type: ClassVar[NodeType]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
@ -251,10 +264,33 @@ class Node(Generic[NodeDataT]):
|
||||
return self._node_execution_id
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
if self._node_execution_id:
|
||||
return self._node_execution_id
|
||||
|
||||
resumed_execution_id = self._restore_execution_id_from_runtime_state()
|
||||
if resumed_execution_id:
|
||||
self._node_execution_id = resumed_execution_id
|
||||
return self._node_execution_id
|
||||
|
||||
self._node_execution_id = str(uuid4())
|
||||
return self._node_execution_id
|
||||
|
||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
||||
graph_execution = self.graph_runtime_state.graph_execution
|
||||
try:
|
||||
node_executions = graph_execution.node_executions
|
||||
except AttributeError:
|
||||
return None
|
||||
if not isinstance(node_executions, dict):
|
||||
return None
|
||||
node_execution = node_executions.get(self._node_id)
|
||||
if node_execution is None:
|
||||
return None
|
||||
execution_id = node_execution.execution_id
|
||||
if not execution_id:
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@ -620,6 +656,28 @@ class Node(Generic[NodeDataT]):
|
||||
metadata=event.metadata,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: HumanInputFormFilledEvent):
|
||||
return NodeRunHumanInputFormFilledEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: HumanInputFormTimeoutEvent):
|
||||
return NodeRunHumanInputFormTimeoutEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=event.node_title,
|
||||
expiration_time=event.expiration_time,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from .human_input_node import HumanInputNode
|
||||
|
||||
__all__ = ["HumanInputNode"]
|
||||
"""
|
||||
Human Input node implementation.
|
||||
"""
|
||||
|
||||
@ -1,10 +1,350 @@
|
||||
from pydantic import Field
|
||||
"""
|
||||
Human Input node entities.
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Any, ClassVar, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
|
||||
|
||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
||||
|
||||
|
||||
class _WebAppDeliveryConfig(BaseModel):
|
||||
"""Configuration for webapp delivery method."""
|
||||
|
||||
pass # Empty for webapp delivery
|
||||
|
||||
|
||||
class MemberRecipient(BaseModel):
|
||||
"""Member recipient for email delivery."""
|
||||
|
||||
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
|
||||
user_id: str
|
||||
|
||||
|
||||
class ExternalRecipient(BaseModel):
|
||||
"""External recipient for email delivery."""
|
||||
|
||||
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
|
||||
email: str
|
||||
|
||||
|
||||
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
|
||||
|
||||
|
||||
class EmailRecipients(BaseModel):
|
||||
"""Email recipients configuration."""
|
||||
|
||||
# When true, recipients are the union of all workspace members and external items.
|
||||
# Member items are ignored because they are already covered by the workspace scope.
|
||||
# De-duplication is applied by email, with member recipients taking precedence.
|
||||
whole_workspace: bool = False
|
||||
items: list[EmailRecipient] = Field(default_factory=list)
|
||||
|
||||
|
||||
class EmailDeliveryConfig(BaseModel):
|
||||
"""Configuration for email delivery method."""
|
||||
|
||||
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
|
||||
|
||||
recipients: EmailRecipients
|
||||
|
||||
# the subject of email
|
||||
subject: str
|
||||
|
||||
# Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which
|
||||
# represent the url to submit the form.
|
||||
#
|
||||
# It may also reference the output variable of the previous node with the syntax
|
||||
# `{{#<node_id>.<field_name>#}}`.
|
||||
body: str
|
||||
debug_mode: bool = False
|
||||
|
||||
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
|
||||
if not user_id:
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
|
||||
return self.model_copy(update={"recipients": debug_recipients})
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
|
||||
return self.model_copy(update={"recipients": debug_recipients})
|
||||
|
||||
@classmethod
|
||||
def replace_url_placeholder(cls, body: str, url: str | None) -> str:
|
||||
"""Replace the url placeholder with provided value."""
|
||||
return body.replace(cls.URL_PLACEHOLDER, url or "")
|
||||
|
||||
@classmethod
|
||||
def render_body_template(
|
||||
cls,
|
||||
*,
|
||||
body: str,
|
||||
url: str | None,
|
||||
variable_pool: VariablePool | None = None,
|
||||
) -> str:
|
||||
"""Render email body by replacing placeholders with runtime values."""
|
||||
templated_body = cls.replace_url_placeholder(body, url)
|
||||
if variable_pool is None:
|
||||
return templated_body
|
||||
return variable_pool.convert_template(templated_body).text
|
||||
|
||||
|
||||
class _DeliveryMethodBase(BaseModel):
|
||||
"""Base delivery method configuration."""
|
||||
|
||||
enabled: bool = True
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||
|
||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
||||
return ()
|
||||
|
||||
|
||||
class WebAppDeliveryMethod(_DeliveryMethodBase):
|
||||
"""Webapp delivery method configuration."""
|
||||
|
||||
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
|
||||
# The config field is not used currently.
|
||||
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
|
||||
|
||||
|
||||
class EmailDeliveryMethod(_DeliveryMethodBase):
|
||||
"""Email delivery method configuration."""
|
||||
|
||||
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
|
||||
config: EmailDeliveryConfig
|
||||
|
||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
||||
variable_template_parser = VariableTemplateParser(template=self.config.body)
|
||||
selectors: list[Sequence[str]] = []
|
||||
for variable_selector in variable_template_parser.extract_variable_selectors():
|
||||
value_selector = list(variable_selector.value_selector)
|
||||
if len(value_selector) < SELECTORS_LENGTH:
|
||||
continue
|
||||
selectors.append(value_selector[:SELECTORS_LENGTH])
|
||||
return selectors
|
||||
|
||||
|
||||
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
|
||||
|
||||
|
||||
def apply_debug_email_recipient(
|
||||
method: DeliveryChannelConfig,
|
||||
*,
|
||||
enabled: bool,
|
||||
user_id: str,
|
||||
) -> DeliveryChannelConfig:
|
||||
if not enabled:
|
||||
return method
|
||||
if not isinstance(method, EmailDeliveryMethod):
|
||||
return method
|
||||
if not method.config.debug_mode:
|
||||
return method
|
||||
debug_config = method.config.with_debug_recipient(user_id or "")
|
||||
return method.model_copy(update={"config": debug_config})
|
||||
|
||||
|
||||
class FormInputDefault(BaseModel):
|
||||
"""Default configuration for form inputs."""
|
||||
|
||||
# NOTE: Ideally, a discriminated union would be used to model
|
||||
# FormInputDefault. However, the UI requires preserving the previous
|
||||
# value when switching between `VARIABLE` and `CONSTANT` types. This
|
||||
# necessitates retaining all fields, making a discriminated union unsuitable.
|
||||
|
||||
type: PlaceholderType
|
||||
|
||||
# The selector of default variable, used when `type` is `VARIABLE`.
|
||||
selector: Sequence[str] = Field(default_factory=tuple) #
|
||||
|
||||
# The value of the default, used when `type` is `CONSTANT`.
|
||||
# TODO: How should we express JSON values?
|
||||
value: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_selector(self) -> Self:
|
||||
if self.type == PlaceholderType.CONSTANT:
|
||||
return self
|
||||
if len(self.selector) < SELECTORS_LENGTH:
|
||||
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
|
||||
return self
|
||||
|
||||
|
||||
class FormInput(BaseModel):
|
||||
"""Form input definition."""
|
||||
|
||||
type: FormInputType
|
||||
output_variable_name: str
|
||||
default: FormInputDefault | None = None
|
||||
|
||||
|
||||
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
class UserAction(BaseModel):
|
||||
"""User action configuration."""
|
||||
|
||||
# id is the identifier for this action.
|
||||
# It also serves as the identifiers of output handle.
|
||||
#
|
||||
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
|
||||
id: str = Field(max_length=20)
|
||||
title: str = Field(max_length=20)
|
||||
button_style: ButtonStyle = ButtonStyle.DEFAULT
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def _validate_id(cls, value: str) -> str:
|
||||
if not _IDENTIFIER_PATTERN.match(value):
|
||||
raise ValueError(
|
||||
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
|
||||
f"and contain only letters, numbers, or underscores."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Configuration schema for the HumanInput node."""
|
||||
"""Human Input node data."""
|
||||
|
||||
required_variables: list[str] = Field(default_factory=list)
|
||||
pause_reason: str | None = Field(default=None)
|
||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
||||
form_content: str = ""
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
user_actions: list[UserAction] = Field(default_factory=list)
|
||||
timeout: int = 36
|
||||
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
|
||||
|
||||
@field_validator("inputs")
|
||||
@classmethod
|
||||
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
|
||||
seen_names: set[str] = set()
|
||||
for form_input in inputs:
|
||||
name = form_input.output_variable_name
|
||||
if name in seen_names:
|
||||
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
|
||||
seen_names.add(name)
|
||||
return inputs
|
||||
|
||||
@field_validator("user_actions")
|
||||
@classmethod
|
||||
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
|
||||
seen_ids: set[str] = set()
|
||||
for action in user_actions:
|
||||
action_id = action.id
|
||||
if action_id in seen_ids:
|
||||
raise ValueError(f"duplicated user action id '{action_id}'")
|
||||
seen_ids.add(action_id)
|
||||
return user_actions
|
||||
|
||||
def is_webapp_enabled(self) -> bool:
|
||||
for dm in self.delivery_methods:
|
||||
if not dm.enabled:
|
||||
continue
|
||||
if dm.type == DeliveryMethodType.WEBAPP:
|
||||
return True
|
||||
return False
|
||||
|
||||
def expiration_time(self, start_time: datetime) -> datetime:
|
||||
if self.timeout_unit == TimeoutUnit.HOUR:
|
||||
return start_time + timedelta(hours=self.timeout)
|
||||
elif self.timeout_unit == TimeoutUnit.DAY:
|
||||
return start_time + timedelta(days=self.timeout)
|
||||
else:
|
||||
raise AssertionError("unknown timeout unit.")
|
||||
|
||||
def outputs_field_names(self) -> Sequence[str]:
|
||||
field_names = []
|
||||
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
|
||||
field_names.append(match.group("field_name"))
|
||||
return field_names
|
||||
|
||||
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
|
||||
variable_mappings: dict[str, Sequence[str]] = {}
|
||||
|
||||
def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
|
||||
for selector in selectors:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
continue
|
||||
qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
|
||||
variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
|
||||
|
||||
form_template_parser = VariableTemplateParser(template=self.form_content)
|
||||
_add_variable_selectors(
|
||||
[selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
|
||||
)
|
||||
for delivery_method in self.delivery_methods:
|
||||
if not delivery_method.enabled:
|
||||
continue
|
||||
_add_variable_selectors(delivery_method.extract_variable_selectors())
|
||||
|
||||
for input in self.inputs:
|
||||
default_value = input.default
|
||||
if default_value is None:
|
||||
continue
|
||||
if default_value.type == PlaceholderType.CONSTANT:
|
||||
continue
|
||||
default_value_key = ".".join(default_value.selector)
|
||||
qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
|
||||
variable_mappings[qualified_variable_mapping_key] = default_value.selector
|
||||
|
||||
return variable_mappings
|
||||
|
||||
def find_action_text(self, action_id: str) -> str:
|
||||
"""
|
||||
Resolve action display text by id.
|
||||
"""
|
||||
for action in self.user_actions:
|
||||
if action.id == action_id:
|
||||
return action.title
|
||||
return action_id
|
||||
|
||||
|
||||
class FormDefinition(BaseModel):
|
||||
form_content: str
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
user_actions: list[UserAction] = Field(default_factory=list)
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
|
||||
# this is used to store the resolved default values
|
||||
default_values: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# node_title records the title of the HumanInput node.
|
||||
node_title: str | None = None
|
||||
|
||||
# display_in_ui controls whether the form should be displayed in UI surfaces.
|
||||
display_in_ui: bool | None = None
|
||||
|
||||
|
||||
class HumanInputSubmissionValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def validate_human_input_submission(
|
||||
*,
|
||||
inputs: Sequence[FormInput],
|
||||
user_actions: Sequence[UserAction],
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
) -> None:
|
||||
available_actions = {action.id for action in user_actions}
|
||||
if selected_action_id not in available_actions:
|
||||
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
|
||||
|
||||
provided_inputs = set(form_data.keys())
|
||||
missing_inputs = [
|
||||
form_input.output_variable_name
|
||||
for form_input in inputs
|
||||
if form_input.output_variable_name not in provided_inputs
|
||||
]
|
||||
|
||||
if missing_inputs:
|
||||
missing_list = ", ".join(missing_inputs)
|
||||
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")
|
||||
|
||||
72
api/core/workflow/nodes/human_input/enums.py
Normal file
72
api/core/workflow/nodes/human_input/enums.py
Normal file
@ -0,0 +1,72 @@
|
||||
import enum
|
||||
|
||||
|
||||
class HumanInputFormStatus(enum.StrEnum):
|
||||
"""Status of a human input form."""
|
||||
|
||||
# Awaiting submission from any recipient. Forms stay in this state until
|
||||
# submitted or a timeout rule applies.
|
||||
WAITING = enum.auto()
|
||||
# Global timeout reached. The workflow run is stopped and will not resume.
|
||||
# This is distinct from node-level timeout.
|
||||
EXPIRED = enum.auto()
|
||||
# Submitted by a recipient; form data is available and execution resumes
|
||||
# along the selected action edge.
|
||||
SUBMITTED = enum.auto()
|
||||
# Node-level timeout reached. The human input node should emit a timeout
|
||||
# event and the workflow should resume along the timeout edge.
|
||||
TIMEOUT = enum.auto()
|
||||
|
||||
|
||||
class HumanInputFormKind(enum.StrEnum):
|
||||
"""Kind of a human input form."""
|
||||
|
||||
RUNTIME = enum.auto() # Form created during workflow execution.
|
||||
DELIVERY_TEST = enum.auto() # Form created for delivery tests.
|
||||
|
||||
|
||||
class DeliveryMethodType(enum.StrEnum):
|
||||
"""Delivery method types for human input forms."""
|
||||
|
||||
# WEBAPP controls whether the form is delivered to the web app. It not only controls
|
||||
# the standalone web app, but also controls the installed apps in the console.
|
||||
WEBAPP = enum.auto()
|
||||
|
||||
EMAIL = enum.auto()
|
||||
|
||||
|
||||
class ButtonStyle(enum.StrEnum):
|
||||
"""Button styles for user actions."""
|
||||
|
||||
PRIMARY = enum.auto()
|
||||
DEFAULT = enum.auto()
|
||||
ACCENT = enum.auto()
|
||||
GHOST = enum.auto()
|
||||
|
||||
|
||||
class TimeoutUnit(enum.StrEnum):
|
||||
"""Timeout unit for form expiration."""
|
||||
|
||||
HOUR = enum.auto()
|
||||
DAY = enum.auto()
|
||||
|
||||
|
||||
class FormInputType(enum.StrEnum):
|
||||
"""Form input types."""
|
||||
|
||||
TEXT_INPUT = enum.auto()
|
||||
PARAGRAPH = enum.auto()
|
||||
|
||||
|
||||
class PlaceholderType(enum.StrEnum):
|
||||
"""Default value types for form inputs."""
|
||||
|
||||
VARIABLE = enum.auto()
|
||||
CONSTANT = enum.auto()
|
||||
|
||||
|
||||
class EmailRecipientType(enum.StrEnum):
|
||||
"""Email recipient types."""
|
||||
|
||||
MEMBER = enum.auto()
|
||||
EXTERNAL = enum.auto()
|
||||
@ -1,12 +1,42 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events import (
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
NodeRunResult,
|
||||
PauseRequestedEvent,
|
||||
)
|
||||
from core.workflow.node_events.base import NodeEventBase
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
|
||||
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
|
||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputNode(Node[HumanInputNodeData]):
|
||||
@ -17,7 +47,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
"edge_source_handle",
|
||||
"edgeSourceHandle",
|
||||
"source_handle",
|
||||
"selected_branch",
|
||||
_SELECTED_BRANCH_KEY,
|
||||
"selectedBranch",
|
||||
"branch",
|
||||
"branch_id",
|
||||
@ -25,43 +55,37 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
"handle",
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
_form_repository: HumanInputFormRepository
|
||||
_OUTPUT_FIELD_ACTION_ID = "__action_id"
|
||||
_OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
|
||||
_TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if form_repository is None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
self._form_repository = form_repository
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self): # type: ignore[override]
|
||||
if self._is_completion_ready():
|
||||
branch_handle = self._resolve_branch_selection()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
edge_source_handle=branch_handle or "source",
|
||||
)
|
||||
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
# TODO(QuantumGhost): yield a real form id.
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
if not self.node_data.required_variables:
|
||||
return False
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for selector_str in self.node_data.required_variables:
|
||||
parts = selector_str.split(".")
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
segment = variable_pool.get(parts)
|
||||
if segment is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_branch_selection(self) -> str | None:
|
||||
"""Determine the branch handle selected by human input if available."""
|
||||
|
||||
@ -108,3 +132,224 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def _workflow_execution_id(self) -> str:
|
||||
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
assert workflow_exec_id is not None
|
||||
return workflow_exec_id
|
||||
|
||||
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
|
||||
required_event = self._human_input_required_event(form_entity)
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return pause_requested_event
|
||||
|
||||
def resolve_default_values(self) -> Mapping[str, Any]:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
resolved_defaults: dict[str, Any] = {}
|
||||
for input in self._node_data.inputs:
|
||||
if (default_value := input.default) is None:
|
||||
continue
|
||||
if default_value.type == PlaceholderType.CONSTANT:
|
||||
continue
|
||||
resolved_value = variable_pool.get(default_value.selector)
|
||||
if resolved_value is None:
|
||||
# TODO: How should we handle this?
|
||||
continue
|
||||
resolved_defaults[input.output_variable_name] = (
|
||||
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
|
||||
)
|
||||
|
||||
return resolved_defaults
|
||||
|
||||
def _should_require_console_recipient(self) -> bool:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
return True
|
||||
if self.invoke_from == InvokeFrom.EXPLORE:
|
||||
return self._node_data.is_webapp_enabled()
|
||||
return False
|
||||
|
||||
def _display_in_ui(self) -> bool:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
return True
|
||||
return self._node_data.is_webapp_enabled()
|
||||
|
||||
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
|
||||
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
|
||||
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
||||
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
|
||||
return [
|
||||
apply_debug_email_recipient(
|
||||
method,
|
||||
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
|
||||
user_id=self.user_id or "",
|
||||
)
|
||||
for method in enabled_methods
|
||||
]
|
||||
|
||||
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
|
||||
node_data = self._node_data
|
||||
resolved_default_values = self.resolve_default_values()
|
||||
display_in_ui = self._display_in_ui()
|
||||
form_token = form_entity.web_app_token
|
||||
if display_in_ui and form_token is None:
|
||||
raise AssertionError("Form token should be available for UI execution.")
|
||||
return HumanInputRequired(
|
||||
form_id=form_entity.id,
|
||||
form_content=form_entity.rendered_content,
|
||||
inputs=node_data.inputs,
|
||||
actions=node_data.user_actions,
|
||||
display_in_ui=display_in_ui,
|
||||
node_id=self.id,
|
||||
node_title=node_data.title,
|
||||
form_token=form_token,
|
||||
resolved_default_values=resolved_default_values,
|
||||
)
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Execute the human input node.
|
||||
|
||||
This method will:
|
||||
1. Generate a unique form ID
|
||||
2. Create form content with variable substitution
|
||||
3. Create form in database
|
||||
4. Send form via configured delivery methods
|
||||
5. Suspend workflow execution
|
||||
6. Wait for form submission to resume
|
||||
"""
|
||||
repo = self._form_repository
|
||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
||||
if form is None:
|
||||
display_in_ui = self._display_in_ui()
|
||||
params = FormCreateParams(
|
||||
app_id=self.app_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self.render_form_content_before_submission(),
|
||||
delivery_methods=self._effective_delivery_methods(),
|
||||
display_in_ui=display_in_ui,
|
||||
resolved_default_values=self.resolve_default_values(),
|
||||
console_recipient_required=self._should_require_console_recipient(),
|
||||
console_creator_account_id=(
|
||||
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
|
||||
),
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
form_entity = self._form_repository.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
form_entity.id,
|
||||
)
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
return
|
||||
|
||||
if (
|
||||
form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
|
||||
or form.expiration_time <= naive_utc_now()
|
||||
):
|
||||
yield HumanInputFormTimeoutEvent(
|
||||
node_title=self._node_data.title,
|
||||
expiration_time=form.expiration_time,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
|
||||
edge_source_handle=self._TIMEOUT_HANDLE,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not form.submitted:
|
||||
yield self._form_to_pause_event(form)
|
||||
return
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if selected_action_id is None:
|
||||
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
|
||||
submitted_data = form.submitted_data or {}
|
||||
outputs: dict[str, Any] = dict(submitted_data)
|
||||
outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
|
||||
rendered_content = self.render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
outputs,
|
||||
self._node_data.outputs_field_names(),
|
||||
)
|
||||
outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
|
||||
|
||||
action_text = self._node_data.find_action_text(selected_action_id)
|
||||
|
||||
yield HumanInputFormFilledEvent(
|
||||
node_title=self._node_data.title,
|
||||
rendered_content=rendered_content,
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
edge_source_handle=selected_action_id,
|
||||
)
|
||||
)
|
||||
|
||||
def render_form_content_before_submission(self) -> str:
|
||||
"""
|
||||
Process form content by substituting variables.
|
||||
|
||||
This method should:
|
||||
1. Parse the form_content markdown
|
||||
2. Substitute {{#node_name.var_name#}} with actual values
|
||||
3. Keep {{#$output.field_name#}} placeholders for form inputs
|
||||
"""
|
||||
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
|
||||
self._node_data.form_content,
|
||||
)
|
||||
return rendered_form_content.markdown
|
||||
|
||||
@staticmethod
|
||||
def render_form_content_with_outputs(
|
||||
form_content: str,
|
||||
outputs: Mapping[str, Any],
|
||||
field_names: Sequence[str],
|
||||
) -> str:
|
||||
"""
|
||||
Replace {{#$output.xxx#}} placeholders with submitted values.
|
||||
"""
|
||||
rendered_content = form_content
|
||||
for field_name in field_names:
|
||||
placeholder = "{{#$output." + field_name + "#}}"
|
||||
value = outputs.get(field_name)
|
||||
if value is None:
|
||||
replacement = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
replacement = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
replacement = str(value)
|
||||
rendered_content = rendered_content.replace(placeholder, replacement)
|
||||
return rendered_content
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input default values.
|
||||
|
||||
This method should parse:
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input default values
|
||||
"""
|
||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
|
||||
Reference in New Issue
Block a user