diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py index 81b03d1958..65bf32cd61 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/core/workflow/nodes/human_input/entities.py @@ -259,3 +259,30 @@ class FormDefinition(BaseModel): # display_in_ui controls whether the form should be displayed in UI surfaces. display_in_ui: bool | None = None + + +class HumanInputSubmissionValidationError(ValueError): + pass + + +def validate_human_input_submission( + *, + inputs: Sequence[FormInput], + user_actions: Sequence[UserAction], + selected_action_id: str, + form_data: Mapping[str, Any], +) -> None: + available_actions = {action.id for action in user_actions} + if selected_action_id not in available_actions: + raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") + + provided_inputs = set(form_data.keys()) + missing_inputs = [ + form_input.output_variable_name + for form_input in inputs + if form_input.output_variable_name not in provided_inputs + ] + + if missing_inputs: + missing_list = ", ".join(missing_inputs) + raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index bf182fce48..d1d31a7ab6 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -9,7 +9,11 @@ from core.repositories.human_input_reposotiry import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from core.workflow.nodes.human_input.entities import FormDefinition +from core.workflow.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) from core.workflow.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.exception import BaseHTTPException @@ -171,20 +175,15 @@ class HumanInputService: def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None: definition = form.get_definition() - - available_actions = {action.id for action in definition.user_actions} - if selected_action_id not in available_actions: - raise InvalidFormDataError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in definition.inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}") + try: + validate_human_input_submission( + inputs=definition.inputs, + user_actions=definition.user_actions, + selected_action_id=selected_action_id, + form_data=form_data, + ) + except HumanInputSubmissionValidationError as exc: + raise InvalidFormDataError(str(exc)) from exc def _enqueue_resume(self, workflow_run_id: str) -> None: workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6b19cf4642..37679f89f6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -24,7 +24,11 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, N from core.workflow.node_events import NodeRunResult from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData +from core.workflow.nodes.human_input.entities import ( + DeliveryChannelConfig, + HumanInputNodeData, + validate_human_input_submission, +) from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData @@ -833,21 +837,18 @@ class WorkflowService: ) node_data = node.node_data - available_actions = {user_action.id for user_action in node_data.user_actions} - if action not in available_actions: - raise ValueError(f"Invalid action: {action}") - - expected_inputs = {form_input.output_variable_name for form_input in node_data.inputs} - missing_inputs = [name for name in expected_inputs if name not in form_inputs] - if missing_inputs: - missing_list = ", ".join(sorted(missing_inputs)) - raise ValueError(f"Missing inputs: {missing_list}") + validate_human_input_submission( + inputs=node_data.inputs, + user_actions=node_data.user_actions, + selected_action_id=action, + form_data=form_inputs, + ) rendered_content = node._render_form_content_before_submission() outputs: dict[str, Any] = dict(form_inputs) outputs["__action_id"] = action outputs["__rendered_content"] = node._render_form_content_with_outputs( - node_data.form_content, outputs, node_data.outputs_field_names() + rendered_content, outputs, node_data.outputs_field_names() ) enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 9700cbaf0e..9833c44543 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -1,9 +1,15 @@ +from contextlib import nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from core.workflow.enums import NodeType +from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import FormInputType from models.model import App from models.workflow import Workflow +from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService @@ -161,3 +167,111 @@ class TestWorkflowService: assert workflows == [] assert has_more is False mock_session.scalars.assert_called_once() + + def test_submit_human_input_form_preview_uses_rendered_content( + self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch + ) -> None: + service = workflow_service + node_data = HumanInputNodeData( + title="Human Input", + form_content="
{{#$output.name#}}
", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + ) + node = MagicMock() + node.node_data = node_data + node._render_form_content_before_submission.return_value = "preview
" + node._render_form_content_with_outputs.return_value = "rendered
" + + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] + + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + workflow.get_enclosing_node_type_and_id.return_value = None + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + + saved_outputs: dict[str, object] = {} + + class DummySession: + def __init__(self, *args, **kwargs): + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return nullcontext() + + class DummySaver: + def __init__(self, *args, **kwargs): + pass + + def save(self, outputs, process_data): + saved_outputs.update(outputs) + + monkeypatch.setattr(workflow_service_module, "Session", DummySession) + monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + account = SimpleNamespace(id="account-1") + + result = service.submit_human_input_form_preview( + app_model=app_model, + account=account, + node_id="node-1", + form_inputs={"name": "Ada", "extra": "ignored"}, + action="approve", + ) + + node._render_form_content_with_outputs.assert_called_once() + called_args = node._render_form_content_with_outputs.call_args.args + assert called_args[0] == "preview
" + assert called_args[2] == node_data.outputs_field_names() + rendered_outputs = called_args[1] + assert rendered_outputs["name"] == "Ada" + assert rendered_outputs["extra"] == "ignored" + assert "extra" in saved_outputs + assert "extra" in result + assert saved_outputs["name"] == "Ada" + assert result["name"] == "Ada" + assert result["__action_id"] == "approve" + assert "__rendered_content" in result + + def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None: + service = workflow_service + node_data = HumanInputNodeData( + title="Human Input", + form_content="{{#$output.name#}}
", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + ) + node = MagicMock() + node.node_data = node_data + node._render_form_content_before_submission.return_value = "preview
" + node._render_form_content_with_outputs.return_value = "rendered
" + + service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] + + workflow = MagicMock() + workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] + + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + account = SimpleNamespace(id="account-1") + + with pytest.raises(ValueError) as exc_info: + service.submit_human_input_form_preview( + app_model=app_model, + account=account, + node_id="node-1", + form_inputs={}, + action="approve", + ) + + assert "Missing required inputs" in str(exc_info.value)