refactor(api): Unify Human Input handling logic

This commit is contained in:
QuantumGhost
2026-01-13 10:39:55 +08:00
parent 18fd308a81
commit 99937aba2e
4 changed files with 167 additions and 26 deletions

View File

@ -259,3 +259,30 @@ class FormDefinition(BaseModel):
# display_in_ui controls whether the form should be displayed in UI surfaces.
display_in_ui: bool | None = None
class HumanInputSubmissionValidationError(ValueError):
pass
def validate_human_input_submission(
*,
inputs: Sequence[FormInput],
user_actions: Sequence[UserAction],
selected_action_id: str,
form_data: Mapping[str, Any],
) -> None:
available_actions = {action.id for action in user_actions}
if selected_action_id not in available_actions:
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
provided_inputs = set(form_data.keys())
missing_inputs = [
form_input.output_variable_name
for form_input in inputs
if form_input.output_variable_name not in provided_inputs
]
if missing_inputs:
missing_list = ", ".join(missing_inputs)
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")

View File

@ -9,7 +9,11 @@ from core.repositories.human_input_reposotiry import (
HumanInputFormRecord,
HumanInputFormSubmissionRepository,
)
from core.workflow.nodes.human_input.entities import FormDefinition
from core.workflow.nodes.human_input.entities import (
FormDefinition,
HumanInputSubmissionValidationError,
validate_human_input_submission,
)
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from libs.exception import BaseHTTPException
@ -171,20 +175,15 @@ class HumanInputService:
def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
definition = form.get_definition()
available_actions = {action.id for action in definition.user_actions}
if selected_action_id not in available_actions:
raise InvalidFormDataError(f"Invalid action: {selected_action_id}")
provided_inputs = set(form_data.keys())
missing_inputs = [
form_input.output_variable_name
for form_input in definition.inputs
if form_input.output_variable_name not in provided_inputs
]
if missing_inputs:
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
try:
validate_human_input_submission(
inputs=definition.inputs,
user_actions=definition.user_actions,
selected_action_id=selected_action_id,
form_data=form_data,
)
except HumanInputSubmissionValidationError as exc:
raise InvalidFormDataError(str(exc)) from exc
def _enqueue_resume(self, workflow_run_id: str) -> None:
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)

View File

@ -24,7 +24,11 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, N
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
HumanInputNodeData,
validate_human_input_submission,
)
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
@ -833,21 +837,18 @@ class WorkflowService:
)
node_data = node.node_data
available_actions = {user_action.id for user_action in node_data.user_actions}
if action not in available_actions:
raise ValueError(f"Invalid action: {action}")
expected_inputs = {form_input.output_variable_name for form_input in node_data.inputs}
missing_inputs = [name for name in expected_inputs if name not in form_inputs]
if missing_inputs:
missing_list = ", ".join(sorted(missing_inputs))
raise ValueError(f"Missing inputs: {missing_list}")
validate_human_input_submission(
inputs=node_data.inputs,
user_actions=node_data.user_actions,
selected_action_id=action,
form_data=form_inputs,
)
rendered_content = node._render_form_content_before_submission()
outputs: dict[str, Any] = dict(form_inputs)
outputs["__action_id"] = action
outputs["__rendered_content"] = node._render_form_content_with_outputs(
node_data.form_content, outputs, node_data.outputs_field_names()
rendered_content, outputs, node_data.outputs_field_names()
)
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)

View File

@ -1,9 +1,15 @@
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.workflow.enums import NodeType
from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import FormInputType
from models.model import App
from models.workflow import Workflow
from services import workflow_service as workflow_service_module
from services.workflow_service import WorkflowService
@ -161,3 +167,111 @@ class TestWorkflowService:
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()
def test_submit_human_input_form_preview_uses_rendered_content(
self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
) -> None:
service = workflow_service
node_data = HumanInputNodeData(
title="Human Input",
form_content="<p>{{#$output.name#}}</p>",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
)
node = MagicMock()
node.node_data = node_data
node._render_form_content_before_submission.return_value = "<p>preview</p>"
node._render_form_content_with_outputs.return_value = "<p>rendered</p>"
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
workflow.get_enclosing_node_type_and_id.return_value = None
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
saved_outputs: dict[str, object] = {}
class DummySession:
def __init__(self, *args, **kwargs):
self.commit = MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def begin(self):
return nullcontext()
class DummySaver:
def __init__(self, *args, **kwargs):
pass
def save(self, outputs, process_data):
saved_outputs.update(outputs)
monkeypatch.setattr(workflow_service_module, "Session", DummySession)
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
account = SimpleNamespace(id="account-1")
result = service.submit_human_input_form_preview(
app_model=app_model,
account=account,
node_id="node-1",
form_inputs={"name": "Ada", "extra": "ignored"},
action="approve",
)
node._render_form_content_with_outputs.assert_called_once()
called_args = node._render_form_content_with_outputs.call_args.args
assert called_args[0] == "<p>preview</p>"
assert called_args[2] == node_data.outputs_field_names()
rendered_outputs = called_args[1]
assert rendered_outputs["name"] == "Ada"
assert rendered_outputs["extra"] == "ignored"
assert "extra" in saved_outputs
assert "extra" in result
assert saved_outputs["name"] == "Ada"
assert result["name"] == "Ada"
assert result["__action_id"] == "approve"
assert "__rendered_content" in result
def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None:
service = workflow_service
node_data = HumanInputNodeData(
title="Human Input",
form_content="<p>{{#$output.name#}}</p>",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
)
node = MagicMock()
node.node_data = node_data
node._render_form_content_before_submission.return_value = "<p>preview</p>"
node._render_form_content_with_outputs.return_value = "<p>rendered</p>"
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
account = SimpleNamespace(id="account-1")
with pytest.raises(ValueError) as exc_info:
service.submit_human_input_form_preview(
app_model=app_model,
account=account,
node_id="node-1",
form_inputs={},
action="approve",
)
assert "Missing required inputs" in str(exc_info.value)