diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 5620cfc1a6..43dddbd011 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -765,12 +765,6 @@ class RepositoryConfig(BaseSettings): default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", ) - CORE_HUMAN_INPUT_FORM_REPOSITORY: str = Field( - description="Repository implementation for HumanInputForm. Options: " - "'core.repositories.sqlalchemy_human_input_form_repository.SQLAlchemyHumanInputFormRepository' (default)", - default="core.repositories.sqlalchemy_human_input_form_repository.SQLAlchemyHumanInputFormRepository", - ) - API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. " "Specify as a module path", diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 21700459bf..2b95e4a865 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -131,7 +131,7 @@ class ConsoleWorkflowEventsApi(Resource): """ Get workflow execution events stream after resume. - GET /console/api/workflow//events + GET /console/api/workflow//events Returns Server-Sent Events stream. """ diff --git a/api/controllers/web/workflow_events.py b/api/controllers/web/workflow_events.py index 0421c4a457..338ced69b8 100644 --- a/api/controllers/web/workflow_events.py +++ b/api/controllers/web/workflow_events.py @@ -56,5 +56,4 @@ class WorkflowEventsApi(WebApiResource): # Register the APIs -api.add_resource(WorkflowResumeWaitApi, "/workflow//resume-wait") api.add_resource(WorkflowEventsApi, "/workflow//events") diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index eacd8e2f0a..02fcabab5d 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -11,7 +11,6 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.module_loading import import_string @@ -107,36 +106,3 @@ class DifyCoreRepositoryFactory: raise RepositoryImportError( f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" ) from e - - @classmethod - def create_human_input_form_repository( - cls, - session_factory: Union[sessionmaker, Engine], - user: Union[Account, EndUser], - app_id: str, - ) -> HumanInputFormRepository: - """ - Create a HumanInputFormRepository instance based on configuration. - - Args: - session_factory: SQLAlchemy sessionmaker or engine - user: Account or EndUser object - app_id: Application ID - - Returns: - Configured HumanInputFormRepository instance - - Raises: - RepositoryImportError: If the configured repository cannot be created - """ - class_path = dify_config.CORE_HUMAN_INPUT_FORM_REPOSITORY - - try: - repository_class = import_string(class_path) - return repository_class( # type: ignore[no-any-return] - session_factory=session_factory, - user=user, - app_id=app_id, - ) - except (ImportError, Exception) as e: - raise RepositoryImportError(f"Failed to create HumanInputFormRepository from '{class_path}': {e}") from e diff --git a/api/core/repositories/human_input_reposotiry.py b/api/core/repositories/human_input_reposotiry.py index df9442e7d3..c4b1926c21 100644 --- a/api/core/repositories/human_input_reposotiry.py +++ b/api/core/repositories/human_input_reposotiry.py @@ -22,6 +22,7 @@ from core.workflow.repositories.human_input_form_repository import ( FormNotFoundError, FormSubmission, HumanInputFormEntity, + HumanInputFormRecipientEntity, ) from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -42,9 +43,6 @@ class _DeliveryAndRecipients: delivery: HumanInputDelivery recipients: Sequence[HumanInputFormRecipient] - def webapp_recipient(self) -> HumanInputFormRecipient | None: - return next((i for i in self.recipients if i.recipient_type == RecipientType.WEBAPP), None) - @dataclasses.dataclass(frozen=True) class _WorkspaceMemberInfo: @@ -52,10 +50,31 @@ class _WorkspaceMemberInfo: email: str +class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): + def __init__(self, recipient_model: HumanInputFormRecipient): + self._recipient_model = recipient_model + + @property + def id(self) -> str: + return self._recipient_model.id + + @property + def token(self) -> str: + if self._recipient_model.access_token is None: + raise AssertionError( + f"access_token should not be None for recipient {self._recipient_model.id}" + ) + return self._recipient_model.access_token + + class _HumanInputFormEntityImpl(HumanInputFormEntity): - def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputFormRecipient | None): + def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): self._form_model = form_model - self._web_app_recipient = web_app_recipient + self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] + self._web_app_recipient = next( + (recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.WEBAPP), + None, + ) @property def id(self) -> str: @@ -67,6 +86,10 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): return None return self._web_app_recipient.access_token + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return list(self._recipients) + class _FormSubmissionImpl(FormSubmission): def __init__(self, form_model: HumanInputForm): @@ -293,7 +316,7 @@ class HumanInputFormRepositoryImpl: expiration_time=form_config.expiration_time(naive_utc_now()), ) session.add(form_model) - web_app_recipient: HumanInputFormRecipient | None = None + recipient_models: list[HumanInputFormRecipient] = [] for delivery in form_config.delivery_methods: delivery_and_recipients = self._delivery_method_to_model( session=session, @@ -302,21 +325,33 @@ class HumanInputFormRepositoryImpl: ) session.add(delivery_and_recipients.delivery) session.add_all(delivery_and_recipients.recipients) - if web_app_recipient is None: - web_app_recipient = delivery_and_recipients.webapp_recipient() + recipient_models.extend(delivery_and_recipients.recipients) session.flush() - return _HumanInputFormEntityImpl(form_model=form_model, web_app_recipient=web_app_recipient) + return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None: - query = select(HumanInputForm).where( + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + form_query = select(HumanInputForm).where( HumanInputForm.workflow_run_id == workflow_execution_id, HumanInputForm.node_id == node_id, + HumanInputForm.tenant_id == self._tenant_id, ) with self._session_factory(expire_on_commit=False) as session: - form_model: HumanInputForm | None = session.scalars(query).first() + form_model: HumanInputForm | None = session.scalars(form_query).first() if form_model is None: - raise FormNotFoundError(f"form not found for node, {workflow_execution_id=}, {node_id=}") + return None + + recipient_query = select(HumanInputFormRecipient).where( + HumanInputFormRecipient.form_id == form_model.id + ) + recipient_models = session.scalars(recipient_query).all() + return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) + + def get_form_submission(self, form_id: str) -> FormSubmission | None: + with self._session_factory(expire_on_commit=False) as session: + form_model: HumanInputForm | None = session.get(HumanInputForm, form_id) + if form_model is None or form_model.tenant_id != self._tenant_id: + raise FormNotFoundError(f"form not found, form_id={form_id}") if form_model.submitted_at is None: return None @@ -324,8 +359,8 @@ class HumanInputFormRepositoryImpl: return _FormSubmissionImpl(form_model=form_model) -class HumanInputFormReadRepository: - """Read/write repository for fetching and submitting human input forms.""" +class HumanInputFormSubmissionRepository: + """Repository for fetching and submitting human input forms.""" def __init__(self, session_factory: sessionmaker | Engine): if isinstance(session_factory, Engine): diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index b114895958..8f4075b7d2 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -1,19 +1,29 @@ import logging from collections.abc import Generator, Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any -from core.workflow.entities.pause_reason import HumanInputRequired +from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult, PauseRequestedEvent from core.workflow.node_events.base import NodeEventBase -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node -from core.workflow.repositories.human_input_form_repository import FormCreateParams, HumanInputFormRepository +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from extensions.ext_database import db from .entities import HumanInputNodeData +if TYPE_CHECKING: + from core.workflow.entities.graph_init_params import GraphInitParams + from core.workflow.runtime.graph_runtime_state import GraphRuntimeState + + _SELECTED_BRANCH_KEY = "selected_branch" + logger = logging.getLogger(__name__) @@ -34,6 +44,28 @@ class HumanInputNode(Node[HumanInputNodeData]): ) _node_data: HumanInputNodeData + _form_repository: HumanInputFormRepository + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + form_repository: HumanInputFormRepository | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + if form_repository is None: + form_repository = HumanInputFormRepositoryImpl( + session_factory=db.engine, + tenant_id=self.tenant_id, + ) + self._form_repository = form_repository @classmethod def version(cls) -> str: @@ -86,19 +118,44 @@ class HumanInputNode(Node[HumanInputNodeData]): return None - def _create_form_repository(self) -> HumanInputFormRepository: - pass - - @staticmethod - def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]: - yield event - @property def _workflow_execution_id(self) -> str: workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id assert workflow_exec_id is not None return workflow_exec_id + def _form_to_pause_event(self, form_entity: HumanInputFormEntity): + required_event = self._human_input_required_event(form_entity) + pause_requested_event = PauseRequestedEvent(reason=required_event) + return pause_requested_event + + def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult: + try: + params = FormCreateParams( + workflow_execution_id=self._workflow_execution_id, + node_id=self.id, + form_config=self._node_data, + rendered_content=self._render_form_content(), + ) + form_entity = self._form_repository.create_form(params) + # Create human input required event + + logger.info( + "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", + self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, + self.id, + form_entity.id, + ) + yield self._human + yield self._form_to_pause_event(form_entity) + except Exception as e: + logger.exception("Human Input node failed to execute, node_id=%s", self.id) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + error_type="HumanInputNodeError", + ) + def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: """ Execute the human input node. @@ -111,51 +168,26 @@ class HumanInputNode(Node[HumanInputNodeData]): 5. Suspend workflow execution 6. Wait for form submission to resume """ - repo = self._create_form_repository() - submission_result = repo.get_form_submission(self._workflow_execution_id, self.app_id) + repo = self._form_repository + form = repo.get_form(self._workflow_execution_id, self.id) + if form is None: + return self._create_form() + + submission_result = repo.get_form_submission(form.id) if submission_result: + outputs: dict[str, Any] = dict(submission_result.form_data()) + outputs["action_id"] = submission_result.selected_action_id + outputs["__action_id"] = submission_result.selected_action_id return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "action_id": submission_result.selected_action_id, - }, + outputs=outputs, edge_source_handle=submission_result.selected_action_id, ) - try: - repo = self._create_form_repository() - params = FormCreateParams( - workflow_execution_id=self._workflow_execution_id, - node_id=self.id, - form_config=self._node_data, - rendered_content=self._render_form_content(), - ) - result = repo.create_form(params) - # Create human input required event - required_event = HumanInputRequired( - form_id=result.id, - form_content=self._node_data.form_content, - inputs=self._node_data.inputs, - web_app_form_token=result.web_app_token, - ) - pause_requested_event = PauseRequestedEvent(reason=required_event) + return self._pause_with_form(form) - # Create workflow suspended event - - logger.info( - "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, - self.id, - result.id, - ) - except Exception as e: - logger.exception("Human Input node failed to execute, node_id=%s", self.id) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - error_type="HumanInputNodeError", - ) - return self._pause_generator(pause_requested_event) + def _pause_with_form(self, form_entity: HumanInputFormEntity) -> Generator[NodeEventBase, None, None]: + yield self._form_to_pause_event(form_entity) def _render_form_content(self) -> str: """ diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/core/workflow/repositories/human_input_form_repository.py index 8fd33086f4..24862a3cb1 100644 --- a/api/core/workflow/repositories/human_input_form_repository.py +++ b/api/core/workflow/repositories/human_input_form_repository.py @@ -93,13 +93,18 @@ class HumanInputFormRepository(Protocol): application domains or deployment scenarios. """ + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + """Get the form created for a given human input node in a workflow execution. Returns + `None` if the form has not been created yet.""" + ... + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: """ Create a human input form from form definition. """ ... - def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmission | None: + def get_form_submission(self, form_id: str) -> FormSubmission | None: """Retrieve the submission for a specific human input node. Returns `FormSubmission` if the form has been submitted, or `None` if not. diff --git a/api/migrations/versions/2025_11_24_0336-d411af417245_add_human_input_models.py b/api/migrations/versions/2025_11_24_0336-d411af417245_add_human_input_models.py new file mode 100644 index 0000000000..1db60fe1a2 --- /dev/null +++ b/api/migrations/versions/2025_11_24_0336-d411af417245_add_human_input_models.py @@ -0,0 +1,72 @@ +"""Add human input related models + +Revision ID: d411af417245 +Revises: 669ffd70119c +Create Date: 2025-11-24 03:36:50.565145 + +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "d411af417245" +down_revision = "669ffd70119c" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "human_input_form_deliveries", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("form_id", models.types.StringUUID(), nullable=False), + sa.Column("delivery_method_type", sa.String(20), nullable=False), + sa.Column("delivery_config_id", models.types.StringUUID(), nullable=True), + sa.Column("channel_payload", sa.Text(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_deliveries_pkey")), + ) + op.create_table( + "human_input_form_recipients", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("form_id", models.types.StringUUID(), nullable=False), + sa.Column("delivery_id", models.types.StringUUID(), nullable=False), + sa.Column("recipient_type", sa.String(20), nullable=False), + sa.Column("recipient_payload", sa.Text(), nullable=False), + sa.Column("access_token", sa.VARCHAR(length=32), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_recipients_pkey")), + ) + op.create_table( + "human_input_forms", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False), + sa.Column("node_id", sa.String(length=60), nullable=False), + sa.Column("form_definition", sa.Text(), nullable=False), + sa.Column("rendered_content", sa.Text(), nullable=False), + sa.Column("status", sa.String(20), nullable=False), + sa.Column("expiration_time", sa.DateTime(), nullable=False), + sa.Column("selected_action_id", sa.String(length=200), nullable=True), + sa.Column("submitted_data", sa.Text(), nullable=True), + sa.Column("submitted_at", sa.DateTime(), nullable=True), + sa.Column("submission_user_id", models.types.StringUUID(), nullable=True), + sa.Column("submission_end_user_id", models.types.StringUUID(), nullable=True), + sa.Column("completed_by_recipient_id", models.types.StringUUID(), nullable=True), + + sa.PrimaryKeyConstraint("id", name=op.f("human_input_forms_pkey")), + ) + + +def downgrade(): + op.drop_table("human_input_forms") + op.drop_table("human_input_form_recipients") + op.drop_table("human_input_form_deliveries") diff --git a/api/models/__init__.py b/api/models/__init__.py index 68129dd6bc..11b2dd42ec 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -54,6 +54,7 @@ from .model import ( MessageAgentThought, MessageAnnotation, MessageChain, + MessageExtraContent, MessageFeedback, MessageFile, OperationLog, @@ -161,6 +162,7 @@ __all__ = [ "MessageAgentThought", "MessageAnnotation", "MessageChain", + "MessageExtraContent", "MessageFeedback", "MessageFile", "OperationLog", diff --git a/api/models/human_input.py b/api/models/human_input.py index b6f5b7911a..3889fc3c50 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -87,7 +87,7 @@ class HumanInputDelivery(DefaultFieldsMixin, Base): nullable=False, ) delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - channel_payload: Mapped[None] = mapped_column(sa.Text, nullable=True) + channel_payload: Mapped[str] = mapped_column(sa.Text, nullable=False) form: Mapped[HumanInputForm] = relationship( "HumanInputForm", diff --git a/api/models/model.py b/api/models/model.py index 88cb945b3f..7173e15eae 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -24,11 +24,11 @@ from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 from .account import Account, Tenant -from .base import Base, TypeBase +from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db from .enums import CreatorUserRole from .provider_ids import GenericProviderID -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from .workflow import Workflow @@ -2065,3 +2065,19 @@ class TraceAppConfig(TypeBase): "created_at": str(self.created_at) if self.created_at else None, "updated_at": str(self.updated_at) if self.updated_at else None, } + + +class MessageExtraContentType(StrEnum): + human_input_result = auto() + + +class MessageExtraContent(DefaultFieldsMixin): + __tablename__ = "message_extra_contents" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="message_extra_content_pkey"), + sa.Index("message_extra_content_message_id_idx", "message_id"), + ) + + message_id = mapped_column(StringUUID, nullable=False, index=True) + type: Mapped[MessageExtraContentType] = mapped_column(EnumText(MessageExtraContentType, length=30), nullable=False) + content = mapped_column(sa.Text, nullable=False) diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 56bb8f2eaf..71a7ee9eff 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -5,7 +5,10 @@ from typing import Any from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord +from core.repositories.human_input_reposotiry import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) from core.workflow.nodes.human_input.entities import FormDefinition from libs.exception import BaseHTTPException from models.account import Account @@ -65,6 +68,14 @@ class FormNotFoundError(HumanInputError, BaseHTTPException): code = 404 +class InvalidFormDataError(HumanInputError, BaseHTTPException): + error_code = "invalid_form_data" + code = 400 + + def __init__(self, description: str): + super().__init__(description=description) + + class WebAppDeliveryNotEnabledError(HumanInputError, BaseException): pass @@ -76,12 +87,12 @@ class HumanInputService: def __init__( self, session_factory: sessionmaker[Session] | Engine, - form_repository: HumanInputFormReadRepository | None = None, + form_repository: HumanInputFormSubmissionRepository | None = None, ): if isinstance(session_factory, Engine): session_factory = sessionmaker(bind=session_factory) self._session_factory = session_factory - self._form_repository = form_repository or HumanInputFormReadRepository(session_factory) + self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory) def get_form_by_token(self, form_token: str) -> Form | None: record = self._form_repository.get_by_token(form_token) @@ -124,6 +135,7 @@ class HumanInputService: raise WebAppDeliveryNotEnabledError() self._ensure_not_submitted(form) + self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data) result = self._form_repository.mark_submitted( form_id=form.id, @@ -149,6 +161,7 @@ class HumanInputService: raise WebAppDeliveryNotEnabledError() self._ensure_not_submitted(form) + self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data) result = self._form_repository.mark_submitted( form_id=form.id, @@ -165,6 +178,23 @@ class HumanInputService: if form.submitted: raise FormSubmittedError(form.id) + def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None: + definition = form.get_definition() + + available_actions = {action.id for action in definition.user_actions} + if selected_action_id not in available_actions: + raise InvalidFormDataError(f"Invalid action: {selected_action_id}") + + provided_inputs = set(form_data.keys()) + missing_inputs = [ + form_input.output_variable_name + for form_input in definition.inputs + if form_input.output_variable_name not in provided_inputs + ] + + if missing_inputs: + raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}") + def _enqueue_resume(self, workflow_run_id: str) -> None: with self._session_factory(expire_on_commit=False) as session: trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 96b6123334..784bebfb71 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -10,9 +10,9 @@ from unittest.mock import MagicMock import pytest from core.repositories.human_input_reposotiry import ( - HumanInputFormReadRepository, HumanInputFormRecord, HumanInputFormRepositoryImpl, + HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) from core.workflow.nodes.human_input.entities import ( @@ -22,6 +22,7 @@ from core.workflow.nodes.human_input.entities import ( TimeoutUnit, UserAction, ) +from core.workflow.repositories.human_input_form_repository import FormNotFoundError from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -197,23 +198,42 @@ class _FakeScalarResult: self._obj = obj def first(self): + if isinstance(self._obj, list): + return self._obj[0] if self._obj else None return self._obj + def all(self): + if isinstance(self._obj, list): + return list(self._obj) + if self._obj is None: + return [] + return [self._obj] + class _FakeSession: def __init__( self, *, scalars_result=None, + scalars_results: list[object] | None = None, forms: dict[str, _DummyForm] | None = None, recipients: dict[str, _DummyRecipient] | None = None, ): - self._scalars_result = scalars_result + if scalars_results is not None: + self._scalars_queue = list(scalars_results) + elif scalars_result is not None: + self._scalars_queue = [scalars_result] + else: + self._scalars_queue = [] self.forms = forms or {} self.recipients = recipients or {} def scalars(self, _query): - return _FakeScalarResult(self._scalars_result) + if self._scalars_queue: + result = self._scalars_queue.pop(0) + else: + result = None + return _FakeScalarResult(result) def get(self, model_cls, obj_id): # type: ignore[no-untyped-def] if getattr(model_cls, "__name__", None) == "HumanInputForm": @@ -255,7 +275,86 @@ def _session_factory(session: _FakeSession): return _factory -class TestHumanInputFormReadRepository: +class TestHumanInputFormRepositoryImplPublicMethods: + def test_get_form_returns_entity_and_recipients(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="recipient-1", + form_id=form.id, + recipient_type=RecipientType.WEBAPP, + access_token="token-123", + ) + session = _FakeSession(scalars_results=[form, [recipient]]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + entity = repo.get_form(form.workflow_run_id, form.node_id) + + assert entity is not None + assert entity.id == form.id + assert entity.web_app_token == "token-123" + assert len(entity.recipients) == 1 + assert entity.recipients[0].token == "token-123" + + def test_get_form_returns_none_when_missing(self): + session = _FakeSession(scalars_results=[None]) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + assert repo.get_form("run-1", "node-1") is None + + def test_get_form_submission_returns_none_when_pending(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + ) + session = _FakeSession(forms={form.id: form}) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + assert repo.get_form_submission(form.id) is None + + def test_get_form_submission_returns_submission_when_completed(self): + form = _DummyForm( + id="form-1", + workflow_run_id="run-1", + node_id="node-1", + tenant_id="tenant-id", + form_definition=_make_form_definition(), + rendered_content="

hello

", + expiration_time=naive_utc_now(), + selected_action_id="approve", + submitted_data='{"field": "value"}', + submitted_at=naive_utc_now(), + ) + session = _FakeSession(forms={form.id: form}) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + submission = repo.get_form_submission(form.id) + + assert submission is not None + assert submission.selected_action_id == "approve" + assert submission.form_data() == {"field": "value"} + + def test_get_form_submission_raises_when_form_missing(self): + session = _FakeSession(forms={}) + repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + + with pytest.raises(FormNotFoundError): + repo.get_form_submission("form-unknown") + + +class TestHumanInputFormSubmissionRepository: def test_get_by_token_returns_record(self): form = _DummyForm( id="form-1", @@ -274,7 +373,7 @@ class TestHumanInputFormReadRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormReadRepository(_session_factory(session)) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) record = repo.get_by_token("token-123") @@ -301,7 +400,7 @@ class TestHumanInputFormReadRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormReadRepository(_session_factory(session)) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.WEBAPP) @@ -332,7 +431,7 @@ class TestHumanInputFormReadRepository: forms={form.id: form}, recipients={recipient.id: recipient}, ) - repo = HumanInputFormReadRepository(_session_factory(session)) + repo = HumanInputFormSubmissionRepository(_session_factory(session)) record: HumanInputFormRecord = repo.mark_submitted( form_id=form.id, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py new file mode 100644 index 0000000000..bc8a08621a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -0,0 +1,101 @@ +"""Utilities for testing HumanInputNode without database dependencies.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + FormSubmission, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + HumanInputFormRepository, +) + + +class _InMemoryFormRecipient(HumanInputFormRecipientEntity): + """Minimal recipient entity required by the repository interface.""" + + def __init__(self, recipient_id: str, token: str) -> None: + self._id = recipient_id + self._token = token + + @property + def id(self) -> str: + return self._id + + @property + def token(self) -> str: + return self._token + + +@dataclass +class _InMemoryFormEntity(HumanInputFormEntity): + form_id: str + token: str | None = None + + @property + def id(self) -> str: + return self.form_id + + @property + def web_app_token(self) -> str | None: + return self.token + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return [] + + +class _InMemoryFormSubmission(FormSubmission): + def __init__(self, selected_action_id: str, form_data: Mapping[str, Any]) -> None: + self._selected_action_id = selected_action_id + self._form_data = form_data + + @property + def selected_action_id(self) -> str: + return self._selected_action_id + + def form_data(self) -> Mapping[str, Any]: + return self._form_data + + +class InMemoryHumanInputFormRepository(HumanInputFormRepository): + """Pure in-memory repository used by workflow graph engine tests.""" + + def __init__(self) -> None: + self._form_counter = 0 + self.created_params: list[FormCreateParams] = [] + self.created_forms: list[_InMemoryFormEntity] = [] + self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} + self._submissions: dict[str, FormSubmission] = {} + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + self.created_params.append(params) + self._form_counter += 1 + form_id = f"form-{self._form_counter}" + entity = _InMemoryFormEntity(form_id=form_id, token=f"token-{form_id}") + self.created_forms.append(entity) + self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity + return entity + + def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_key.get((workflow_execution_id, node_id)) + + def get_form_submission(self, form_id: str) -> FormSubmission | None: + return self._submissions.get(form_id) + + # Convenience helpers for tests ------------------------------------- + + def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: + """Simulate a human submission for the next repository lookup.""" + + if not self.created_forms: + raise AssertionError("no form has been created to attach submission data") + target_form_id = self.created_forms[-1].id + self._submissions[target_form_id] = _InMemoryFormSubmission(action_id, form_data or {}) + + def clear_submission(self) -> None: + self._submissions.clear() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 8aa04a448c..69cde2cb4f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -1,5 +1,6 @@ import time from collections.abc import Iterable +from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -17,7 +18,7 @@ from core.workflow.graph_events import ( from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.llm.entities import ( ContextConfig, @@ -28,6 +29,11 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormSubmission, + HumanInputFormEntity, + HumanInputFormRepository, +) from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable @@ -36,7 +42,11 @@ from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: +def _build_branching_graph( + mock_config: MockConfig, + form_repository: HumanInputFormRepository, + graph_runtime_state: GraphRuntimeState | None = None, +) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -49,12 +59,18 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + if graph_runtime_state is None: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="test-execution-id", + ), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -93,15 +109,21 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime human_data = HumanInputNodeData( title="Human Input", - required_variables=["human.input_ready"], - pause_reason="Awaiting human input", + form_content="Human input required", + inputs=[], + user_actions=[ + UserAction(id="primary", title="Primary"), + UserAction(id="secondary", title="Secondary"), + ], ) + human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=form_repository, ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") @@ -219,8 +241,17 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: for scenario in branch_scenarios: runner = TableTestRunner() + mock_create_repo = MagicMock(spec=HumanInputFormRepository) + mock_create_repo.get_form_submission.return_value = None + mock_create_repo.get_form.return_value = None + mock_form_entity = MagicMock(spec=HumanInputFormEntity) + mock_form_entity.id = "test_form_id" + mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.recipients = [] + mock_create_repo.create_form.return_value = mock_form_entity + def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config) + return _build_branching_graph(mock_config, mock_create_repo) initial_case = WorkflowTestCase( description="HumanInput pause before branching decision", @@ -242,15 +273,6 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: assert initial_result.success, initial_result.event_mismatch_details assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) - graph_runtime_state = initial_result.graph_runtime_state - graph = initial_result.graph - assert graph_runtime_state is not None - assert graph is not None - - graph_runtime_state.variable_pool.add(("human", "input_ready"), True) - graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"]) - graph_runtime_state.graph_execution.pause_reason = None - pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) @@ -273,11 +295,18 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: ] ) - def resume_graph_factory( - graph_snapshot: Graph = graph, - state_snapshot: GraphRuntimeState = graph_runtime_state, - ) -> tuple[Graph, GraphRuntimeState]: - return graph_snapshot, state_snapshot + mock_get_repo = MagicMock(spec=HumanInputFormRepository) + mock_form_submission = MagicMock(spec=FormSubmission) + mock_form_submission.selected_action_id = scenario["handle"] + mock_form_submission.form_data.return_value = {} + mock_get_repo.get_form_submission.return_value = mock_form_submission + mock_get_repo.get_form.return_value = mock_form_entity + + def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: + assert initial_result.graph_runtime_state is not None + serialized_runtime_state = initial_result.graph_runtime_state.dumps() + resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) + return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) resume_case = WorkflowTestCase( description=f"HumanInput resumes via {scenario['handle']} branch", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 47e3412b74..ad921b6279 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -1,4 +1,5 @@ import time +from unittest.mock import MagicMock from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.message_entities import PromptMessageRole @@ -16,7 +17,7 @@ from core.workflow.graph_events import ( from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.llm.entities import ( ContextConfig, @@ -27,6 +28,11 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormSubmission, + HumanInputFormEntity, + HumanInputFormRepository, +) from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable @@ -35,7 +41,11 @@ from .test_mock_nodes import MockLLMNode from .test_table_runner import TableTestRunner, WorkflowTestCase -def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: +def _build_llm_human_llm_graph( + mock_config: MockConfig, + form_repository: HumanInputFormRepository, + graph_runtime_state: GraphRuntimeState | None = None, +) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} graph_init_params = GraphInitParams( tenant_id="tenant", @@ -48,12 +58,15 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + if graph_runtime_state is None: + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," + ), + user_inputs={}, + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( @@ -92,15 +105,21 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun human_data = HumanInputNodeData( title="Human Input", - required_variables=["human.input_ready"], - pause_reason="Awaiting human input", + form_content="Human input required", + inputs=[], + user_actions=[ + UserAction(id="accept", title="Accept"), + UserAction(id="reject", title="Reject"), + ], ) + human_config = {"id": "human", "data": human_data.model_dump()} human_node = HumanInputNode( id=human_config["id"], config=human_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=form_repository, ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -130,7 +149,7 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun .add_root(start_node) .add_node(llm_first) .add_node(human_node) - .add_node(llm_second) + .add_node(llm_second, source_handle="accept") .add_node(end_node) .build() ) @@ -167,8 +186,17 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunPausedEvent, # graph run pauses awaiting resume ] + mock_create_repo = MagicMock(spec=HumanInputFormRepository) + mock_create_repo.get_form_submission.return_value = None + mock_create_repo.get_form.return_value = None + mock_form_entity = MagicMock(spec=HumanInputFormEntity) + mock_form_entity.id = "test_form_id" + mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.recipients = [] + mock_create_repo.create_form.return_value = mock_form_entity + def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config) + return _build_llm_human_llm_graph(mock_config, mock_create_repo) initial_case = WorkflowTestCase( description="HumanInput pause preserves LLM streaming order", @@ -225,12 +253,22 @@ def test_human_input_llm_streaming_order_across_pause() -> None: GraphRunSucceededEvent, # graph run succeeds after resume ] + mock_get_repo = MagicMock(spec=HumanInputFormRepository) + mock_form_submission = MagicMock(spec=FormSubmission) + mock_form_submission.selected_action_id = "accept" + mock_form_submission.form_data.return_value = {} + mock_get_repo.get_form_submission.return_value = mock_form_submission + mock_get_repo.get_form.return_value = mock_form_entity + def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - assert graph_runtime_state is not None - assert graph is not None - graph_runtime_state.variable_pool.add(("human", "input_ready"), True) - graph_runtime_state.graph_execution.pause_reason = None - return graph, graph_runtime_state + # restruct the graph runtime state + serialized_runtime_state = initial_result.graph_runtime_state.dumps() + resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) + return _build_llm_human_llm_graph( + mock_config, + mock_get_repo, + resume_runtime_state, + ) resume_case = WorkflowTestCase( description="HumanInput resume continues LLM streaming order", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 5b3e8ad76a..1f8e5855f5 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -1,11 +1,8 @@ import time -from collections.abc import Generator, Mapping from typing import Any +from unittest.mock import MagicMock -import core.workflow.nodes.human_input.entities # noqa: F401 from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel from core.workflow.graph_engine.graph_engine import GraphEngine @@ -15,72 +12,66 @@ from core.workflow.graph_events import ( GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeEventBase, NodeRunResult, PauseRequestedEvent -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector -from core.workflow.nodes.base.node import Node +from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.human_input_node import HumanInputNode from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import ( + FormSubmission, + HumanInputFormEntity, + HumanInputFormRepository, +) from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable -class _PausingNodeData(BaseNodeData): - pass - - -class _PausingNode(Node): - node_type = NodeType.TOOL - - def init_node_data(self, data: Mapping[str, Any]) -> None: - self._node_data = _PausingNodeData.model_validate(data) - - def _get_error_strategy(self): - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - - @staticmethod - def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]: - yield event - - def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: - resumed_flag = self.graph_runtime_state.variable_pool.get((self.id, "resumed")) - if resumed_flag is None: - # mark as resumed and request pause - self.graph_runtime_state.variable_pool.add((self.id, "resumed"), True) - return self._pause_generator(PauseRequestedEvent(reason=SchedulingPause(message="test pause"))) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"value": "completed"}, - ) - - def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=SystemVariable( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="test-execution-id", + ), user_inputs={}, conversation_variables=[], ) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) -def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph: +def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: + submission = MagicMock(spec=FormSubmission) + submission.selected_action_id = action_id + submission.form_data.return_value = {} + repo = MagicMock(spec=HumanInputFormRepository) + repo.get_form_submission.return_value = submission + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + repo.get_form.return_value = form_entity + return repo + + +def _mock_form_repository_without_submission() -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + repo.get_form_submission.return_value = None + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + repo.create_form.return_value = form_entity + repo.get_form.return_value = None + return repo + + +def _build_human_input_graph( + runtime_state: GraphRuntimeState, + form_repository: HumanInputFormRepository, +) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} params = GraphInitParams( tenant_id="tenant", @@ -102,19 +93,27 @@ def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph: ) start_node.init_node_data(start_data.model_dump()) - pause_data = _PausingNodeData(title="pausing") - pause_node = _PausingNode( - id="pausing", - config={"id": "pausing", "data": pause_data.model_dump()}, + human_data = HumanInputNodeData( + title="human", + form_content="Awaiting human input", + inputs=[], + user_actions=[ + UserAction(id="continue", title="Continue"), + ], + ) + human_node = HumanInputNode( + id="human", + config={"id": "human", "data": human_data.model_dump()}, graph_init_params=params, graph_runtime_state=runtime_state, + form_repository=form_repository, ) - pause_node.init_node_data(pause_data.model_dump()) + human_node.init_node_data(human_data.model_dump()) end_data = EndNodeData( title="end", outputs=[ - VariableSelector(variable="result", value_selector=["pausing", "value"]), + VariableSelector(variable="result", value_selector=["human", "action_id"]), ], desc=None, ) @@ -126,7 +125,13 @@ def _build_pausing_graph(runtime_state: GraphRuntimeState) -> Graph: ) end_node.init_node_data(end_data.model_dump()) - return Graph.new().add_root(start_node).add_node(pause_node).add_node(end_node).build() + return ( + Graph.new() + .add_root(start_node) + .add_node(human_node) + .add_node(end_node, from_node_id="human", source_handle="continue") + .build() + ) def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: @@ -152,22 +157,24 @@ def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> An def test_engine_resume_restores_state_and_completion(): # Baseline run without pausing baseline_state = _build_runtime_state() - baseline_graph = _build_pausing_graph(baseline_state) - baseline_state.variable_pool.add(("pausing", "resumed"), True) + baseline_repo = _mock_form_repository_with_submission(action_id="continue") + baseline_graph = _build_human_input_graph(baseline_state, baseline_repo) baseline_events = _run_graph(baseline_graph, baseline_state) assert isinstance(baseline_events[-1], GraphRunSucceededEvent) baseline_success_nodes = _node_successes(baseline_events) # Run with pause paused_state = _build_runtime_state() - paused_graph = _build_pausing_graph(paused_state) + pause_repo = _mock_form_repository_without_submission() + paused_graph = _build_human_input_graph(paused_state, pause_repo) paused_events = _run_graph(paused_graph, paused_state) assert isinstance(paused_events[-1], GraphRunPausedEvent) snapshot = paused_state.dumps() # Resume from snapshot resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_pausing_graph(resumed_state) + resume_repo = _mock_form_repository_with_submission(action_id="continue") + resumed_graph = _build_human_input_graph(resumed_state, resume_repo) resumed_events = _run_graph(resumed_graph, resumed_state) assert isinstance(resumed_events[-1], GraphRunSucceededEvent) @@ -175,11 +182,8 @@ def test_engine_resume_restores_state_and_completion(): assert combined_success_nodes == baseline_success_nodes assert baseline_state.outputs == resumed_state.outputs - assert _segment_value(baseline_state.variable_pool, ("pausing", "resumed")) == _segment_value( - resumed_state.variable_pool, ("pausing", "resumed") - ) - assert _segment_value(baseline_state.variable_pool, ("pausing", "value")) == _segment_value( - resumed_state.variable_pool, ("pausing", "value") + assert _segment_value(baseline_state.variable_pool, ("human", "action_id")) == _segment_value( + resumed_state.variable_pool, ("human", "action_id") ) assert baseline_state.graph_execution.completed assert resumed_state.graph_execution.completed diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 488b47761b..21a642c2f8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -7,6 +7,7 @@ from core.workflow.nodes.base.node import Node # Ensures that all node classes are imported. from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed. _ = NODE_TYPE_CLASSES_MAPPING @@ -45,7 +46,9 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined assert isinstance(cls.node_type, NodeType) assert isinstance(node_version, str) node_type_and_version = (node_type, node_version) - assert node_type_and_version not in type_version_set + assert node_type_and_version not in type_version_set, ( + f"Duplicate node type and version for class: {cls=} {node_type_and_version=}" + ) type_version_set.add(node_type_and_version) diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index ed9569be4b..504f27bc2b 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -4,11 +4,14 @@ from unittest.mock import MagicMock import pytest -from core.repositories.human_input_reposotiry import HumanInputFormReadRepository, HumanInputFormRecord -from core.workflow.nodes.human_input.entities import FormDefinition, TimeoutUnit, UserAction +from core.repositories.human_input_reposotiry import ( + HumanInputFormRecord, + HumanInputFormSubmissionRepository, +) +from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, FormInputType, TimeoutUnit, UserAction from models.account import Account from models.human_input import RecipientType -from services.human_input_service import FormSubmittedError, HumanInputService +from services.human_input_service import FormSubmittedError, HumanInputService, InvalidFormDataError @pytest.fixture @@ -129,7 +132,7 @@ def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory): def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory): session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormReadRepository) + repo = MagicMock(spec=HumanInputFormSubmissionRepository) repo.get_by_form_id_and_recipient_type.return_value = sample_form_record service = HumanInputService(session_factory, form_repository=repo) @@ -146,7 +149,7 @@ def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_sess def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_session_factory): session_factory, _ = mock_session_factory submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime(2024, 1, 1)) - repo = MagicMock(spec=HumanInputFormReadRepository) + repo = MagicMock(spec=HumanInputFormSubmissionRepository) repo.get_by_form_id_and_recipient_type.return_value = submitted_record service = HumanInputService(session_factory, form_repository=repo) @@ -157,7 +160,7 @@ def test_get_form_definition_by_id_raises_on_submitted(sample_form_record, mock_ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker): session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormReadRepository) + repo = MagicMock(spec=HumanInputFormSubmissionRepository) repo.get_by_token.return_value = sample_form_record repo.mark_submitted.return_value = sample_form_record service = HumanInputService(session_factory, form_repository=repo) @@ -166,7 +169,7 @@ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, m service.submit_form_by_token( recipient_type=RecipientType.WEBAPP, form_token="token", - selected_action_id="approve", + selected_action_id="submit", form_data={"field": "value"}, submission_end_user_id="end-user-id", ) @@ -176,7 +179,7 @@ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, m call_kwargs = repo.mark_submitted.call_args.kwargs assert call_kwargs["form_id"] == sample_form_record.form_id assert call_kwargs["recipient_id"] == sample_form_record.recipient_id - assert call_kwargs["selected_action_id"] == "approve" + assert call_kwargs["selected_action_id"] == "submit" assert call_kwargs["form_data"] == {"field": "value"} assert call_kwargs["submission_end_user_id"] == "end-user-id" enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) @@ -184,7 +187,7 @@ def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, m def test_submit_form_by_id_passes_account(sample_form_record, mock_session_factory, mocker): session_factory, _ = mock_session_factory - repo = MagicMock(spec=HumanInputFormReadRepository) + repo = MagicMock(spec=HumanInputFormSubmissionRepository) repo.get_by_form_id_and_recipient_type.return_value = sample_form_record repo.mark_submitted.return_value = sample_form_record service = HumanInputService(session_factory, form_repository=repo) @@ -194,7 +197,7 @@ def test_submit_form_by_id_passes_account(sample_form_record, mock_session_facto service.submit_form_by_id( form_id="form-id", - selected_action_id="approve", + selected_action_id="submit", form_data={"x": 1}, user=account, ) @@ -203,3 +206,49 @@ def test_submit_form_by_id_passes_account(sample_form_record, mock_session_facto repo.mark_submitted.assert_called_once() assert repo.mark_submitted.call_args.kwargs["submission_user_id"] == "account-id" enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id) + + +def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = dataclasses.replace(sample_form_record) + service = HumanInputService(session_factory, form_repository=repo) + + with pytest.raises(InvalidFormDataError) as exc_info: + service.submit_form_by_token( + recipient_type=RecipientType.WEBAPP, + form_token="token", + selected_action_id="invalid", + form_data={}, + ) + + assert "Invalid action" in str(exc_info.value) + repo.mark_submitted.assert_not_called() + + +def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + + definition_with_input = FormDefinition( + form_content="hello", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")], + user_actions=sample_form_record.definition.user_actions, + rendered_content="

hello

", + timeout=1, + timeout_unit=TimeoutUnit.HOUR, + ) + form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input) + repo.get_by_token.return_value = form_with_input + service = HumanInputService(session_factory, form_repository=repo) + + with pytest.raises(InvalidFormDataError) as exc_info: + service.submit_form_by_token( + recipient_type=RecipientType.WEBAPP, + form_token="token", + selected_action_id="submit", + form_data={}, + ) + + assert "Missing required inputs" in str(exc_info.value) + repo.mark_submitted.assert_not_called()