Files
dify/api/dify_graph/nodes/human_input/human_input_node.py

356 lines
13 KiB
Python

import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.app.entities.app_invoke_entities import InvokeFrom
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import (
    HumanInputFormFilledEvent,
    HumanInputFormTimeoutEvent,
    NodeRunResult,
    PauseRequestedEvent,
)
from dify_graph.node_events.base import NodeEventBase
from dify_graph.node_events.node import StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.repositories.human_input_form_repository import (
    FormCreateParams,
    HumanInputFormEntity,
    HumanInputFormRepository,
)
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now

from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
# These names are only referenced in (string) annotations, so they are
# imported for type checkers only and never at runtime.
if TYPE_CHECKING:
    from dify_graph.entities.graph_init_params import GraphInitParams
    from dify_graph.runtime.graph_runtime_state import GraphRuntimeState

# One of the keys probed when resolving the branch selected by the human.
_SELECTED_BRANCH_KEY = "selected_branch"

logger = logging.getLogger(__name__)
class HumanInputNode(Node[HumanInputNodeData]):
    """Branch node that pauses the workflow until a human-input form is resolved.

    On each run the node looks up the form stored for the current workflow
    execution and node id. Depending on what it finds it either creates the
    form and requests a pause, requests a pause again (form still pending),
    follows the timeout branch, or emits the submitted data and follows the
    branch named by the selected action.
    """

    node_type = NodeType.HUMAN_INPUT
    execution_type = NodeExecutionType.BRANCH

    # Keys probed, in this order, when looking for a selected branch handle
    # in the variable pool and then in the node's default values.
    _BRANCH_SELECTION_KEYS: tuple[str, ...] = (
        "edge_source_handle",
        "edgeSourceHandle",
        "source_handle",
        _SELECTED_BRANCH_KEY,
        "selectedBranch",
        "branch",
        "branch_id",
        "branchId",
        "handle",
    )

    _node_data: HumanInputNodeData
    _form_repository: HumanInputFormRepository

    # Reserved output keys added next to the submitted form data.
    _OUTPUT_FIELD_ACTION_ID = "__action_id"
    _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
    # Edge handle used when the form timed out / expired.
    _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        form_repository: HumanInputFormRepository | None = None,
    ) -> None:
        """Initialise the node.

        Args:
            id: Node identifier, forwarded to the base class.
            config: Node configuration, forwarded to the base class.
            graph_init_params: Graph initialisation parameters.
            graph_runtime_state: Runtime state holding the variable pool.
            form_repository: Repository used to load and create forms. When
                omitted, a ``HumanInputFormRepositoryImpl`` bound to
                ``db.engine`` and this node's tenant is created.
        """
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        if form_repository is None:
            # tenant_id is only readable after the base class initialised.
            form_repository = HumanInputFormRepositoryImpl(
                session_factory=db.engine,
                tenant_id=self.tenant_id,
            )
        self._form_repository = form_repository

    @classmethod
    def version(cls) -> str:
        """Return the version string of this node implementation."""
        return "1"

    def _resolve_branch_selection(self) -> str | None:
        """Determine the branch handle selected by human input if available.

        The variable pool is searched first (under this node's id), then the
        node's default values; within each, ``_BRANCH_SELECTION_KEYS`` are
        tried in order and the first non-empty handle wins.

        Returns:
            The selected handle, or ``None`` when no key yields one.
        """
        variable_pool = self.graph_runtime_state.variable_pool
        for key in self._BRANCH_SELECTION_KEYS:
            handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
            if handle:
                return handle

        default_values = self.node_data.default_value_dict
        for key in self._BRANCH_SELECTION_KEYS:
            handle = self._normalize_branch_value(default_values.get(key))
            if handle:
                return handle
        return None

    @staticmethod
    def _extract_branch_handle(segment: Any) -> str | None:
        """Unwrap a variable-pool entry and normalise it to a branch handle.

        Args:
            segment: Object returned by the variable pool, or ``None``. Its
                raw value is taken from ``to_object()`` when that attribute
                is callable, otherwise from its ``value`` attribute.

        Returns:
            The normalised handle, or ``None`` when nothing usable is found.
        """
        if segment is None:
            return None
        to_object = getattr(segment, "to_object", None)
        raw_value = to_object() if callable(to_object) else getattr(segment, "value", None)
        if raw_value is None:
            return None
        return HumanInputNode._normalize_branch_value(raw_value)

    @staticmethod
    def _normalize_branch_value(value: Any) -> str | None:
        """Reduce a raw branch value to a non-empty, stripped handle string.

        Args:
            value: A string, a mapping carrying the handle under one of the
                known keys, or anything else.

        Returns:
            The stripped handle, or ``None`` for ``None``, blank strings,
            mappings without a usable string entry, and unsupported types.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return value.strip() or None
        if isinstance(value, Mapping):
            for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
                candidate = value.get(key)
                # Strip here as well so mapping values are treated exactly
                # like plain strings (blank candidates are skipped).
                if isinstance(candidate, str) and (stripped := candidate.strip()):
                    return stripped
        return None

    @property
    def _workflow_execution_id(self) -> str:
        """Return the workflow execution id from the system variables.

        Raises:
            AssertionError: If the execution id is not set.
        """
        workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with -O.
        if workflow_exec_id is None:
            raise AssertionError("workflow_execution_id should not be None when running a human input node.")
        return workflow_exec_id

    def _form_to_pause_event(self, form_entity: HumanInputFormEntity) -> PauseRequestedEvent:
        """Wrap the "human input required" reason for a form in a pause request."""
        required_event = self._human_input_required_event(form_entity)
        return PauseRequestedEvent(reason=required_event)

    def resolve_default_values(self) -> Mapping[str, Any]:
        """Resolve variable-backed input defaults against the variable pool.

        Inputs without a default, with a constant default, or whose selector
        is not present in the variable pool are omitted from the result.

        Returns:
            Mapping of input output-variable name to the JSON-encodable
            resolved default value.
        """
        variable_pool = self.graph_runtime_state.variable_pool
        resolved_defaults: dict[str, Any] = {}
        for form_input in self._node_data.inputs:
            if (default_value := form_input.default) is None:
                continue
            if default_value.type == PlaceholderType.CONSTANT:
                continue
            resolved_value = variable_pool.get(default_value.selector)
            if resolved_value is None:
                # TODO: How should we handle this?
                continue
            resolved_defaults[form_input.output_variable_name] = (
                WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
            )
        return resolved_defaults

    def _should_require_console_recipient(self) -> bool:
        """Return the ``console_recipient_required`` flag used at form creation.

        Always true for debugger runs; for explore runs it follows whether
        the web app delivery is enabled; false otherwise.
        """
        if self.invoke_from == InvokeFrom.DEBUGGER:
            return True
        if self.invoke_from == InvokeFrom.EXPLORE:
            return self._node_data.is_webapp_enabled()
        return False

    def _display_in_ui(self) -> bool:
        """Return whether the form is displayed in the UI.

        Always true for debugger runs; otherwise follows whether the web app
        delivery is enabled on the node.
        """
        if self.invoke_from == InvokeFrom.DEBUGGER:
            return True
        return self._node_data.is_webapp_enabled()

    def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
        """Return the delivery methods to use for the current invocation.

        Only enabled methods are kept. For debugger and explore runs the
        web app method is dropped. Every remaining method is passed through
        ``apply_debug_email_recipient``, which is enabled for debugger runs
        only.
        """
        enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
        if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
            enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
        return [
            apply_debug_email_recipient(
                method,
                enabled=self.invoke_from == InvokeFrom.DEBUGGER,
                user_id=self.user_id or "",
            )
            for method in enabled_methods
        ]

    def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
        """Build the pause reason describing the form the human must fill in.

        Raises:
            AssertionError: If the form is displayed in the UI but the form
                entity carries no web app token.
        """
        node_data = self._node_data
        resolved_default_values = self.resolve_default_values()
        display_in_ui = self._display_in_ui()
        form_token = form_entity.web_app_token
        if display_in_ui and form_token is None:
            raise AssertionError("Form token should be available for UI execution.")
        return HumanInputRequired(
            form_id=form_entity.id,
            form_content=form_entity.rendered_content,
            inputs=node_data.inputs,
            actions=node_data.user_actions,
            display_in_ui=display_in_ui,
            node_id=self.id,
            node_title=node_data.title,
            form_token=form_token,
            resolved_default_values=resolved_default_values,
        )

    def _run(self) -> Generator[NodeEventBase, None, None]:
        """Execute the human input node.

        The form for this workflow execution and node is looked up, then:

        1. No form yet: create it through the repository and request a pause.
        2. Form timed out / expired: emit a timeout event and complete via
           the timeout handle.
        3. Form not submitted yet: request a pause again.
        4. Form submitted: emit the filled event and complete with the
           submitted data, following the handle of the selected action.

        Raises:
            AssertionError: If the workflow execution id is missing, or a
                submitted form has no selected action id.
        """
        repo = self._form_repository
        workflow_execution_id = self._workflow_execution_id
        form = repo.get_form(workflow_execution_id, self.id)

        if form is None:
            display_in_ui = self._display_in_ui()
            params = FormCreateParams(
                app_id=self.app_id,
                workflow_execution_id=workflow_execution_id,
                node_id=self.id,
                form_config=self._node_data,
                rendered_content=self.render_form_content_before_submission(),
                delivery_methods=self._effective_delivery_methods(),
                display_in_ui=display_in_ui,
                resolved_default_values=self.resolve_default_values(),
                console_recipient_required=self._should_require_console_recipient(),
                console_creator_account_id=(
                    self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
                ),
                backstage_recipient_required=True,
            )
            form_entity = repo.create_form(params)
            logger.info(
                "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
                workflow_execution_id,
                self.id,
                form_entity.id,
            )
            yield self._form_to_pause_event(form_entity)
            return

        # NOTE(review): expiry is checked before `form.submitted`, so a form
        # whose expiration time has passed takes the timeout branch even if
        # it was submitted - confirm this precedence is intended.
        if (
            form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
            or form.expiration_time <= naive_utc_now()
        ):
            yield HumanInputFormTimeoutEvent(
                node_title=self._node_data.title,
                expiration_time=form.expiration_time,
            )
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
                    edge_source_handle=self._TIMEOUT_HANDLE,
                )
            )
            return

        if not form.submitted:
            # Still waiting for the human: pause again on the existing form.
            yield self._form_to_pause_event(form)
            return

        selected_action_id = form.selected_action_id
        if selected_action_id is None:
            raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")

        submitted_data = form.submitted_data or {}
        outputs: dict[str, Any] = dict(submitted_data)
        outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
        rendered_content = self.render_form_content_with_outputs(
            form.rendered_content,
            outputs,
            self._node_data.outputs_field_names(),
        )
        outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
        action_text = self._node_data.find_action_text(selected_action_id)

        yield HumanInputFormFilledEvent(
            node_title=self._node_data.title,
            rendered_content=rendered_content,
            action_id=selected_action_id,
            action_text=action_text,
        )
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs=outputs,
                # The chosen action id doubles as the outgoing edge handle.
                edge_source_handle=selected_action_id,
            )
        )

    def render_form_content_before_submission(self) -> str:
        """Render the form content template against the variable pool.

        The node's ``form_content`` is passed to the variable pool's
        ``convert_template`` and the ``markdown`` of the result is returned.
        Presumably ``{{#$output.field_name#}}`` placeholders survive this
        step so they can be filled after submission - verify against
        ``convert_template``.
        """
        rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
            self._node_data.form_content,
        )
        return rendered_form_content.markdown

    @staticmethod
    def render_form_content_with_outputs(
        form_content: str,
        outputs: Mapping[str, Any],
        field_names: Sequence[str],
    ) -> str:
        """Replace ``{{#$output.xxx#}}`` placeholders with submitted values.

        All placeholders are substituted in a single pass over
        ``form_content``, so placeholder-looking text inside a submitted
        value is never substituted again and the result does not depend on
        the order of ``field_names``.

        Args:
            form_content: Rendered form content containing placeholders.
            outputs: Submitted values keyed by field name.
            field_names: Names of the fields whose placeholders are replaced.

        Returns:
            The content with each listed placeholder replaced: missing or
            ``None`` values become ``""``, dicts and lists are JSON-encoded
            (non-ASCII preserved), everything else goes through ``str()``.
        """
        replacements: dict[str, str] = {}
        for field_name in field_names:
            value = outputs.get(field_name)
            if value is None:
                replacement = ""
            elif isinstance(value, (dict, list)):
                replacement = json.dumps(value, ensure_ascii=False)
            else:
                replacement = str(value)
            replacements["{{#$output." + field_name + "#}}"] = replacement

        if not replacements:
            # An empty alternation would match at every position.
            return form_content

        pattern = re.compile("|".join(re.escape(placeholder) for placeholder in replacements))
        # A callable replacement keeps backslashes in values literal.
        return pattern.sub(lambda match: replacements[match.group(0)], form_content)

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Extract the variable selectors referenced by this node's data.

        The raw ``node_data`` is validated into ``HumanInputNodeData`` and
        the extraction is delegated to that model. ``graph_config`` is
        accepted for interface compatibility and not used here.

        Args:
            graph_config: Graph configuration (unused).
            node_id: Id of the node the mapping is built for.
            node_data: Raw node data to validate.

        Returns:
            The mapping produced by
            ``HumanInputNodeData.extract_variable_selector_to_variable_mapping``.
        """
        validated_node_data = HumanInputNodeData.model_validate(node_data)
        return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)