WIP: unify Form And FormSubmission

This commit is contained in:
QuantumGhost
2025-12-08 02:52:34 +08:00
parent 1f64281ce5
commit e6fbf3a198
14 changed files with 113 additions and 206 deletions

View File

@ -22,7 +22,6 @@ from core.workflow.nodes.human_input.entities import (
TimeoutUnit,
UserAction,
)
from core.workflow.repositories.human_input_form_repository import FormNotFoundError
from libs.datetime_utils import naive_utc_now
from models.human_input import (
EmailExternalRecipientPayload,
@ -309,7 +308,7 @@ class TestHumanInputFormRepositoryImplPublicMethods:
assert repo.get_form("run-1", "node-1") is None
def test_get_form_submission_returns_none_when_pending(self):
def test_get_form_returns_unsubmitted_state(self):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -319,12 +318,17 @@ class TestHumanInputFormRepositoryImplPublicMethods:
rendered_content="<p>hello</p>",
expiration_time=naive_utc_now(),
)
session = _FakeSession(forms={form.id: form})
session = _FakeSession(scalars_results=[form, []])
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
assert repo.get_form_submission(form.id) is None
entity = repo.get_form(form.workflow_run_id, form.node_id)
def test_get_form_submission_returns_submission_when_completed(self):
assert entity is not None
assert entity.submitted is False
assert entity.selected_action_id is None
assert entity.submitted_data is None
def test_get_form_returns_submission_when_completed(self):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -337,21 +341,15 @@ class TestHumanInputFormRepositoryImplPublicMethods:
submitted_data='{"field": "value"}',
submitted_at=naive_utc_now(),
)
session = _FakeSession(forms={form.id: form})
session = _FakeSession(scalars_results=[form, []])
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
submission = repo.get_form_submission(form.id)
entity = repo.get_form(form.workflow_run_id, form.node_id)
assert submission is not None
assert submission.selected_action_id == "approve"
assert submission.form_data() == {"field": "value"}
def test_get_form_submission_raises_when_form_missing(self):
session = _FakeSession(forms={})
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
with pytest.raises(FormNotFoundError):
repo.get_form_submission("form-unknown")
assert entity is not None
assert entity.submitted is True
assert entity.selected_action_id == "approve"
assert entity.submitted_data == {"field": "value"}
class TestHumanInputFormSubmissionRepository:

View File

@ -8,7 +8,6 @@ from typing import Any
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
FormSubmission,
HumanInputFormEntity,
HumanInputFormRecipientEntity,
HumanInputFormRepository,
@ -36,6 +35,9 @@ class _InMemoryFormEntity(HumanInputFormEntity):
form_id: str
rendered: str
token: str | None = None
action_id: str | None = None
data: Mapping[str, Any] | None = None
is_submitted: bool = False
@property
def id(self) -> str:
@ -53,18 +55,17 @@ class _InMemoryFormEntity(HumanInputFormEntity):
def rendered_content(self) -> str:
return self.rendered
class _InMemoryFormSubmission(FormSubmission):
def __init__(self, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
self._selected_action_id = selected_action_id
self._form_data = form_data
@property
def selected_action_id(self) -> str | None:
return self.action_id
@property
def selected_action_id(self) -> str:
return self._selected_action_id
def submitted_data(self) -> Mapping[str, Any] | None:
return self.data
def form_data(self) -> Mapping[str, Any]:
return self._form_data
@property
def submitted(self) -> bool:
return self.is_submitted
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
@ -75,7 +76,6 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
self.created_params: list[FormCreateParams] = []
self.created_forms: list[_InMemoryFormEntity] = []
self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}
self._submissions: dict[str, FormSubmission] = {}
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
self.created_params.append(params)
@ -89,9 +89,6 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
return self._forms_by_key.get((workflow_execution_id, node_id))
def get_form_submission(self, form_id: str) -> FormSubmission | None:
return self._submissions.get(form_id)
# Convenience helpers for tests -------------------------------------
def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
@ -99,8 +96,15 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
if not self.created_forms:
raise AssertionError("no form has been created to attach submission data")
target_form_id = self.created_forms[-1].id
self._submissions[target_form_id] = _InMemoryFormSubmission(action_id, form_data or {})
entity = self.created_forms[-1]
entity.action_id = action_id
entity.data = form_data or {}
entity.is_submitted = True
def clear_submission(self) -> None:
self._submissions.clear()
if not self.created_forms:
return
for form in self.created_forms:
form.action_id = None
form.data = None
form.is_submitted = False

View File

@ -29,11 +29,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormSubmission,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@ -242,13 +238,13 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
runner = TableTestRunner()
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
mock_create_repo.get_form_submission.return_value = None
mock_create_repo.get_form.return_value = None
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_form_entity.submitted = False
mock_create_repo.create_form.return_value = mock_form_entity
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
@ -297,11 +293,15 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
)
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
mock_form_submission = MagicMock(spec=FormSubmission)
mock_form_submission.selected_action_id = scenario["handle"]
mock_form_submission.form_data.return_value = {}
mock_get_repo.get_form_submission.return_value = mock_form_submission
mock_get_repo.get_form.return_value = mock_form_entity
submitted_form = MagicMock(spec=HumanInputFormEntity)
submitted_form.id = mock_form_entity.id
submitted_form.web_app_token = mock_form_entity.web_app_token
submitted_form.recipients = []
submitted_form.rendered_content = mock_form_entity.rendered_content
submitted_form.submitted = True
submitted_form.selected_action_id = scenario["handle"]
submitted_form.submitted_data = {}
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory(
initial_result=initial_result, mock_get_repo=mock_get_repo

View File

@ -28,11 +28,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormSubmission,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@ -187,13 +183,13 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
]
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
mock_create_repo.get_form_submission.return_value = None
mock_create_repo.get_form.return_value = None
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_form_entity.submitted = False
mock_create_repo.create_form.return_value = mock_form_entity
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
@ -255,11 +251,15 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
]
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
mock_form_submission = MagicMock(spec=FormSubmission)
mock_form_submission.selected_action_id = "accept"
mock_form_submission.form_data.return_value = {}
mock_get_repo.get_form_submission.return_value = mock_form_submission
mock_get_repo.get_form.return_value = mock_form_entity
submitted_form = MagicMock(spec=HumanInputFormEntity)
submitted_form.id = mock_form_entity.id
submitted_form.web_app_token = mock_form_entity.web_app_token
submitted_form.recipients = []
submitted_form.rendered_content = mock_form_entity.rendered_content
submitted_form.submitted = True
submitted_form.selected_action_id = "accept"
submitted_form.submitted_data = {}
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
# restruct the graph runtime state

View File

@ -20,7 +20,6 @@ from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormSubmission,
HumanInputFormEntity,
HumanInputFormRepository,
)
@ -43,28 +42,27 @@ def _build_runtime_state() -> GraphRuntimeState:
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
submission = MagicMock(spec=FormSubmission)
submission.selected_action_id = action_id
submission.form_data.return_value = {}
repo = MagicMock(spec=HumanInputFormRepository)
repo.get_form_submission.return_value = submission
form_entity = MagicMock(spec=HumanInputFormEntity)
form_entity.id = "test-form-id"
form_entity.web_app_token = "test-form-token"
form_entity.recipients = []
form_entity.rendered_content = "rendered"
form_entity.submitted = True
form_entity.selected_action_id = action_id
form_entity.submitted_data = {}
repo.get_form.return_value = form_entity
return repo
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
repo = MagicMock(spec=HumanInputFormRepository)
repo.get_form_submission.return_value = None
form_entity = MagicMock(spec=HumanInputFormEntity)
form_entity.id = "test-form-id"
form_entity.web_app_token = "test-form-token"
form_entity.recipients = []
form_entity.rendered_content = "rendered"
form_entity.submitted = False
repo.create_form.return_value = form_entity
repo.get_form.return_value = None
return repo
@ -184,8 +182,8 @@ def test_engine_resume_restores_state_and_completion():
assert combined_success_nodes == baseline_success_nodes
assert baseline_state.outputs == resumed_state.outputs
assert _segment_value(baseline_state.variable_pool, ("human", "action_id")) == _segment_value(
resumed_state.variable_pool, ("human", "action_id")
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
resumed_state.variable_pool, ("human", "__action_id")
)
assert baseline_state.graph_execution.completed
assert resumed_state.graph_execution.completed

View File

@ -320,12 +320,12 @@ class TestHumanInputNodeVariableResolution:
mock_repo = MagicMock(spec=HumanInputFormRepository)
mock_repo.get_form.return_value = None
mock_repo.get_form_submission.return_value = None
mock_repo.create_form.return_value = SimpleNamespace(
id="form-1",
rendered_content="Provide your name",
web_app_token="token",
recipients=[],
submitted=False,
)
node = HumanInputNode(

View File

@ -168,11 +168,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
):
"""Test workflow pause creation when workflow not in RUNNING status."""
# Arrange
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
mock_session.get.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",