diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py
index 8d66984447..6850ec1c69 100644
--- a/api/core/workflow/nodes/human_input/human_input_node.py
+++ b/api/core/workflow/nodes/human_input/human_input_node.py
@@ -1,3 +1,4 @@
+import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
@@ -216,7 +217,11 @@ class HumanInputNode(Node[HumanInputNodeData]):
             submitted_data = form.submitted_data or {}
             outputs: dict[str, Any] = dict(submitted_data)
             outputs["__action_id"] = selected_action_id
-            outputs["__rendered_content"] = form.rendered_content
+            outputs["__rendered_content"] = self._render_form_content_with_outputs(
+                form.rendered_content,
+                outputs,
+                self._node_data.outputs_field_names(),
+            )
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
                 outputs=outputs,
@@ -224,7 +229,13 @@ class HumanInputNode(Node[HumanInputNodeData]):
             )
 
         if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now():
-            outputs: dict[str, Any] = {"__rendered_content": form.rendered_content}
+            outputs: dict[str, Any] = {
+                "__rendered_content": self._render_form_content_with_outputs(
+                    form.rendered_content,
+                    {},
+                    self._node_data.outputs_field_names(),
+                )
+            }
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
                 outputs=outputs,
@@ -243,13 +254,35 @@ class HumanInputNode(Node[HumanInputNodeData]):
         This method should:
         1. Parse the form_content markdown
        2. Substitute {{#node_name.var_name#}} with actual values
-        3. Keep {{#$output.field_name#}} placeholders for form inputs
+        3. Keep {{#$outputs.field_name#}} placeholders for form inputs
         """
         rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
             self._node_data.form_content,
         )
         return rendered_form_content.markdown
 
+    @staticmethod
+    def _render_form_content_with_outputs(
+        form_content: str,
+        outputs: Mapping[str, Any],
+        field_names: Sequence[str],
+    ) -> str:
+        """
+        Replace {{#$outputs.xxx#}} placeholders with submitted values.
+        """
+        rendered_content = form_content
+        for field_name in field_names:
+            placeholder = "{{#$outputs." + field_name + "#}}"
+            value = outputs.get(field_name)
+            if value is None:
+                replacement = ""
+            elif isinstance(value, (dict, list)):
+                replacement = json.dumps(value, ensure_ascii=False)
+            else:
+                replacement = str(value)
+            rendered_content = rendered_content.replace(placeholder, replacement)
+        return rendered_content
+
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
index 5abf5cdf81..be54113f63 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
@@ -9,7 +9,7 @@ import pytest
 from pydantic import ValidationError
 
 from core.workflow.entities import GraphInitParams
-from core.workflow.node_events import PauseRequestedEvent
+from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
 from core.workflow.nodes.human_input.entities import (
     EmailDeliveryConfig,
     EmailDeliveryMethod,
@@ -35,6 +35,7 @@ from core.workflow.nodes.human_input.human_input_node import HumanInputNode
 from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
+from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository
 
 
 class TestDeliveryMethod:
@@ -337,7 +338,6 @@ class TestHumanInputNodeVariableResolution:
             graph_runtime_state=runtime_state,
             form_repository=mock_repo,
         )
-        node.init_node_data(config["data"])
 
         run_result = node._run()
         pause_event = next(run_result)
@@ -377,3 +377,64 @@
                 title="Test",
                 timeout_unit="invalid-unit",  # Invalid unit
             )
+
+
+class TestHumanInputNodeRenderedContent:
+    """Tests for rendering submitted content."""
+
+    def test_replaces_outputs_placeholders_after_submission(self):
+        variable_pool = VariablePool(
+            system_variables=SystemVariable(
+                user_id="user",
+                app_id="app",
+                workflow_id="workflow",
+                workflow_execution_id="exec-1",
+            ),
+            user_inputs={},
+            conversation_variables=[],
+        )
+        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
+        graph_init_params = GraphInitParams(
+            tenant_id="tenant",
+            app_id="app",
+            workflow_id="workflow",
+            graph_config={"nodes": [], "edges": []},
+            user_id="user",
+            user_from="account",
+            invoke_from="debugger",
+            call_depth=0,
+        )
+
+        node_data = HumanInputNodeData(
+            title="Human Input",
+            form_content="Name: {{#$outputs.name#}}",
+            inputs=[
+                FormInput(
+                    type=FormInputType.TEXT_INPUT,
+                    output_variable_name="name",
+                )
+            ],
+            user_actions=[UserAction(id="approve", title="Approve")],
+        )
+        config = {"id": "human", "data": node_data.model_dump()}
+
+        form_repository = InMemoryHumanInputFormRepository()
+        node = HumanInputNode(
+            id=config["id"],
+            config=config,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=runtime_state,
+            form_repository=form_repository,
+        )
+
+        pause_gen = node._run()
+        pause_event = next(pause_gen)
+        assert isinstance(pause_event, PauseRequestedEvent)
+        with pytest.raises(StopIteration):
+            next(pause_gen)
+
+        form_repository.set_submission(action_id="approve", form_data={"name": "Alice"})
+
+        result = node._run()
+        assert isinstance(result, NodeRunResult)
+        assert result.outputs["__rendered_content"] == "Name: Alice"
diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py
index 3f3054c008..161151305d 100644
--- a/api/tests/unit_tests/tasks/test_workflow_execute_task.py
+++ b/api/tests/unit_tests/tasks/test_workflow_execute_task.py
@@ -6,6 +6,7 @@ from unittest.mock import MagicMock
 
 import pytest
 
+from models.model import AppMode
 from tasks.app_generate.workflow_execute_task import _publish_streaming_response
 
 
@@ -13,7 +14,7 @@ from tasks.app_generate.workflow_execute_task import _publish_streaming_response
 def mock_topic(mocker) -> MagicMock:
     topic = MagicMock()
     mocker.patch(
-        "tasks.app_generate.workflow_execute_task.AdvancedChatAppGenerator.get_response_topic",
+        "tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic",
         return_value=topic,
     )
     return topic
@@ -23,7 +24,7 @@ def test_publish_streaming_response_with_uuid(mock_topic: MagicMock):
     workflow_run_id = uuid.uuid4()
     response_stream = iter([{"event": "foo"}, "ping"])
 
-    _publish_streaming_response(response_stream, workflow_run_id)
+    _publish_streaming_response(response_stream, workflow_run_id, app_mode=AppMode.ADVANCED_CHAT)
 
     payloads = [call.args[0] for call in mock_topic.publish.call_args_list]
     assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()]
@@ -33,6 +34,6 @@ def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock):
     workflow_run_id = uuid.uuid4()
     response_stream = iter([{"event": "bar"}])
 
-    _publish_streaming_response(response_stream, str(workflow_run_id))
+    _publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT)
 
     mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode())