mirror of
https://github.com/langgenius/dify.git
synced 2026-02-23 03:17:57 +08:00
resume test
This commit is contained in:
@ -765,12 +765,6 @@ class RepositoryConfig(BaseSettings):
|
||||
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
)
|
||||
|
||||
CORE_HUMAN_INPUT_FORM_REPOSITORY: str = Field(
|
||||
description="Repository implementation for HumanInputForm. Options: "
|
||||
"'core.repositories.sqlalchemy_human_input_form_repository.SQLAlchemyHumanInputFormRepository' (default)",
|
||||
default="core.repositories.sqlalchemy_human_input_form_repository.SQLAlchemyHumanInputFormRepository",
|
||||
)
|
||||
|
||||
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
|
||||
description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. "
|
||||
"Specify as a module path",
|
||||
|
||||
@ -131,7 +131,7 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
GET /console/api/workflow/<task_id>/events
|
||||
GET /console/api/workflow/<workflow_run_id>/events
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
@ -56,5 +56,4 @@ class WorkflowEventsApi(WebApiResource):
|
||||
|
||||
|
||||
# Register the APIs
|
||||
api.add_resource(WorkflowResumeWaitApi, "/workflow/<string:task_id>/resume-wait")
|
||||
api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events")
|
||||
|
||||
@ -11,7 +11,6 @@ from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from libs.module_loading import import_string
|
||||
@ -107,36 +106,3 @@ class DifyCoreRepositoryFactory:
|
||||
raise RepositoryImportError(
|
||||
f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}"
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def create_human_input_form_repository(
|
||||
cls,
|
||||
session_factory: Union[sessionmaker, Engine],
|
||||
user: Union[Account, EndUser],
|
||||
app_id: str,
|
||||
) -> HumanInputFormRepository:
|
||||
"""
|
||||
Create a HumanInputFormRepository instance based on configuration.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine
|
||||
user: Account or EndUser object
|
||||
app_id: Application ID
|
||||
|
||||
Returns:
|
||||
Configured HumanInputFormRepository instance
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the configured repository cannot be created
|
||||
"""
|
||||
class_path = dify_config.CORE_HUMAN_INPUT_FORM_REPOSITORY
|
||||
|
||||
try:
|
||||
repository_class = import_string(class_path)
|
||||
return repository_class( # type: ignore[no-any-return]
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_id,
|
||||
)
|
||||
except (ImportError, Exception) as e:
|
||||
raise RepositoryImportError(f"Failed to create HumanInputFormRepository from '{class_path}': {e}") from e
|
||||
|
||||
@ -22,6 +22,7 @@ from core.workflow.repositories.human_input_form_repository import (
|
||||
FormNotFoundError,
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
@ -42,9 +43,6 @@ class _DeliveryAndRecipients:
|
||||
delivery: HumanInputDelivery
|
||||
recipients: Sequence[HumanInputFormRecipient]
|
||||
|
||||
def webapp_recipient(self) -> HumanInputFormRecipient | None:
|
||||
return next((i for i in self.recipients if i.recipient_type == RecipientType.WEBAPP), None)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _WorkspaceMemberInfo:
|
||||
@ -52,10 +50,31 @@ class _WorkspaceMemberInfo:
|
||||
email: str
|
||||
|
||||
|
||||
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
|
||||
def __init__(self, recipient_model: HumanInputFormRecipient):
|
||||
self._recipient_model = recipient_model
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._recipient_model.id
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
if self._recipient_model.access_token is None:
|
||||
raise AssertionError(
|
||||
f"access_token should not be None for recipient {self._recipient_model.id}"
|
||||
)
|
||||
return self._recipient_model.access_token
|
||||
|
||||
|
||||
class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputFormRecipient | None):
|
||||
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
|
||||
self._form_model = form_model
|
||||
self._web_app_recipient = web_app_recipient
|
||||
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
|
||||
self._web_app_recipient = next(
|
||||
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.WEBAPP),
|
||||
None,
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
@ -67,6 +86,10 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
return None
|
||||
return self._web_app_recipient.access_token
|
||||
|
||||
@property
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return list(self._recipients)
|
||||
|
||||
|
||||
class _FormSubmissionImpl(FormSubmission):
|
||||
def __init__(self, form_model: HumanInputForm):
|
||||
@ -293,7 +316,7 @@ class HumanInputFormRepositoryImpl:
|
||||
expiration_time=form_config.expiration_time(naive_utc_now()),
|
||||
)
|
||||
session.add(form_model)
|
||||
web_app_recipient: HumanInputFormRecipient | None = None
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
for delivery in form_config.delivery_methods:
|
||||
delivery_and_recipients = self._delivery_method_to_model(
|
||||
session=session,
|
||||
@ -302,21 +325,33 @@ class HumanInputFormRepositoryImpl:
|
||||
)
|
||||
session.add(delivery_and_recipients.delivery)
|
||||
session.add_all(delivery_and_recipients.recipients)
|
||||
if web_app_recipient is None:
|
||||
web_app_recipient = delivery_and_recipients.webapp_recipient()
|
||||
recipient_models.extend(delivery_and_recipients.recipients)
|
||||
session.flush()
|
||||
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, web_app_recipient=web_app_recipient)
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None:
|
||||
query = select(HumanInputForm).where(
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
form_query = select(HumanInputForm).where(
|
||||
HumanInputForm.workflow_run_id == workflow_execution_id,
|
||||
HumanInputForm.node_id == node_id,
|
||||
HumanInputForm.tenant_id == self._tenant_id,
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model: HumanInputForm | None = session.scalars(query).first()
|
||||
form_model: HumanInputForm | None = session.scalars(form_query).first()
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found for node, {workflow_execution_id=}, {node_id=}")
|
||||
return None
|
||||
|
||||
recipient_query = select(HumanInputFormRecipient).where(
|
||||
HumanInputFormRecipient.form_id == form_model.id
|
||||
)
|
||||
recipient_models = session.scalars(recipient_query).all()
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
def get_form_submission(self, form_id: str) -> FormSubmission | None:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model: HumanInputForm | None = session.get(HumanInputForm, form_id)
|
||||
if form_model is None or form_model.tenant_id != self._tenant_id:
|
||||
raise FormNotFoundError(f"form not found, form_id={form_id}")
|
||||
|
||||
if form_model.submitted_at is None:
|
||||
return None
|
||||
@ -324,8 +359,8 @@ class HumanInputFormRepositoryImpl:
|
||||
return _FormSubmissionImpl(form_model=form_model)
|
||||
|
||||
|
||||
class HumanInputFormReadRepository:
|
||||
"""Read/write repository for fetching and submitting human input forms."""
|
||||
class HumanInputFormSubmissionRepository:
|
||||
"""Repository for fetching and submitting human input forms."""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine):
|
||||
if isinstance(session_factory, Engine):
|
||||
|
||||
@ -1,19 +1,29 @@
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events.base import NodeEventBase
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.repositories.human_input_form_repository import FormCreateParams, HumanInputFormRepository
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
|
||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -34,6 +44,28 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
_form_repository: HumanInputFormRepository
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if form_repository is None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
self._form_repository = form_repository
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -86,19 +118,44 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
|
||||
return None
|
||||
|
||||
def _create_form_repository(self) -> HumanInputFormRepository:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]:
|
||||
yield event
|
||||
|
||||
@property
|
||||
def _workflow_execution_id(self) -> str:
|
||||
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
assert workflow_exec_id is not None
|
||||
return workflow_exec_id
|
||||
|
||||
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
|
||||
required_event = self._human_input_required_event(form_entity)
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return pause_requested_event
|
||||
|
||||
def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult:
|
||||
try:
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self._render_form_content(),
|
||||
)
|
||||
form_entity = self._form_repository.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
form_entity.id,
|
||||
)
|
||||
yield self._human
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
except Exception as e:
|
||||
logger.exception("Human Input node failed to execute, node_id=%s", self.id)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="HumanInputNodeError",
|
||||
)
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Execute the human input node.
|
||||
@ -111,51 +168,26 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
5. Suspend workflow execution
|
||||
6. Wait for form submission to resume
|
||||
"""
|
||||
repo = self._create_form_repository()
|
||||
submission_result = repo.get_form_submission(self._workflow_execution_id, self.app_id)
|
||||
repo = self._form_repository
|
||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
||||
if form is None:
|
||||
return self._create_form()
|
||||
|
||||
submission_result = repo.get_form_submission(form.id)
|
||||
if submission_result:
|
||||
outputs: dict[str, Any] = dict(submission_result.form_data())
|
||||
outputs["action_id"] = submission_result.selected_action_id
|
||||
outputs["__action_id"] = submission_result.selected_action_id
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"action_id": submission_result.selected_action_id,
|
||||
},
|
||||
outputs=outputs,
|
||||
edge_source_handle=submission_result.selected_action_id,
|
||||
)
|
||||
try:
|
||||
repo = self._create_form_repository()
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self._render_form_content(),
|
||||
)
|
||||
result = repo.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
required_event = HumanInputRequired(
|
||||
form_id=result.id,
|
||||
form_content=self._node_data.form_content,
|
||||
inputs=self._node_data.inputs,
|
||||
web_app_form_token=result.web_app_token,
|
||||
)
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return self._pause_with_form(form)
|
||||
|
||||
# Create workflow suspended event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
result.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Human Input node failed to execute, node_id=%s", self.id)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="HumanInputNodeError",
|
||||
)
|
||||
return self._pause_generator(pause_requested_event)
|
||||
def _pause_with_form(self, form_entity: HumanInputFormEntity) -> Generator[NodeEventBase, None, None]:
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
|
||||
def _render_form_content(self) -> str:
|
||||
"""
|
||||
|
||||
@ -93,13 +93,18 @@ class HumanInputFormRepository(Protocol):
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
"""Get the form created for a given human input node in a workflow execution. Returns
|
||||
`None` if the form has not been created yet."""
|
||||
...
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
"""
|
||||
Create a human input form from form definition.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None:
|
||||
def get_form_submission(self, form_id: str) -> FormSubmission | None:
|
||||
"""Retrieve the submission for a specific human input node.
|
||||
|
||||
Returns `FormSubmission` if the form has been submitted, or `None` if not.
|
||||
|
||||
@ -0,0 +1,72 @@
|
||||
"""Add human input related models
|
||||
|
||||
Revision ID: d411af417245
|
||||
Revises: 669ffd70119c
|
||||
Create Date: 2025-11-24 03:36:50.565145
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d411af417245"
|
||||
down_revision = "669ffd70119c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"human_input_form_deliveries",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("form_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("delivery_method_type", sa.String(20), nullable=False),
|
||||
sa.Column("delivery_config_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("channel_payload", sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_deliveries_pkey")),
|
||||
)
|
||||
op.create_table(
|
||||
"human_input_form_recipients",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("form_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("delivery_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("recipient_type", sa.String(20), nullable=False),
|
||||
sa.Column("recipient_payload", sa.Text(), nullable=False),
|
||||
sa.Column("access_token", sa.VARCHAR(length=32), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_recipients_pkey")),
|
||||
)
|
||||
op.create_table(
|
||||
"human_input_forms",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("node_id", sa.String(length=60), nullable=False),
|
||||
sa.Column("form_definition", sa.Text(), nullable=False),
|
||||
sa.Column("rendered_content", sa.Text(), nullable=False),
|
||||
sa.Column("status", sa.String(20), nullable=False),
|
||||
sa.Column("expiration_time", sa.DateTime(), nullable=False),
|
||||
sa.Column("selected_action_id", sa.String(length=200), nullable=True),
|
||||
sa.Column("submitted_data", sa.Text(), nullable=True),
|
||||
sa.Column("submitted_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("submission_user_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("submission_end_user_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("completed_by_recipient_id", models.types.StringUUID(), nullable=True),
|
||||
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_forms_pkey")),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("human_input_forms")
|
||||
op.drop_table("human_input_form_recipients")
|
||||
op.drop_table("human_input_form_deliveries")
|
||||
@ -54,6 +54,7 @@ from .model import (
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageExtraContent,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
OperationLog,
|
||||
@ -161,6 +162,7 @@ __all__ = [
|
||||
"MessageAgentThought",
|
||||
"MessageAnnotation",
|
||||
"MessageChain",
|
||||
"MessageExtraContent",
|
||||
"MessageFeedback",
|
||||
"MessageFile",
|
||||
"OperationLog",
|
||||
|
||||
@ -87,7 +87,7 @@ class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
nullable=False,
|
||||
)
|
||||
delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
channel_payload: Mapped[None] = mapped_column(sa.Text, nullable=True)
|
||||
channel_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
form: Mapped[HumanInputForm] = relationship(
|
||||
"HumanInputForm",
|
||||
|
||||
@ -24,11 +24,11 @@ from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .base import Base, TypeBase
|
||||
from .base import Base, DefaultFieldsMixin, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import LongText, StringUUID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .workflow import Workflow
|
||||
@ -2065,3 +2065,19 @@ class TraceAppConfig(TypeBase):
|
||||
"created_at": str(self.created_at) if self.created_at else None,
|
||||
"updated_at": str(self.updated_at) if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class MessageExtraContentType(StrEnum):
|
||||
human_input_result = auto()
|
||||
|
||||
|
||||
class MessageExtraContent(DefaultFieldsMixin):
|
||||
__tablename__ = "message_extra_contents"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="message_extra_content_pkey"),
|
||||
sa.Index("message_extra_content_message_id_idx", "message_id"),
|
||||
)
|
||||
|
||||
message_id = mapped_column(StringUUID, nullable=False, index=True)
|
||||
type: Mapped[MessageExtraContentType] = mapped_column(EnumText(MessageExtraContentType, length=30), nullable=False)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
@ -5,7 +5,10 @@ from typing import Any
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord
|
||||
from core.repositories.human_input_reposotiry import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormSubmissionRepository,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from libs.exception import BaseHTTPException
|
||||
from models.account import Account
|
||||
@ -65,6 +68,14 @@ class FormNotFoundError(HumanInputError, BaseHTTPException):
|
||||
code = 404
|
||||
|
||||
|
||||
class InvalidFormDataError(HumanInputError, BaseHTTPException):
|
||||
error_code = "invalid_form_data"
|
||||
code = 400
|
||||
|
||||
def __init__(self, description: str):
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
|
||||
pass
|
||||
|
||||
@ -76,12 +87,12 @@ class HumanInputService:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker[Session] | Engine,
|
||||
form_repository: HumanInputFormReadRepository | None = None,
|
||||
form_repository: HumanInputFormSubmissionRepository | None = None,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._form_repository = form_repository or HumanInputFormReadRepository(session_factory)
|
||||
self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory)
|
||||
|
||||
def get_form_by_token(self, form_token: str) -> Form | None:
|
||||
record = self._form_repository.get_by_token(form_token)
|
||||
@ -124,6 +135,7 @@ class HumanInputService:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
self._ensure_not_submitted(form)
|
||||
self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)
|
||||
|
||||
result = self._form_repository.mark_submitted(
|
||||
form_id=form.id,
|
||||
@ -149,6 +161,7 @@ class HumanInputService:
|
||||
raise WebAppDeliveryNotEnabledError()
|
||||
|
||||
self._ensure_not_submitted(form)
|
||||
self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)
|
||||
|
||||
result = self._form_repository.mark_submitted(
|
||||
form_id=form.id,
|
||||
@ -165,6 +178,23 @@ class HumanInputService:
|
||||
if form.submitted:
|
||||
raise FormSubmittedError(form.id)
|
||||
|
||||
def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
|
||||
definition = form.get_definition()
|
||||
|
||||
available_actions = {action.id for action in definition.user_actions}
|
||||
if selected_action_id not in available_actions:
|
||||
raise InvalidFormDataError(f"Invalid action: {selected_action_id}")
|
||||
|
||||
provided_inputs = set(form_data.keys())
|
||||
missing_inputs = [
|
||||
form_input.output_variable_name
|
||||
for form_input in definition.inputs
|
||||
if form_input.output_variable_name not in provided_inputs
|
||||
]
|
||||
|
||||
if missing_inputs:
|
||||
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
|
||||
|
||||
def _enqueue_resume(self, workflow_run_id: str) -> None:
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
|
||||
@ -10,9 +10,9 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_reposotiry import (
|
||||
HumanInputFormReadRepository,
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
HumanInputFormSubmissionRepository,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
@ -22,6 +22,7 @@ from core.workflow.nodes.human_input.entities import (
|
||||
TimeoutUnit,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import FormNotFoundError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
@ -197,23 +198,42 @@ class _FakeScalarResult:
|
||||
self._obj = obj
|
||||
|
||||
def first(self):
|
||||
if isinstance(self._obj, list):
|
||||
return self._obj[0] if self._obj else None
|
||||
return self._obj
|
||||
|
||||
def all(self):
|
||||
if isinstance(self._obj, list):
|
||||
return list(self._obj)
|
||||
if self._obj is None:
|
||||
return []
|
||||
return [self._obj]
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scalars_result=None,
|
||||
scalars_results: list[object] | None = None,
|
||||
forms: dict[str, _DummyForm] | None = None,
|
||||
recipients: dict[str, _DummyRecipient] | None = None,
|
||||
):
|
||||
self._scalars_result = scalars_result
|
||||
if scalars_results is not None:
|
||||
self._scalars_queue = list(scalars_results)
|
||||
elif scalars_result is not None:
|
||||
self._scalars_queue = [scalars_result]
|
||||
else:
|
||||
self._scalars_queue = []
|
||||
self.forms = forms or {}
|
||||
self.recipients = recipients or {}
|
||||
|
||||
def scalars(self, _query):
|
||||
return _FakeScalarResult(self._scalars_result)
|
||||
if self._scalars_queue:
|
||||
result = self._scalars_queue.pop(0)
|
||||
else:
|
||||
result = None
|
||||
return _FakeScalarResult(result)
|
||||
|
||||
def get(self, model_cls, obj_id): # type: ignore[no-untyped-def]
|
||||
if getattr(model_cls, "__name__", None) == "HumanInputForm":
|
||||
@ -255,7 +275,86 @@ def _session_factory(session: _FakeSession):
|
||||
return _factory
|
||||
|
||||
|
||||
class TestHumanInputFormReadRepository:
|
||||
class TestHumanInputFormRepositoryImplPublicMethods:
|
||||
def test_get_form_returns_entity_and_recipients(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
access_token="token-123",
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, [recipient]])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
assert entity is not None
|
||||
assert entity.id == form.id
|
||||
assert entity.web_app_token == "token-123"
|
||||
assert len(entity.recipients) == 1
|
||||
assert entity.recipients[0].token == "token-123"
|
||||
|
||||
def test_get_form_returns_none_when_missing(self):
|
||||
session = _FakeSession(scalars_results=[None])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
assert repo.get_form("run-1", "node-1") is None
|
||||
|
||||
def test_get_form_submission_returns_none_when_pending(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
assert repo.get_form_submission(form.id) is None
|
||||
|
||||
def test_get_form_submission_returns_submission_when_completed(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
selected_action_id="approve",
|
||||
submitted_data='{"field": "value"}',
|
||||
submitted_at=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(forms={form.id: form})
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
submission = repo.get_form_submission(form.id)
|
||||
|
||||
assert submission is not None
|
||||
assert submission.selected_action_id == "approve"
|
||||
assert submission.form_data() == {"field": "value"}
|
||||
|
||||
def test_get_form_submission_raises_when_form_missing(self):
|
||||
session = _FakeSession(forms={})
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
with pytest.raises(FormNotFoundError):
|
||||
repo.get_form_submission("form-unknown")
|
||||
|
||||
|
||||
class TestHumanInputFormSubmissionRepository:
|
||||
def test_get_by_token_returns_record(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
@ -274,7 +373,7 @@ class TestHumanInputFormReadRepository:
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormReadRepository(_session_factory(session))
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_token("token-123")
|
||||
|
||||
@ -301,7 +400,7 @@ class TestHumanInputFormReadRepository:
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormReadRepository(_session_factory(session))
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.WEBAPP)
|
||||
|
||||
@ -332,7 +431,7 @@ class TestHumanInputFormReadRepository:
|
||||
forms={form.id: form},
|
||||
recipients={recipient.id: recipient},
|
||||
)
|
||||
repo = HumanInputFormReadRepository(_session_factory(session))
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record: HumanInputFormRecord = repo.mark_submitted(
|
||||
form_id=form.id,
|
||||
|
||||
@ -0,0 +1,101 @@
|
||||
"""Utilities for testing HumanInputNode without database dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
|
||||
|
||||
class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
|
||||
"""Minimal recipient entity required by the repository interface."""
|
||||
|
||||
def __init__(self, recipient_id: str, token: str) -> None:
|
||||
self._id = recipient_id
|
||||
self._token = token
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
return self._token
|
||||
|
||||
|
||||
@dataclass
|
||||
class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
form_id: str
|
||||
token: str | None = None
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return self.token
|
||||
|
||||
@property
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return []
|
||||
|
||||
|
||||
class _InMemoryFormSubmission(FormSubmission):
|
||||
def __init__(self, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
|
||||
self._selected_action_id = selected_action_id
|
||||
self._form_data = form_data
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str:
|
||||
return self._selected_action_id
|
||||
|
||||
def form_data(self) -> Mapping[str, Any]:
|
||||
return self._form_data
|
||||
|
||||
|
||||
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
"""Pure in-memory repository used by workflow graph engine tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._form_counter = 0
|
||||
self.created_params: list[FormCreateParams] = []
|
||||
self.created_forms: list[_InMemoryFormEntity] = []
|
||||
self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}
|
||||
self._submissions: dict[str, FormSubmission] = {}
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
self.created_params.append(params)
|
||||
self._form_counter += 1
|
||||
form_id = f"form-{self._form_counter}"
|
||||
entity = _InMemoryFormEntity(form_id=form_id, token=f"token-{form_id}")
|
||||
self.created_forms.append(entity)
|
||||
self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
|
||||
return entity
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_key.get((workflow_execution_id, node_id))
|
||||
|
||||
def get_form_submission(self, form_id: str) -> FormSubmission | None:
|
||||
return self._submissions.get(form_id)
|
||||
|
||||
# Convenience helpers for tests -------------------------------------
|
||||
|
||||
def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
|
||||
"""Simulate a human submission for the next repository lookup."""
|
||||
|
||||
if not self.created_forms:
|
||||
raise AssertionError("no form has been created to attach submission data")
|
||||
target_form_id = self.created_forms[-1].id
|
||||
self._submissions[target_form_id] = _InMemoryFormSubmission(action_id, form_data or {})
|
||||
|
||||
def clear_submission(self) -> None:
|
||||
self._submissions.clear()
|
||||
@ -1,5 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -17,7 +18,7 @@ from core.workflow.graph_events import (
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
@ -28,6 +29,11 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
@ -36,7 +42,11 @@ from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
def _build_branching_graph(
|
||||
mock_config: MockConfig,
|
||||
form_repository: HumanInputFormRepository,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
@ -49,12 +59,18 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
if graph_runtime_state is None:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="test-execution-id",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
@ -93,15 +109,21 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
required_variables=["human.input_ready"],
|
||||
pause_reason="Awaiting human input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="primary", title="Primary"),
|
||||
UserAction(id="secondary", title="Secondary"),
|
||||
],
|
||||
)
|
||||
|
||||
human_config = {"id": "human", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||
@ -219,8 +241,17 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
for scenario in branch_scenarios:
|
||||
runner = TableTestRunner()
|
||||
|
||||
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_create_repo.get_form_submission.return_value = None
|
||||
mock_create_repo.get_form.return_value = None
|
||||
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_branching_graph(mock_config)
|
||||
return _build_branching_graph(mock_config, mock_create_repo)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
description="HumanInput pause before branching decision",
|
||||
@ -242,15 +273,6 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
assert initial_result.success, initial_result.event_mismatch_details
|
||||
assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)
|
||||
|
||||
graph_runtime_state = initial_result.graph_runtime_state
|
||||
graph = initial_result.graph
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
|
||||
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
|
||||
graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
|
||||
graph_runtime_state.graph_execution.pause_reason = None
|
||||
|
||||
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
|
||||
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
|
||||
|
||||
@ -273,11 +295,18 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
]
|
||||
)
|
||||
|
||||
def resume_graph_factory(
|
||||
graph_snapshot: Graph = graph,
|
||||
state_snapshot: GraphRuntimeState = graph_runtime_state,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
return graph_snapshot, state_snapshot
|
||||
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_form_submission = MagicMock(spec=FormSubmission)
|
||||
mock_form_submission.selected_action_id = scenario["handle"]
|
||||
mock_form_submission.form_data.return_value = {}
|
||||
mock_get_repo.get_form_submission.return_value = mock_form_submission
|
||||
mock_get_repo.get_form.return_value = mock_form_entity
|
||||
|
||||
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
assert initial_result.graph_runtime_state is not None
|
||||
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
|
||||
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
|
||||
return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state)
|
||||
|
||||
resume_case = WorkflowTestCase(
|
||||
description=f"HumanInput resumes via {scenario['handle']} branch",
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -16,7 +17,7 @@ from core.workflow.graph_events import (
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
@ -27,6 +28,11 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
@ -35,7 +41,11 @@ from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
def _build_llm_human_llm_graph(
|
||||
mock_config: MockConfig,
|
||||
form_repository: HumanInputFormRepository,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
@ -48,12 +58,15 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
if graph_runtime_state is None:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id,"
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
@ -92,15 +105,21 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
required_variables=["human.input_ready"],
|
||||
pause_reason="Awaiting human input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="accept", title="Accept"),
|
||||
UserAction(id="reject", title="Reject"),
|
||||
],
|
||||
)
|
||||
|
||||
human_config = {"id": "human", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
|
||||
@ -130,7 +149,7 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
.add_root(start_node)
|
||||
.add_node(llm_first)
|
||||
.add_node(human_node)
|
||||
.add_node(llm_second)
|
||||
.add_node(llm_second, source_handle="accept")
|
||||
.add_node(end_node)
|
||||
.build()
|
||||
)
|
||||
@ -167,8 +186,17 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
GraphRunPausedEvent, # graph run pauses awaiting resume
|
||||
]
|
||||
|
||||
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_create_repo.get_form_submission.return_value = None
|
||||
mock_create_repo.get_form.return_value = None
|
||||
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_llm_human_llm_graph(mock_config)
|
||||
return _build_llm_human_llm_graph(mock_config, mock_create_repo)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
description="HumanInput pause preserves LLM streaming order",
|
||||
@ -225,12 +253,22 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
GraphRunSucceededEvent, # graph run succeeds after resume
|
||||
]
|
||||
|
||||
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_form_submission = MagicMock(spec=FormSubmission)
|
||||
mock_form_submission.selected_action_id = "accept"
|
||||
mock_form_submission.form_data.return_value = {}
|
||||
mock_get_repo.get_form_submission.return_value = mock_form_submission
|
||||
mock_get_repo.get_form.return_value = mock_form_entity
|
||||
|
||||
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
|
||||
graph_runtime_state.graph_execution.pause_reason = None
|
||||
return graph, graph_runtime_state
|
||||
# restruct the graph runtime state
|
||||
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
|
||||
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
|
||||
return _build_llm_human_llm_graph(
|
||||
mock_config,
|
||||
mock_get_repo,
|
||||
resume_runtime_state,
|
||||
)
|
||||
|
||||
resume_case = WorkflowTestCase(
|
||||
description="HumanInput resume continues LLM streaming order",
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import core.workflow.nodes.human_input.entities # noqa: F401
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
@ -15,72 +12,66 @@ from core.workflow.graph_events import (
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormSubmission,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
class _PausingNodeData(BaseNodeData):
|
||||
pass
|
||||
|
||||
|
||||
class _PausingNode(Node):
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = _PausingNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@staticmethod
|
||||
def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]:
|
||||
yield event
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
resumed_flag = self.graph_runtime_state.variable_pool.get((self.id, "resumed"))
|
||||
if resumed_flag is None:
|
||||
# mark as resumed and request pause
|
||||
self.graph_runtime_state.variable_pool.add((self.id, "resumed"), True)
|
||||
return self._pause_generator(PauseRequestedEvent(reason=SchedulingPause(message="test pause")))
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"value": "completed"},
|
||||
)
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="test-execution-id",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
|
||||
submission = MagicMock(spec=FormSubmission)
|
||||
submission.selected_action_id = action_id
|
||||
submission.form_data.return_value = {}
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
repo.get_form_submission.return_value = submission
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
repo.get_form.return_value = form_entity
|
||||
return repo
|
||||
|
||||
|
||||
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
repo.get_form_submission.return_value = None
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
repo.create_form.return_value = form_entity
|
||||
repo.get_form.return_value = None
|
||||
return repo
|
||||
|
||||
|
||||
def _build_human_input_graph(
|
||||
runtime_state: GraphRuntimeState,
|
||||
form_repository: HumanInputFormRepository,
|
||||
) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
@ -102,19 +93,27 @@ def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
)
|
||||
start_node.init_node_data(start_data.model_dump())
|
||||
|
||||
pause_data = _PausingNodeData(title="pausing")
|
||||
pause_node = _PausingNode(
|
||||
id="pausing",
|
||||
config={"id": "pausing", "data": pause_data.model_dump()},
|
||||
human_data = HumanInputNodeData(
|
||||
title="human",
|
||||
form_content="Awaiting human input",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="continue", title="Continue"),
|
||||
],
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
id="human",
|
||||
config={"id": "human", "data": human_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
pause_node.init_node_data(pause_data.model_dump())
|
||||
human_node.init_node_data(human_data.model_dump())
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[
|
||||
VariableSelector(variable="result", value_selector=["pausing", "value"]),
|
||||
VariableSelector(variable="result", value_selector=["human", "action_id"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
@ -126,7 +125,13 @@ def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
)
|
||||
end_node.init_node_data(end_data.model_dump())
|
||||
|
||||
return Graph.new().add_root(start_node).add_node(pause_node).add_node(end_node).build()
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_node)
|
||||
.add_node(end_node, from_node_id="human", source_handle="continue")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
|
||||
@ -152,22 +157,24 @@ def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> An
|
||||
def test_engine_resume_restores_state_and_completion():
|
||||
# Baseline run without pausing
|
||||
baseline_state = _build_runtime_state()
|
||||
baseline_graph = _build_pausing_graph(baseline_state)
|
||||
baseline_state.variable_pool.add(("pausing", "resumed"), True)
|
||||
baseline_repo = _mock_form_repository_with_submission(action_id="continue")
|
||||
baseline_graph = _build_human_input_graph(baseline_state, baseline_repo)
|
||||
baseline_events = _run_graph(baseline_graph, baseline_state)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_success_nodes = _node_successes(baseline_events)
|
||||
|
||||
# Run with pause
|
||||
paused_state = _build_runtime_state()
|
||||
paused_graph = _build_pausing_graph(paused_state)
|
||||
pause_repo = _mock_form_repository_without_submission()
|
||||
paused_graph = _build_human_input_graph(paused_state, pause_repo)
|
||||
paused_events = _run_graph(paused_graph, paused_state)
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
# Resume from snapshot
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resumed_graph = _build_pausing_graph(resumed_state)
|
||||
resume_repo = _mock_form_repository_with_submission(action_id="continue")
|
||||
resumed_graph = _build_human_input_graph(resumed_state, resume_repo)
|
||||
resumed_events = _run_graph(resumed_graph, resumed_state)
|
||||
assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
|
||||
|
||||
@ -175,11 +182,8 @@ def test_engine_resume_restores_state_and_completion():
|
||||
assert combined_success_nodes == baseline_success_nodes
|
||||
|
||||
assert baseline_state.outputs == resumed_state.outputs
|
||||
assert _segment_value(baseline_state.variable_pool, ("pausing", "resumed")) == _segment_value(
|
||||
resumed_state.variable_pool, ("pausing", "resumed")
|
||||
)
|
||||
assert _segment_value(baseline_state.variable_pool, ("pausing", "value")) == _segment_value(
|
||||
resumed_state.variable_pool, ("pausing", "value")
|
||||
assert _segment_value(baseline_state.variable_pool, ("human", "action_id")) == _segment_value(
|
||||
resumed_state.variable_pool, ("human", "action_id")
|
||||
)
|
||||
assert baseline_state.graph_execution.completed
|
||||
assert resumed_state.graph_execution.completed
|
||||
|
||||
@ -7,6 +7,7 @@ from core.workflow.nodes.base.node import Node
|
||||
# Ensures that all node classes are imported.
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed.
|
||||
_ = NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
|
||||
@ -45,7 +46,9 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
||||
assert isinstance(cls.node_type, NodeType)
|
||||
assert isinstance(node_version, str)
|
||||
node_type_and_version = (node_type, node_version)
|
||||
assert node_type_and_version not in type_version_set
|
||||
assert node_type_and_version not in type_version_set, (
|
||||
f"Duplicate node type and version for class: {cls=} {node_type_and_version=}"
|
||||
)
|
||||
type_version_set.add(node_type_and_version)
|
||||
|
||||
|
||||
|
||||
@ -4,11 +4,14 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, TimeoutUnit, UserAction
|
||||
from core.repositories.human_input_reposotiry import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormSubmissionRepository,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, FormInputType, TimeoutUnit, UserAction
|
||||
from models.account import Account
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import FormSubmittedError, HumanInputService
|
||||
from services.human_input_service import FormSubmittedError, HumanInputService, InvalidFormDataError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -129,7 +132,7 @@ def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
|
||||
|
||||
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormReadRepository)
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
|
||||
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
@ -146,7 +149,7 @@ def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_sess
|
||||
def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime(2024, 1, 1))
|
||||
repo = MagicMock(spec=HumanInputFormReadRepository)
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = submitted_record
|
||||
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
@ -157,7 +160,7 @@ def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_
|
||||
|
||||
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormReadRepository)
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_token.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
@ -166,7 +169,7 @@ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, m
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
form_token="token",
|
||||
selected_action_id="approve",
|
||||
selected_action_id="submit",
|
||||
form_data={"field": "value"},
|
||||
submission_end_user_id="end-user-id",
|
||||
)
|
||||
@ -176,7 +179,7 @@ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, m
|
||||
call_kwargs = repo.mark_submitted.call_args.kwargs
|
||||
assert call_kwargs["form_id"] == sample_form_record.form_id
|
||||
assert call_kwargs["recipient_id"] == sample_form_record.recipient_id
|
||||
assert call_kwargs["selected_action_id"] == "approve"
|
||||
assert call_kwargs["selected_action_id"] == "submit"
|
||||
assert call_kwargs["form_data"] == {"field": "value"}
|
||||
assert call_kwargs["submission_end_user_id"] == "end-user-id"
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
@ -184,7 +187,7 @@ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, m
|
||||
|
||||
def test_submit_form_by_id_passes_account(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormReadRepository)
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_form_id_and_recipient_type.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
@ -194,7 +197,7 @@ def test_submit_form_by_id_passes_account(sample_form_record, mock_session_facto
|
||||
|
||||
service.submit_form_by_id(
|
||||
form_id="form-id",
|
||||
selected_action_id="approve",
|
||||
selected_action_id="submit",
|
||||
form_data={"x": 1},
|
||||
user=account,
|
||||
)
|
||||
@ -203,3 +206,49 @@ def test_submit_form_by_id_passes_account(sample_form_record, mock_session_facto
|
||||
repo.mark_submitted.assert_called_once()
|
||||
assert repo.mark_submitted.call_args.kwargs["submission_user_id"] == "account-id"
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
|
||||
|
||||
def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_token.return_value = dataclasses.replace(sample_form_record)
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
form_token="token",
|
||||
selected_action_id="invalid",
|
||||
form_data={},
|
||||
)
|
||||
|
||||
assert "Invalid action" in str(exc_info.value)
|
||||
repo.mark_submitted.assert_not_called()
|
||||
|
||||
|
||||
def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
|
||||
definition_with_input = FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")],
|
||||
user_actions=sample_form_record.definition.user_actions,
|
||||
rendered_content="<p>hello</p>",
|
||||
timeout=1,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
)
|
||||
form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input)
|
||||
repo.get_by_token.return_value = form_with_input
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={},
|
||||
)
|
||||
|
||||
assert "Missing required inputs" in str(exc_info.value)
|
||||
repo.mark_submitted.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user