Files
dify/api/graphon/nodes/human_input/human_input_node.py

300 lines
11 KiB
Python

import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast
from graphon.entities.graph_config import NodeConfigDict
from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
NodeRunResult,
PauseRequestedEvent,
)
from graphon.node_events.base import NodeEventBase
from graphon.node_events.node import StreamCompletedEvent
from graphon.nodes.base.node import Node
from graphon.nodes.runtime import HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
from .entities import HumanInputNodeData
from .enums import HumanInputFormStatus, PlaceholderType
if TYPE_CHECKING:
from graphon.entities.graph_init_params import GraphInitParams
from graphon.runtime.graph_runtime_state import GraphRuntimeState
_SELECTED_BRANCH_KEY = "selected_branch"
logger = logging.getLogger(__name__)
class HumanInputNode(Node[HumanInputNodeData]):
node_type = BuiltinNodeTypes.HUMAN_INPUT
execution_type = NodeExecutionType.BRANCH
_BRANCH_SELECTION_KEYS: tuple[str, ...] = (
"edge_source_handle",
"edgeSourceHandle",
"source_handle",
_SELECTED_BRANCH_KEY,
"selectedBranch",
"branch",
"branch_id",
"branchId",
"handle",
)
_node_data: HumanInputNodeData
_OUTPUT_FIELD_ACTION_ID = "__action_id"
_OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
_TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
def __init__(
self,
id: str,
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
runtime: HumanInputNodeRuntimeProtocol | None = None,
form_repository: object | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
resolved_runtime = runtime
if resolved_runtime is None:
raise ValueError("runtime is required")
if form_repository is not None:
with_form_repository = getattr(resolved_runtime, "with_form_repository", None)
if callable(with_form_repository):
resolved_runtime = cast(HumanInputNodeRuntimeProtocol, with_form_repository(form_repository))
self._runtime: HumanInputNodeRuntimeProtocol = resolved_runtime
@classmethod
def version(cls) -> str:
return "1"
def _resolve_branch_selection(self) -> str | None:
"""Determine the branch handle selected by human input if available."""
variable_pool = self.graph_runtime_state.variable_pool
for key in self._BRANCH_SELECTION_KEYS:
handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
if handle:
return handle
default_values = self.node_data.default_value_dict
for key in self._BRANCH_SELECTION_KEYS:
handle = self._normalize_branch_value(default_values.get(key))
if handle:
return handle
return None
@staticmethod
def _extract_branch_handle(segment: Any) -> str | None:
if segment is None:
return None
candidate = getattr(segment, "to_object", None)
raw_value = candidate() if callable(candidate) else getattr(segment, "value", None)
if raw_value is None:
return None
return HumanInputNode._normalize_branch_value(raw_value)
@staticmethod
def _normalize_branch_value(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
stripped = value.strip()
return stripped or None
if isinstance(value, Mapping):
for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
candidate = value.get(key)
if isinstance(candidate, str) and candidate:
return candidate
return None
def _form_to_pause_event(self, form_entity: HumanInputFormStateProtocol):
required_event = self._human_input_required_event(form_entity)
pause_requested_event = PauseRequestedEvent(reason=required_event)
return pause_requested_event
def resolve_default_values(self) -> Mapping[str, Any]:
variable_pool = self.graph_runtime_state.variable_pool
resolved_defaults: dict[str, Any] = {}
for input in self._node_data.inputs:
if (default_value := input.default) is None:
continue
if default_value.type == PlaceholderType.CONSTANT:
continue
resolved_value = variable_pool.get(default_value.selector)
if resolved_value is None:
# TODO: How should we handle this?
continue
resolved_defaults[input.output_variable_name] = (
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
)
return resolved_defaults
def _human_input_required_event(self, form_entity: HumanInputFormStateProtocol) -> HumanInputRequired:
node_data = self._node_data
resolved_default_values = self.resolve_default_values()
return HumanInputRequired(
form_id=form_entity.id,
form_content=form_entity.rendered_content,
inputs=node_data.inputs,
actions=node_data.user_actions,
node_id=self.id,
node_title=node_data.title,
resolved_default_values=resolved_default_values,
)
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Execute the human input node.
This method will:
1. Generate a unique form ID
2. Create form content with variable substitution
3. Persist the form through the configured repository
4. Send form via configured delivery methods
5. Suspend workflow execution
6. Wait for form submission to resume
"""
form = self._runtime.get_form(node_id=self.id)
if form is None:
form_entity = self._runtime.create_form(
node_id=self.id,
node_data=self._node_data,
rendered_content=self.render_form_content_before_submission(),
resolved_default_values=self.resolve_default_values(),
)
logger.info(
"Human Input node suspended workflow for form. node_id=%s, form_id=%s",
self.id,
form_entity.id,
)
yield self._form_to_pause_event(form_entity)
return
if form.status in {
HumanInputFormStatus.TIMEOUT,
HumanInputFormStatus.EXPIRED,
} or form.expiration_time <= datetime.now(UTC).replace(tzinfo=None):
yield HumanInputFormTimeoutEvent(
node_title=self._node_data.title,
expiration_time=form.expiration_time,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
edge_source_handle=self._TIMEOUT_HANDLE,
)
)
return
if not form.submitted:
yield self._form_to_pause_event(form)
return
selected_action_id = form.selected_action_id
if selected_action_id is None:
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
submitted_data = form.submitted_data or {}
outputs: dict[str, Any] = dict(submitted_data)
outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
rendered_content = self.render_form_content_with_outputs(
form.rendered_content,
outputs,
self._node_data.outputs_field_names(),
)
outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
action_text = self._node_data.find_action_text(selected_action_id)
yield HumanInputFormFilledEvent(
node_title=self._node_data.title,
rendered_content=rendered_content,
action_id=selected_action_id,
action_text=action_text,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
edge_source_handle=selected_action_id,
)
)
def render_form_content_before_submission(self) -> str:
"""
Process form content by substituting variables.
This method should:
1. Parse the form_content markdown
2. Substitute {{#node_name.var_name#}} with actual values
3. Keep {{#$output.field_name#}} placeholders for form inputs
"""
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
self._node_data.form_content,
)
return rendered_form_content.markdown
@staticmethod
def render_form_content_with_outputs(
form_content: str,
outputs: Mapping[str, Any],
field_names: Sequence[str],
) -> str:
"""
Replace {{#$output.xxx#}} placeholders with submitted values.
"""
rendered_content = form_content
for field_name in field_names:
placeholder = "{{#$output." + field_name + "#}}"
value = outputs.get(field_name)
if value is None:
replacement = ""
elif isinstance(value, (dict, list)):
replacement = json.dumps(value, ensure_ascii=False)
else:
replacement = str(value)
rendered_content = rendered_content.replace(placeholder, replacement)
return rendered_content
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HumanInputNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selectors referenced in form content and input default values.
This method should parse:
1. Variables referenced in form_content ({{#node_name.var_name#}})
2. Variables referenced in input default values
"""
return node_data.extract_variable_selector_to_variable_mapping(node_id)