mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
WIP: feat(api): always use form_token to submit human input form
This commit is contained in:
@ -4,12 +4,12 @@ from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowType
|
||||
|
||||
@ -69,7 +69,7 @@ def test_graph_run_paused_event_emits_queue_pause_event():
|
||||
actions=[],
|
||||
node_id="node-human",
|
||||
node_title="Human Step",
|
||||
web_app_form_token="tok",
|
||||
form_token="tok",
|
||||
)
|
||||
event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"})
|
||||
workflow_entry = SimpleNamespace(
|
||||
@ -128,7 +128,7 @@ def test_queue_workflow_paused_event_to_stream_responses():
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
node_id="node-id",
|
||||
node_title="Human Step",
|
||||
web_app_form_token="token",
|
||||
form_token="token",
|
||||
)
|
||||
queue_event = QueueWorkflowPausedEvent(
|
||||
reasons=[reason],
|
||||
|
||||
@ -290,7 +290,7 @@ class TestHumanInputFormRepositoryImplPublicMethods:
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, [recipient]])
|
||||
@ -368,7 +368,7 @@ class TestHumanInputFormSubmissionRepository:
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
form=form,
|
||||
)
|
||||
@ -379,7 +379,7 @@ class TestHumanInputFormSubmissionRepository:
|
||||
|
||||
assert record is not None
|
||||
assert record.form_id == form.id
|
||||
assert record.recipient_type == RecipientType.WEBAPP
|
||||
assert record.recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
assert record.submitted is False
|
||||
|
||||
def test_get_by_form_id_and_recipient_type_uses_recipient(self):
|
||||
@ -395,14 +395,17 @@ class TestHumanInputFormSubmissionRepository:
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.WEBAPP)
|
||||
record = repo.get_by_form_id_and_recipient_type(
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
)
|
||||
|
||||
assert record is not None
|
||||
assert record.recipient_id == recipient.id
|
||||
@ -424,7 +427,7 @@ class TestHumanInputFormSubmissionRepository:
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id="form-1",
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
)
|
||||
session = _FakeSession(
|
||||
|
||||
@ -38,6 +38,7 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
token: str | None = None
|
||||
console_token_value: str | None = None
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
is_submitted: bool = False
|
||||
@ -50,6 +51,8 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
if self.console_token_value is not None:
|
||||
return self.console_token_value
|
||||
return self.token
|
||||
|
||||
@property
|
||||
@ -94,7 +97,13 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
self.created_params.append(params)
|
||||
self._form_counter += 1
|
||||
form_id = f"form-{self._form_counter}"
|
||||
entity = _InMemoryFormEntity(form_id=form_id, rendered=params.rendered_content, token=f"token-{form_id}")
|
||||
console_token = f"console-{form_id}" if params.console_recipient_required else None
|
||||
entity = _InMemoryFormEntity(
|
||||
form_id=form_id,
|
||||
rendered=params.rendered_content,
|
||||
token=f"token-{form_id}",
|
||||
console_token_value=console_token,
|
||||
)
|
||||
self.created_forms.append(entity)
|
||||
self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
|
||||
return entity
|
||||
|
||||
@ -350,6 +350,61 @@ class TestHumanInputNodeVariableResolution:
|
||||
params = mock_repo.create_form.call_args.args[0]
|
||||
assert params.resolved_placeholder_values == expected_values
|
||||
|
||||
def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-2",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Provide your name",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-2",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="console-token",
|
||||
recipients=[SimpleNamespace(token="recipient-token")],
|
||||
submitted=False,
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=mock_repo,
|
||||
)
|
||||
|
||||
run_result = node._run()
|
||||
pause_event = next(run_result)
|
||||
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
assert pause_event.reason.form_token == "console-token"
|
||||
|
||||
|
||||
class TestValidation:
|
||||
"""Test validation scenarios."""
|
||||
|
||||
@ -48,7 +48,7 @@ class HumanInputForm:
|
||||
user_actions: list[dict[str, Any]]
|
||||
timeout: int
|
||||
timeout_unit: TimeoutUnit
|
||||
web_app_form_token: str | None = None
|
||||
form_token: str | None = None
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: datetime | None = None
|
||||
submitted_at: datetime | None = None
|
||||
@ -141,7 +141,7 @@ class InMemoryFormRepository:
|
||||
|
||||
def get_by_token(self, token: str) -> Optional[HumanInputForm]:
|
||||
for form in self._forms.values():
|
||||
if form.web_app_form_token == token:
|
||||
if form.form_token == token:
|
||||
return form
|
||||
return None
|
||||
|
||||
@ -169,7 +169,7 @@ class FormService:
|
||||
user_actions,
|
||||
timeout: int,
|
||||
timeout_unit: TimeoutUnit,
|
||||
web_app_form_token: str | None = None,
|
||||
form_token: str | None = None,
|
||||
) -> HumanInputForm:
|
||||
form = HumanInputForm(
|
||||
form_id=form_id,
|
||||
@ -182,7 +182,7 @@ class FormService:
|
||||
user_actions=[{"id": action.id, "title": action.title} for action in user_actions],
|
||||
timeout=timeout,
|
||||
timeout_unit=timeout_unit,
|
||||
web_app_form_token=web_app_form_token,
|
||||
form_token=form_token,
|
||||
)
|
||||
form.calculate_expiration()
|
||||
self.repository.save(form)
|
||||
|
||||
@ -14,6 +14,7 @@ from core.workflow.nodes.human_input.enums import (
|
||||
FormInputType,
|
||||
TimeoutUnit,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .support import (
|
||||
FormAlreadySubmittedError,
|
||||
@ -53,7 +54,7 @@ class TestFormService:
|
||||
"user_actions": [UserAction(id="submit", title="Submit")],
|
||||
"timeout": 1,
|
||||
"timeout_unit": TimeoutUnit.HOUR,
|
||||
"web_app_form_token": "token-xyz",
|
||||
"form_token": "token-xyz",
|
||||
}
|
||||
|
||||
def test_create_form(self, form_service, sample_form_data):
|
||||
@ -65,7 +66,7 @@ class TestFormService:
|
||||
assert form.node_id == "node-789"
|
||||
assert form.tenant_id == "tenant-abc"
|
||||
assert form.app_id == "app-def"
|
||||
assert form.web_app_form_token == "token-xyz"
|
||||
assert form.form_token == "token-xyz"
|
||||
assert form.timeout == 1
|
||||
assert form.timeout_unit == TimeoutUnit.HOUR
|
||||
assert form.expires_at is not None
|
||||
@ -99,7 +100,7 @@ class TestFormService:
|
||||
retrieved_form = form_service.get_form_by_token("token-xyz")
|
||||
|
||||
assert retrieved_form.form_id == created_form.form_id
|
||||
assert retrieved_form.web_app_form_token == "token-xyz"
|
||||
assert retrieved_form.form_token == "token-xyz"
|
||||
|
||||
def test_get_form_by_token_not_found(self, form_service):
|
||||
"""Test getting non-existent form by token."""
|
||||
@ -261,13 +262,13 @@ class TestFormService:
|
||||
for i in range(3):
|
||||
data = sample_form_data.copy()
|
||||
data["form_id"] = f"form-{i}"
|
||||
data["web_app_form_token"] = f"token-{i}"
|
||||
data["form_token"] = f"token-{i}"
|
||||
form_service.create_form(**data)
|
||||
|
||||
# Manually expire some forms
|
||||
for i in range(2): # Expire first 2 forms
|
||||
form = form_service.get_form_by_id(f"form-{i}")
|
||||
form.expires_at = datetime.utcnow() - timedelta(hours=1)
|
||||
form.expires_at = naive_utc_now() - timedelta(hours=1)
|
||||
form_service.repository.save(form)
|
||||
|
||||
# Clean up expired forms
|
||||
|
||||
@ -35,7 +35,7 @@ class TestHumanInputForm:
|
||||
"user_actions": [UserAction(id="submit", title="Submit")],
|
||||
"timeout": 2,
|
||||
"timeout_unit": TimeoutUnit.HOUR,
|
||||
"web_app_form_token": "token-xyz",
|
||||
"form_token": "token-xyz",
|
||||
}
|
||||
|
||||
def test_form_creation(self, sample_form_data):
|
||||
@ -47,7 +47,7 @@ class TestHumanInputForm:
|
||||
assert form.node_id == "node-789"
|
||||
assert form.tenant_id == "tenant-abc"
|
||||
assert form.app_id == "app-def"
|
||||
assert form.web_app_form_token == "token-xyz"
|
||||
assert form.form_token == "token-xyz"
|
||||
assert form.timeout == 2
|
||||
assert form.timeout_unit == TimeoutUnit.HOUR
|
||||
assert form.created_at is not None
|
||||
@ -148,11 +148,11 @@ class TestHumanInputForm:
|
||||
|
||||
def test_form_without_web_app_token(self, sample_form_data):
|
||||
"""Test form creation without web app token."""
|
||||
sample_form_data["web_app_form_token"] = None
|
||||
sample_form_data["form_token"] = None
|
||||
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
assert form.web_app_form_token is None
|
||||
assert form.form_token is None
|
||||
assert form.form_id == "form-123" # Other fields should still work
|
||||
|
||||
def test_form_with_explicit_timestamps(self):
|
||||
|
||||
@ -18,9 +18,8 @@ from core.workflow.nodes.human_input.enums import (
|
||||
HumanInputFormStatus,
|
||||
TimeoutUnit,
|
||||
)
|
||||
from models.account import Account
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import FormSubmittedError, HumanInputService, InvalidFormDataError
|
||||
from services.human_input_service import HumanInputService, InvalidFormDataError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -60,7 +59,7 @@ def sample_form_record():
|
||||
submission_end_user_id=None,
|
||||
completed_by_recipient_id=None,
|
||||
recipient_id="recipient-id",
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token",
|
||||
)
|
||||
|
||||
@ -146,32 +145,18 @@ def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory)
|
||||
resume_task.apply_async.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
|
||||
def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
|
||||
console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE)
|
||||
repo.get_by_token.return_value = console_record
|
||||
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
form = service.get_form_definition_by_id("form-id")
|
||||
form = service.get_form_definition_by_token_for_console("token")
|
||||
|
||||
repo.get_by_form_id_and_recipient_type.assert_called_once_with(
|
||||
form_id="form-id",
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
)
|
||||
repo.get_by_token.assert_called_once_with("token")
|
||||
assert form is not None
|
||||
assert form.get_definition() == sample_form_record.definition
|
||||
|
||||
|
||||
def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime(2024, 1, 1))
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = submitted_record
|
||||
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
|
||||
with pytest.raises(FormSubmittedError):
|
||||
service.get_form_definition_by_id("form-id")
|
||||
assert form.get_definition() == console_record.definition
|
||||
|
||||
|
||||
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
|
||||
@ -183,7 +168,7 @@ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, m
|
||||
enqueue_spy = mocker.patch.object(service, "_enqueue_resume")
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={"field": "value"},
|
||||
@ -201,26 +186,25 @@ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, m
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
|
||||
|
||||
def test_submit_form_by_id_passes_account(sample_form_record, mock_session_factory, mocker):
|
||||
def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
|
||||
repo.get_by_token.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "_enqueue_resume")
|
||||
account = MagicMock(spec=Account)
|
||||
account.id = "account-id"
|
||||
|
||||
service.submit_form_by_id(
|
||||
form_id="form-id",
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={"x": 1},
|
||||
user=account,
|
||||
form_data={"field": "value"},
|
||||
submission_user_id="account-id",
|
||||
)
|
||||
|
||||
repo.get_by_form_id_and_recipient_type.assert_called_once()
|
||||
repo.mark_submitted.assert_called_once()
|
||||
assert repo.mark_submitted.call_args.kwargs["submission_user_id"] == "account-id"
|
||||
call_kwargs = repo.mark_submitted.call_args.kwargs
|
||||
assert call_kwargs["submission_user_id"] == "account-id"
|
||||
assert call_kwargs["submission_end_user_id"] is None
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
|
||||
|
||||
@ -232,7 +216,7 @@ def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_fa
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="invalid",
|
||||
form_data={},
|
||||
@ -260,7 +244,7 @@ def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_fa
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={},
|
||||
|
||||
Reference in New Issue
Block a user