diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 8882d51c00..8a7cf4d274 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -4,56 +4,79 @@ Console/Studio Human Input Form APIs. import json import logging +from collections.abc import Generator -from flask import g, jsonify +from flask import Response, jsonify from flask_restx import Resource, reqparse +from pydantic import BaseModel +from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError +from core.workflow.nodes.human_input.entities import FormDefinition from extensions.ext_database import db -from libs.login import login_required -from models.human_input import HumanInputSubmissionType -from services.human_input_form_service import ( - HumanInputFormAlreadySubmittedError, - HumanInputFormExpiredError, - HumanInputFormNotFoundError, - HumanInputFormService, - InvalidFormDataError, -) +from libs.login import current_account_with_tenant, login_required +from models.human_input import HumanInputForm as HumanInputFormModel +from services.human_input_service import HumanInputService logger = logging.getLogger(__name__) +class _FormDefinitionWithSite(FormDefinition): + site: None + + +def _jsonify_pydantic_model(model: BaseModel) -> Response: + return Response(model.model_dump_json(), mimetype="application/json") + + class ConsoleHumanInputFormApi(Resource): """Console API for getting human input form definition.""" - @account_initialization_required + @setup_required @login_required + @account_initialization_required def get(self, form_id: str): """ Get human input form definition by form ID. GET /console/api/form/human_input/ """ - try: - service = HumanInputFormService(db.session()) - form_definition = service.get_form_definition(identifier=form_id, is_token=False, include_site_info=False) - return form_definition, 200 + service = HumanInputService(db.engine) + form = service.get_form_definition_by_id( + form_id=form_id, + ) + if form is None: + raise NotFoundError(f"form not found, id={form_id}") - except HumanInputFormNotFoundError: - raise NotFoundError("Form not found") - except HumanInputFormExpiredError: - return jsonify( - {"error_code": "human_input_form_expired", "description": "Human input form has expired"} - ), 400 - except HumanInputFormAlreadySubmittedError: - return jsonify( - { - "error_code": "human_input_form_submitted", - "description": "Human input form has already been submitted", - } - ), 400 + current_user, current_tenant_id = current_account_with_tenant() + form_model = db.session.get(HumanInputFormModel, form_id) + if form_model is None or form_model.tenant_id != current_tenant_id: + raise NotFoundError(f"form not found, id={form_id}") + + from models import App + from models.workflow import Workflow, WorkflowRun + + workflow_run = db.session.get(WorkflowRun, form_model.workflow_run_id) + if workflow_run is None or workflow_run.tenant_id != current_tenant_id: + raise NotFoundError("Workflow run not found") + + if workflow_run.app_id: + app = db.session.get(App, workflow_run.app_id) + if app is None or app.tenant_id != current_tenant_id: + raise NotFoundError("App not found") + owner_account_id = app.created_by + else: + workflow = db.session.get(Workflow, workflow_run.workflow_id) + if workflow is None or workflow.tenant_id != current_tenant_id: + raise NotFoundError("Workflow not found") + owner_account_id = workflow.created_by + + if owner_account_id != current_user.id: + raise Forbidden("You do not have permission to access this human input form.") + + return _jsonify_pydantic_model(form.get_definition()) class ConsoleHumanInputFormSubmissionApi(Resource): @@ -80,75 +103,15 @@ class ConsoleHumanInputFormSubmissionApi(Resource): parser.add_argument("action", type=str, required=True, location="json") args = parser.parse_args() - try: - # Submit the form - service = HumanInputFormService(db.session()) - service.submit_form( - identifier=form_id, - form_data=args["inputs"], - action=args["action"], - is_token=False, - submission_type=HumanInputSubmissionType.web_form, - submission_user_id=g.current_user.id, - ) + # Submit the form + service = HumanInputService(db.engine) + service.submit_form_by_id( + form_id=form_id, + selected_action_id=args["action"], + form_data=args["inputs"], + ) - return {}, 200 - - except HumanInputFormNotFoundError: - raise NotFoundError("Form not found") - except HumanInputFormExpiredError: - return jsonify( - {"error_code": "human_input_form_expired", "description": "Human input form has expired"} - ), 400 - except HumanInputFormAlreadySubmittedError: - return jsonify( - { - "error_code": "human_input_form_submitted", - "description": "Human input form has already been submitted", - } - ), 400 - except InvalidFormDataError as e: - return jsonify({"error_code": "invalid_form_data", "description": e.message}), 400 - - -class ConsoleWorkflowResumeWaitApi(Resource): - """Console API for long-polling workflow resume wait.""" - - @account_initialization_required - @login_required - def get(self, task_id: str): - """ - Get workflow execution resume notification. - - GET /console/api/workflow//resume-wait - - This is a long-polling API that waits for workflow to resume from paused state. - """ - import time - - # TODO: Implement actual workflow status checking - # For now, return a basic response - - timeout = 30 # 30 seconds timeout for demo - start_time = time.time() - - while time.time() - start_time < timeout: - # TODO: Check workflow status from database/cache - # workflow_status = workflow_service.get_status(task_id) - - # For demo purposes, simulate different states - # In real implementation, this would check the actual workflow state - workflow_status = "paused" # or "running" or "ended" - - if workflow_status == "running": - return {"status": "running"}, 200 - elif workflow_status == "ended": - return {"status": "ended"}, 200 - - time.sleep(1) # Poll every second - - # Return paused status if timeout reached - return {"status": "paused"}, 200 + return jsonify({}) class ConsoleWorkflowEventsApi(Resource): @@ -156,7 +119,7 @@ class ConsoleWorkflowEventsApi(Resource): @account_initialization_required @login_required - def get(self, task_id: str): + def get(self, workflow_run_id: str): """ Get workflow execution events stream after resume. @@ -164,9 +127,8 @@ class ConsoleWorkflowEventsApi(Resource): Returns Server-Sent Events stream. """ - from collections.abc import Generator - from flask import Response + events = def generate_events() -> Generator[str, None, None]: """Generate SSE events for workflow execution.""" @@ -263,6 +225,5 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Register the APIs api.add_resource(ConsoleHumanInputFormApi, "/form/human_input/") api.add_resource(ConsoleHumanInputFormSubmissionApi, "/form/human_input/", methods=["POST"]) -api.add_resource(ConsoleWorkflowResumeWaitApi, "/workflow//resume-wait") -api.add_resource(ConsoleWorkflowEventsApi, "/workflow//events") +api.add_resource(ConsoleWorkflowEventsApi, "/workflow//events") api.add_resource(ConsoleWorkflowPauseDetailsApi, "/workflow//pause-details") diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 38ecec5d30..e1a9e38166 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any, NewType, Union +from typing import Any, NamedTuple, NewType, Union, final from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -53,6 +53,7 @@ from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from models import Account, EndUser +from models.workflow import WorkflowRun from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator NodeExecutionId = NewType("NodeExecutionId", str) @@ -264,6 +265,54 @@ class WorkflowResponseConverter: ), ) + @classmethod + def workflow_run_result_to_finish_response( + cls, + *, + workflow_run: WorkflowRun, + creator_user: Account | EndUser, + ) -> WorkflowFinishStreamResponse: + run_id = workflow_run.id + elapsed_time = workflow_run.elapsed_time + + encoded_outputs = workflow_run.outputs_dict + finished_at = workflow_run.finished_at + assert finished_at is not None + + created_by: Mapping[str, object] + user = creator_user + if isinstance(user, Account): + created_by = { + "id": user.id, + "name": user.name, + "email": user.email, + } + else: + created_by = { + "id": user.id, + "user": user.session_id, + } + + return WorkflowFinishStreamResponse( + task_id=task_id, # TODO + workflow_run_id=run_id, + data=WorkflowFinishStreamResponse.Data( + id=run_id, + workflow_id=workflow_run.workflow_id, + status=workflow_run.status.value, + outputs=encoded_outputs, + error=workflow_run.error, + elapsed_time=elapsed_time, + total_tokens=workflow_run.total_tokens, + total_steps=workflow_run.total_steps, + created_by=created_by, + created_at=int(workflow_run.created_at.timestamp()), + finished_at=int(finished_at.timestamp()), + files=cls.fetch_files_from_node_outputs(encoded_outputs), + exceptions_count=workflow_run.exceptions_count, + ), + ) + def workflow_node_start_to_stream_response( self, *, @@ -592,7 +641,8 @@ class WorkflowResponseConverter: ), ) - def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: + @classmethod + def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -601,7 +651,7 @@ class WorkflowResponseConverter: if not outputs_dict: return [] - files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] # Remove None files = [file for file in files if file] # Flatten list diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 6d8e03efa0..f27819c34c 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,9 +1,13 @@ import json import logging +import time import uuid from collections.abc import Generator -from typing import Mapping, Union, cast +from typing import Any, Mapping, Union, cast +from libs.broadcast_channel.channel import Subscription, Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from sqlalchemy import select from sqlalchemy.orm import Session @@ -29,8 +33,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db -from libs.broadcast_channel.channel import Topic -from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel +from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models import Account from models.enums import CreatorUserRole @@ -38,7 +41,6 @@ from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Me from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError -from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -302,15 +304,34 @@ class MessageBasedAppGenerator(BaseAppGenerator): return topic @classmethod - def retrieve_events(cls, app_mode: AppMode, workflow_run_id: uuid.UUID) -> Generator[Mapping | str, None, None]: + def retrieve_events( + cls, app_mode: AppMode, workflow_run_id: uuid.UUID, idle_timeout=300 + ) -> Generator[Mapping | str, None, None]: topic = cls.get_response_topic(app_mode, workflow_run_id) - with topic.subscribe() as sub: - for payload in sub: - event = json.loads(payload) - yield event - if not isinstance(event, dict): - continue + return _topic_msg_generator(topic, idle_timeout) - event_type = event.get("event") - if event_type == StreamEvent.WORKFLOW_FINISHED: + +def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]: + last_msg_time = time.time() + with topic.subscribe() as sub: + while True: + try: + msg = sub.receive() + except SubscriptionClosedError: + return + if msg is None: + current_time = time.time() + if current_time - last_msg_time > idle_timeout: return + # skip the `None` message + continue + + last_msg_time = time.time() + event = json.loads(msg) + yield event + if not isinstance(event, dict): + continue + + event_type = event.get("event") + if event_type in (StreamEvent.WORKFLOW_FINISHED, StreamEvent.WORKFLOW_PAUSED): + return diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 79a5e657b3..89a8d5720e 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -69,6 +69,7 @@ class StreamEvent(StrEnum): AGENT_THOUGHT = "agent_thought" AGENT_MESSAGE = "agent_message" WORKFLOW_STARTED = "workflow_started" + WORKFLOW_PAUSED = "workflow_paused" WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" diff --git a/api/core/repositories/human_input_reposotiry.py b/api/core/repositories/human_input_reposotiry.py new file mode 100644 index 0000000000..1ab09b91cd --- /dev/null +++ b/api/core/repositories/human_input_reposotiry.py @@ -0,0 +1,232 @@ +import abc +import dataclasses +import json +import uuid +from collections.abc import Sequence +from typing import Any, Mapping + +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session, sessionmaker + +from core.workflow.nodes.human_input.entities import ( + DeliveryChannelConfig, + EmailDeliveryMethod, + EmailRecipient, + ExternalRecipient, + FormDefinition, + HumanInputNodeData, + MemberRecipient, + WebAppDeliveryMethod, +) +from core.workflow.repositories.human_input_form_repository import ( + FormCreateParams, + FormNotFoundError, + FormSubmissionEntity, + HumanInputFormEntity, +) +from libs.datetime_utils import naive_utc_now +from libs.uuid_utils import uuidv7 +from models.human_input import ( + EmailExternalRecipientPayload, + EmailMemberRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputRecipient, + RecipientType, + WebAppRecipientPayload, +) + + +@dataclasses.dataclass(frozen=True) +class _DeliveryAndRecipients: + delivery: HumanInputDelivery + recipients: Sequence[HumanInputRecipient] + + def webapp_recipient(self) -> HumanInputRecipient | None: + return next((i for i in self.recipients if i.recipient_type == RecipientType.WEBAPP), None) + + +class _HumanInputFormEntityImpl(HumanInputFormEntity): + def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputRecipient | None): + self._form_model = form_model + self._web_app_recipient = web_app_recipient + + @property + def id(self) -> str: + return self._form_model.id + + @property + def web_app_token(self): + if self._web_app_recipient is None: + return None + return self._web_app_recipient.access_token + + +class _FormSubmissionEntityImpl(FormSubmissionEntity): + def __init__(self, form_model: HumanInputForm): + self._form_model = form_model + + @property + def selected_action_id(self) -> str: + selected_action_id = self._form_model.selected_action_id + if selected_action_id is None: + raise AssertionError(f"selected_action_id should not be None, form_id={self._form_model.id}") + return selected_action_id + + def form_data(self) -> Mapping[str, Any]: + submitted_data = self._form_model.submitted_data + if submitted_data is None: + raise AssertionError(f"submitted_data should not be None, form_id={self._form_model.id}") + return json.loads(submitted_data) + + +class WorkspaceMember: + def user_id(self) -> str: + pass + + def email(self) -> str: + pass + + +class WorkspaceMemberQueirer: + def get_all_workspace_members(self) -> Sequence[WorkspaceMember]: + # TOOD: need a way to query all members in the current workspace. + pass + + def get_members_by_ids(self, user_ids: Sequence[str]) -> Sequence[WorkspaceMember]: + pass + + +class HumanInputFormRepositoryImpl: + def __init__( + self, + session_factory: sessionmaker | Engine, + tenant_id: str, + member_quierer: WorkspaceMemberQueirer, + ): + if isinstance(session_factory, Engine): + session_factory = sessionmaker(bind=session_factory) + self._session_factory = session_factory + self._tenant_id = tenant_id + self._member_queirer = member_quierer + + def _delivery_method_to_model(self, form_id, delivery_method: DeliveryChannelConfig) -> _DeliveryAndRecipients: + delivery_id = str(uuidv7()) + delivery_model = HumanInputDelivery( + id=delivery_id, + form_id=form_id, + delivery_method_type=delivery_method.type, + delivery_config_id=delivery_method.id, + channel_payload=delivery_method.model_dump_json(), + ) + recipients: list[HumanInputRecipient] = [] + if isinstance(delivery_method, WebAppDeliveryMethod): + recipient_model = HumanInputRecipient( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=RecipientType.WEBAPP, + recipient_payload=WebAppRecipientPayload().model_dump_json(), + ) + recipients.append(recipient_model) + elif isinstance(delivery_method, EmailDeliveryMethod): + email_recipients_config = delivery_method.config.recipients + if email_recipients_config.whole_workspace: + recipients.extend(self._create_whole_workspace_recipients(form_id=form_id, delivery_id=delivery_id)) + else: + recipients.extend( + self._create_email_recipients( + form_id=form_id, delivery_id=delivery_id, recipients=email_recipients_config.items + ) + ) + + return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients) + + def _create_email_recipients( + self, + form_id: str, + delivery_id: str, + recipients: Sequence[EmailRecipient], + ) -> list[HumanInputRecipient]: + recipient_models: list[HumanInputRecipient] = [] + member_user_ids: list[str] = [] + for r in recipients: + if isinstance(r, MemberRecipient): + member_user_ids.append(r.user_id) + elif isinstance(r, ExternalRecipient): + recipient_model = HumanInputRecipient.new( + form_id=form_id, delivery_id=delivery_id, payload=EmailExternalRecipientPayload(email=r.email) + ) + recipient_models.append(recipient_model) + else: + raise AssertionError(f"unknown recipient type: recipient={r}") + + members = self._member_queirer.get_members_by_ids(member_user_ids) + for member in members: + payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email()) + recipient_model = HumanInputRecipient.new( + form_id=form_id, + delivery_id=delivery_id, + payload=payload, + ) + recipient_models.append(recipient_model) + return recipient_models + + def _create_whole_workspace_recipients(self, form_id: str, delivery_id: str) -> list[HumanInputRecipient]: + recipeint_models = [] + members = self._member_queirer.get_all_workspace_members() + for member in members: + payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email()) + recipient_model = HumanInputRecipient.new( + form_id=form_id, + delivery_id=delivery_id, + payload=payload, + ) + recipeint_models.append(recipient_model) + + return recipeint_models + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: + form_config: HumanInputNodeData = params.form_config + + with self._session_factory(expire_on_commit=False) as session, session.begin(): + # Generate unique form ID + form_id = str(uuidv7()) + form_definition = FormDefinition( + inputs=form_config.inputs, + user_actions=form_config.user_actions, + ) + form_model = HumanInputForm( + id=form_id, + tenant_id=self._tenant_id, + workflow_run_id=params.workflow_execution_id, + node_id=params.node_id, + form_definition=form_definition.model_dump_json(), + rendered_content=params.rendered_content, + expiration_time=form_config.expiration_time(naive_utc_now()), + ) + session.add(form_model) + web_app_recipient: HumanInputRecipient | None = None + for delivery in form_config.delivery_methods: + delivery_and_recipients = self._delivery_method_to_model(form_id, delivery) + session.add(delivery_and_recipients.delivery) + session.add_all(delivery_and_recipients.recipients) + if web_app_recipient is None: + web_app_recipient = delivery_and_recipients.webapp_recipient() + session.flush() + + return _HumanInputFormEntityImpl(form_model=form_model, web_app_recipient=web_app_recipient) + + def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None: + query = select(HumanInputForm).where( + HumanInputForm.workflow_run_id == workflow_execution_id, + HumanInputForm.node_id == node_id, + ) + with self._session_factory(expire_on_commit=False) as session: + form_model: HumanInputForm | None = session.scalars(query).first() + if form_model is None: + raise FormNotFoundError(f"form not found for node, {workflow_execution_id=}, {node_id=}") + + if form_model.submitted_at is None: + return None + + return _FormSubmissionEntityImpl(form_model=form_model) diff --git a/api/core/repositories/sqlalchemy_human_input_form_repository.py b/api/core/repositories/sqlalchemy_human_input_form_repository.py deleted file mode 100644 index 12e5dac90b..0000000000 --- a/api/core/repositories/sqlalchemy_human_input_form_repository.py +++ /dev/null @@ -1,272 +0,0 @@ -""" -SQLAlchemy implementation of the HumanInputFormRepository. -""" - -import json -import logging -from datetime import datetime, timedelta -from typing import Union - -from sqlalchemy.engine import Engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy import and_, select - -from core.workflow.entities.human_input_form import ( - HumanInputForm, - HumanInputFormStatus, - HumanInputSubmissionType, - FormSubmission, -) -from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository -from libs.helper import extract_tenant_id -from models.human_input import ( - HumanInputForm as HumanInputFormModel, - HumanInputFormStatus as HumanInputFormStatusModel, - HumanInputSubmissionType as HumanInputSubmissionTypeModel, -) -from models import Account, EndUser - -logger = logging.getLogger(__name__) - - -class SQLAlchemyHumanInputFormRepository(HumanInputFormRepository): - """ - SQLAlchemy implementation of the HumanInputFormRepository interface. - - This implementation supports multi-tenancy by filtering operations based on tenant_id. - Each method creates its own session, handles the transaction, and commits changes - to the database. This prevents long-running connections in the workflow core. - """ - - def __init__( - self, - session_factory: Union[sessionmaker, Engine], - user: Union[Account, EndUser], - app_id: str | None, - ): - """ - Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. - - Args: - session_factory: SQLAlchemy sessionmaker or engine for creating sessions - user: Account or EndUser object containing tenant_id, user ID, and role information - app_id: App ID for filtering by application (can be None) - """ - # If an engine is provided, create a sessionmaker from it - if isinstance(session_factory, Engine): - self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) - elif isinstance(session_factory, sessionmaker): - self._session_factory = session_factory - else: - raise ValueError( - f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" - ) - - # Extract tenant_id from user - tenant_id = extract_tenant_id(user) - if not tenant_id: - raise ValueError("User must have a tenant_id or current_tenant_id") - self._tenant_id = tenant_id - - # Store app context - self._app_id = app_id - - def _to_domain_model(self, db_model: HumanInputFormModel) -> HumanInputForm: - """ - Convert a database model to a domain model. - - Args: - db_model: The database model to convert - - Returns: - The domain model - """ - # Parse JSON fields - form_definition = json.loads(db_model.form_definition) if db_model.form_definition else {} - - # Create submission if present - submission = None - if db_model.status == HumanInputFormStatusModel.SUBMITTED and db_model.submitted_data: - submission = FormSubmission( - data=json.loads(db_model.submitted_data) if db_model.submitted_data else {}, - action="", # Action is not stored separately in DB model, would need to be stored in submitted_data - submitted_at=db_model.submitted_at or datetime.utcnow(), - submission_type=HumanInputSubmissionType(db_model.submission_type.value) - if db_model.submission_type - else HumanInputSubmissionType.web_form, - submission_user_id=db_model.submission_user_id, - submission_end_user_id=db_model.submission_end_user_id, - submitter_email=db_model.submitter_email, - ) - - return HumanInputForm( - id_=db_model.id, - workflow_run_id=db_model.workflow_run_id, - form_definition=form_definition, - rendered_content=db_model.rendered_content, - status=HumanInputFormStatus(db_model.status.value), - web_app_token=db_model.web_app_token, - submission=submission, - created_at=db_model.created_at, - ) - - def _to_db_model(self, domain_model: HumanInputForm) -> HumanInputFormModel: - """ - Convert a domain model to a database model. - - Args: - domain_model: The domain model to convert - - Returns: - The database model - """ - db_model = HumanInputFormModel() - db_model.id = domain_model.id_ - db_model.tenant_id = self._tenant_id - if self._app_id is not None: - db_model.app_id = self._app_id - db_model.workflow_run_id = domain_model.workflow_run_id - db_model.form_definition = json.dumps(domain_model.form_definition) if domain_model.form_definition else None - db_model.rendered_content = domain_model.rendered_content - db_model.status = HumanInputFormStatusModel(domain_model.status.value) - db_model.web_app_token = domain_model.web_app_token - db_model.created_at = domain_model.created_at - - # Handle submission data - if domain_model.submission: - db_model.submitted_data = json.dumps(domain_model.submission.data) if domain_model.submission.data else None - db_model.submitted_at = domain_model.submission.submitted_at - db_model.submission_type = HumanInputSubmissionTypeModel(domain_model.submission.submission_type.value) - db_model.submission_user_id = domain_model.submission.submission_user_id - db_model.submission_end_user_id = domain_model.submission.submission_end_user_id - db_model.submitter_email = domain_model.submission.submitter_email - - return db_model - - def save(self, form: HumanInputForm) -> None: - """ - Save or update a HumanInputForm domain entity to the database. - - This method serves as a domain-to-database adapter that: - 1. Converts the domain entity to its database representation - 2. Persists the database model using SQLAlchemy's merge operation - 3. Maintains proper multi-tenancy by including tenant context during conversion - - The method handles both creating new records and updating existing ones through - SQLAlchemy's merge operation. - - Args: - form: The HumanInputForm domain entity to persist - """ - db_model = self._to_db_model(form) - - with self._session_factory() as session: - session.merge(db_model) - session.commit() - logger.info("Saved human input form %s", form.id_) - - def get_by_id(self, form_id: str) -> HumanInputForm: - """Get a form by its ID.""" - with self._session_factory() as session: - stmt = select(HumanInputFormModel).where( - and_( - HumanInputFormModel.id == form_id, - HumanInputFormModel.tenant_id == self._tenant_id, - ) - ) - if self._app_id is not None: - stmt = stmt.where(HumanInputFormModel.app_id == self._app_id) - - db_model = session.scalar(stmt) - if not db_model: - raise ValueError(f"Human input form not found: {form_id}") - - return self._to_domain_model(db_model) - - def get_by_web_app_token(self, web_app_token: str) -> HumanInputForm: - """Get a form by its web app token.""" - with self._session_factory() as session: - stmt = select(HumanInputFormModel).where( - and_( - HumanInputFormModel.web_app_token == web_app_token, - HumanInputFormModel.tenant_id == self._tenant_id, - ) - ) - if self._app_id is not None: - stmt = stmt.where(HumanInputFormModel.app_id == self._app_id) - - db_model = session.scalar(stmt) - if not db_model: - raise ValueError(f"Human input form not found with token: {web_app_token}") - - return self._to_domain_model(db_model) - - def get_pending_forms_for_workflow_run(self, workflow_run_id: str) -> list[HumanInputForm]: - """Get all pending human input forms for a workflow run.""" - with self._session_factory() as session: - stmt = select(HumanInputFormModel).where( - and_( - HumanInputFormModel.workflow_run_id == workflow_run_id, - HumanInputFormModel.status == HumanInputFormStatusModel.WAITING, - HumanInputFormModel.tenant_id == self._tenant_id, - ) - ) - if self._app_id is not None: - stmt = stmt.where(HumanInputFormModel.app_id == self._app_id) - - db_models = list(session.scalars(stmt).all()) - return [self._to_domain_model(db_model) for db_model in db_models] - - def mark_expired_forms(self, expiry_hours: int = 48) -> int: - """Mark expired forms as expired.""" - with self._session_factory() as session: - expiry_time = datetime.utcnow() - timedelta(hours=expiry_hours) - - stmt = select(HumanInputFormModel).where( - and_( - HumanInputFormModel.status == HumanInputFormStatusModel.WAITING, - HumanInputFormModel.created_at < expiry_time, - HumanInputFormModel.tenant_id == self._tenant_id, - ) - ) - if self._app_id is not None: - stmt = stmt.where(HumanInputFormModel.app_id == self._app_id) - - expired_forms = list(session.scalars(stmt).all()) - - count = 0 - for form in expired_forms: - form.status = HumanInputFormStatusModel.EXPIRED - count += 1 - - session.commit() - logger.info("Marked %d forms as expired", count) - return count - - def exists_by_id(self, form_id: str) -> bool: - """Check if a form exists by ID.""" - with self._session_factory() as session: - stmt = select(HumanInputFormModel).where( - and_( - HumanInputFormModel.id == form_id, - HumanInputFormModel.tenant_id == self._tenant_id, - ) - ) - if self._app_id is not None: - stmt = stmt.where(HumanInputFormModel.app_id == self._app_id) - - return session.scalar(stmt) is not None - - def exists_by_web_app_token(self, web_app_token: str) -> bool: - """Check if a form exists by web app token.""" - with self._session_factory() as session: - stmt = select(HumanInputFormModel).where( - and_( - HumanInputFormModel.web_app_token == web_app_token, - HumanInputFormModel.tenant_id == self._tenant_id, - ) - ) - if self._app_id is not None: - stmt = stmt.where(HumanInputFormModel.app_id == self._app_id) - - return session.scalar(stmt) is not None diff --git a/api/core/workflow/entities/human_input_form.py b/api/core/workflow/entities/human_input_form.py deleted file mode 100644 index 46f611ee28..0000000000 --- a/api/core/workflow/entities/human_input_form.py +++ /dev/null @@ -1,197 +0,0 @@ -""" -Domain entities for human input forms. - -Models are independent of the storage mechanism and don't contain -implementation details like tenant_id, app_id, etc. -""" - -from datetime import datetime -from enum import StrEnum -from typing import Any, Optional - -from pydantic import BaseModel, Field - -from libs.datetime_utils import naive_utc_now - - -def naive_utc_from_now() -> datetime: - """Get current UTC datetime.""" - return naive_utc_now() - - -class HumanInputFormStatus(StrEnum): - """Status of a human input form.""" - - WAITING = "waiting" - EXPIRED = "expired" - SUBMITTED = "submitted" - TIMEOUT = "timeout" - - -class HumanInputSubmissionType(StrEnum): - """Type of submission for human input forms.""" - - web_form = "web_form" - web_app = "web_app" - email = "email" - - -class FormSubmission(BaseModel): - """Represents a form submission.""" - - data: dict[str, Any] = Field(default_factory=dict) - action: str = "" - submitted_at: datetime = Field(default_factory=naive_utc_now) - submission_type: HumanInputSubmissionType = HumanInputSubmissionType.web_form - submission_user_id: Optional[str] = None - submission_end_user_id: Optional[str] = None - submitter_email: Optional[str] = None - - -class HumanInputForm(BaseModel): - """ - Domain model for human input forms. - - This model represents the business concept of a human input form without - infrastructure concerns like tenant_id, app_id, etc. - """ - - id_: str = Field(...) - workflow_run_id: str = Field(...) - form_definition: dict[str, Any] = Field(default_factory=dict) - rendered_content: str = "" - status: HumanInputFormStatus = HumanInputFormStatus.WAITING - web_app_token: Optional[str] = None - submission: Optional[FormSubmission] = None - created_at: datetime = Field(default_factory=naive_utc_from_now) - - @property - def is_submitted(self) -> bool: - """Check if the form has been submitted.""" - return self.status == HumanInputFormStatus.SUBMITTED - - @property - def is_expired(self) -> bool: - """Check if the form has expired.""" - return self.status == HumanInputFormStatus.EXPIRED - - @property - def is_waiting(self) -> bool: - """Check if the form is waiting for submission.""" - return self.status == HumanInputFormStatus.WAITING - - @property - def can_be_submitted(self) -> bool: - """Check if the form can still be submitted.""" - return self.status == HumanInputFormStatus.WAITING - - def submit( - self, - data: dict[str, Any], - action: str, - submission_type: HumanInputSubmissionType = HumanInputSubmissionType.web_form, - submission_user_id: Optional[str] = None, - submission_end_user_id: Optional[str] = None, - submitter_email: Optional[str] = None, - ) -> None: - """ - Submit the form with the given data and action. - - Args: - data: The form data submitted by the user - action: The action taken by the user - submission_type: Type of submission - submission_user_id: ID of the user who submitted (console submissions) - submission_end_user_id: ID of the end user who submitted (webapp submissions) - submitter_email: Email of the submitter (if applicable) - - Raises: - ValueError: If the form cannot be submitted - """ - if not self.can_be_submitted: - raise ValueError(f"Form cannot be submitted in status: {self.status}") - - # Validate that the action is valid based on form definition - valid_actions = {act.get("id") for act in self.form_definition.get("user_actions", [])} - if action not in valid_actions: - raise ValueError(f"Invalid action: {action}") - - self.submission = FormSubmission( - data=data, - action=action, - submission_type=submission_type, - submission_user_id=submission_user_id, - submission_end_user_id=submission_end_user_id, - submitter_email=submitter_email, - ) - self.status = HumanInputFormStatus.SUBMITTED - - def expire(self) -> None: - """Mark the form as expired.""" - if self.status != HumanInputFormStatus.WAITING: - raise ValueError(f"Form cannot be expired in status: {self.status}") - - self.status = HumanInputFormStatus.EXPIRED - - def get_form_definition_for_display(self, include_site_info: bool = False) -> dict[str, Any]: - """ - Get form definition for display purposes. - - Args: - include_site_info: Whether to include site information in the response - - Returns: - Form definition dictionary for display - """ - if self.status == HumanInputFormStatus.EXPIRED: - raise ValueError("Form has expired") - - if self.status == HumanInputFormStatus.SUBMITTED: - raise ValueError("Form has already been submitted") - - response = { - "form_content": self.rendered_content, - "inputs": self.form_definition.get("inputs", []), - "user_actions": self.form_definition.get("user_actions", []), - } - - if include_site_info: - # Note: In domain model, we don't have app_id - # This would be added at the application layer - response["site"] = { - "title": "Workflow Form", - } - - return response - - @classmethod - def create( - cls, - *, - id_: str, - workflow_run_id: str, - form_definition: dict[str, Any], - rendered_content: str, - web_app_token: Optional[str] = None, - ) -> "HumanInputForm": - """ - Create a new human input form. - - Args: - id_: Unique identifier for the form - workflow_run_id: ID of the associated workflow run - form_definition: Form definition as a dictionary - rendered_content: Rendered HTML content of the form - web_app_token: Optional token for web app access - - Returns: - New HumanInputForm instance - """ - return cls( - id_=id_, - workflow_run_id=workflow_run_id, - form_definition=form_definition, - rendered_content=rendered_content, - status=HumanInputFormStatus.WAITING, - web_app_token=web_app_token, - ) diff --git a/api/core/workflow/entities/pause_reason.py b/api/core/workflow/entities/pause_reason.py index c6655b7eab..05b6dc98c6 100644 --- a/api/core/workflow/entities/pause_reason.py +++ b/api/core/workflow/entities/pause_reason.py @@ -3,6 +3,8 @@ from typing import Annotated, Literal, TypeAlias from pydantic import BaseModel, Field +from core.workflow.nodes.human_input.entities import FormInput + class PauseReasonType(StrEnum): HUMAN_INPUT_REQUIRED = auto() @@ -11,10 +13,10 @@ class PauseReasonType(StrEnum): class HumanInputRequired(BaseModel): TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED - form_id: str - # The identifier of the human input node causing the pause. - node_id: str + form_content: str + inputs: list[FormInput] = Field(default_factory=list) + web_app_form_token: str | None = None class SchedulingPause(BaseModel): diff --git a/api/core/workflow/nodes/human_input/__init__.py b/api/core/workflow/nodes/human_input/__init__.py index 1ea167aa15..d3861d2c99 100644 --- a/api/core/workflow/nodes/human_input/__init__.py +++ b/api/core/workflow/nodes/human_input/__init__.py @@ -4,6 +4,5 @@ Human Input node implementation. from .entities import HumanInputNodeData from .human_input_node import HumanInputNode -from .node import HumanInputNode __all__ = ["HumanInputNode", "HumanInputNodeData"] diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py index 5af7ec4aa2..06bf83bcfd 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/core/workflow/nodes/human_input/entities.py @@ -2,59 +2,77 @@ Human Input node entities. """ +import enum +import re +import uuid +from collections.abc import Mapping, Sequence +from datetime import datetime, timedelta from enum import StrEnum -from typing import Any, Literal, Optional, Union +from typing import Annotated, Literal, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator +from core.variables.consts import SELECTORS_LENGTH from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser + +_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$outputs\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") + + +class HumanInputFormStatus(StrEnum): + """Status of a human input form.""" + + WAITING = enum.auto() + EXPIRED = enum.auto() + SUBMITTED = enum.auto() + TIMEOUT = enum.auto() class DeliveryMethodType(StrEnum): """Delivery method types for human input forms.""" - WEBAPP = "webapp" - EMAIL = "email" + WEBAPP = enum.auto() + EMAIL = enum.auto() class ButtonStyle(StrEnum): """Button styles for user actions.""" - PRIMARY = "primary" - DEFAULT = "default" - ACCENT = "accent" - GHOST = "ghost" + PRIMARY = enum.auto() + DEFAULT = enum.auto() + ACCENT = enum.auto() + GHOST = enum.auto() class TimeoutUnit(StrEnum): """Timeout unit for form expiration.""" - HOUR = "hour" - DAY = "day" + HOUR = enum.auto() + DAY = enum.auto() class FormInputType(StrEnum): """Form input types.""" - TEXT_INPUT = "text-input" - PARAGRAPH = "paragraph" + TEXT_INPUT = enum.auto() + PARAGRAPH = enum.auto() class PlaceholderType(StrEnum): """Placeholder types for form inputs.""" - VARIABLE = "variable" - CONSTANT = "constant" + VARIABLE = enum.auto() + CONSTANT = enum.auto() -class RecipientType(StrEnum): +class EmailRecipientType(StrEnum): """Email recipient types.""" - MEMBER = "member" - EXTERNAL = "external" + MEMBER = enum.auto() + EXTERNAL = enum.auto() -class WebAppDeliveryConfig(BaseModel): +class _WebAppDeliveryConfig(BaseModel): """Configuration for webapp delivery method.""" pass # Empty for webapp delivery @@ -63,25 +81,25 @@ class WebAppDeliveryConfig(BaseModel): class MemberRecipient(BaseModel): """Member recipient for email delivery.""" - type: Literal[RecipientType.MEMBER] + type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER user_id: str class ExternalRecipient(BaseModel): """External recipient for email delivery.""" - type: Literal[RecipientType.EXTERNAL] + type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL email: str -Recipient = Union[MemberRecipient, ExternalRecipient] +EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] class EmailRecipients(BaseModel): """Email recipients configuration.""" whole_workspace: bool = False - items: list[Recipient] = Field(default_factory=list) + items: list[EmailRecipient] = Field(default_factory=list) class EmailDeliveryConfig(BaseModel): @@ -92,51 +110,54 @@ class EmailDeliveryConfig(BaseModel): body: str -DeliveryConfig = Union[WebAppDeliveryConfig, EmailDeliveryConfig] +class _DeliveryMethodBase(BaseModel): + """Base delivery method configuration.""" - -class DeliveryMethod(BaseModel): - """Delivery method configuration.""" - - type: DeliveryMethodType enabled: bool = True - config: Optional[DeliveryConfig] = None + id: uuid.UUID = Field(default_factory=uuid.uuid4) - @model_validator(mode="after") - def validate_config_type(self): - """Validate that config matches the delivery method type.""" - if self.config is None: - return self - if self.type == DeliveryMethodType.EMAIL: - if isinstance(self.config, dict): - # Try to parse as EmailDeliveryConfig - this will raise validation errors - try: - self.config = EmailDeliveryConfig.model_validate(self.config) - except Exception as e: - # Re-raise with more specific context - raise ValueError(f"Invalid email delivery configuration: {str(e)}") - elif not isinstance(self.config, EmailDeliveryConfig): - raise ValueError("Config must be EmailDeliveryConfig for email delivery method") - elif self.type == DeliveryMethodType.WEBAPP: - if isinstance(self.config, dict): - # Try to parse as WebAppDeliveryConfig - try: - self.config = WebAppDeliveryConfig.model_validate(self.config) - except Exception as e: - raise ValueError(f"Invalid webapp delivery configuration: {str(e)}") - elif not isinstance(self.config, WebAppDeliveryConfig): - raise ValueError("Config must be WebAppDeliveryConfig for webapp delivery method") +class WebAppDeliveryMethod(_DeliveryMethodBase): + """Webapp delivery method configuration.""" - return self + type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP + # The config field is not used currently. + config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) + + +class EmailDeliveryMethod(_DeliveryMethodBase): + """Email delivery method configuration.""" + + type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL + config: EmailDeliveryConfig + + +DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] class FormInputPlaceholder(BaseModel): """Placeholder configuration for form inputs.""" + # NOTE: Ideally, a discriminated union would be used to model + # FormInputPlaceholder. However, the UI requires preserving the previous + # value when switching between `VARIABLE` and `CONSTANT` types. This + # necessitates retaining all fields, making a discriminated union unsuitable. + type: PlaceholderType - selector: list[str] = Field(default_factory=list) # Used when type is VARIABLE - value: str = "" # Used when type is CONSTANT + + # The selector of placeholder variable, used when `type` is `VARIABLE` + selector: Sequence[str] = Field(default_factory=tuple) # + + # The value of the placeholder, used when `type` is `CONSTANT`. + # TODO: How should we express JSON values? + value: str = "" + + @field_validator("selector") + @classmethod + def _validate_selector(cls, selector: Sequence[str]) -> Sequence[str]: + if len(selector) < SELECTORS_LENGTH: + raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={selector}") + return selector class FormInput(BaseModel): @@ -147,24 +168,106 @@ class FormInput(BaseModel): placeholder: Optional[FormInputPlaceholder] = None +_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + class UserAction(BaseModel): """User action configuration.""" + # id is the identifier for this action. + # It also serves as the identifiers of output handle. + # + # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) id: str title: str button_style: ButtonStyle = ButtonStyle.DEFAULT + @field_validator("id") + @classmethod + def _validate_id(cls, value: str) -> str: + if not _IDENTIFIER_PATTERN.match(value): + raise ValueError( + f"'{value}' is not a valid identifier. It must start with a letter or underscore, " + f"and contain only letters, numbers, or underscores." + ) + return value + class HumanInputNodeData(BaseNodeData): """Human Input node data.""" - delivery_methods: list[DeliveryMethod] = Field(default_factory=list) + delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) form_content: str = "" inputs: list[FormInput] = Field(default_factory=list) user_actions: list[UserAction] = Field(default_factory=list) timeout: int = 36 timeout_unit: TimeoutUnit = TimeoutUnit.HOUR + @field_validator("inputs") + @classmethod + def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: + seen_names: set[str] = set() + for form_input in inputs: + name = form_input.output_variable_name + if name in seen_names: + raise ValueError(f"duplicated output_variable_name '{name}' in inputs") + seen_names.add(name) + return inputs + + @field_validator("user_actions") + @classmethod + def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: + seen_ids: set[str] = set() + for action in user_actions: + action_id = action.id + if action_id in seen_ids: + raise ValueError(f"duplicated user action id '{action_id}'") + seen_ids.add(action_id) + return user_actions + + def is_webapp_enabled(self) -> bool: + for dm in self.delivery_methods: + if not dm.enabled: + continue + if dm.type == DeliveryMethodType.WEBAPP: + return True + return False + + def expiration_time(self, start_time: datetime) -> datetime: + if self.timeout_unit == TimeoutUnit.HOUR: + return start_time + timedelta(hours=self.timeout) + elif self.timeout_unit == TimeoutUnit.DAY: + return start_time + timedelta(days=self.timeout) + else: + raise AssertionError("unknown timeout unit.") + + def outputs_field_names(self) -> Sequence[str]: + field_names = [] + for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): + field_names.append(match.group("field_name")) + return field_names + + def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: + variable_selectors = [] + variable_template_parser = VariableTemplateParser(template=self.form_content) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + variable_mappings = {} + for variable_selector in variable_selectors: + qualified_variable_mapping_key = f"{node_id}.{variable_selector.variable}" + variable_mappings[qualified_variable_mapping_key] = variable_selector.value_selector + + for input in self.inputs: + placeholder = input.placeholder + if placeholder is None: + continue + if placeholder.type == PlaceholderType.CONSTANT: + continue + placeholder_key = ".".join(placeholder.selector) + qualified_variable_mapping_key = f"{node_id}.#{placeholder_key}#" + variable_mappings[qualified_variable_mapping_key] = placeholder.selector + + return variable_mappings + class HumanInputRequired(BaseModel): """Event data for human input required.""" @@ -176,72 +279,11 @@ class HumanInputRequired(BaseModel): web_app_form_token: Optional[str] = None -class WorkflowSuspended(BaseModel): - """Event data for workflow suspended.""" - - suspended_at_node_ids: list[str] - - -class PauseTypeHumanInput(BaseModel): - """Pause type for human input.""" - - type: Literal["human_input"] - form_id: str - - -class PauseTypeBreakpoint(BaseModel): - """Pause type for breakpoint (debugging).""" - - type: Literal["breakpoint"] - - -PauseType = Union[PauseTypeHumanInput, PauseTypeBreakpoint] - - -class PausedNode(BaseModel): - """Information about a paused node.""" - - node_id: str - node_title: str - pause_type: PauseType - - -class WorkflowPauseDetails(BaseModel): - """Details about workflow pause.""" - - paused_at: str # ISO datetime - paused_nodes: list[PausedNode] - - -class FormSubmissionRequest(BaseModel): - """Form submission request data.""" - - inputs: dict[str, str] # mapping of output_variable_name to user input - action: str # UserAction id - - -class FormGetResponse(BaseModel): - """Response for form get API.""" - - site: Optional[dict[str, Any]] = None # Site information for webapp +class FormDefinition(BaseModel): form_content: str - inputs: list[FormInput] + inputs: list[FormInput] = Field(default_factory=list) + user_actions: list[UserAction] = Field(default_factory=list) + rendered_content: str - -class FormSubmissionResponse(BaseModel): - """Response for successful form submission.""" - - pass # Empty response for success - - -class FormErrorResponse(BaseModel): - """Response for form submission errors.""" - - error_code: str - description: str - - -class ResumeWaitResponse(BaseModel): - """Response for resume wait API.""" - - status: Literal["paused", "running", "ended"] + timeout: int + timeout_unit: TimeoutUnit diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index 6c8bf36fab..02bfcf6b32 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -1,13 +1,27 @@ -from collections.abc import Mapping +import dataclasses +import logging +from collections.abc import Generator, Mapping, Sequence from typing import Any from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult, PauseRequestedEvent +from core.workflow.node_events.base import NodeEventBase +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node +from core.workflow.repositories.human_input_form_repository import FormCreateParams, HumanInputFormRepository from .entities import HumanInputNodeData +_SELECTED_BRANCH_KEY = "selected_branch" + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class _FormSubmissionResult: + action_id: str + class HumanInputNode(Node[HumanInputNodeData]): node_type = NodeType.HUMAN_INPUT @@ -17,7 +31,7 @@ class HumanInputNode(Node[HumanInputNodeData]): "edge_source_handle", "edgeSourceHandle", "source_handle", - "selected_branch", + _SELECTED_BRANCH_KEY, "selectedBranch", "branch", "branch_id", @@ -25,43 +39,12 @@ class HumanInputNode(Node[HumanInputNodeData]): "handle", ) + _node_data: HumanInputNodeData + @classmethod def version(cls) -> str: return "1" - def _run(self): # type: ignore[override] - if self._is_completion_ready(): - branch_handle = self._resolve_branch_selection() - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={}, - edge_source_handle=branch_handle or "source", - ) - - return self._pause_generator() - - def _pause_generator(self): - # TODO(QuantumGhost): yield a real form id. - yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id)) - - def _is_completion_ready(self) -> bool: - """Determine whether all required inputs are satisfied.""" - - if not self.node_data.required_variables: - return False - - variable_pool = self.graph_runtime_state.variable_pool - - for selector_str in self.node_data.required_variables: - parts = selector_str.split(".") - if len(parts) != 2: - return False - segment = variable_pool.get(parts) - if segment is None: - return False - - return True - def _resolve_branch_selection(self) -> str | None: """Determine the branch handle selected by human input if available.""" @@ -108,3 +91,106 @@ class HumanInputNode(Node[HumanInputNodeData]): return candidate return None + + def _create_form_repository(self) -> HumanInputFormRepository: + pass + + @staticmethod + def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]: + yield event + + @property + def _workflow_execution_id(self) -> str: + workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id + assert workflow_exec_id is not None + return workflow_exec_id + + def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: + """ + Execute the human input node. + + This method will: + 1. Generate a unique form ID + 2. Create form content with variable substitution + 3. Create form in database + 4. Send form via configured delivery methods + 5. Suspend workflow execution + 6. Wait for form submission to resume + """ + repo = self._create_form_repository() + submission_result = repo.get_form_submission(self._workflow_execution_id, self.app_id) + if submission_result: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "action_id": submission_result.selected_action_id, + }, + edge_source_handle=submission_result.selected_action_id, + ) + try: + repo = self._create_form_repository() + params = FormCreateParams( + workflow_execution_id=self._workflow_execution_id, + node_id=self.id, + form_config=self._node_data, + rendered_content=self._render_form_content(), + ) + result = repo.create_form(params) + # Create human input required event + + required_event = HumanInputRequired( + form_id=result.id, + form_content=self._node_data.form_content, + inputs=self._node_data.inputs, + web_app_form_token=result.web_app_token, + ) + pause_requested_event = PauseRequestedEvent(reason=required_event) + + # Create workflow suspended event + + logger.info( + "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", + self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, + self.id, + result.id, + ) + except Exception as e: + logger.exception("Human Input node failed to execute, node_id=%s", self.id) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + error_type="HumanInputNodeError", + ) + return self._pause_generator(pause_requested_event) + + def _render_form_content(self) -> str: + """ + Process form content by substituting variables. + + This method should: + 1. Parse the form_content markdown + 2. Substitute {{#node_name.var_name#}} with actual values + 3. Keep {{#$output.field_name#}} placeholders for form inputs + """ + rendered_form_content = self.graph_runtime_state.variable_pool.convert_template( + self._node_data.form_content, + ) + return rendered_form_content.markdown + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: Mapping[str, Any], + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selectors referenced in form content and input placeholders. + + This method should parse: + 1. Variables referenced in form_content ({{#node_name.var_name#}}) + 2. Variables referenced in input placeholders + """ + validated_node_data = HumanInputNodeData.model_validate(node_data) + return validated_node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/core/workflow/nodes/human_input/node.py b/api/core/workflow/nodes/human_input/node.py deleted file mode 100644 index 4e0cd03bbe..0000000000 --- a/api/core/workflow/nodes/human_input/node.py +++ /dev/null @@ -1,259 +0,0 @@ -""" -Human Input node implementation. -""" - -import json -import logging -import uuid -from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, Union - -from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.graph_engine.entities.event import InNodeEvent -from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig -from core.workflow.nodes.enums import ErrorStrategy, NodeType -from core.workflow.nodes.event import NodeEvent -from core.workflow.nodes.human_input.entities import ( - HumanInputNodeData, - HumanInputRequired, - WorkflowSuspended, -) -from extensions.ext_database import db -from services.human_input_form_service import HumanInputFormService - -logger = logging.getLogger(__name__) - - -class HumanInputNode(BaseNode): - """ - Human Input Node implementation. - - This node pauses workflow execution and waits for human input through - configured delivery methods (webapp or email). The workflow resumes - once the form is submitted. - """ - - _node_type: NodeType = NodeType.HUMAN_INPUT - _node_data_cls = HumanInputNodeData - node_data: HumanInputNodeData - - def init_node_data(self, data: Mapping[str, Any]) -> None: - """Initialize node data from configuration.""" - self.node_data = self._node_data_cls.model_validate(data) - - def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, InNodeEvent], None, None]: - """ - Execute the human input node. - - This method will: - 1. Generate a unique form ID - 2. Create form content with variable substitution - 3. Create form in database - 4. Send form via configured delivery methods - 5. Suspend workflow execution - 6. Wait for form submission to resume - """ - try: - # Generate unique form ID - form_id = str(uuid.uuid4()) - - # Create form content with variable substitution - form_content = self._process_form_content() - - # Generate webapp token if webapp delivery is enabled - web_app_form_token = None - webapp_enabled = any(dm.enabled and dm.type.value == "webapp" for dm in self.node_data.delivery_methods) - if webapp_enabled: - web_app_form_token = str(uuid.uuid4()).replace("-", "") - - # Create form definition for database storage - form_definition = { - "node_id": self.node_id, - "title": self.node_data.title, - "inputs": [inp.model_dump() for inp in self.node_data.inputs], - "user_actions": [action.model_dump() for action in self.node_data.user_actions], - "timeout": self.node_data.timeout, - "timeout_unit": self.node_data.timeout_unit.value, - "delivery_methods": [dm.model_dump() for dm in self.node_data.delivery_methods], - } - - # Create form in database - service = HumanInputFormService(db.session()) - service.create_form( - form_id=form_id, - workflow_run_id=self.graph_runtime_state.workflow_run_id, - tenant_id=self.graph_init_params.tenant_id, - app_id=self.graph_init_params.app_id, - form_definition=json.dumps(form_definition), - rendered_content=form_content, - web_app_token=web_app_form_token, - ) - - # Create human input required event - human_input_event = HumanInputRequired( - form_id=form_id, - node_id=self.node_id, - form_content=form_content, - inputs=self.node_data.inputs, - web_app_form_token=web_app_form_token, - ) - - # Create workflow suspended event - suspended_event = WorkflowSuspended(suspended_at_node_ids=[self.node_id]) - - logger.info(f"Human Input node {self.node_id} suspended workflow for form {form_id}") - - # Return suspension result - # The workflow engine should handle the suspension and resume logic - return NodeRunResult( - status=WorkflowNodeExecutionStatus.RUNNING, # Node is still running, waiting for input - inputs={}, - outputs={}, - metadata={ - "form_id": form_id, - "web_app_form_token": web_app_form_token, - "human_input_event": human_input_event.model_dump(), - "suspended_event": suspended_event.model_dump(), - "suspended": True, # Flag to indicate this node caused suspension - }, - ) - - except Exception as e: - logger.exception(f"Human Input node {self.node_id} failed to execute") - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - error_type="HumanInputNodeError", - ) - - def _process_form_content(self) -> str: - """ - Process form content by substituting variables. - - This method should: - 1. Parse the form_content markdown - 2. Substitute {{#node_name.var_name#}} with actual values - 3. Keep {{#$output.field_name#}} placeholders for form inputs - """ - # TODO: Implement variable substitution logic - # For now, return the raw form content - # This should integrate with the existing variable template parser - return self.node_data.form_content - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: Mapping[str, Any], - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selectors referenced in form content and input placeholders. - - This method should parse: - 1. Variables referenced in form_content ({{#node_name.var_name#}}) - 2. Variables referenced in input placeholders - """ - # TODO: Implement variable extraction logic - # This should parse the form_content and placeholder configurations - # to extract all referenced variables - return {} - - @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """Get default configuration for human input node.""" - return { - "type": "human_input", - "config": { - "delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}], - "form_content": "# Human Input\n\nPlease provide your input:\n\n{{#$output.input#}}", - "inputs": [ - { - "type": "text-input", - "output_variable_name": "input", - "placeholder": {"type": "constant", "value": "Enter your response here..."}, - } - ], - "user_actions": [{"id": "submit", "title": "Submit", "button_style": "primary"}], - "timeout": 24, - "timeout_unit": "hour", - }, - } - - @classmethod - def version(cls) -> str: - """Return the version of the human input node.""" - return "1" - - def _get_error_strategy(self) -> Optional[ErrorStrategy]: - """Get the error strategy for this node.""" - return self.node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self.node_data.retry_config - - def _get_title(self) -> str: - """Get the node title.""" - return self.node_data.title - - def _get_description(self) -> Optional[str]: - """Get the node description.""" - return self.node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self.node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - """Get the BaseNodeData object for this node.""" - return self.node_data - - def resume_from_human_input(self, form_submission_data: dict[str, Any]) -> NodeRunResult: - """ - Resume node execution after human input form is submitted. - - Args: - form_submission_data: Dict containing: - - inputs: Dict of input field values - - action: The user action taken - - Returns: - NodeRunResult with the form inputs as outputs - """ - try: - inputs = form_submission_data.get("inputs", {}) - action = form_submission_data.get("action", "") - - # Create output dictionary with form inputs - outputs = {} - for input_field in self.node_data.inputs: - field_name = input_field.output_variable_name - if field_name in inputs: - outputs[field_name] = inputs[field_name] - - # Add the action to outputs - outputs["_action"] = action - - logger.info(f"Human Input node {self.node_id} resumed with action {action}") - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={}, - outputs=outputs, - metadata={ - "form_submitted": True, - "submitted_action": action, - }, - ) - - except Exception as e: - logger.exception(f"Human Input node {self.node_id} failed to resume") - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - error_type="HumanInputResumeError", - ) diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/core/workflow/repositories/human_input_form_repository.py index ab96ca664f..d98a4a034c 100644 --- a/api/core/workflow/repositories/human_input_form_repository.py +++ b/api/core/workflow/repositories/human_input_form_repository.py @@ -1,6 +1,65 @@ -from typing import Protocol +import abc +import dataclasses +from collections.abc import Mapping +from typing import Any, Protocol -from core.workflow.entities.human_input_form import HumanInputForm +from core.workflow.nodes.human_input.entities import HumanInputNodeData + + +class HumanInputError(Exception): + pass + + +class FormNotFoundError(HumanInputError): + pass + + +@dataclasses.dataclass +class FormCreateParams: + workflow_execution_id: str + + # node_id is the identifier for a specific + # node in the graph. + # + # TODO: for node inside loop / iteration, this would + # cause problems, as a single node may be executed multiple times. + node_id: str + + form_config: HumanInputNodeData + rendered_content: str + + +class HumanInputFormEntity(abc.ABC): + @property + @abc.abstractmethod + def id(self) -> str: + """id returns the identifer of the form.""" + pass + + @property + @abc.abstractmethod + def web_app_token(self) -> str | None: + """web_app_token returns the token for submission inside webapp. + + If web app delivery is not enabled, this method would return `None`. + """ + + # TODO: what if the users are allowed to add multiple + # webapp delivery? + pass + + +class FormSubmissionEntity(abc.ABC): + @property + @abc.abstractmethod + def selected_action_id(self) -> str: + """The identifier of action user has selected, correspond to `UserAction.id`.""" + pass + + @abc.abstractmethod + def form_data(self) -> Mapping[str, Any]: + """The data submitted for this form""" + pass class HumanInputFormRepository(Protocol): @@ -16,93 +75,17 @@ class HumanInputFormRepository(Protocol): application domains or deployment scenarios. """ - def save(self, form: HumanInputForm) -> None: + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: """ - Save or update a HumanInputForm instance. - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the form's ID or other identifying fields. - - Args: - form: The HumanInputForm instance to save or update + Create a human input form from form definition. """ ... - def get_by_id(self, form_id: str) -> HumanInputForm: - """ - Get a form by its ID. + def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None: + """Retrieve the submission for a specific human input node. - Args: - form_id: The ID of the form to retrieve + Returns `FormSubmission` if the form has been submitted, or `None` if not. - Returns: - The HumanInputForm instance - - Raises: - NotFoundError: If the form is not found - """ - ... - - def get_by_web_app_token(self, web_app_token: str) -> HumanInputForm: - """ - Get a form by its web app token. - - Args: - web_app_token: The web app token to search for - - Returns: - The HumanInputForm instance - - Raises: - NotFoundError: If the form is not found - """ - ... - - def get_pending_forms_for_workflow_run(self, workflow_run_id: str) -> list[HumanInputForm]: - """ - Get all pending human input forms for a workflow run. - - Args: - workflow_run_id: The workflow run ID to filter by - - Returns: - List of pending HumanInputForm instances - """ - ... - - def mark_expired_forms(self, expiry_hours: int = 48) -> int: - """ - Mark expired forms as expired. - - Args: - expiry_hours: Number of hours after which forms should be expired - - Returns: - Number of forms marked as expired - """ - ... - - def exists_by_id(self, form_id: str) -> bool: - """ - Check if a form exists by ID. - - Args: - form_id: The ID of the form to check - - Returns: - True if the form exists, False otherwise - """ - ... - - def exists_by_web_app_token(self, web_app_token: str) -> bool: - """ - Check if a form exists by web app token. - - Args: - web_app_token: The web app token to check - - Returns: - True if the form exists, False otherwise + Raises `FormNotFoundError` if correspond form record is not found. """ ... diff --git a/api/models/base.py b/api/models/base.py index c8a5e20f25..aa93d31199 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -41,7 +41,7 @@ class DefaultFieldsMixin: ) updated_at: Mapped[datetime] = mapped_column( - __name_pos=DateTime, + DateTime, nullable=False, default=naive_utc_now, server_default=func.current_timestamp(), diff --git a/api/models/human_input.py b/api/models/human_input.py index ec09cf8d4a..7bfae33050 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -1,28 +1,21 @@ from datetime import datetime from enum import StrEnum +from typing import Annotated, Any, ClassVar, Literal, Self, final import sqlalchemy as sa -from sqlalchemy.orm import Mapped, mapped_column +from pydantic import BaseModel, Field +from sqlalchemy.orm import Mapped, mapped_column, relationship +from core.workflow.nodes.human_input.entities import ( + DeliveryMethodType, + EmailRecipientType, + HumanInputFormStatus, +) from libs.helper import generate_string -from .base import Base, ModelMixin +from .base import Base, DefaultFieldsMixin from .types import EnumText, StringUUID - -class HumanInputFormStatus(StrEnum): - WAITING = "waiting" - EXPIRED = "expired" - SUBMITTED = "submitted" - TIMEOUT = "timeout" - - -class HumanInputSubmissionType(StrEnum): - web_form = "web_form" - web_app = "web_app" - email = "email" - - _token_length = 22 # A 32-character string can store a base64-encoded value with 192 bits of entropy # or a base62-encoded value with over 180 bits of entropy, providing sufficient @@ -31,36 +24,18 @@ _token_field_length = 32 _email_field_length = 330 -def _generate_token(): +def _generate_token() -> str: return generate_string(_token_length) -class HumanInputForm(ModelMixin, Base): +class HumanInputForm(DefaultFieldsMixin, Base): __tablename__ = "human_input_forms" - # `tenant_id` identifies the tenant associated with this suspension, - # corresponding to the `id` field in the `Tenant` model. - tenant_id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - ) - - # `app_id` represents the application identifier associated with this state. - # It corresponds to the `id` field in the `App` model. - # - # While this field is technically redundant (as the corresponding app can be - # determined by querying the `Workflow`), it is retained to simplify data - # cleanup and management processes. - app_id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - ) - - workflow_run_id: Mapped[str] = mapped_column( - StringUUID, - nullable=False, - ) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # The human input node the current form corresponds to. + node_id: Mapped[str] = mapped_column(sa.String(60), nullable=False) form_definition: Mapped[str] = mapped_column(sa.Text, nullable=False) rendered_content: Mapped[str] = mapped_column(sa.Text, nullable=False) status: Mapped[HumanInputFormStatus] = mapped_column( @@ -69,56 +44,137 @@ class HumanInputForm(ModelMixin, Base): default=HumanInputFormStatus.WAITING, ) - web_app_token: Mapped[str] = mapped_column( + expiration_time: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + ) + + # Submission-related fields (nullable until a submission happens). + selected_action_id: Mapped[str | None] = mapped_column(sa.String(200), nullable=True) + submitted_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True) + submitted_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) + submission_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + submission_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + + completed_by_recipient_id: Mapped[str | None] = mapped_column( + StringUUID, + sa.ForeignKey("human_input_recipients.id"), + nullable=True, + ) + + deliveries: Mapped[list["HumanInputDelivery"]] = relationship( + "HumanInputDelivery", + back_populates="form", + lazy="raise", + ) + completed_by_recipient: Mapped["HumanInputRecipient | None"] = relationship( + "HumanInputRecipient", + primaryjoin="HumanInputForm.completed_by_recipient_id == HumanInputRecipient.id", + lazy="raise", + viewonly=True, + ) + + +class HumanInputDelivery(DefaultFieldsMixin, Base): + __tablename__ = "human_input_deliveries" + + form_id: Mapped[str] = mapped_column( + StringUUID, + sa.ForeignKey("human_input_forms.id"), + nullable=False, + ) + delivery_method_type: Mapped[DeliveryMethodType] = mapped_column( + EnumText(DeliveryMethodType), + nullable=False, + ) + delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + channel_payload: Mapped[None] = mapped_column(sa.Text, nullable=True) + + form: Mapped[HumanInputForm] = relationship( + "HumanInputForm", + back_populates="deliveries", + lazy="raise", + ) + recipients: Mapped[list["HumanInputRecipient"]] = relationship( + "HumanInputRecipient", + back_populates="delivery", + cascade="all, delete-orphan", + lazy="raise", + ) + + +class RecipientType(StrEnum): + # EMAIL_MEMBER member means that the + EMAIL_MEMBER = "email_member" + EMAIL_EXTERNAL = "email_external" + WEBAPP = "webapp" + + +@final +class EmailMemberRecipientPayload(BaseModel): + TYPE: Literal[RecipientType.EMAIL_MEMBER] = RecipientType.EMAIL_MEMBER + user_id: str + + # The `email` field here is only used for mail sending. + email: str + + +@final +class EmailExternalRecipientPayload(BaseModel): + TYPE: Literal[RecipientType.EMAIL_EXTERNAL] = RecipientType.EMAIL_EXTERNAL + email: str + + +@final +class WebAppRecipientPayload(BaseModel): + TYPE: Literal[RecipientType.WEBAPP] = RecipientType.WEBAPP + + +RecipientPayload = Annotated[ + EmailMemberRecipientPayload | EmailExternalRecipientPayload | WebAppRecipientPayload, + Field(discriminator="TYPE"), +] + + +class HumanInputRecipient(DefaultFieldsMixin, Base): + __tablename__ = "human_input_recipients" + + form_id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + ) + delivery_id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + ) + recipient_type: Mapped["RecipientType"] = mapped_column(EnumText(RecipientType), nullable=False) + recipient_payload: Mapped[str] = mapped_column(sa.Text, nullable=False) + + # Token primarily used for authenticated resume links (email, etc.). + access_token: Mapped[str | None] = mapped_column( sa.VARCHAR(_token_field_length), nullable=True, + default=_generate_token, ) - # The following fields are not null if the form is submitted. - - # The inputs provided by the user when resuming the suspended workflow. - # These inputs are serialized as a JSON-formatted string (e.g., `{}`). - # - # This field is `NULL` if no inputs were submitted by the user. - submitted_data: Mapped[str] = mapped_column(sa.Text, nullable=True) - - submitted_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=True) - - submission_type: Mapped[HumanInputSubmissionType] = mapped_column( - EnumText(HumanInputSubmissionType), - nullable=True, + delivery: Mapped[HumanInputDelivery] = relationship( + "HumanInputDelivery", + back_populates="recipients", + lazy="raise", ) - # If the submission happens in dashboard (Studio for orchestrate the workflow, or - # Explore for using published apps), which requires user to login before submission. - # Then the `submission_user_id` records the user id - # of submitter, else `None`. - submission_user_id: Mapped[str] = mapped_column(StringUUID, nullable=True) - - # If the submission happens in WebApp (which does not requires user to login before submission) - # Then the `submission_user_id` records the end_user_id of submitter, else `None`. - submission_end_user_id: Mapped[str] = mapped_column(StringUUID, nullable=True) - - # IF the submitter receives a email and - submitter_email: Mapped[str] = mapped_column(sa.VARCHAR(_email_field_length), nullable=True) - - -# class HumanInputEmailDelivery(ModelMixin, Base): -# # form_id refers to `HumanInputForm.id` -# form_id: Mapped[str] = mapped_column( -# StringUUID, -# nullable=False, -# ) - -# # IF the submitter receives a email and -# email: Mapped[str] = mapped_column(__name_pos=sa.VARCHAR(_email_field_length), nullable=False) -# user_id: Mapped[str] = mapped_column( -# StringUUID, -# nullable=True, -# ) - -# email_link_token: Mapped[str] = mapped_column( -# sa.VARCHAR(_token_field_length), -# nullable=False, -# default=_generate_token, -# ) + @classmethod + def new( + cls, + form_id: str, + delivery_id: str, + payload: RecipientPayload, + ) -> Self: + recipient_model = cls( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + access_token=_generate_token(), + ) + return recipient_model diff --git a/api/models/workflow.py b/api/models/workflow.py index 853d5afefc..34095cf99b 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -30,7 +30,7 @@ from core.workflow.constants import ( SYSTEM_VARIABLE_NODE_ID, ) from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import NodeType +from core.workflow.enums import NodeType, WorkflowExecutionStatus from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -607,7 +607,10 @@ class WorkflowRun(Base): version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(LongText) inputs: Mapped[str | None] = mapped_column(LongText) - status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded + status: Mapped[WorkflowExecutionStatus] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), + nullable=False, + ) outputs: Mapped[str | None] = mapped_column(LongText, default="{}") error: Mapped[str | None] = mapped_column(LongText) elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 90413878ce..764f9b3ab6 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -16,6 +16,10 @@ from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError +from models.model import Account, App, AppMode, EndUser +from models.workflow import Workflow, WorkflowRun +from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError +from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import ChatflowExecutionParams, chatflow_execute_task @@ -246,3 +250,19 @@ class AppGenerateService: raise ValueError("Workflow not published") return workflow + + @classmethod + def get_response_generator( + cls, + app_model: App, + workflow_run: WorkflowRun, + ): + if workflow_run.status.is_ended(): + # TODO(QuantumGhost): handled the ended scenario. + return + + generator = AdvancedChatAppGenerator() + + return generator.convert_to_event_stream( + generator.retrieve_events(app_model.mode, workflow_run.id), + ) diff --git a/api/services/human_input_form_domain_service.py b/api/services/human_input_form_domain_service.py deleted file mode 100644 index bcb10c8064..0000000000 --- a/api/services/human_input_form_domain_service.py +++ /dev/null @@ -1,272 +0,0 @@ -""" -Service for managing human input forms using domain models. - -This service layer operates on domain models and uses repositories for persistence, -keeping the business logic clean and independent of database concerns. -""" - -import logging -from typing import Any, Optional - -from core.workflow.entities.human_input_form import HumanInputForm, HumanInputFormStatus, HumanInputSubmissionType -from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository -from services.errors.base import BaseServiceError - -logger = logging.getLogger(__name__) - - -class HumanInputFormNotFoundError(BaseServiceError): - """Raised when a human input form is not found.""" - - def __init__(self, identifier: str): - super().__init__(f"Human input form not found: {identifier}") - self.identifier = identifier - - -class HumanInputFormExpiredError(BaseServiceError): - """Raised when a human input form has expired.""" - - def __init__(self): - super().__init__("Human input form has expired") - - -class HumanInputFormAlreadySubmittedError(BaseServiceError): - """Raised when trying to operate on an already submitted form.""" - - def __init__(self): - super().__init__("Human input form has already been submitted") - - -class InvalidFormDataError(BaseServiceError): - """Raised when form submission data is invalid.""" - - def __init__(self, message: str): - super().__init__(f"Invalid form data: {message}") - self.message = message - - -class HumanInputFormDomainService: - """Service for managing human input forms using domain models.""" - - def __init__(self, repository: HumanInputFormRepository): - """ - Initialize the service with a repository. - - Args: - repository: Repository for human input form persistence - """ - self._repository = repository - - def create_form( - self, - *, - form_id: str, - workflow_run_id: str, - form_definition: dict[str, Any], - rendered_content: str, - web_app_token: Optional[str] = None, - ) -> HumanInputForm: - """ - Create a new human input form. - - Args: - form_id: Unique identifier for the form - workflow_run_id: ID of the associated workflow run - form_definition: Form definition as a dictionary - rendered_content: Rendered HTML content of the form - web_app_token: Optional token for web app access - - Returns: - Created HumanInputForm domain model - """ - form = HumanInputForm.create( - id_=form_id, - workflow_run_id=workflow_run_id, - form_definition=form_definition, - rendered_content=rendered_content, - web_app_token=web_app_token, - ) - - self._repository.save(form) - logger.info("Created human input form %s", form_id) - return form - - def get_form_by_id(self, form_id: str) -> HumanInputForm: - """ - Get a form by its ID. - - Args: - form_id: The ID of the form to retrieve - - Returns: - HumanInputForm domain model - - Raises: - HumanInputFormNotFoundError: If the form is not found - """ - try: - return self._repository.get_by_id(form_id) - except ValueError as e: - raise HumanInputFormNotFoundError(form_id) from e - - def get_form_by_token(self, web_app_token: str) -> HumanInputForm: - """ - Get a form by its web app token. - - Args: - web_app_token: The web app token to search for - - Returns: - HumanInputForm domain model - - Raises: - HumanInputFormNotFoundError: If the form is not found - """ - try: - return self._repository.get_by_web_app_token(web_app_token) - except ValueError as e: - raise HumanInputFormNotFoundError(web_app_token) from e - - def get_form_definition( - self, - identifier: str, - is_token: bool = False, - include_site_info: bool = False, - app_id: Optional[str] = None, - ) -> dict[str, Any]: - """ - Get form definition for display. - - Args: - identifier: Form ID or web app token - is_token: True if identifier is a web app token, False if it's a form ID - include_site_info: Whether to include site information in the response - app_id: App ID for site information (if include_site_info is True) - - Returns: - Form definition dictionary for display - - Raises: - HumanInputFormNotFoundError: If the form is not found - HumanInputFormExpiredError: If the form has expired - HumanInputFormAlreadySubmittedError: If the form has already been submitted - """ - if is_token: - form = self.get_form_by_token(identifier) - else: - form = self.get_form_by_id(identifier) - - try: - form_definition = form.get_form_definition_for_display(include_site_info=include_site_info) - except ValueError as e: - if "expired" in str(e).lower(): - raise HumanInputFormExpiredError() from e - elif "submitted" in str(e).lower(): - raise HumanInputFormAlreadySubmittedError() from e - else: - raise InvalidFormDataError(str(e)) from e - - # Add site info if requested and app_id is provided - if include_site_info and app_id and "site" in form_definition: - form_definition["site"]["app_id"] = app_id - - return form_definition - - def submit_form( - self, - identifier: str, - form_data: dict[str, Any], - action: str, - is_token: bool = False, - submission_type: HumanInputSubmissionType = HumanInputSubmissionType.web_form, - submission_user_id: Optional[str] = None, - submission_end_user_id: Optional[str] = None, - ) -> HumanInputForm: - """ - Submit a form. - - Args: - identifier: Form ID or web app token - form_data: The submitted form data - action: The action taken by the user - is_token: True if identifier is a web app token, False if it's a form ID - submission_type: Type of submission (web_form, web_app, email) - submission_user_id: ID of the user who submitted (for console submissions) - submission_end_user_id: ID of the end user who submitted (for webapp submissions) - - Returns: - Updated HumanInputForm domain model - - Raises: - HumanInputFormNotFoundError: If the form is not found - HumanInputFormExpiredError: If the form has expired - HumanInputFormAlreadySubmittedError: If the form has already been submitted - InvalidFormDataError: If the submission data is invalid - """ - if is_token: - form = self.get_form_by_token(identifier) - else: - form = self.get_form_by_id(identifier) - - if form.is_expired: - raise HumanInputFormExpiredError() - - if form.is_submitted: - raise HumanInputFormAlreadySubmittedError() - - try: - form.submit( - data=form_data, - action=action, - submission_type=submission_type, - submission_user_id=submission_user_id, - submission_end_user_id=submission_end_user_id, - ) - except ValueError as e: - raise InvalidFormDataError(str(e)) from e - - self._repository.save(form) - logger.info("Form %s submitted with action %s", form.id_, action) - return form - - def cleanup_expired_forms(self, expiry_hours: int = 48) -> int: - """ - Clean up expired forms. - - Args: - expiry_hours: Number of hours after which forms should be expired - - Returns: - Number of forms cleaned up - """ - count = self._repository.mark_expired_forms(expiry_hours) - logger.info("Cleaned up %d expired forms", count) - return count - - def get_pending_forms_for_workflow_run(self, workflow_run_id: str) -> list[HumanInputForm]: - """ - Get all pending human input forms for a workflow run. - - Args: - workflow_run_id: The workflow run ID to filter by - - Returns: - List of pending HumanInputForm domain models - """ - return self._repository.get_pending_forms_for_workflow_run(workflow_run_id) - - def form_exists(self, identifier: str, is_token: bool = False) -> bool: - """ - Check if a form exists. - - Args: - identifier: Form ID or web app token - is_token: True if identifier is a web app token, False if it's a form ID - - Returns: - True if the form exists, False otherwise - """ - if is_token: - return self._repository.exists_by_web_app_token(identifier) - else: - return self._repository.exists_by_id(identifier) diff --git a/api/services/human_input_form_service.py b/api/services/human_input_form_service.py deleted file mode 100644 index 1b6d4926be..0000000000 --- a/api/services/human_input_form_service.py +++ /dev/null @@ -1,352 +0,0 @@ -""" -Service for managing human input forms. - -This service maintains backward compatibility while internally using domain models -and repositories for better architecture. -""" - -import json -import logging -from datetime import datetime, timedelta -from typing import Any, Optional - -from sqlalchemy import and_, select -from sqlalchemy.orm import Session - -from core.repositories.factory import DifyCoreRepositoryFactory -from core.workflow.entities.human_input_form import HumanInputForm as DomainHumanInputForm, HumanInputSubmissionType -from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository -from models.human_input import ( - HumanInputForm, - HumanInputFormStatus, - HumanInputSubmissionType as DBHumanInputSubmissionType, -) -from services.errors.base import BaseServiceError -from services.human_input_form_domain_service import ( - HumanInputFormDomainService, - HumanInputFormNotFoundError as DomainNotFoundError, - HumanInputFormExpiredError as DomainExpiredError, - HumanInputFormAlreadySubmittedError as DomainAlreadySubmittedError, - InvalidFormDataError as DomainInvalidFormDataError, -) - -logger = logging.getLogger(__name__) - - -class HumanInputFormNotFoundError(BaseServiceError): - """Raised when a human input form is not found.""" - - def __init__(self, identifier: str): - super().__init__(f"Human input form not found: {identifier}") - self.identifier = identifier - - -class HumanInputFormExpiredError(BaseServiceError): - """Raised when a human input form has expired.""" - - def __init__(self): - super().__init__("Human input form has expired") - - -class HumanInputFormAlreadySubmittedError(BaseServiceError): - """Raised when trying to operate on an already submitted form.""" - - def __init__(self): - super().__init__("Human input form has already been submitted") - - -class InvalidFormDataError(BaseServiceError): - """Raised when form submission data is invalid.""" - - def __init__(self, message: str): - super().__init__(f"Invalid form data: {message}") - self.message = message - - -class HumanInputFormService: - """Service for managing human input forms using domain models internally.""" - - def __init__(self, session: Session): - """ - Initialize the service with a database session. - - Args: - session: SQLAlchemy session - """ - self._session = session - # For backward compatibility, we need user and app_id context - # These would typically be available from the request context - # For now, we'll extract them from the session or use defaults - self._user = None # This should be set from request context - self._app_id = None # This should be set from request context - self._domain_service = None - - def _get_domain_service(self) -> HumanInputFormDomainService: - """ - Get the domain service instance. - - Note: This requires user and app_id context to be properly set. - In a real implementation, these would be extracted from the request context. - """ - if self._domain_service is None: - if not self._user: - # For backward compatibility, we need to handle this case - # In practice, the user should be available from the request context - raise ValueError("User context is required for domain operations") - - repository = DifyCoreRepositoryFactory.create_human_input_form_repository( - session_factory=self._session, - user=self._user, - app_id=self._app_id or "", - ) - self._domain_service = HumanInputFormDomainService(repository) - - return self._domain_service - - def _domain_to_db_model(self, domain_form: DomainHumanInputForm) -> HumanInputForm: - """Convert domain model to database model for backward compatibility.""" - # Find existing DB model or create new one - db_model = self._session.get(HumanInputForm, domain_form.id_) - if db_model is None: - db_model = HumanInputForm() - db_model.id = domain_form.id_ - # Set tenant_id and app_id from context - if hasattr(self._user, "current_tenant_id"): - db_model.tenant_id = self._user.current_tenant_id - elif hasattr(self._user, "tenant_id"): - db_model.tenant_id = self._user.tenant_id - if self._app_id: - db_model.app_id = self._app_id - - # Update fields - db_model.workflow_run_id = domain_form.workflow_run_id - db_model.form_definition = json.dumps(domain_form.form_definition) - db_model.rendered_content = domain_form.rendered_content - db_model.status = HumanInputFormStatus(domain_form.status.value) - db_model.web_app_token = domain_form.web_app_token - db_model.created_at = domain_form.created_at - - # Handle submission data - if domain_form.submission: - db_model.submitted_data = json.dumps(domain_form.submission.data) - db_model.submitted_at = domain_form.submission.submitted_at - db_model.submission_type = DBHumanInputSubmissionType(domain_form.submission.submission_type.value) - db_model.submission_user_id = domain_form.submission.submission_user_id - db_model.submission_end_user_id = domain_form.submission.submission_end_user_id - # Note: submitter_email is not in the current DB model schema - - return db_model - - def set_context(self, user, app_id: Optional[str] = None) -> None: - """ - Set user and app context for the service. - - Args: - user: User object (Account or EndUser) - app_id: Application ID - """ - self._user = user - self._app_id = app_id - self._domain_service = None # Reset to force recreation with new context - - def create_form( - self, - *, - form_id: str, - workflow_run_id: str, - tenant_id: str, - app_id: str, - form_definition: str, - rendered_content: str, - web_app_token: Optional[str] = None, - ) -> HumanInputForm: - """Create a new human input form.""" - # Set context for this operation - self._app_id = app_id - - try: - domain_service = self._get_domain_service() - domain_form = domain_service.create_form( - form_id=form_id, - workflow_run_id=workflow_run_id, - form_definition=json.loads(form_definition), - rendered_content=rendered_content, - web_app_token=web_app_token, - ) - - # Convert back to DB model for backward compatibility - db_model = self._domain_to_db_model(domain_form) - self._session.add(db_model) - self._session.commit() - - logger.info("Created human input form %s", form_id) - return db_model - - except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e: - # Convert domain exceptions to service exceptions - if isinstance(e, DomainNotFoundError): - raise HumanInputFormNotFoundError(e.identifier) from e - elif isinstance(e, DomainExpiredError): - raise HumanInputFormExpiredError() from e - elif isinstance(e, DomainAlreadySubmittedError): - raise HumanInputFormAlreadySubmittedError() from e - elif isinstance(e, DomainInvalidFormDataError): - raise InvalidFormDataError(e.message) from e - - def get_form_by_id(self, form_id: str) -> HumanInputForm: - """Get a form by its ID.""" - try: - domain_service = self._get_domain_service() - domain_form = domain_service.get_form_by_id(form_id) - return self._domain_to_db_model(domain_form) - except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e: - if isinstance(e, DomainNotFoundError): - raise HumanInputFormNotFoundError(e.identifier) from e - elif isinstance(e, DomainExpiredError): - raise HumanInputFormExpiredError() from e - elif isinstance(e, DomainAlreadySubmittedError): - raise HumanInputFormAlreadySubmittedError() from e - elif isinstance(e, DomainInvalidFormDataError): - raise InvalidFormDataError(e.message) from e - - def get_form_by_token(self, web_app_token: str) -> HumanInputForm: - """Get a form by its web app token.""" - try: - domain_service = self._get_domain_service() - domain_form = domain_service.get_form_by_token(web_app_token) - return self._domain_to_db_model(domain_form) - except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e: - if isinstance(e, DomainNotFoundError): - raise HumanInputFormNotFoundError(e.identifier) from e - elif isinstance(e, DomainExpiredError): - raise HumanInputFormExpiredError() from e - elif isinstance(e, DomainAlreadySubmittedError): - raise HumanInputFormAlreadySubmittedError() from e - elif isinstance(e, DomainInvalidFormDataError): - raise InvalidFormDataError(e.message) from e - - def get_form_definition( - self, - identifier: str, - is_token: bool = False, - include_site_info: bool = False, - ) -> dict[str, Any]: - """ - Get form definition for display. - - Args: - identifier: Form ID or web app token - is_token: True if identifier is a web app token, False if it's a form ID - include_site_info: Whether to include site information in the response - """ - try: - domain_service = self._get_domain_service() - return domain_service.get_form_definition( - identifier=identifier, - is_token=is_token, - include_site_info=include_site_info, - app_id=self._app_id, - ) - except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e: - if isinstance(e, DomainNotFoundError): - raise HumanInputFormNotFoundError(e.identifier) from e - elif isinstance(e, DomainExpiredError): - raise HumanInputFormExpiredError() from e - elif isinstance(e, DomainAlreadySubmittedError): - raise HumanInputFormAlreadySubmittedError() from e - elif isinstance(e, DomainInvalidFormDataError): - raise InvalidFormDataError(e.message) from e - - def submit_form( - self, - identifier: str, - form_data: dict[str, Any], - action: str, - is_token: bool = False, - submission_type: HumanInputSubmissionType = HumanInputSubmissionType.web_form, - submission_user_id: Optional[str] = None, - submission_end_user_id: Optional[str] = None, - ) -> HumanInputForm: - """ - Submit a form. - - Args: - identifier: Form ID or web app token - form_data: The submitted form data - action: The action taken by the user - is_token: True if identifier is a web app token, False if it's a form ID - submission_type: Type of submission (web_form, web_app, email) - submission_user_id: ID of the user who submitted (for console submissions) - submission_end_user_id: ID of the end user who submitted (for webapp submissions) - """ - try: - domain_service = self._get_domain_service() - domain_form = domain_service.submit_form( - identifier=identifier, - form_data=form_data, - action=action, - is_token=is_token, - submission_type=submission_type, - submission_user_id=submission_user_id, - submission_end_user_id=submission_end_user_id, - ) - - # Convert back to DB model for backward compatibility - db_model = self._domain_to_db_model(domain_form) - self._session.merge(db_model) - self._session.commit() - - return db_model - - except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e: - if isinstance(e, DomainNotFoundError): - raise HumanInputFormNotFoundError(e.identifier) from e - elif isinstance(e, DomainExpiredError): - raise HumanInputFormExpiredError() from e - elif isinstance(e, DomainAlreadySubmittedError): - raise HumanInputFormAlreadySubmittedError() from e - elif isinstance(e, DomainInvalidFormDataError): - raise InvalidFormDataError(e.message) from e - - def _validate_submission(self, form: HumanInputForm, form_data: dict[str, Any], action: str) -> None: - """Validate form submission data.""" - form_definition = json.loads(form.form_definition) - - # Check that the action is valid - valid_actions = {act.get("id") for act in form_definition.get("user_actions", [])} - if action not in valid_actions: - raise InvalidFormDataError(f"Invalid action: {action}") - - # Note: We don't validate required inputs here as the original implementation - # allows extra inputs and doesn't strictly enforce all inputs to be present - - def cleanup_expired_forms(self) -> int: - """Clean up expired forms. Returns the number of forms cleaned up.""" - try: - domain_service = self._get_domain_service() - return domain_service.cleanup_expired_forms() - except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e: - if isinstance(e, DomainNotFoundError): - raise HumanInputFormNotFoundError(e.identifier) from e - elif isinstance(e, DomainExpiredError): - raise HumanInputFormExpiredError() from e - elif isinstance(e, DomainAlreadySubmittedError): - raise HumanInputFormAlreadySubmittedError() from e - elif isinstance(e, DomainInvalidFormDataError): - raise InvalidFormDataError(e.message) from e - - def get_pending_forms_for_workflow_run(self, workflow_run_id: str) -> list[HumanInputForm]: - """Get all pending human input forms for a workflow run.""" - try: - domain_service = self._get_domain_service() - domain_forms = domain_service.get_pending_forms_for_workflow_run(workflow_run_id) - return [self._domain_to_db_model(domain_form) for domain_form in domain_forms] - except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e: - if isinstance(e, DomainNotFoundError): - raise HumanInputFormNotFoundError(e.identifier) from e - elif isinstance(e, DomainExpiredError): - raise HumanInputFormExpiredError() from e - elif isinstance(e, DomainAlreadySubmittedError): - raise HumanInputFormAlreadySubmittedError() from e - elif isinstance(e, DomainInvalidFormDataError): - raise InvalidFormDataError(e.message) from e diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py new file mode 100644 index 0000000000..aca1058cf9 --- /dev/null +++ b/api/services/human_input_service.py @@ -0,0 +1,63 @@ +import abc +from collections.abc import Mapping +from typing import Any + +from sqlalchemy import Engine +from sqlalchemy.orm import sessionmaker + +from core.workflow.nodes.human_input.entities import FormDefinition +from libs.exception import BaseHTTPException +from models.human_input import RecipientType + + +class Form(abc.ABC): + @abc.abstractmethod + def get_definition(self) -> FormDefinition: + pass + + @abc.abstractmethod + def submitted(self) -> bool: + pass + + +class HumanInputError(Exception): + pass + + +class FormSubmittedError(HumanInputError, BaseHTTPException): + error_code = "human_input_form_submitted" + description = "This form has already been submitted by another user, form_id={form_id}" + code = 412 + + def __init__(self, form_id: str): + description = self.description.format(form_id=form_id) + super().__init__(description=description) + + +class FormNotFoundError(HumanInputError, BaseException): + error_code = "human_input_form_not_found" + code = 404 + + +class HumanInputService: + def __init__( + self, + session_factory: sessionmaker | Engine, + ): + if isinstance(session_factory, Engine): + session_factory = sessionmaker(bind=session_factory) + self._session_factory = session_factory + + def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form: + pass + + def get_form_definition_by_id(self, form_id: str) -> Form | None: + pass + + def submit_form_by_id(self, form_id: str, selected_action_id: str, form_data: Mapping[str, Any]): + pass + + def submit_form_by_token( + self, recipient_type: RecipientType, form_token: str, selected_action_id: str, form_data: Mapping[str, Any] + ): + pass diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index f94d219dcf..0a3b5edfbd 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -136,7 +136,7 @@ class _ChatflowRunner: chat_generator = AdvancedChatAppGenerator() - workflow_run_id = uuid.uuid4() + workflow_run_id = exec_params.workflow_run_id with self._setup_flask_context(user): response = chat_generator.generate( diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py new file mode 100644 index 0000000000..7784bbdd2f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -0,0 +1,75 @@ +""" +Tests for PauseReason discriminated union serialization/deserialization. +""" + +import pytest +from pydantic import BaseModel, ValidationError + +from core.workflow.entities.pause_reason import ( + HumanInputRequired, + PauseReason, + SchedulingPause, +) + + +class _Holder(BaseModel): + """Helper model that embeds PauseReason for union tests.""" + + reason: PauseReason + + +class TestPauseReasonDiscriminator: + """Test suite for PauseReason union discriminator.""" + + @pytest.mark.parametrize( + ("dict_value", "expected"), + [ + pytest.param( + { + "reason": { + "TYPE": "human_input_required", + "human_input_form_id": "form_id", + }, + }, + HumanInputRequired(human_input_form_id="form_id"), + id="HumanInputRequired", + ), + pytest.param( + { + "reason": { + "TYPE": "scheduled_pause", + "message": "Hold on", + } + }, + SchedulingPause(message="Hold on"), + id="SchedulingPause", + ), + ], + ) + def test_model_validate(self, dict_value, expected): + """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" + holder = _Holder.model_validate(dict_value) + + assert type(holder.reason) == type(expected) + assert holder.reason == expected + + @pytest.mark.parametrize( + "reason", + [ + HumanInputRequired(human_input_form_id="form_id"), + SchedulingPause(message="Hold on"), + ], + ids=lambda x: type(x).__name__, + ) + def test_model_construct(self, reason): + holder = _Holder(reason=reason) + assert holder.reason == reason + + def test_model_construct_with_invalid_type(self): + with pytest.raises(ValidationError): + holder = _Holder(reason=object()) # type: ignore + + def test_unknown_type_fails_validation(self): + """Unknown TYPE values should raise a validation error.""" + with pytest.raises(ValidationError): + _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 653d5dda59..73d109f033 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -7,10 +7,11 @@ from pydantic import ValidationError from core.workflow.nodes.human_input.entities import ( ButtonStyle, - DeliveryMethod, DeliveryMethodType, EmailDeliveryConfig, + EmailDeliveryMethod, EmailRecipients, + EmailRecipientType, ExternalRecipient, FormInput, FormInputPlaceholder, @@ -18,10 +19,10 @@ from core.workflow.nodes.human_input.entities import ( HumanInputNodeData, MemberRecipient, PlaceholderType, - RecipientType, TimeoutUnit, UserAction, - WebAppDeliveryConfig, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, ) @@ -30,19 +31,19 @@ class TestDeliveryMethod: def test_webapp_delivery_method(self): """Test webapp delivery method creation.""" - delivery_method = DeliveryMethod(type=DeliveryMethodType.WEBAPP, enabled=True, config=WebAppDeliveryConfig()) + delivery_method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()) assert delivery_method.type == DeliveryMethodType.WEBAPP assert delivery_method.enabled is True - assert isinstance(delivery_method.config, WebAppDeliveryConfig) + assert isinstance(delivery_method.config, _WebAppDeliveryConfig) def test_email_delivery_method(self): """Test email delivery method creation.""" recipients = EmailRecipients( whole_workspace=False, items=[ - MemberRecipient(type=RecipientType.MEMBER, user_id="test-user-123"), - ExternalRecipient(type=RecipientType.EXTERNAL, email="test@example.com"), + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), ], ) @@ -50,7 +51,7 @@ class TestDeliveryMethod: recipients=recipients, subject="Test Subject", body="Test body with {{#url#}} placeholder" ) - delivery_method = DeliveryMethod(type=DeliveryMethodType.EMAIL, enabled=True, config=config) + delivery_method = EmailDeliveryMethod(enabled=True, config=config) assert delivery_method.type == DeliveryMethodType.EMAIL assert delivery_method.enabled is True @@ -118,7 +119,7 @@ class TestHumanInputNodeData: def test_valid_node_data_creation(self): """Test creating valid human input node data.""" - delivery_methods = [DeliveryMethod(type=DeliveryMethodType.WEBAPP, enabled=True, config=WebAppDeliveryConfig())] + delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())] inputs = [ FormInput( @@ -153,11 +154,12 @@ class TestHumanInputNodeData: def test_node_data_with_multiple_delivery_methods(self): """Test node data with multiple delivery methods.""" delivery_methods = [ - DeliveryMethod(type=DeliveryMethodType.WEBAPP, enabled=True, config=WebAppDeliveryConfig()), - DeliveryMethod( - type=DeliveryMethodType.EMAIL, + WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()), + EmailDeliveryMethod( enabled=False, # Disabled method should be fine - config=None, + config=EmailDeliveryConfig( + subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + ), ), ] @@ -182,28 +184,64 @@ class TestHumanInputNodeData: assert node_data.timeout == 36 assert node_data.timeout_unit == TimeoutUnit.HOUR + def test_duplicate_input_output_variable_name_raises_validation_error(self): + """Duplicate form input output_variable_name should raise validation error.""" + duplicate_inputs = [ + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"), + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"), + ] + + with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"): + HumanInputNodeData(title="Test Node", inputs=duplicate_inputs) + + def test_duplicate_user_action_ids_raise_validation_error(self): + """Duplicate user action ids should raise validation error.""" + duplicate_actions = [ + UserAction(id="submit", title="Submit"), + UserAction(id="submit", title="Submit Again"), + ] + + with pytest.raises(ValidationError, match="duplicated user action id 'submit'"): + HumanInputNodeData(title="Test Node", user_actions=duplicate_actions) + + def test_extract_outputs_field_names(self): + content = r"""This is titile {{#start.title#}} + + A content is required: + + {{#$outputs.content#}} + + A ending is required: + + {{#$outputs.ending#}} + """ + + node_data = HumanInputNodeData(title="Human Input", form_content=content) + field_names = node_data.outputs_field_names() + assert field_names == ["content", "ending"] + class TestRecipients: """Test email recipient entities.""" def test_member_recipient(self): """Test member recipient creation.""" - recipient = MemberRecipient(type=RecipientType.MEMBER, user_id="user-123") + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") - assert recipient.type == RecipientType.MEMBER + assert recipient.type == EmailRecipientType.MEMBER assert recipient.user_id == "user-123" def test_external_recipient(self): """Test external recipient creation.""" - recipient = ExternalRecipient(type=RecipientType.EXTERNAL, email="test@example.com") + recipient = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com") - assert recipient.type == RecipientType.EXTERNAL + assert recipient.type == EmailRecipientType.EXTERNAL assert recipient.email == "test@example.com" def test_email_recipients_whole_workspace(self): """Test email recipients with whole workspace enabled.""" recipients = EmailRecipients( - whole_workspace=True, items=[MemberRecipient(type=RecipientType.MEMBER, user_id="user-123")] + whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] ) assert recipients.whole_workspace is True @@ -214,8 +252,8 @@ class TestRecipients: recipients = EmailRecipients( whole_workspace=False, items=[ - MemberRecipient(type=RecipientType.MEMBER, user_id="user-123"), - ExternalRecipient(type=RecipientType.EXTERNAL, email="external@example.com"), + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), ], )