From e6fbf3a198dff2c1c0d69946fb3a129a5ac88f20 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 8 Dec 2025 02:52:34 +0800 Subject: [PATCH] WIP: unify Form And FormSubmission --- .../repositories/human_input_reposotiry.py | 37 +++------ .../nodes/human_input/human_input_node.py | 15 ++-- .../human_input_form_repository.py | 40 +++++----- .../execution_extra_content_repository.py | 3 +- .../models/test_message_extra_content.py | 75 ------------------- .../services/test_app_generate_service.py | 10 ++- .../test_workflow_pause_integration.py | 5 -- .../test_human_input_form_repository_impl.py | 32 ++++---- .../graph_engine/human_input_test_utils.py | 38 +++++----- .../test_human_input_pause_multi_branch.py | 22 +++--- .../test_human_input_pause_single_branch.py | 22 +++--- .../graph_engine/test_pause_resume_state.py | 14 ++-- .../nodes/human_input/test_entities.py | 2 +- ..._sqlalchemy_api_workflow_run_repository.py | 4 +- 14 files changed, 113 insertions(+), 206 deletions(-) delete mode 100644 api/tests/test_containers_integration_tests/models/test_message_extra_content.py diff --git a/api/core/repositories/human_input_reposotiry.py b/api/core/repositories/human_input_reposotiry.py index f99bc877bc..3df02e79ce 100644 --- a/api/core/repositories/human_input_reposotiry.py +++ b/api/core/repositories/human_input_reposotiry.py @@ -20,7 +20,6 @@ from core.workflow.nodes.human_input.entities import ( from core.workflow.repositories.human_input_form_repository import ( FormCreateParams, FormNotFoundError, - FormSubmission, HumanInputFormEntity, HumanInputFormRecipientEntity, ) @@ -73,6 +72,9 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): (recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.WEBAPP), None, ) + self._submitted_data: Mapping[str, Any] | None = ( + json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None + ) @property def id(self) -> str: @@ -92,23 +94,17 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): def rendered_content(self) -> str: return self._form_model.rendered_content - -class _FormSubmissionImpl(FormSubmission): - def __init__(self, form_model: HumanInputForm): - self._form_model = form_model + @property + def selected_action_id(self) -> str | None: + return self._form_model.selected_action_id @property - def selected_action_id(self) -> str: - selected_action_id = self._form_model.selected_action_id - if selected_action_id is None: - raise AssertionError(f"selected_action_id should not be None, form_id={self._form_model.id}") - return selected_action_id + def submitted_data(self) -> Mapping[str, Any] | None: + return self._submitted_data - def form_data(self) -> Mapping[str, Any]: - submitted_data = self._form_model.submitted_data - if submitted_data is None: - raise AssertionError(f"submitted_data should not be None, form_id={self._form_model.id}") - return json.loads(submitted_data) + @property + def submitted(self) -> bool: + return self._form_model.submitted_at is not None @dataclasses.dataclass(frozen=True) @@ -348,17 +344,6 @@ class HumanInputFormRepositoryImpl: recipient_models = session.scalars(recipient_query).all() return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - def get_form_submission(self, form_id: str) -> FormSubmission | None: - with self._session_factory(expire_on_commit=False) as session: - form_model: HumanInputForm | None = session.get(HumanInputForm, form_id) - if form_model is None or form_model.tenant_id != self._tenant_id: - raise FormNotFoundError(f"form not found, form_id={form_id}") - - if form_model.submitted_at is None: - return None - - return _FormSubmissionImpl(form_model=form_model) - class HumanInputFormSubmissionRepository: """Repository for fetching and submitting human input forms.""" diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index d2403d780d..c8827f8209 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -207,15 +207,18 @@ class HumanInputNode(Node[HumanInputNodeData]): if form is None: return self._create_form() - submission_result = repo.get_form_submission(form.id) - if submission_result: - outputs: dict[str, Any] = dict(submission_result.form_data()) - outputs["action_id"] = submission_result.selected_action_id - outputs["__action_id"] = submission_result.selected_action_id + if form.submitted: + selected_action_id = form.selected_action_id + if selected_action_id is None: + raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") + submitted_data = form.submitted_data or {} + outputs: dict[str, Any] = dict(submitted_data) + outputs["__action_id"] = selected_action_id + outputs["__rendered_content"] = form.rendered_content return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, - edge_source_handle=submission_result.selected_action_id, + edge_source_handle=selected_action_id, ) return self._pause_with_form(form) diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/core/workflow/repositories/human_input_form_repository.py index 6050e86c8d..b401cace11 100644 --- a/api/core/workflow/repositories/human_input_form_repository.py +++ b/api/core/workflow/repositories/human_input_form_repository.py @@ -64,6 +64,24 @@ class HumanInputFormEntity(abc.ABC): """Rendered markdown content associated with the form.""" ... + @property + @abc.abstractmethod + def selected_action_id(self) -> str | None: + """Identifier of the selected user action if the form has been submitted.""" + ... + + @property + @abc.abstractmethod + def submitted_data(self) -> Mapping[str, Any] | None: + """Submitted form data if available.""" + ... + + @property + @abc.abstractmethod + def submitted(self) -> bool: + """Whether the form has been submitted.""" + ... + class HumanInputFormRecipientEntity(abc.ABC): @property @@ -79,19 +97,6 @@ class HumanInputFormRecipientEntity(abc.ABC): ... -class FormSubmission(abc.ABC): - @property - @abc.abstractmethod - def selected_action_id(self) -> str: - """The identifier of action user has selected, correspond to `UserAction.id`.""" - pass - - @abc.abstractmethod - def form_data(self) -> Mapping[str, Any]: - """The data submitted for this form""" - pass - - class HumanInputFormRepository(Protocol): """ Repository interface for HumanInputForm. @@ -115,12 +120,3 @@ class HumanInputFormRepository(Protocol): Create a human input form from form definition. """ ... - - def get_form_submission(self, form_id: str) -> FormSubmission | None: - """Retrieve the submission for a specific human input node. - - Returns `FormSubmission` if the form has been submitted, or `None` if not. - - Raises `FormNotFoundError` if correspond form record is not found. - """ - ... diff --git a/api/repositories/execution_extra_content_repository.py b/api/repositories/execution_extra_content_repository.py index b741fcb56c..72b5443d2c 100644 --- a/api/repositories/execution_extra_content_repository.py +++ b/api/repositories/execution_extra_content_repository.py @@ -7,8 +7,7 @@ from core.entities.execution_extra_content import ExecutionExtraContentDomainMod class ExecutionExtraContentRepository(Protocol): - def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: - ... + def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ... __all__ = ["ExecutionExtraContentRepository"] diff --git a/api/tests/test_containers_integration_tests/models/test_message_extra_content.py b/api/tests/test_containers_integration_tests/models/test_message_extra_content.py deleted file mode 100644 index eb6ae45b33..0000000000 --- a/api/tests/test_containers_integration_tests/models/test_message_extra_content.py +++ /dev/null @@ -1,75 +0,0 @@ -import uuid -from decimal import Decimal - -from sqlalchemy import select -from sqlalchemy.orm import selectinload - -from libs.uuid_utils import uuidv7 -from models.enums import CreatorUserRole -from models.model import AppMode, Conversation, Message - - -def _create_conversation(session) -> Conversation: - conversation = Conversation( - app_id=str(uuid.uuid4()), - mode=AppMode.CHAT, - name="Test Conversation", - status="normal", - from_source=CreatorUserRole.ACCOUNT, - from_account_id=str(uuid.uuid4()), - ) - conversation.inputs = {} - session.add(conversation) - session.commit() - return conversation - - -def _create_message(session, conversation: Conversation) -> Message: - message = Message( - app_id=conversation.app_id, - conversation_id=conversation.id, - query="Need manual approval", - message={"type": "text", "content": "Need manual approval"}, - answer="Acknowledged", - message_tokens=10, - answer_tokens=20, - message_unit_price=Decimal("0.001"), - answer_unit_price=Decimal("0.001"), - message_price_unit=Decimal("0.001"), - answer_price_unit=Decimal("0.001"), - currency="USD", - status="normal", - from_source=CreatorUserRole.ACCOUNT, - ) - message.inputs = {} - session.add(message) - session.commit() - return message - - -def test_message_auto_loads_multiple_extra_variants(db_session_with_containers): - conversation = _create_conversation(db_session_with_containers) - message = _create_message(db_session_with_containers, conversation) - - human_input_result_content_id = str(uuidv7()) - human_input_result_content = HumanInputResultRelation( - id=human_input_result_content_id, - message_id=message.id, - form_id=None, - ) - db_session_with_containers.add(human_input_result_content) - db_session_with_containers.commit() - - # polymorphic_extra = with_polymorphic( - # MessageExtraContent, - # [HumanInputResultRelation], - # ) - - stmt = select(Message).options(selectinload(Message.extra_content)).where(Message.id == message.id) - loaded_message = db_session_with_containers.execute(stmt).scalar_one() - - assert len(loaded_message.extra_content) == 1 - assert human_input_result_content_id in {extra.id for extra in loaded_message.extra_content} - - loaded_types = {type(extra) for extra in loaded_message.extra_content} - assert HumanInputResultRelation in loaded_types diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 476f58585d..2e62796fec 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -38,9 +38,13 @@ class TestAppGenerateService: # Setup default mock returns for workflow service mock_workflow_service_instance = mock_workflow_service.return_value - mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow) - mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow) - mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow) + mock_published_workflow = MagicMock(spec=Workflow) + mock_published_workflow.id = str(uuid.uuid4()) + mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow + mock_draft_workflow = MagicMock(spec=Workflow) + mock_draft_workflow.id = str(uuid.uuid4()) + mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow + mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow # Setup default mock returns for rate limiting mock_rate_limit_instance = mock_rate_limit.return_value diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 889e3d1d83..5f4f28cf4f 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -94,11 +94,6 @@ class PrunePausesTestCase: def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]: """Create test cases for pause workflow failure scenarios.""" return [ - PauseWorkflowFailureCase( - name="pause_already_paused_workflow", - initial_status=WorkflowExecutionStatus.PAUSED, - description="Should fail to pause an already paused workflow", - ), PauseWorkflowFailureCase( name="pause_completed_workflow", initial_status=WorkflowExecutionStatus.SUCCEEDED, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 784bebfb71..96d5894164 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -22,7 +22,6 @@ from core.workflow.nodes.human_input.entities import ( TimeoutUnit, UserAction, ) -from core.workflow.repositories.human_input_form_repository import FormNotFoundError from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -309,7 +308,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert repo.get_form("run-1", "node-1") is None - def test_get_form_submission_returns_none_when_pending(self): + def test_get_form_returns_unsubmitted_state(self): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -319,12 +318,17 @@ class TestHumanInputFormRepositoryImplPublicMethods: rendered_content="

hello

", expiration_time=naive_utc_now(), ) - session = _FakeSession(forms={form.id: form}) + session = _FakeSession(scalars_results=[form, []]) repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - assert repo.get_form_submission(form.id) is None + entity = repo.get_form(form.workflow_run_id, form.node_id) - def test_get_form_submission_returns_submission_when_completed(self): + assert entity is not None + assert entity.submitted is False + assert entity.selected_action_id is None + assert entity.submitted_data is None + + def test_get_form_returns_submission_when_completed(self): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -337,21 +341,15 @@ class TestHumanInputFormRepositoryImplPublicMethods: submitted_data='{"field": "value"}', submitted_at=naive_utc_now(), ) - session = _FakeSession(forms={form.id: form}) + session = _FakeSession(scalars_results=[form, []]) repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - submission = repo.get_form_submission(form.id) + entity = repo.get_form(form.workflow_run_id, form.node_id) - assert submission is not None - assert submission.selected_action_id == "approve" - assert submission.form_data() == {"field": "value"} - - def test_get_form_submission_raises_when_form_missing(self): - session = _FakeSession(forms={}) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") - - with pytest.raises(FormNotFoundError): - repo.get_form_submission("form-unknown") + assert entity is not None + assert entity.submitted is True + assert entity.selected_action_id == "approve" + assert entity.submitted_data == {"field": "value"} class TestHumanInputFormSubmissionRepository: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py index 4155e5b9c0..082cc4ca12 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -8,7 +8,6 @@ from typing import Any from core.workflow.repositories.human_input_form_repository import ( FormCreateParams, - FormSubmission, HumanInputFormEntity, HumanInputFormRecipientEntity, HumanInputFormRepository, @@ -36,6 +35,9 @@ class _InMemoryFormEntity(HumanInputFormEntity): form_id: str rendered: str token: str | None = None + action_id: str | None = None + data: Mapping[str, Any] | None = None + is_submitted: bool = False @property def id(self) -> str: @@ -53,18 +55,17 @@ class _InMemoryFormEntity(HumanInputFormEntity): def rendered_content(self) -> str: return self.rendered - -class _InMemoryFormSubmission(FormSubmission): - def __init__(self, selected_action_id: str, form_data: Mapping[str, Any]) -> None: - self._selected_action_id = selected_action_id - self._form_data = form_data + @property + def selected_action_id(self) -> str | None: + return self.action_id @property - def selected_action_id(self) -> str: - return self._selected_action_id + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data - def form_data(self) -> Mapping[str, Any]: - return self._form_data + @property + def submitted(self) -> bool: + return self.is_submitted class InMemoryHumanInputFormRepository(HumanInputFormRepository): @@ -75,7 +76,6 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): self.created_params: list[FormCreateParams] = [] self.created_forms: list[_InMemoryFormEntity] = [] self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} - self._submissions: dict[str, FormSubmission] = {} def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: self.created_params.append(params) @@ -89,9 +89,6 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_key.get((workflow_execution_id, node_id)) - def get_form_submission(self, form_id: str) -> FormSubmission | None: - return self._submissions.get(form_id) - # Convenience helpers for tests ------------------------------------- def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: @@ -99,8 +96,15 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): if not self.created_forms: raise AssertionError("no form has been created to attach submission data") - target_form_id = self.created_forms[-1].id - self._submissions[target_form_id] = _InMemoryFormSubmission(action_id, form_data or {}) + entity = self.created_forms[-1] + entity.action_id = action_id + entity.data = form_data or {} + entity.is_submitted = True def clear_submission(self) -> None: - self._submissions.clear() + if not self.created_forms: + return + for form in self.created_forms: + form.action_id = None + form.data = None + form.is_submitted = False diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 8058432d8c..a9b00386bb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -29,11 +29,7 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - FormSubmission, - HumanInputFormEntity, - HumanInputFormRepository, -) +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable @@ -242,13 +238,13 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: runner = TableTestRunner() mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form_submission.return_value = None mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" mock_form_entity.web_app_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" + mock_form_entity.submitted = False mock_create_repo.create_form.return_value = mock_form_entity def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: @@ -297,11 +293,15 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: ) mock_get_repo = MagicMock(spec=HumanInputFormRepository) - mock_form_submission = MagicMock(spec=FormSubmission) - mock_form_submission.selected_action_id = scenario["handle"] - mock_form_submission.form_data.return_value = {} - mock_get_repo.get_form_submission.return_value = mock_form_submission - mock_get_repo.get_form.return_value = mock_form_entity + submitted_form = MagicMock(spec=HumanInputFormEntity) + submitted_form.id = mock_form_entity.id + submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.recipients = [] + submitted_form.rendered_content = mock_form_entity.rendered_content + submitted_form.submitted = True + submitted_form.selected_action_id = scenario["handle"] + submitted_form.submitted_data = {} + mock_get_repo.get_form.return_value = submitted_form def resume_graph_factory( initial_result=initial_result, mock_get_repo=mock_get_repo diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index f3fa1c0ffe..d52f43dacf 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -28,11 +28,7 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( - FormSubmission, - HumanInputFormEntity, - HumanInputFormRepository, -) +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable @@ -187,13 +183,13 @@ def test_human_input_llm_streaming_order_across_pause() -> None: ] mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form_submission.return_value = None mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" mock_form_entity.web_app_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" + mock_form_entity.submitted = False mock_create_repo.create_form.return_value = mock_form_entity def graph_factory() -> tuple[Graph, GraphRuntimeState]: @@ -255,11 +251,15 @@ def test_human_input_llm_streaming_order_across_pause() -> None: ] mock_get_repo = MagicMock(spec=HumanInputFormRepository) - mock_form_submission = MagicMock(spec=FormSubmission) - mock_form_submission.selected_action_id = "accept" - mock_form_submission.form_data.return_value = {} - mock_get_repo.get_form_submission.return_value = mock_form_submission - mock_get_repo.get_form.return_value = mock_form_entity + submitted_form = MagicMock(spec=HumanInputFormEntity) + submitted_form.id = mock_form_entity.id + submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.recipients = [] + submitted_form.rendered_content = mock_form_entity.rendered_content + submitted_form.submitted = True + submitted_form.selected_action_id = "accept" + submitted_form.submitted_data = {} + mock_get_repo.get_form.return_value = submitted_form def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: # restruct the graph runtime state diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 6ec9e9420c..bd9da85ac5 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -20,7 +20,6 @@ from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode from core.workflow.repositories.human_input_form_repository import ( - FormSubmission, HumanInputFormEntity, HumanInputFormRepository, ) @@ -43,28 +42,27 @@ def _build_runtime_state() -> GraphRuntimeState: def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: - submission = MagicMock(spec=FormSubmission) - submission.selected_action_id = action_id - submission.form_data.return_value = {} repo = MagicMock(spec=HumanInputFormRepository) - repo.get_form_submission.return_value = submission form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" form_entity.web_app_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" + form_entity.submitted = True + form_entity.selected_action_id = action_id + form_entity.submitted_data = {} repo.get_form.return_value = form_entity return repo def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) - repo.get_form_submission.return_value = None form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" form_entity.web_app_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" + form_entity.submitted = False repo.create_form.return_value = form_entity repo.get_form.return_value = None return repo @@ -184,8 +182,8 @@ def test_engine_resume_restores_state_and_completion(): assert combined_success_nodes == baseline_success_nodes assert baseline_state.outputs == resumed_state.outputs - assert _segment_value(baseline_state.variable_pool, ("human", "action_id")) == _segment_value( - resumed_state.variable_pool, ("human", "action_id") + assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value( + resumed_state.variable_pool, ("human", "__action_id") ) assert baseline_state.graph_execution.completed assert resumed_state.graph_execution.completed diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index ba24c6a679..74e4bb7e63 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -320,12 +320,12 @@ class TestHumanInputNodeVariableResolution: mock_repo = MagicMock(spec=HumanInputFormRepository) mock_repo.get_form.return_value = None - mock_repo.get_form_submission.return_value = None mock_repo.create_form.return_value = SimpleNamespace( id="form-1", rendered_content="Provide your name", web_app_token="token", recipients=[], + submitted=False, ) node = HumanInputNode( diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 0c34676252..3fdf266bbe 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -168,11 +168,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): ): """Test workflow pause creation when workflow not in RUNNING status.""" # Arrange - sample_workflow_run.status = WorkflowExecutionStatus.PAUSED + sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED mock_session.get.return_value = sample_workflow_run # Act & Assert - with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"): + with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"): repository.create_workflow_pause( workflow_run_id="workflow-run-123", state_owner_user_id="user-123",