mirror of
https://github.com/langgenius/dify.git
synced 2026-02-22 19:15:47 +08:00
refactor(api): Unify Human Input handling logic
This commit is contained in:
@ -259,3 +259,30 @@ class FormDefinition(BaseModel):
|
||||
|
||||
# display_in_ui controls whether the form should be displayed in UI surfaces.
|
||||
display_in_ui: bool | None = None
|
||||
|
||||
|
||||
class HumanInputSubmissionValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def validate_human_input_submission(
|
||||
*,
|
||||
inputs: Sequence[FormInput],
|
||||
user_actions: Sequence[UserAction],
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
) -> None:
|
||||
available_actions = {action.id for action in user_actions}
|
||||
if selected_action_id not in available_actions:
|
||||
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
|
||||
|
||||
provided_inputs = set(form_data.keys())
|
||||
missing_inputs = [
|
||||
form_input.output_variable_name
|
||||
for form_input in inputs
|
||||
if form_input.output_variable_name not in provided_inputs
|
||||
]
|
||||
|
||||
if missing_inputs:
|
||||
missing_list = ", ".join(missing_inputs)
|
||||
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")
|
||||
|
||||
@ -9,7 +9,11 @@ from core.repositories.human_input_reposotiry import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormSubmissionRepository,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormDefinition,
|
||||
HumanInputSubmissionValidationError,
|
||||
validate_human_input_submission,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.exception import BaseHTTPException
|
||||
@ -171,20 +175,15 @@ class HumanInputService:
|
||||
|
||||
def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
|
||||
definition = form.get_definition()
|
||||
|
||||
available_actions = {action.id for action in definition.user_actions}
|
||||
if selected_action_id not in available_actions:
|
||||
raise InvalidFormDataError(f"Invalid action: {selected_action_id}")
|
||||
|
||||
provided_inputs = set(form_data.keys())
|
||||
missing_inputs = [
|
||||
form_input.output_variable_name
|
||||
for form_input in definition.inputs
|
||||
if form_input.output_variable_name not in provided_inputs
|
||||
]
|
||||
|
||||
if missing_inputs:
|
||||
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
|
||||
try:
|
||||
validate_human_input_submission(
|
||||
inputs=definition.inputs,
|
||||
user_actions=definition.user_actions,
|
||||
selected_action_id=selected_action_id,
|
||||
form_data=form_data,
|
||||
)
|
||||
except HumanInputSubmissionValidationError as exc:
|
||||
raise InvalidFormDataError(str(exc)) from exc
|
||||
|
||||
def _enqueue_resume(self, workflow_run_id: str) -> None:
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
|
||||
|
||||
@ -24,7 +24,11 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, N
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
HumanInputNodeData,
|
||||
validate_human_input_submission,
|
||||
)
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
@ -833,21 +837,18 @@ class WorkflowService:
|
||||
)
|
||||
node_data = node.node_data
|
||||
|
||||
available_actions = {user_action.id for user_action in node_data.user_actions}
|
||||
if action not in available_actions:
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
|
||||
expected_inputs = {form_input.output_variable_name for form_input in node_data.inputs}
|
||||
missing_inputs = [name for name in expected_inputs if name not in form_inputs]
|
||||
if missing_inputs:
|
||||
missing_list = ", ".join(sorted(missing_inputs))
|
||||
raise ValueError(f"Missing inputs: {missing_list}")
|
||||
validate_human_input_submission(
|
||||
inputs=node_data.inputs,
|
||||
user_actions=node_data.user_actions,
|
||||
selected_action_id=action,
|
||||
form_data=form_inputs,
|
||||
)
|
||||
|
||||
rendered_content = node._render_form_content_before_submission()
|
||||
outputs: dict[str, Any] = dict(form_inputs)
|
||||
outputs["__action_id"] = action
|
||||
outputs["__rendered_content"] = node._render_form_content_with_outputs(
|
||||
node_data.form_content, outputs, node_data.outputs_field_names()
|
||||
rendered_content, outputs, node_data.outputs_field_names()
|
||||
)
|
||||
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services import workflow_service as workflow_service_module
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@ -161,3 +167,111 @@ class TestWorkflowService:
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_submit_human_input_form_preview_uses_rendered_content(
|
||||
self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
service = workflow_service
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="<p>{{#$output.name#}}</p>",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
node = MagicMock()
|
||||
node.node_data = node_data
|
||||
node._render_form_content_before_submission.return_value = "<p>preview</p>"
|
||||
node._render_form_content_with_outputs.return_value = "<p>rendered</p>"
|
||||
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
|
||||
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
|
||||
saved_outputs: dict[str, object] = {}
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.commit = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def begin(self):
|
||||
return nullcontext()
|
||||
|
||||
class DummySaver:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def save(self, outputs, process_data):
|
||||
saved_outputs.update(outputs)
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "Session", DummySession)
|
||||
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
|
||||
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
result = service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
form_inputs={"name": "Ada", "extra": "ignored"},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
node._render_form_content_with_outputs.assert_called_once()
|
||||
called_args = node._render_form_content_with_outputs.call_args.args
|
||||
assert called_args[0] == "<p>preview</p>"
|
||||
assert called_args[2] == node_data.outputs_field_names()
|
||||
rendered_outputs = called_args[1]
|
||||
assert rendered_outputs["name"] == "Ada"
|
||||
assert rendered_outputs["extra"] == "ignored"
|
||||
assert "extra" in saved_outputs
|
||||
assert "extra" in result
|
||||
assert saved_outputs["name"] == "Ada"
|
||||
assert result["name"] == "Ada"
|
||||
assert result["__action_id"] == "approve"
|
||||
assert "__rendered_content" in result
|
||||
|
||||
def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None:
|
||||
service = workflow_service
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="<p>{{#$output.name#}}</p>",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
node = MagicMock()
|
||||
node.node_data = node_data
|
||||
node._render_form_content_before_submission.return_value = "<p>preview</p>"
|
||||
node._render_form_content_with_outputs.return_value = "<p>rendered</p>"
|
||||
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
|
||||
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
form_inputs={},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
assert "Missing required inputs" in str(exc_info.value)
|
||||
|
||||
Reference in New Issue
Block a user