WIP: unify Form And FormSubmission

This commit is contained in:
QuantumGhost
2025-12-08 02:52:34 +08:00
parent 1f64281ce5
commit e6fbf3a198
14 changed files with 113 additions and 206 deletions

View File

@ -20,7 +20,6 @@ from core.workflow.nodes.human_input.entities import (
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
FormNotFoundError,
FormSubmission,
HumanInputFormEntity,
HumanInputFormRecipientEntity,
)
@ -73,6 +72,9 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.WEBAPP),
None,
)
self._submitted_data: Mapping[str, Any] | None = (
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
)
@property
def id(self) -> str:
@ -92,23 +94,17 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
def rendered_content(self) -> str:
return self._form_model.rendered_content
class _FormSubmissionImpl(FormSubmission):
def __init__(self, form_model: HumanInputForm):
self._form_model = form_model
@property
def selected_action_id(self) -> str | None:
return self._form_model.selected_action_id
@property
def selected_action_id(self) -> str:
selected_action_id = self._form_model.selected_action_id
if selected_action_id is None:
raise AssertionError(f"selected_action_id should not be None, form_id={self._form_model.id}")
return selected_action_id
def submitted_data(self) -> Mapping[str, Any] | None:
return self._submitted_data
def form_data(self) -> Mapping[str, Any]:
submitted_data = self._form_model.submitted_data
if submitted_data is None:
raise AssertionError(f"submitted_data should not be None, form_id={self._form_model.id}")
return json.loads(submitted_data)
@property
def submitted(self) -> bool:
return self._form_model.submitted_at is not None
@dataclasses.dataclass(frozen=True)
@ -348,17 +344,6 @@ class HumanInputFormRepositoryImpl:
recipient_models = session.scalars(recipient_query).all()
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
def get_form_submission(self, form_id: str) -> FormSubmission | None:
with self._session_factory(expire_on_commit=False) as session:
form_model: HumanInputForm | None = session.get(HumanInputForm, form_id)
if form_model is None or form_model.tenant_id != self._tenant_id:
raise FormNotFoundError(f"form not found, form_id={form_id}")
if form_model.submitted_at is None:
return None
return _FormSubmissionImpl(form_model=form_model)
class HumanInputFormSubmissionRepository:
"""Repository for fetching and submitting human input forms."""

View File

@ -207,15 +207,18 @@ class HumanInputNode(Node[HumanInputNodeData]):
if form is None:
return self._create_form()
submission_result = repo.get_form_submission(form.id)
if submission_result:
outputs: dict[str, Any] = dict(submission_result.form_data())
outputs["action_id"] = submission_result.selected_action_id
outputs["__action_id"] = submission_result.selected_action_id
if form.submitted:
selected_action_id = form.selected_action_id
if selected_action_id is None:
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
submitted_data = form.submitted_data or {}
outputs: dict[str, Any] = dict(submitted_data)
outputs["__action_id"] = selected_action_id
outputs["__rendered_content"] = form.rendered_content
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
edge_source_handle=submission_result.selected_action_id,
edge_source_handle=selected_action_id,
)
return self._pause_with_form(form)

View File

@ -64,6 +64,24 @@ class HumanInputFormEntity(abc.ABC):
"""Rendered markdown content associated with the form."""
...
@property
@abc.abstractmethod
def selected_action_id(self) -> str | None:
"""Identifier of the selected user action if the form has been submitted."""
...
@property
@abc.abstractmethod
def submitted_data(self) -> Mapping[str, Any] | None:
"""Submitted form data if available."""
...
@property
@abc.abstractmethod
def submitted(self) -> bool:
"""Whether the form has been submitted."""
...
class HumanInputFormRecipientEntity(abc.ABC):
@property
@ -79,19 +97,6 @@ class HumanInputFormRecipientEntity(abc.ABC):
...
class FormSubmission(abc.ABC):
@property
@abc.abstractmethod
def selected_action_id(self) -> str:
"""The identifier of action user has selected, correspond to `UserAction.id`."""
pass
@abc.abstractmethod
def form_data(self) -> Mapping[str, Any]:
"""The data submitted for this form"""
pass
class HumanInputFormRepository(Protocol):
"""
Repository interface for HumanInputForm.
@ -115,12 +120,3 @@ class HumanInputFormRepository(Protocol):
Create a human input form from form definition.
"""
...
def get_form_submission(self, form_id: str) -> FormSubmission | None:
"""Retrieve the submission for a specific human input node.
Returns `FormSubmission` if the form has been submitted, or `None` if not.
Raises `FormNotFoundError` if correspond form record is not found.
"""
...

View File

@ -7,8 +7,7 @@ from core.entities.execution_extra_content import ExecutionExtraContentDomainMod
class ExecutionExtraContentRepository(Protocol):
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
...
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ...
__all__ = ["ExecutionExtraContentRepository"]

View File

@ -1,75 +0,0 @@
import uuid
from decimal import Decimal
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from libs.uuid_utils import uuidv7
from models.enums import CreatorUserRole
from models.model import AppMode, Conversation, Message
def _create_conversation(session) -> Conversation:
conversation = Conversation(
app_id=str(uuid.uuid4()),
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source=CreatorUserRole.ACCOUNT,
from_account_id=str(uuid.uuid4()),
)
conversation.inputs = {}
session.add(conversation)
session.commit()
return conversation
def _create_message(session, conversation: Conversation) -> Message:
message = Message(
app_id=conversation.app_id,
conversation_id=conversation.id,
query="Need manual approval",
message={"type": "text", "content": "Need manual approval"},
answer="Acknowledged",
message_tokens=10,
answer_tokens=20,
message_unit_price=Decimal("0.001"),
answer_unit_price=Decimal("0.001"),
message_price_unit=Decimal("0.001"),
answer_price_unit=Decimal("0.001"),
currency="USD",
status="normal",
from_source=CreatorUserRole.ACCOUNT,
)
message.inputs = {}
session.add(message)
session.commit()
return message
def test_message_auto_loads_multiple_extra_variants(db_session_with_containers):
conversation = _create_conversation(db_session_with_containers)
message = _create_message(db_session_with_containers, conversation)
human_input_result_content_id = str(uuidv7())
human_input_result_content = HumanInputResultRelation(
id=human_input_result_content_id,
message_id=message.id,
form_id=None,
)
db_session_with_containers.add(human_input_result_content)
db_session_with_containers.commit()
# polymorphic_extra = with_polymorphic(
# MessageExtraContent,
# [HumanInputResultRelation],
# )
stmt = select(Message).options(selectinload(Message.extra_content)).where(Message.id == message.id)
loaded_message = db_session_with_containers.execute(stmt).scalar_one()
assert len(loaded_message.extra_content) == 1
assert human_input_result_content_id in {extra.id for extra in loaded_message.extra_content}
loaded_types = {type(extra) for extra in loaded_message.extra_content}
assert HumanInputResultRelation in loaded_types

View File

@ -38,9 +38,13 @@ class TestAppGenerateService:
# Setup default mock returns for workflow service
mock_workflow_service_instance = mock_workflow_service.return_value
mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow)
mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow)
mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow)
mock_published_workflow = MagicMock(spec=Workflow)
mock_published_workflow.id = str(uuid.uuid4())
mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow
mock_draft_workflow = MagicMock(spec=Workflow)
mock_draft_workflow.id = str(uuid.uuid4())
mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow
mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow
# Setup default mock returns for rate limiting
mock_rate_limit_instance = mock_rate_limit.return_value

View File

@ -94,11 +94,6 @@ class PrunePausesTestCase:
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
"""Create test cases for pause workflow failure scenarios."""
return [
PauseWorkflowFailureCase(
name="pause_already_paused_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
description="Should fail to pause an already paused workflow",
),
PauseWorkflowFailureCase(
name="pause_completed_workflow",
initial_status=WorkflowExecutionStatus.SUCCEEDED,

View File

@ -22,7 +22,6 @@ from core.workflow.nodes.human_input.entities import (
TimeoutUnit,
UserAction,
)
from core.workflow.repositories.human_input_form_repository import FormNotFoundError
from libs.datetime_utils import naive_utc_now
from models.human_input import (
EmailExternalRecipientPayload,
@ -309,7 +308,7 @@ class TestHumanInputFormRepositoryImplPublicMethods:
assert repo.get_form("run-1", "node-1") is None
def test_get_form_submission_returns_none_when_pending(self):
def test_get_form_returns_unsubmitted_state(self):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -319,12 +318,17 @@ class TestHumanInputFormRepositoryImplPublicMethods:
rendered_content="<p>hello</p>",
expiration_time=naive_utc_now(),
)
session = _FakeSession(forms={form.id: form})
session = _FakeSession(scalars_results=[form, []])
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
assert repo.get_form_submission(form.id) is None
entity = repo.get_form(form.workflow_run_id, form.node_id)
def test_get_form_submission_returns_submission_when_completed(self):
assert entity is not None
assert entity.submitted is False
assert entity.selected_action_id is None
assert entity.submitted_data is None
def test_get_form_returns_submission_when_completed(self):
form = _DummyForm(
id="form-1",
workflow_run_id="run-1",
@ -337,21 +341,15 @@ class TestHumanInputFormRepositoryImplPublicMethods:
submitted_data='{"field": "value"}',
submitted_at=naive_utc_now(),
)
session = _FakeSession(forms={form.id: form})
session = _FakeSession(scalars_results=[form, []])
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
submission = repo.get_form_submission(form.id)
entity = repo.get_form(form.workflow_run_id, form.node_id)
assert submission is not None
assert submission.selected_action_id == "approve"
assert submission.form_data() == {"field": "value"}
def test_get_form_submission_raises_when_form_missing(self):
session = _FakeSession(forms={})
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
with pytest.raises(FormNotFoundError):
repo.get_form_submission("form-unknown")
assert entity is not None
assert entity.submitted is True
assert entity.selected_action_id == "approve"
assert entity.submitted_data == {"field": "value"}
class TestHumanInputFormSubmissionRepository:

View File

@ -8,7 +8,6 @@ from typing import Any
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
FormSubmission,
HumanInputFormEntity,
HumanInputFormRecipientEntity,
HumanInputFormRepository,
@ -36,6 +35,9 @@ class _InMemoryFormEntity(HumanInputFormEntity):
form_id: str
rendered: str
token: str | None = None
action_id: str | None = None
data: Mapping[str, Any] | None = None
is_submitted: bool = False
@property
def id(self) -> str:
@ -53,18 +55,17 @@ class _InMemoryFormEntity(HumanInputFormEntity):
def rendered_content(self) -> str:
return self.rendered
class _InMemoryFormSubmission(FormSubmission):
def __init__(self, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
self._selected_action_id = selected_action_id
self._form_data = form_data
@property
def selected_action_id(self) -> str | None:
return self.action_id
@property
def selected_action_id(self) -> str:
return self._selected_action_id
def submitted_data(self) -> Mapping[str, Any] | None:
return self.data
def form_data(self) -> Mapping[str, Any]:
return self._form_data
@property
def submitted(self) -> bool:
return self.is_submitted
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
@ -75,7 +76,6 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
self.created_params: list[FormCreateParams] = []
self.created_forms: list[_InMemoryFormEntity] = []
self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}
self._submissions: dict[str, FormSubmission] = {}
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
self.created_params.append(params)
@ -89,9 +89,6 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
return self._forms_by_key.get((workflow_execution_id, node_id))
def get_form_submission(self, form_id: str) -> FormSubmission | None:
return self._submissions.get(form_id)
# Convenience helpers for tests -------------------------------------
def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
@ -99,8 +96,15 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
if not self.created_forms:
raise AssertionError("no form has been created to attach submission data")
target_form_id = self.created_forms[-1].id
self._submissions[target_form_id] = _InMemoryFormSubmission(action_id, form_data or {})
entity = self.created_forms[-1]
entity.action_id = action_id
entity.data = form_data or {}
entity.is_submitted = True
def clear_submission(self) -> None:
self._submissions.clear()
if not self.created_forms:
return
for form in self.created_forms:
form.action_id = None
form.data = None
form.is_submitted = False

View File

@ -29,11 +29,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormSubmission,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@ -242,13 +238,13 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
runner = TableTestRunner()
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
mock_create_repo.get_form_submission.return_value = None
mock_create_repo.get_form.return_value = None
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_form_entity.submitted = False
mock_create_repo.create_form.return_value = mock_form_entity
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
@ -297,11 +293,15 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
)
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
mock_form_submission = MagicMock(spec=FormSubmission)
mock_form_submission.selected_action_id = scenario["handle"]
mock_form_submission.form_data.return_value = {}
mock_get_repo.get_form_submission.return_value = mock_form_submission
mock_get_repo.get_form.return_value = mock_form_entity
submitted_form = MagicMock(spec=HumanInputFormEntity)
submitted_form.id = mock_form_entity.id
submitted_form.web_app_token = mock_form_entity.web_app_token
submitted_form.recipients = []
submitted_form.rendered_content = mock_form_entity.rendered_content
submitted_form.submitted = True
submitted_form.selected_action_id = scenario["handle"]
submitted_form.submitted_data = {}
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory(
initial_result=initial_result, mock_get_repo=mock_get_repo

View File

@ -28,11 +28,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormSubmission,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@ -187,13 +183,13 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
]
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
mock_create_repo.get_form_submission.return_value = None
mock_create_repo.get_form.return_value = None
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_form_entity.submitted = False
mock_create_repo.create_form.return_value = mock_form_entity
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
@ -255,11 +251,15 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
]
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
mock_form_submission = MagicMock(spec=FormSubmission)
mock_form_submission.selected_action_id = "accept"
mock_form_submission.form_data.return_value = {}
mock_get_repo.get_form_submission.return_value = mock_form_submission
mock_get_repo.get_form.return_value = mock_form_entity
submitted_form = MagicMock(spec=HumanInputFormEntity)
submitted_form.id = mock_form_entity.id
submitted_form.web_app_token = mock_form_entity.web_app_token
submitted_form.recipients = []
submitted_form.rendered_content = mock_form_entity.rendered_content
submitted_form.submitted = True
submitted_form.selected_action_id = "accept"
submitted_form.submitted_data = {}
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
# restruct the graph runtime state

View File

@ -20,7 +20,6 @@ from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormSubmission,
HumanInputFormEntity,
HumanInputFormRepository,
)
@ -43,28 +42,27 @@ def _build_runtime_state() -> GraphRuntimeState:
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
submission = MagicMock(spec=FormSubmission)
submission.selected_action_id = action_id
submission.form_data.return_value = {}
repo = MagicMock(spec=HumanInputFormRepository)
repo.get_form_submission.return_value = submission
form_entity = MagicMock(spec=HumanInputFormEntity)
form_entity.id = "test-form-id"
form_entity.web_app_token = "test-form-token"
form_entity.recipients = []
form_entity.rendered_content = "rendered"
form_entity.submitted = True
form_entity.selected_action_id = action_id
form_entity.submitted_data = {}
repo.get_form.return_value = form_entity
return repo
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
repo = MagicMock(spec=HumanInputFormRepository)
repo.get_form_submission.return_value = None
form_entity = MagicMock(spec=HumanInputFormEntity)
form_entity.id = "test-form-id"
form_entity.web_app_token = "test-form-token"
form_entity.recipients = []
form_entity.rendered_content = "rendered"
form_entity.submitted = False
repo.create_form.return_value = form_entity
repo.get_form.return_value = None
return repo
@ -184,8 +182,8 @@ def test_engine_resume_restores_state_and_completion():
assert combined_success_nodes == baseline_success_nodes
assert baseline_state.outputs == resumed_state.outputs
assert _segment_value(baseline_state.variable_pool, ("human", "action_id")) == _segment_value(
resumed_state.variable_pool, ("human", "action_id")
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
resumed_state.variable_pool, ("human", "__action_id")
)
assert baseline_state.graph_execution.completed
assert resumed_state.graph_execution.completed

View File

@ -320,12 +320,12 @@ class TestHumanInputNodeVariableResolution:
mock_repo = MagicMock(spec=HumanInputFormRepository)
mock_repo.get_form.return_value = None
mock_repo.get_form_submission.return_value = None
mock_repo.create_form.return_value = SimpleNamespace(
id="form-1",
rendered_content="Provide your name",
web_app_token="token",
recipients=[],
submitted=False,
)
node = HumanInputNode(

View File

@ -168,11 +168,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
):
"""Test workflow pause creation when workflow not in RUNNING status."""
# Arrange
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
mock_session.get.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",