mirror of
https://github.com/langgenius/dify.git
synced 2026-03-29 01:49:57 +08:00
WIP: unify Form And FormSubmission
This commit is contained in:
@ -20,7 +20,6 @@ from core.workflow.nodes.human_input.entities import (
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
FormNotFoundError,
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
)
|
||||
@ -73,6 +72,9 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.WEBAPP),
|
||||
None,
|
||||
)
|
||||
self._submitted_data: Mapping[str, Any] | None = (
|
||||
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
@ -92,23 +94,17 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def rendered_content(self) -> str:
|
||||
return self._form_model.rendered_content
|
||||
|
||||
|
||||
class _FormSubmissionImpl(FormSubmission):
|
||||
def __init__(self, form_model: HumanInputForm):
|
||||
self._form_model = form_model
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self._form_model.selected_action_id
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str:
|
||||
selected_action_id = self._form_model.selected_action_id
|
||||
if selected_action_id is None:
|
||||
raise AssertionError(f"selected_action_id should not be None, form_id={self._form_model.id}")
|
||||
return selected_action_id
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self._submitted_data
|
||||
|
||||
def form_data(self) -> Mapping[str, Any]:
|
||||
submitted_data = self._form_model.submitted_data
|
||||
if submitted_data is None:
|
||||
raise AssertionError(f"submitted_data should not be None, form_id={self._form_model.id}")
|
||||
return json.loads(submitted_data)
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self._form_model.submitted_at is not None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -348,17 +344,6 @@ class HumanInputFormRepositoryImpl:
|
||||
recipient_models = session.scalars(recipient_query).all()
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
def get_form_submission(self, form_id: str) -> FormSubmission | None:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model: HumanInputForm | None = session.get(HumanInputForm, form_id)
|
||||
if form_model is None or form_model.tenant_id != self._tenant_id:
|
||||
raise FormNotFoundError(f"form not found, form_id={form_id}")
|
||||
|
||||
if form_model.submitted_at is None:
|
||||
return None
|
||||
|
||||
return _FormSubmissionImpl(form_model=form_model)
|
||||
|
||||
|
||||
class HumanInputFormSubmissionRepository:
|
||||
"""Repository for fetching and submitting human input forms."""
|
||||
|
||||
@ -207,15 +207,18 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
if form is None:
|
||||
return self._create_form()
|
||||
|
||||
submission_result = repo.get_form_submission(form.id)
|
||||
if submission_result:
|
||||
outputs: dict[str, Any] = dict(submission_result.form_data())
|
||||
outputs["action_id"] = submission_result.selected_action_id
|
||||
outputs["__action_id"] = submission_result.selected_action_id
|
||||
if form.submitted:
|
||||
selected_action_id = form.selected_action_id
|
||||
if selected_action_id is None:
|
||||
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
|
||||
submitted_data = form.submitted_data or {}
|
||||
outputs: dict[str, Any] = dict(submitted_data)
|
||||
outputs["__action_id"] = selected_action_id
|
||||
outputs["__rendered_content"] = form.rendered_content
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
edge_source_handle=submission_result.selected_action_id,
|
||||
edge_source_handle=selected_action_id,
|
||||
)
|
||||
|
||||
return self._pause_with_form(form)
|
||||
|
||||
@ -64,6 +64,24 @@ class HumanInputFormEntity(abc.ABC):
|
||||
"""Rendered markdown content associated with the form."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def selected_action_id(self) -> str | None:
|
||||
"""Identifier of the selected user action if the form has been submitted."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
"""Submitted form data if available."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def submitted(self) -> bool:
|
||||
"""Whether the form has been submitted."""
|
||||
...
|
||||
|
||||
|
||||
class HumanInputFormRecipientEntity(abc.ABC):
|
||||
@property
|
||||
@ -79,19 +97,6 @@ class HumanInputFormRecipientEntity(abc.ABC):
|
||||
...
|
||||
|
||||
|
||||
class FormSubmission(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def selected_action_id(self) -> str:
|
||||
"""The identifier of action user has selected, correspond to `UserAction.id`."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def form_data(self) -> Mapping[str, Any]:
|
||||
"""The data submitted for this form"""
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputFormRepository(Protocol):
|
||||
"""
|
||||
Repository interface for HumanInputForm.
|
||||
@ -115,12 +120,3 @@ class HumanInputFormRepository(Protocol):
|
||||
Create a human input form from form definition.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_form_submission(self, form_id: str) -> FormSubmission | None:
|
||||
"""Retrieve the submission for a specific human input node.
|
||||
|
||||
Returns `FormSubmission` if the form has been submitted, or `None` if not.
|
||||
|
||||
Raises `FormNotFoundError` if correspond form record is not found.
|
||||
"""
|
||||
...
|
||||
|
||||
@ -7,8 +7,7 @@ from core.entities.execution_extra_content import ExecutionExtraContentDomainMod
|
||||
|
||||
|
||||
class ExecutionExtraContentRepository(Protocol):
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
|
||||
...
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ...
|
||||
|
||||
|
||||
__all__ = ["ExecutionExtraContentRepository"]
|
||||
|
||||
@ -1,75 +0,0 @@
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
def _create_conversation(session) -> Conversation:
|
||||
conversation = Conversation(
|
||||
app_id=str(uuid.uuid4()),
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
)
|
||||
conversation.inputs = {}
|
||||
session.add(conversation)
|
||||
session.commit()
|
||||
return conversation
|
||||
|
||||
|
||||
def _create_message(session, conversation: Conversation) -> Message:
|
||||
message = Message(
|
||||
app_id=conversation.app_id,
|
||||
conversation_id=conversation.id,
|
||||
query="Need manual approval",
|
||||
message={"type": "text", "content": "Need manual approval"},
|
||||
answer="Acknowledged",
|
||||
message_tokens=10,
|
||||
answer_tokens=20,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
answer_unit_price=Decimal("0.001"),
|
||||
message_price_unit=Decimal("0.001"),
|
||||
answer_price_unit=Decimal("0.001"),
|
||||
currency="USD",
|
||||
status="normal",
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
message.inputs = {}
|
||||
session.add(message)
|
||||
session.commit()
|
||||
return message
|
||||
|
||||
|
||||
def test_message_auto_loads_multiple_extra_variants(db_session_with_containers):
|
||||
conversation = _create_conversation(db_session_with_containers)
|
||||
message = _create_message(db_session_with_containers, conversation)
|
||||
|
||||
human_input_result_content_id = str(uuidv7())
|
||||
human_input_result_content = HumanInputResultRelation(
|
||||
id=human_input_result_content_id,
|
||||
message_id=message.id,
|
||||
form_id=None,
|
||||
)
|
||||
db_session_with_containers.add(human_input_result_content)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# polymorphic_extra = with_polymorphic(
|
||||
# MessageExtraContent,
|
||||
# [HumanInputResultRelation],
|
||||
# )
|
||||
|
||||
stmt = select(Message).options(selectinload(Message.extra_content)).where(Message.id == message.id)
|
||||
loaded_message = db_session_with_containers.execute(stmt).scalar_one()
|
||||
|
||||
assert len(loaded_message.extra_content) == 1
|
||||
assert human_input_result_content_id in {extra.id for extra in loaded_message.extra_content}
|
||||
|
||||
loaded_types = {type(extra) for extra in loaded_message.extra_content}
|
||||
assert HumanInputResultRelation in loaded_types
|
||||
@ -38,9 +38,13 @@ class TestAppGenerateService:
|
||||
|
||||
# Setup default mock returns for workflow service
|
||||
mock_workflow_service_instance = mock_workflow_service.return_value
|
||||
mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow)
|
||||
mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow)
|
||||
mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow)
|
||||
mock_published_workflow = MagicMock(spec=Workflow)
|
||||
mock_published_workflow.id = str(uuid.uuid4())
|
||||
mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow
|
||||
mock_draft_workflow = MagicMock(spec=Workflow)
|
||||
mock_draft_workflow.id = str(uuid.uuid4())
|
||||
mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow
|
||||
mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow
|
||||
|
||||
# Setup default mock returns for rate limiting
|
||||
mock_rate_limit_instance = mock_rate_limit.return_value
|
||||
|
||||
@ -94,11 +94,6 @@ class PrunePausesTestCase:
|
||||
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
|
||||
"""Create test cases for pause workflow failure scenarios."""
|
||||
return [
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_already_paused_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
description="Should fail to pause an already paused workflow",
|
||||
),
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_completed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
|
||||
@ -22,7 +22,6 @@ from core.workflow.nodes.human_input.entities import (
|
||||
TimeoutUnit,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import FormNotFoundError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
@ -309,7 +308,7 @@ class TestHumanInputFormRepositoryImplPublicMethods:
|
||||
|
||||
assert repo.get_form("run-1", "node-1") is None
|
||||
|
||||
def test_get_form_submission_returns_none_when_pending(self):
|
||||
def test_get_form_returns_unsubmitted_state(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
@ -319,12 +318,17 @@ class TestHumanInputFormRepositoryImplPublicMethods:
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
session = _FakeSession(scalars_results=[form, []])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
assert repo.get_form_submission(form.id) is None
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
def test_get_form_submission_returns_submission_when_completed(self):
|
||||
assert entity is not None
|
||||
assert entity.submitted is False
|
||||
assert entity.selected_action_id is None
|
||||
assert entity.submitted_data is None
|
||||
|
||||
def test_get_form_returns_submission_when_completed(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
@ -337,21 +341,15 @@ class TestHumanInputFormRepositoryImplPublicMethods:
|
||||
submitted_data='{"field": "value"}',
|
||||
submitted_at=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
session = _FakeSession(scalars_results=[form, []])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
submission = repo.get_form_submission(form.id)
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
assert submission is not None
|
||||
assert submission.selected_action_id == "approve"
|
||||
assert submission.form_data() == {"field": "value"}
|
||||
|
||||
def test_get_form_submission_raises_when_form_missing(self):
|
||||
session = _FakeSession(forms={})
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
with pytest.raises(FormNotFoundError):
|
||||
repo.get_form_submission("form-unknown")
|
||||
assert entity is not None
|
||||
assert entity.submitted is True
|
||||
assert entity.selected_action_id == "approve"
|
||||
assert entity.submitted_data == {"field": "value"}
|
||||
|
||||
|
||||
class TestHumanInputFormSubmissionRepository:
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any
|
||||
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
HumanInputFormRepository,
|
||||
@ -36,6 +35,9 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
token: str | None = None
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
is_submitted: bool = False
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
@ -53,18 +55,17 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
|
||||
class _InMemoryFormSubmission(FormSubmission):
|
||||
def __init__(self, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
|
||||
self._selected_action_id = selected_action_id
|
||||
self._form_data = form_data
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str:
|
||||
return self._selected_action_id
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
def form_data(self) -> Mapping[str, Any]:
|
||||
return self._form_data
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
|
||||
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
@ -75,7 +76,6 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
self.created_params: list[FormCreateParams] = []
|
||||
self.created_forms: list[_InMemoryFormEntity] = []
|
||||
self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}
|
||||
self._submissions: dict[str, FormSubmission] = {}
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
self.created_params.append(params)
|
||||
@ -89,9 +89,6 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_key.get((workflow_execution_id, node_id))
|
||||
|
||||
def get_form_submission(self, form_id: str) -> FormSubmission | None:
|
||||
return self._submissions.get(form_id)
|
||||
|
||||
# Convenience helpers for tests -------------------------------------
|
||||
|
||||
def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
|
||||
@ -99,8 +96,15 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
|
||||
if not self.created_forms:
|
||||
raise AssertionError("no form has been created to attach submission data")
|
||||
target_form_id = self.created_forms[-1].id
|
||||
self._submissions[target_form_id] = _InMemoryFormSubmission(action_id, form_data or {})
|
||||
entity = self.created_forms[-1]
|
||||
entity.action_id = action_id
|
||||
entity.data = form_data or {}
|
||||
entity.is_submitted = True
|
||||
|
||||
def clear_submission(self) -> None:
|
||||
self._submissions.clear()
|
||||
if not self.created_forms:
|
||||
return
|
||||
for form in self.created_forms:
|
||||
form.action_id = None
|
||||
form.data = None
|
||||
form.is_submitted = False
|
||||
|
||||
@ -29,11 +29,7 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
@ -242,13 +238,13 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
runner = TableTestRunner()
|
||||
|
||||
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_create_repo.get_form_submission.return_value = None
|
||||
mock_create_repo.get_form.return_value = None
|
||||
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_form_entity.submitted = False
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
|
||||
@ -297,11 +293,15 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
)
|
||||
|
||||
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_form_submission = MagicMock(spec=FormSubmission)
|
||||
mock_form_submission.selected_action_id = scenario["handle"]
|
||||
mock_form_submission.form_data.return_value = {}
|
||||
mock_get_repo.get_form_submission.return_value = mock_form_submission
|
||||
mock_get_repo.get_form.return_value = mock_form_entity
|
||||
submitted_form = MagicMock(spec=HumanInputFormEntity)
|
||||
submitted_form.id = mock_form_entity.id
|
||||
submitted_form.web_app_token = mock_form_entity.web_app_token
|
||||
submitted_form.recipients = []
|
||||
submitted_form.rendered_content = mock_form_entity.rendered_content
|
||||
submitted_form.submitted = True
|
||||
submitted_form.selected_action_id = scenario["handle"]
|
||||
submitted_form.submitted_data = {}
|
||||
mock_get_repo.get_form.return_value = submitted_form
|
||||
|
||||
def resume_graph_factory(
|
||||
initial_result=initial_result, mock_get_repo=mock_get_repo
|
||||
|
||||
@ -28,11 +28,7 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
@ -187,13 +183,13 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
]
|
||||
|
||||
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_create_repo.get_form_submission.return_value = None
|
||||
mock_create_repo.get_form.return_value = None
|
||||
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_form_entity.submitted = False
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
@ -255,11 +251,15 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
]
|
||||
|
||||
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_form_submission = MagicMock(spec=FormSubmission)
|
||||
mock_form_submission.selected_action_id = "accept"
|
||||
mock_form_submission.form_data.return_value = {}
|
||||
mock_get_repo.get_form_submission.return_value = mock_form_submission
|
||||
mock_get_repo.get_form.return_value = mock_form_entity
|
||||
submitted_form = MagicMock(spec=HumanInputFormEntity)
|
||||
submitted_form.id = mock_form_entity.id
|
||||
submitted_form.web_app_token = mock_form_entity.web_app_token
|
||||
submitted_form.recipients = []
|
||||
submitted_form.rendered_content = mock_form_entity.rendered_content
|
||||
submitted_form.submitted = True
|
||||
submitted_form.selected_action_id = "accept"
|
||||
submitted_form.submitted_data = {}
|
||||
mock_get_repo.get_form.return_value = submitted_form
|
||||
|
||||
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
# restruct the graph runtime state
|
||||
|
||||
@ -20,7 +20,6 @@ from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
@ -43,28 +42,27 @@ def _build_runtime_state() -> GraphRuntimeState:
|
||||
|
||||
|
||||
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
|
||||
submission = MagicMock(spec=FormSubmission)
|
||||
submission.selected_action_id = action_id
|
||||
submission.form_data.return_value = {}
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
repo.get_form_submission.return_value = submission
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = True
|
||||
form_entity.selected_action_id = action_id
|
||||
form_entity.submitted_data = {}
|
||||
repo.get_form.return_value = form_entity
|
||||
return repo
|
||||
|
||||
|
||||
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
repo.get_form_submission.return_value = None
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = False
|
||||
repo.create_form.return_value = form_entity
|
||||
repo.get_form.return_value = None
|
||||
return repo
|
||||
@ -184,8 +182,8 @@ def test_engine_resume_restores_state_and_completion():
|
||||
assert combined_success_nodes == baseline_success_nodes
|
||||
|
||||
assert baseline_state.outputs == resumed_state.outputs
|
||||
assert _segment_value(baseline_state.variable_pool, ("human", "action_id")) == _segment_value(
|
||||
resumed_state.variable_pool, ("human", "action_id")
|
||||
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
|
||||
resumed_state.variable_pool, ("human", "__action_id")
|
||||
)
|
||||
assert baseline_state.graph_execution.completed
|
||||
assert resumed_state.graph_execution.completed
|
||||
|
||||
@ -320,12 +320,12 @@ class TestHumanInputNodeVariableResolution:
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.get_form_submission.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-1",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="token",
|
||||
recipients=[],
|
||||
submitted=False,
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
|
||||
@ -168,11 +168,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
):
|
||||
"""Test workflow pause creation when workflow not in RUNNING status."""
|
||||
# Arrange
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
mock_session.get.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
|
||||
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
|
||||
Reference in New Issue
Block a user