mirror of
https://github.com/langgenius/dify.git
synced 2026-02-23 03:17:57 +08:00
WIP: P2
This commit is contained in:
@ -4,56 +4,79 @@ Console/Studio Human Input Form APIs.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import g, jsonify
|
||||
from flask import Response, jsonify
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.human_input import HumanInputSubmissionType
|
||||
from services.human_input_form_service import (
|
||||
HumanInputFormAlreadySubmittedError,
|
||||
HumanInputFormExpiredError,
|
||||
HumanInputFormNotFoundError,
|
||||
HumanInputFormService,
|
||||
InvalidFormDataError,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.human_input import HumanInputForm as HumanInputFormModel
|
||||
from services.human_input_service import HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _FormDefinitionWithSite(FormDefinition):
|
||||
site: None
|
||||
|
||||
|
||||
def _jsonify_pydantic_model(model: BaseModel) -> Response:
|
||||
return Response(model.model_dump_json(), mimetype="application/json")
|
||||
|
||||
|
||||
class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_id: str):
|
||||
"""
|
||||
Get human input form definition by form ID.
|
||||
|
||||
GET /console/api/form/human_input/<form_id>
|
||||
"""
|
||||
try:
|
||||
service = HumanInputFormService(db.session())
|
||||
form_definition = service.get_form_definition(identifier=form_id, is_token=False, include_site_info=False)
|
||||
return form_definition, 200
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_definition_by_id(
|
||||
form_id=form_id,
|
||||
)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
except HumanInputFormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
except HumanInputFormExpiredError:
|
||||
return jsonify(
|
||||
{"error_code": "human_input_form_expired", "description": "Human input form has expired"}
|
||||
), 400
|
||||
except HumanInputFormAlreadySubmittedError:
|
||||
return jsonify(
|
||||
{
|
||||
"error_code": "human_input_form_submitted",
|
||||
"description": "Human input form has already been submitted",
|
||||
}
|
||||
), 400
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
form_model = db.session.get(HumanInputFormModel, form_id)
|
||||
if form_model is None or form_model.tenant_id != current_tenant_id:
|
||||
raise NotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
from models import App
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
|
||||
workflow_run = db.session.get(WorkflowRun, form_model.workflow_run_id)
|
||||
if workflow_run is None or workflow_run.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
if workflow_run.app_id:
|
||||
app = db.session.get(App, workflow_run.app_id)
|
||||
if app is None or app.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
owner_account_id = app.created_by
|
||||
else:
|
||||
workflow = db.session.get(Workflow, workflow_run.workflow_id)
|
||||
if workflow is None or workflow.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("Workflow not found")
|
||||
owner_account_id = workflow.created_by
|
||||
|
||||
if owner_account_id != current_user.id:
|
||||
raise Forbidden("You do not have permission to access this human input form.")
|
||||
|
||||
return _jsonify_pydantic_model(form.get_definition())
|
||||
|
||||
|
||||
class ConsoleHumanInputFormSubmissionApi(Resource):
|
||||
@ -80,75 +103,15 @@ class ConsoleHumanInputFormSubmissionApi(Resource):
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Submit the form
|
||||
service = HumanInputFormService(db.session())
|
||||
service.submit_form(
|
||||
identifier=form_id,
|
||||
form_data=args["inputs"],
|
||||
action=args["action"],
|
||||
is_token=False,
|
||||
submission_type=HumanInputSubmissionType.web_form,
|
||||
submission_user_id=g.current_user.id,
|
||||
)
|
||||
# Submit the form
|
||||
service = HumanInputService(db.engine)
|
||||
service.submit_form_by_id(
|
||||
form_id=form_id,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
)
|
||||
|
||||
return {}, 200
|
||||
|
||||
except HumanInputFormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
except HumanInputFormExpiredError:
|
||||
return jsonify(
|
||||
{"error_code": "human_input_form_expired", "description": "Human input form has expired"}
|
||||
), 400
|
||||
except HumanInputFormAlreadySubmittedError:
|
||||
return jsonify(
|
||||
{
|
||||
"error_code": "human_input_form_submitted",
|
||||
"description": "Human input form has already been submitted",
|
||||
}
|
||||
), 400
|
||||
except InvalidFormDataError as e:
|
||||
return jsonify({"error_code": "invalid_form_data", "description": e.message}), 400
|
||||
|
||||
|
||||
class ConsoleWorkflowResumeWaitApi(Resource):
|
||||
"""Console API for long-polling workflow resume wait."""
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, task_id: str):
|
||||
"""
|
||||
Get workflow execution resume notification.
|
||||
|
||||
GET /console/api/workflow/<task_id>/resume-wait
|
||||
|
||||
This is a long-polling API that waits for workflow to resume from paused state.
|
||||
"""
|
||||
import time
|
||||
|
||||
# TODO: Implement actual workflow status checking
|
||||
# For now, return a basic response
|
||||
|
||||
timeout = 30 # 30 seconds timeout for demo
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# TODO: Check workflow status from database/cache
|
||||
# workflow_status = workflow_service.get_status(task_id)
|
||||
|
||||
# For demo purposes, simulate different states
|
||||
# In real implementation, this would check the actual workflow state
|
||||
workflow_status = "paused" # or "running" or "ended"
|
||||
|
||||
if workflow_status == "running":
|
||||
return {"status": "running"}, 200
|
||||
elif workflow_status == "ended":
|
||||
return {"status": "ended"}, 200
|
||||
|
||||
time.sleep(1) # Poll every second
|
||||
|
||||
# Return paused status if timeout reached
|
||||
return {"status": "paused"}, 200
|
||||
return jsonify({})
|
||||
|
||||
|
||||
class ConsoleWorkflowEventsApi(Resource):
|
||||
@ -156,7 +119,7 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, task_id: str):
|
||||
def get(self, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
@ -164,9 +127,8 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response
|
||||
events =
|
||||
|
||||
def generate_events() -> Generator[str, None, None]:
|
||||
"""Generate SSE events for workflow execution."""
|
||||
@ -263,6 +225,5 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
# Register the APIs
|
||||
api.add_resource(ConsoleHumanInputFormApi, "/form/human_input/<string:form_id>")
|
||||
api.add_resource(ConsoleHumanInputFormSubmissionApi, "/form/human_input/<string:form_id>", methods=["POST"])
|
||||
api.add_resource(ConsoleWorkflowResumeWaitApi, "/workflow/<string:task_id>/resume-wait")
|
||||
api.add_resource(ConsoleWorkflowEventsApi, "/workflow/<string:task_id>/events")
|
||||
api.add_resource(ConsoleWorkflowEventsApi, "/workflow/<string:workflow_run_id>/events")
|
||||
api.add_resource(ConsoleWorkflowPauseDetailsApi, "/workflow/<string:workflow_run_id>/pause-details")
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
from typing import Any, NamedTuple, NewType, Union, final
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
@ -53,6 +53,7 @@ from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||
|
||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
@ -264,6 +265,54 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def workflow_run_result_to_finish_response(
|
||||
cls,
|
||||
*,
|
||||
workflow_run: WorkflowRun,
|
||||
creator_user: Account | EndUser,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
run_id = workflow_run.id
|
||||
elapsed_time = workflow_run.elapsed_time
|
||||
|
||||
encoded_outputs = workflow_run.outputs_dict
|
||||
finished_at = workflow_run.finished_at
|
||||
assert finished_at is not None
|
||||
|
||||
created_by: Mapping[str, object]
|
||||
user = creator_user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id, # TODO
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
status=workflow_run.status.value,
|
||||
outputs=encoded_outputs,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(finished_at.timestamp()),
|
||||
files=cls.fetch_files_from_node_outputs(encoded_outputs),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_node_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
@ -592,7 +641,8 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||
@classmethod
|
||||
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
@ -601,7 +651,7 @@ class WorkflowResponseConverter:
|
||||
if not outputs_dict:
|
||||
return []
|
||||
|
||||
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
# Remove None
|
||||
files = [file for file in files if file]
|
||||
# Flatten list
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Mapping, Union, cast
|
||||
from typing import Any, Mapping, Union, cast
|
||||
|
||||
from libs.broadcast_channel.channel import Subscription, Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -29,8 +33,7 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
@ -38,7 +41,6 @@ from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Me
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -302,15 +304,34 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
return topic
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(cls, app_mode: AppMode, workflow_run_id: uuid.UUID) -> Generator[Mapping | str, None, None]:
|
||||
def retrieve_events(
|
||||
cls, app_mode: AppMode, workflow_run_id: uuid.UUID, idle_timeout=300
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
with topic.subscribe() as sub:
|
||||
for payload in sub:
|
||||
event = json.loads(payload)
|
||||
yield event
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
return _topic_msg_generator(topic, idle_timeout)
|
||||
|
||||
event_type = event.get("event")
|
||||
if event_type == StreamEvent.WORKFLOW_FINISHED:
|
||||
|
||||
def _topic_msg_generator(topic: Topic, idle_timeout: float) -> Generator[Mapping[str, Any], None, None]:
|
||||
last_msg_time = time.time()
|
||||
with topic.subscribe() as sub:
|
||||
while True:
|
||||
try:
|
||||
msg = sub.receive()
|
||||
except SubscriptionClosedError:
|
||||
return
|
||||
if msg is None:
|
||||
current_time = time.time()
|
||||
if current_time - last_msg_time > idle_timeout:
|
||||
return
|
||||
# skip the `None` message
|
||||
continue
|
||||
|
||||
last_msg_time = time.time()
|
||||
event = json.loads(msg)
|
||||
yield event
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
|
||||
event_type = event.get("event")
|
||||
if event_type in (StreamEvent.WORKFLOW_FINISHED, StreamEvent.WORKFLOW_PAUSED):
|
||||
return
|
||||
|
||||
@ -69,6 +69,7 @@ class StreamEvent(StrEnum):
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
AGENT_MESSAGE = "agent_message"
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_PAUSED = "workflow_paused"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
|
||||
232
api/core/repositories/human_input_reposotiry.py
Normal file
232
api/core/repositories/human_input_reposotiry.py
Normal file
@ -0,0 +1,232 @@
|
||||
import abc
|
||||
import dataclasses
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Mapping
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipient,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
FormNotFoundError,
|
||||
FormSubmissionEntity,
|
||||
HumanInputFormEntity,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputRecipient,
|
||||
RecipientType,
|
||||
WebAppRecipientPayload,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeliveryAndRecipients:
|
||||
delivery: HumanInputDelivery
|
||||
recipients: Sequence[HumanInputRecipient]
|
||||
|
||||
def webapp_recipient(self) -> HumanInputRecipient | None:
|
||||
return next((i for i in self.recipients if i.recipient_type == RecipientType.WEBAPP), None)
|
||||
|
||||
|
||||
class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def __init__(self, form_model: HumanInputForm, web_app_recipient: HumanInputRecipient | None):
|
||||
self._form_model = form_model
|
||||
self._web_app_recipient = web_app_recipient
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._form_model.id
|
||||
|
||||
@property
|
||||
def web_app_token(self):
|
||||
if self._web_app_recipient is None:
|
||||
return None
|
||||
return self._web_app_recipient.access_token
|
||||
|
||||
|
||||
class _FormSubmissionEntityImpl(FormSubmissionEntity):
|
||||
def __init__(self, form_model: HumanInputForm):
|
||||
self._form_model = form_model
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str:
|
||||
selected_action_id = self._form_model.selected_action_id
|
||||
if selected_action_id is None:
|
||||
raise AssertionError(f"selected_action_id should not be None, form_id={self._form_model.id}")
|
||||
return selected_action_id
|
||||
|
||||
def form_data(self) -> Mapping[str, Any]:
|
||||
submitted_data = self._form_model.submitted_data
|
||||
if submitted_data is None:
|
||||
raise AssertionError(f"submitted_data should not be None, form_id={self._form_model.id}")
|
||||
return json.loads(submitted_data)
|
||||
|
||||
|
||||
class WorkspaceMember:
|
||||
def user_id(self) -> str:
|
||||
pass
|
||||
|
||||
def email(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class WorkspaceMemberQueirer:
|
||||
def get_all_workspace_members(self) -> Sequence[WorkspaceMember]:
|
||||
# TOOD: need a way to query all members in the current workspace.
|
||||
pass
|
||||
|
||||
def get_members_by_ids(self, user_ids: Sequence[str]) -> Sequence[WorkspaceMember]:
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputFormRepositoryImpl:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
tenant_id: str,
|
||||
member_quierer: WorkspaceMemberQueirer,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._tenant_id = tenant_id
|
||||
self._member_queirer = member_quierer
|
||||
|
||||
def _delivery_method_to_model(self, form_id, delivery_method: DeliveryChannelConfig) -> _DeliveryAndRecipients:
|
||||
delivery_id = str(uuidv7())
|
||||
delivery_model = HumanInputDelivery(
|
||||
id=delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=delivery_method.type,
|
||||
delivery_config_id=delivery_method.id,
|
||||
channel_payload=delivery_method.model_dump_json(),
|
||||
)
|
||||
recipients: list[HumanInputRecipient] = []
|
||||
if isinstance(delivery_method, WebAppDeliveryMethod):
|
||||
recipient_model = HumanInputRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=RecipientType.WEBAPP,
|
||||
recipient_payload=WebAppRecipientPayload().model_dump_json(),
|
||||
)
|
||||
recipients.append(recipient_model)
|
||||
elif isinstance(delivery_method, EmailDeliveryMethod):
|
||||
email_recipients_config = delivery_method.config.recipients
|
||||
if email_recipients_config.whole_workspace:
|
||||
recipients.extend(self._create_whole_workspace_recipients(form_id=form_id, delivery_id=delivery_id))
|
||||
else:
|
||||
recipients.extend(
|
||||
self._create_email_recipients(
|
||||
form_id=form_id, delivery_id=delivery_id, recipients=email_recipients_config.items
|
||||
)
|
||||
)
|
||||
|
||||
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
|
||||
|
||||
def _create_email_recipients(
|
||||
self,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
recipients: Sequence[EmailRecipient],
|
||||
) -> list[HumanInputRecipient]:
|
||||
recipient_models: list[HumanInputRecipient] = []
|
||||
member_user_ids: list[str] = []
|
||||
for r in recipients:
|
||||
if isinstance(r, MemberRecipient):
|
||||
member_user_ids.append(r.user_id)
|
||||
elif isinstance(r, ExternalRecipient):
|
||||
recipient_model = HumanInputRecipient.new(
|
||||
form_id=form_id, delivery_id=delivery_id, payload=EmailExternalRecipientPayload(email=r.email)
|
||||
)
|
||||
recipient_models.append(recipient_model)
|
||||
else:
|
||||
raise AssertionError(f"unknown recipient type: recipient={r}")
|
||||
|
||||
members = self._member_queirer.get_members_by_ids(member_user_ids)
|
||||
for member in members:
|
||||
payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email())
|
||||
recipient_model = HumanInputRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=payload,
|
||||
)
|
||||
recipient_models.append(recipient_model)
|
||||
return recipient_models
|
||||
|
||||
def _create_whole_workspace_recipients(self, form_id: str, delivery_id: str) -> list[HumanInputRecipient]:
|
||||
recipeint_models = []
|
||||
members = self._member_queirer.get_all_workspace_members()
|
||||
for member in members:
|
||||
payload = EmailMemberRecipientPayload(user_id=member.user_id(), email=member.email())
|
||||
recipient_model = HumanInputRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=payload,
|
||||
)
|
||||
recipeint_models.append(recipient_model)
|
||||
|
||||
return recipeint_models
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
form_config: HumanInputNodeData = params.form_config
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
# Generate unique form ID
|
||||
form_id = str(uuidv7())
|
||||
form_definition = FormDefinition(
|
||||
inputs=form_config.inputs,
|
||||
user_actions=form_config.user_actions,
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id=form_id,
|
||||
tenant_id=self._tenant_id,
|
||||
workflow_run_id=params.workflow_execution_id,
|
||||
node_id=params.node_id,
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content=params.rendered_content,
|
||||
expiration_time=form_config.expiration_time(naive_utc_now()),
|
||||
)
|
||||
session.add(form_model)
|
||||
web_app_recipient: HumanInputRecipient | None = None
|
||||
for delivery in form_config.delivery_methods:
|
||||
delivery_and_recipients = self._delivery_method_to_model(form_id, delivery)
|
||||
session.add(delivery_and_recipients.delivery)
|
||||
session.add_all(delivery_and_recipients.recipients)
|
||||
if web_app_recipient is None:
|
||||
web_app_recipient = delivery_and_recipients.webapp_recipient()
|
||||
session.flush()
|
||||
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, web_app_recipient=web_app_recipient)
|
||||
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None:
|
||||
query = select(HumanInputForm).where(
|
||||
HumanInputForm.workflow_run_id == workflow_execution_id,
|
||||
HumanInputForm.node_id == node_id,
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model: HumanInputForm | None = session.scalars(query).first()
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found for node, {workflow_execution_id=}, {node_id=}")
|
||||
|
||||
if form_model.submitted_at is None:
|
||||
return None
|
||||
|
||||
return _FormSubmissionEntityImpl(form_model=form_model)
|
||||
@ -1,272 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy implementation of the HumanInputFormRepository.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from core.workflow.entities.human_input_form import (
|
||||
HumanInputForm,
|
||||
HumanInputFormStatus,
|
||||
HumanInputSubmissionType,
|
||||
FormSubmission,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
|
||||
from libs.helper import extract_tenant_id
|
||||
from models.human_input import (
|
||||
HumanInputForm as HumanInputFormModel,
|
||||
HumanInputFormStatus as HumanInputFormStatusModel,
|
||||
HumanInputSubmissionType as HumanInputSubmissionTypeModel,
|
||||
)
|
||||
from models import Account, EndUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLAlchemyHumanInputFormRepository(HumanInputFormRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of the HumanInputFormRepository interface.
|
||||
|
||||
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
||||
Each method creates its own session, handles the transaction, and commits changes
|
||||
to the database. This prevents long-running connections in the workflow core.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: Union[sessionmaker, Engine],
|
||||
user: Union[Account, EndUser],
|
||||
app_id: str | None,
|
||||
):
|
||||
"""
|
||||
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
||||
user: Account or EndUser object containing tenant_id, user ID, and role information
|
||||
app_id: App ID for filtering by application (can be None)
|
||||
"""
|
||||
# If an engine is provided, create a sessionmaker from it
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# Store app context
|
||||
self._app_id = app_id
|
||||
|
||||
def _to_domain_model(self, db_model: HumanInputFormModel) -> HumanInputForm:
|
||||
"""
|
||||
Convert a database model to a domain model.
|
||||
|
||||
Args:
|
||||
db_model: The database model to convert
|
||||
|
||||
Returns:
|
||||
The domain model
|
||||
"""
|
||||
# Parse JSON fields
|
||||
form_definition = json.loads(db_model.form_definition) if db_model.form_definition else {}
|
||||
|
||||
# Create submission if present
|
||||
submission = None
|
||||
if db_model.status == HumanInputFormStatusModel.SUBMITTED and db_model.submitted_data:
|
||||
submission = FormSubmission(
|
||||
data=json.loads(db_model.submitted_data) if db_model.submitted_data else {},
|
||||
action="", # Action is not stored separately in DB model, would need to be stored in submitted_data
|
||||
submitted_at=db_model.submitted_at or datetime.utcnow(),
|
||||
submission_type=HumanInputSubmissionType(db_model.submission_type.value)
|
||||
if db_model.submission_type
|
||||
else HumanInputSubmissionType.web_form,
|
||||
submission_user_id=db_model.submission_user_id,
|
||||
submission_end_user_id=db_model.submission_end_user_id,
|
||||
submitter_email=db_model.submitter_email,
|
||||
)
|
||||
|
||||
return HumanInputForm(
|
||||
id_=db_model.id,
|
||||
workflow_run_id=db_model.workflow_run_id,
|
||||
form_definition=form_definition,
|
||||
rendered_content=db_model.rendered_content,
|
||||
status=HumanInputFormStatus(db_model.status.value),
|
||||
web_app_token=db_model.web_app_token,
|
||||
submission=submission,
|
||||
created_at=db_model.created_at,
|
||||
)
|
||||
|
||||
def _to_db_model(self, domain_model: HumanInputForm) -> HumanInputFormModel:
|
||||
"""
|
||||
Convert a domain model to a database model.
|
||||
|
||||
Args:
|
||||
domain_model: The domain model to convert
|
||||
|
||||
Returns:
|
||||
The database model
|
||||
"""
|
||||
db_model = HumanInputFormModel()
|
||||
db_model.id = domain_model.id_
|
||||
db_model.tenant_id = self._tenant_id
|
||||
if self._app_id is not None:
|
||||
db_model.app_id = self._app_id
|
||||
db_model.workflow_run_id = domain_model.workflow_run_id
|
||||
db_model.form_definition = json.dumps(domain_model.form_definition) if domain_model.form_definition else None
|
||||
db_model.rendered_content = domain_model.rendered_content
|
||||
db_model.status = HumanInputFormStatusModel(domain_model.status.value)
|
||||
db_model.web_app_token = domain_model.web_app_token
|
||||
db_model.created_at = domain_model.created_at
|
||||
|
||||
# Handle submission data
|
||||
if domain_model.submission:
|
||||
db_model.submitted_data = json.dumps(domain_model.submission.data) if domain_model.submission.data else None
|
||||
db_model.submitted_at = domain_model.submission.submitted_at
|
||||
db_model.submission_type = HumanInputSubmissionTypeModel(domain_model.submission.submission_type.value)
|
||||
db_model.submission_user_id = domain_model.submission.submission_user_id
|
||||
db_model.submission_end_user_id = domain_model.submission.submission_end_user_id
|
||||
db_model.submitter_email = domain_model.submission.submitter_email
|
||||
|
||||
return db_model
|
||||
|
||||
def save(self, form: HumanInputForm) -> None:
|
||||
"""
|
||||
Save or update a HumanInputForm domain entity to the database.
|
||||
|
||||
This method serves as a domain-to-database adapter that:
|
||||
1. Converts the domain entity to its database representation
|
||||
2. Persists the database model using SQLAlchemy's merge operation
|
||||
3. Maintains proper multi-tenancy by including tenant context during conversion
|
||||
|
||||
The method handles both creating new records and updating existing ones through
|
||||
SQLAlchemy's merge operation.
|
||||
|
||||
Args:
|
||||
form: The HumanInputForm domain entity to persist
|
||||
"""
|
||||
db_model = self._to_db_model(form)
|
||||
|
||||
with self._session_factory() as session:
|
||||
session.merge(db_model)
|
||||
session.commit()
|
||||
logger.info("Saved human input form %s", form.id_)
|
||||
|
||||
def get_by_id(self, form_id: str) -> HumanInputForm:
|
||||
"""Get a form by its ID."""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(HumanInputFormModel).where(
|
||||
and_(
|
||||
HumanInputFormModel.id == form_id,
|
||||
HumanInputFormModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
)
|
||||
if self._app_id is not None:
|
||||
stmt = stmt.where(HumanInputFormModel.app_id == self._app_id)
|
||||
|
||||
db_model = session.scalar(stmt)
|
||||
if not db_model:
|
||||
raise ValueError(f"Human input form not found: {form_id}")
|
||||
|
||||
return self._to_domain_model(db_model)
|
||||
|
||||
def get_by_web_app_token(self, web_app_token: str) -> HumanInputForm:
|
||||
"""Get a form by its web app token."""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(HumanInputFormModel).where(
|
||||
and_(
|
||||
HumanInputFormModel.web_app_token == web_app_token,
|
||||
HumanInputFormModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
)
|
||||
if self._app_id is not None:
|
||||
stmt = stmt.where(HumanInputFormModel.app_id == self._app_id)
|
||||
|
||||
db_model = session.scalar(stmt)
|
||||
if not db_model:
|
||||
raise ValueError(f"Human input form not found with token: {web_app_token}")
|
||||
|
||||
return self._to_domain_model(db_model)
|
||||
|
||||
def get_pending_forms_for_workflow_run(self, workflow_run_id: str) -> list[HumanInputForm]:
|
||||
"""Get all pending human input forms for a workflow run."""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(HumanInputFormModel).where(
|
||||
and_(
|
||||
HumanInputFormModel.workflow_run_id == workflow_run_id,
|
||||
HumanInputFormModel.status == HumanInputFormStatusModel.WAITING,
|
||||
HumanInputFormModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
)
|
||||
if self._app_id is not None:
|
||||
stmt = stmt.where(HumanInputFormModel.app_id == self._app_id)
|
||||
|
||||
db_models = list(session.scalars(stmt).all())
|
||||
return [self._to_domain_model(db_model) for db_model in db_models]
|
||||
|
||||
def mark_expired_forms(self, expiry_hours: int = 48) -> int:
|
||||
"""Mark expired forms as expired."""
|
||||
with self._session_factory() as session:
|
||||
expiry_time = datetime.utcnow() - timedelta(hours=expiry_hours)
|
||||
|
||||
stmt = select(HumanInputFormModel).where(
|
||||
and_(
|
||||
HumanInputFormModel.status == HumanInputFormStatusModel.WAITING,
|
||||
HumanInputFormModel.created_at < expiry_time,
|
||||
HumanInputFormModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
)
|
||||
if self._app_id is not None:
|
||||
stmt = stmt.where(HumanInputFormModel.app_id == self._app_id)
|
||||
|
||||
expired_forms = list(session.scalars(stmt).all())
|
||||
|
||||
count = 0
|
||||
for form in expired_forms:
|
||||
form.status = HumanInputFormStatusModel.EXPIRED
|
||||
count += 1
|
||||
|
||||
session.commit()
|
||||
logger.info("Marked %d forms as expired", count)
|
||||
return count
|
||||
|
||||
def exists_by_id(self, form_id: str) -> bool:
|
||||
"""Check if a form exists by ID."""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(HumanInputFormModel).where(
|
||||
and_(
|
||||
HumanInputFormModel.id == form_id,
|
||||
HumanInputFormModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
)
|
||||
if self._app_id is not None:
|
||||
stmt = stmt.where(HumanInputFormModel.app_id == self._app_id)
|
||||
|
||||
return session.scalar(stmt) is not None
|
||||
|
||||
def exists_by_web_app_token(self, web_app_token: str) -> bool:
|
||||
"""Check if a form exists by web app token."""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(HumanInputFormModel).where(
|
||||
and_(
|
||||
HumanInputFormModel.web_app_token == web_app_token,
|
||||
HumanInputFormModel.tenant_id == self._tenant_id,
|
||||
)
|
||||
)
|
||||
if self._app_id is not None:
|
||||
stmt = stmt.where(HumanInputFormModel.app_id == self._app_id)
|
||||
|
||||
return session.scalar(stmt) is not None
|
||||
@ -1,197 +0,0 @@
|
||||
"""
|
||||
Domain entities for human input forms.
|
||||
|
||||
Models are independent of the storage mechanism and don't contain
|
||||
implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
def naive_utc_from_now() -> datetime:
|
||||
"""Get current UTC datetime."""
|
||||
return naive_utc_now()
|
||||
|
||||
|
||||
class HumanInputFormStatus(StrEnum):
|
||||
"""Status of a human input form."""
|
||||
|
||||
WAITING = "waiting"
|
||||
EXPIRED = "expired"
|
||||
SUBMITTED = "submitted"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class HumanInputSubmissionType(StrEnum):
|
||||
"""Type of submission for human input forms."""
|
||||
|
||||
web_form = "web_form"
|
||||
web_app = "web_app"
|
||||
email = "email"
|
||||
|
||||
|
||||
class FormSubmission(BaseModel):
|
||||
"""Represents a form submission."""
|
||||
|
||||
data: dict[str, Any] = Field(default_factory=dict)
|
||||
action: str = ""
|
||||
submitted_at: datetime = Field(default_factory=naive_utc_now)
|
||||
submission_type: HumanInputSubmissionType = HumanInputSubmissionType.web_form
|
||||
submission_user_id: Optional[str] = None
|
||||
submission_end_user_id: Optional[str] = None
|
||||
submitter_email: Optional[str] = None
|
||||
|
||||
|
||||
class HumanInputForm(BaseModel):
|
||||
"""
|
||||
Domain model for human input forms.
|
||||
|
||||
This model represents the business concept of a human input form without
|
||||
infrastructure concerns like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
id_: str = Field(...)
|
||||
workflow_run_id: str = Field(...)
|
||||
form_definition: dict[str, Any] = Field(default_factory=dict)
|
||||
rendered_content: str = ""
|
||||
status: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
web_app_token: Optional[str] = None
|
||||
submission: Optional[FormSubmission] = None
|
||||
created_at: datetime = Field(default_factory=naive_utc_from_now)
|
||||
|
||||
@property
|
||||
def is_submitted(self) -> bool:
|
||||
"""Check if the form has been submitted."""
|
||||
return self.status == HumanInputFormStatus.SUBMITTED
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the form has expired."""
|
||||
return self.status == HumanInputFormStatus.EXPIRED
|
||||
|
||||
@property
|
||||
def is_waiting(self) -> bool:
|
||||
"""Check if the form is waiting for submission."""
|
||||
return self.status == HumanInputFormStatus.WAITING
|
||||
|
||||
@property
|
||||
def can_be_submitted(self) -> bool:
|
||||
"""Check if the form can still be submitted."""
|
||||
return self.status == HumanInputFormStatus.WAITING
|
||||
|
||||
def submit(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
action: str,
|
||||
submission_type: HumanInputSubmissionType = HumanInputSubmissionType.web_form,
|
||||
submission_user_id: Optional[str] = None,
|
||||
submission_end_user_id: Optional[str] = None,
|
||||
submitter_email: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Submit the form with the given data and action.
|
||||
|
||||
Args:
|
||||
data: The form data submitted by the user
|
||||
action: The action taken by the user
|
||||
submission_type: Type of submission
|
||||
submission_user_id: ID of the user who submitted (console submissions)
|
||||
submission_end_user_id: ID of the end user who submitted (webapp submissions)
|
||||
submitter_email: Email of the submitter (if applicable)
|
||||
|
||||
Raises:
|
||||
ValueError: If the form cannot be submitted
|
||||
"""
|
||||
if not self.can_be_submitted:
|
||||
raise ValueError(f"Form cannot be submitted in status: {self.status}")
|
||||
|
||||
# Validate that the action is valid based on form definition
|
||||
valid_actions = {act.get("id") for act in self.form_definition.get("user_actions", [])}
|
||||
if action not in valid_actions:
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
|
||||
self.submission = FormSubmission(
|
||||
data=data,
|
||||
action=action,
|
||||
submission_type=submission_type,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
submitter_email=submitter_email,
|
||||
)
|
||||
self.status = HumanInputFormStatus.SUBMITTED
|
||||
|
||||
def expire(self) -> None:
|
||||
"""Mark the form as expired."""
|
||||
if self.status != HumanInputFormStatus.WAITING:
|
||||
raise ValueError(f"Form cannot be expired in status: {self.status}")
|
||||
|
||||
self.status = HumanInputFormStatus.EXPIRED
|
||||
|
||||
def get_form_definition_for_display(self, include_site_info: bool = False) -> dict[str, Any]:
|
||||
"""
|
||||
Get form definition for display purposes.
|
||||
|
||||
Args:
|
||||
include_site_info: Whether to include site information in the response
|
||||
|
||||
Returns:
|
||||
Form definition dictionary for display
|
||||
"""
|
||||
if self.status == HumanInputFormStatus.EXPIRED:
|
||||
raise ValueError("Form has expired")
|
||||
|
||||
if self.status == HumanInputFormStatus.SUBMITTED:
|
||||
raise ValueError("Form has already been submitted")
|
||||
|
||||
response = {
|
||||
"form_content": self.rendered_content,
|
||||
"inputs": self.form_definition.get("inputs", []),
|
||||
"user_actions": self.form_definition.get("user_actions", []),
|
||||
}
|
||||
|
||||
if include_site_info:
|
||||
# Note: In domain model, we don't have app_id
|
||||
# This would be added at the application layer
|
||||
response["site"] = {
|
||||
"title": "Workflow Form",
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
*,
|
||||
id_: str,
|
||||
workflow_run_id: str,
|
||||
form_definition: dict[str, Any],
|
||||
rendered_content: str,
|
||||
web_app_token: Optional[str] = None,
|
||||
) -> "HumanInputForm":
|
||||
"""
|
||||
Create a new human input form.
|
||||
|
||||
Args:
|
||||
id_: Unique identifier for the form
|
||||
workflow_run_id: ID of the associated workflow run
|
||||
form_definition: Form definition as a dictionary
|
||||
rendered_content: Rendered HTML content of the form
|
||||
web_app_token: Optional token for web app access
|
||||
|
||||
Returns:
|
||||
New HumanInputForm instance
|
||||
"""
|
||||
return cls(
|
||||
id_=id_,
|
||||
workflow_run_id=workflow_run_id,
|
||||
form_definition=form_definition,
|
||||
rendered_content=rendered_content,
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
web_app_token=web_app_token,
|
||||
)
|
||||
@ -3,6 +3,8 @@ from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput
|
||||
|
||||
|
||||
class PauseReasonType(StrEnum):
|
||||
HUMAN_INPUT_REQUIRED = auto()
|
||||
@ -11,10 +13,10 @@ class PauseReasonType(StrEnum):
|
||||
|
||||
class HumanInputRequired(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
form_id: str
|
||||
# The identifier of the human input node causing the pause.
|
||||
node_id: str
|
||||
form_content: str
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
web_app_form_token: str | None = None
|
||||
|
||||
|
||||
class SchedulingPause(BaseModel):
|
||||
|
||||
@ -4,6 +4,5 @@ Human Input node implementation.
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
from .human_input_node import HumanInputNode
|
||||
from .node import HumanInputNode
|
||||
|
||||
__all__ = ["HumanInputNode", "HumanInputNodeData"]
|
||||
|
||||
@ -2,59 +2,77 @@
|
||||
Human Input node entities.
|
||||
"""
|
||||
|
||||
import enum
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$outputs\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
||||
|
||||
|
||||
class HumanInputFormStatus(StrEnum):
|
||||
"""Status of a human input form."""
|
||||
|
||||
WAITING = enum.auto()
|
||||
EXPIRED = enum.auto()
|
||||
SUBMITTED = enum.auto()
|
||||
TIMEOUT = enum.auto()
|
||||
|
||||
|
||||
class DeliveryMethodType(StrEnum):
|
||||
"""Delivery method types for human input forms."""
|
||||
|
||||
WEBAPP = "webapp"
|
||||
EMAIL = "email"
|
||||
WEBAPP = enum.auto()
|
||||
EMAIL = enum.auto()
|
||||
|
||||
|
||||
class ButtonStyle(StrEnum):
|
||||
"""Button styles for user actions."""
|
||||
|
||||
PRIMARY = "primary"
|
||||
DEFAULT = "default"
|
||||
ACCENT = "accent"
|
||||
GHOST = "ghost"
|
||||
PRIMARY = enum.auto()
|
||||
DEFAULT = enum.auto()
|
||||
ACCENT = enum.auto()
|
||||
GHOST = enum.auto()
|
||||
|
||||
|
||||
class TimeoutUnit(StrEnum):
|
||||
"""Timeout unit for form expiration."""
|
||||
|
||||
HOUR = "hour"
|
||||
DAY = "day"
|
||||
HOUR = enum.auto()
|
||||
DAY = enum.auto()
|
||||
|
||||
|
||||
class FormInputType(StrEnum):
|
||||
"""Form input types."""
|
||||
|
||||
TEXT_INPUT = "text-input"
|
||||
PARAGRAPH = "paragraph"
|
||||
TEXT_INPUT = enum.auto()
|
||||
PARAGRAPH = enum.auto()
|
||||
|
||||
|
||||
class PlaceholderType(StrEnum):
|
||||
"""Placeholder types for form inputs."""
|
||||
|
||||
VARIABLE = "variable"
|
||||
CONSTANT = "constant"
|
||||
VARIABLE = enum.auto()
|
||||
CONSTANT = enum.auto()
|
||||
|
||||
|
||||
class RecipientType(StrEnum):
|
||||
class EmailRecipientType(StrEnum):
|
||||
"""Email recipient types."""
|
||||
|
||||
MEMBER = "member"
|
||||
EXTERNAL = "external"
|
||||
MEMBER = enum.auto()
|
||||
EXTERNAL = enum.auto()
|
||||
|
||||
|
||||
class WebAppDeliveryConfig(BaseModel):
|
||||
class _WebAppDeliveryConfig(BaseModel):
|
||||
"""Configuration for webapp delivery method."""
|
||||
|
||||
pass # Empty for webapp delivery
|
||||
@ -63,25 +81,25 @@ class WebAppDeliveryConfig(BaseModel):
|
||||
class MemberRecipient(BaseModel):
|
||||
"""Member recipient for email delivery."""
|
||||
|
||||
type: Literal[RecipientType.MEMBER]
|
||||
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
|
||||
user_id: str
|
||||
|
||||
|
||||
class ExternalRecipient(BaseModel):
|
||||
"""External recipient for email delivery."""
|
||||
|
||||
type: Literal[RecipientType.EXTERNAL]
|
||||
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
|
||||
email: str
|
||||
|
||||
|
||||
Recipient = Union[MemberRecipient, ExternalRecipient]
|
||||
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
|
||||
|
||||
|
||||
class EmailRecipients(BaseModel):
|
||||
"""Email recipients configuration."""
|
||||
|
||||
whole_workspace: bool = False
|
||||
items: list[Recipient] = Field(default_factory=list)
|
||||
items: list[EmailRecipient] = Field(default_factory=list)
|
||||
|
||||
|
||||
class EmailDeliveryConfig(BaseModel):
|
||||
@ -92,51 +110,54 @@ class EmailDeliveryConfig(BaseModel):
|
||||
body: str
|
||||
|
||||
|
||||
DeliveryConfig = Union[WebAppDeliveryConfig, EmailDeliveryConfig]
|
||||
class _DeliveryMethodBase(BaseModel):
|
||||
"""Base delivery method configuration."""
|
||||
|
||||
|
||||
class DeliveryMethod(BaseModel):
|
||||
"""Delivery method configuration."""
|
||||
|
||||
type: DeliveryMethodType
|
||||
enabled: bool = True
|
||||
config: Optional[DeliveryConfig] = None
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_config_type(self):
|
||||
"""Validate that config matches the delivery method type."""
|
||||
if self.config is None:
|
||||
return self
|
||||
|
||||
if self.type == DeliveryMethodType.EMAIL:
|
||||
if isinstance(self.config, dict):
|
||||
# Try to parse as EmailDeliveryConfig - this will raise validation errors
|
||||
try:
|
||||
self.config = EmailDeliveryConfig.model_validate(self.config)
|
||||
except Exception as e:
|
||||
# Re-raise with more specific context
|
||||
raise ValueError(f"Invalid email delivery configuration: {str(e)}")
|
||||
elif not isinstance(self.config, EmailDeliveryConfig):
|
||||
raise ValueError("Config must be EmailDeliveryConfig for email delivery method")
|
||||
elif self.type == DeliveryMethodType.WEBAPP:
|
||||
if isinstance(self.config, dict):
|
||||
# Try to parse as WebAppDeliveryConfig
|
||||
try:
|
||||
self.config = WebAppDeliveryConfig.model_validate(self.config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid webapp delivery configuration: {str(e)}")
|
||||
elif not isinstance(self.config, WebAppDeliveryConfig):
|
||||
raise ValueError("Config must be WebAppDeliveryConfig for webapp delivery method")
|
||||
class WebAppDeliveryMethod(_DeliveryMethodBase):
|
||||
"""Webapp delivery method configuration."""
|
||||
|
||||
return self
|
||||
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
|
||||
# The config field is not used currently.
|
||||
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
|
||||
|
||||
|
||||
class EmailDeliveryMethod(_DeliveryMethodBase):
|
||||
"""Email delivery method configuration."""
|
||||
|
||||
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
|
||||
config: EmailDeliveryConfig
|
||||
|
||||
|
||||
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
|
||||
|
||||
|
||||
class FormInputPlaceholder(BaseModel):
|
||||
"""Placeholder configuration for form inputs."""
|
||||
|
||||
# NOTE: Ideally, a discriminated union would be used to model
|
||||
# FormInputPlaceholder. However, the UI requires preserving the previous
|
||||
# value when switching between `VARIABLE` and `CONSTANT` types. This
|
||||
# necessitates retaining all fields, making a discriminated union unsuitable.
|
||||
|
||||
type: PlaceholderType
|
||||
selector: list[str] = Field(default_factory=list) # Used when type is VARIABLE
|
||||
value: str = "" # Used when type is CONSTANT
|
||||
|
||||
# The selector of placeholder variable, used when `type` is `VARIABLE`
|
||||
selector: Sequence[str] = Field(default_factory=tuple) #
|
||||
|
||||
# The value of the placeholder, used when `type` is `CONSTANT`.
|
||||
# TODO: How should we express JSON values?
|
||||
value: str = ""
|
||||
|
||||
@field_validator("selector")
|
||||
@classmethod
|
||||
def _validate_selector(cls, selector: Sequence[str]) -> Sequence[str]:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={selector}")
|
||||
return selector
|
||||
|
||||
|
||||
class FormInput(BaseModel):
|
||||
@ -147,24 +168,106 @@ class FormInput(BaseModel):
|
||||
placeholder: Optional[FormInputPlaceholder] = None
|
||||
|
||||
|
||||
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
class UserAction(BaseModel):
|
||||
"""User action configuration."""
|
||||
|
||||
# id is the identifier for this action.
|
||||
# It also serves as the identifiers of output handle.
|
||||
#
|
||||
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
|
||||
id: str
|
||||
title: str
|
||||
button_style: ButtonStyle = ButtonStyle.DEFAULT
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def _validate_id(cls, value: str) -> str:
|
||||
if not _IDENTIFIER_PATTERN.match(value):
|
||||
raise ValueError(
|
||||
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
|
||||
f"and contain only letters, numbers, or underscores."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Human Input node data."""
|
||||
|
||||
delivery_methods: list[DeliveryMethod] = Field(default_factory=list)
|
||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
||||
form_content: str = ""
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
user_actions: list[UserAction] = Field(default_factory=list)
|
||||
timeout: int = 36
|
||||
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
|
||||
|
||||
@field_validator("inputs")
|
||||
@classmethod
|
||||
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
|
||||
seen_names: set[str] = set()
|
||||
for form_input in inputs:
|
||||
name = form_input.output_variable_name
|
||||
if name in seen_names:
|
||||
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
|
||||
seen_names.add(name)
|
||||
return inputs
|
||||
|
||||
@field_validator("user_actions")
|
||||
@classmethod
|
||||
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
|
||||
seen_ids: set[str] = set()
|
||||
for action in user_actions:
|
||||
action_id = action.id
|
||||
if action_id in seen_ids:
|
||||
raise ValueError(f"duplicated user action id '{action_id}'")
|
||||
seen_ids.add(action_id)
|
||||
return user_actions
|
||||
|
||||
def is_webapp_enabled(self) -> bool:
|
||||
for dm in self.delivery_methods:
|
||||
if not dm.enabled:
|
||||
continue
|
||||
if dm.type == DeliveryMethodType.WEBAPP:
|
||||
return True
|
||||
return False
|
||||
|
||||
def expiration_time(self, start_time: datetime) -> datetime:
|
||||
if self.timeout_unit == TimeoutUnit.HOUR:
|
||||
return start_time + timedelta(hours=self.timeout)
|
||||
elif self.timeout_unit == TimeoutUnit.DAY:
|
||||
return start_time + timedelta(days=self.timeout)
|
||||
else:
|
||||
raise AssertionError("unknown timeout unit.")
|
||||
|
||||
def outputs_field_names(self) -> Sequence[str]:
|
||||
field_names = []
|
||||
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
|
||||
field_names.append(match.group("field_name"))
|
||||
return field_names
|
||||
|
||||
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
|
||||
variable_selectors = []
|
||||
variable_template_parser = VariableTemplateParser(template=self.form_content)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
variable_mappings = {}
|
||||
for variable_selector in variable_selectors:
|
||||
qualified_variable_mapping_key = f"{node_id}.{variable_selector.variable}"
|
||||
variable_mappings[qualified_variable_mapping_key] = variable_selector.value_selector
|
||||
|
||||
for input in self.inputs:
|
||||
placeholder = input.placeholder
|
||||
if placeholder is None:
|
||||
continue
|
||||
if placeholder.type == PlaceholderType.CONSTANT:
|
||||
continue
|
||||
placeholder_key = ".".join(placeholder.selector)
|
||||
qualified_variable_mapping_key = f"{node_id}.#{placeholder_key}#"
|
||||
variable_mappings[qualified_variable_mapping_key] = placeholder.selector
|
||||
|
||||
return variable_mappings
|
||||
|
||||
|
||||
class HumanInputRequired(BaseModel):
|
||||
"""Event data for human input required."""
|
||||
@ -176,72 +279,11 @@ class HumanInputRequired(BaseModel):
|
||||
web_app_form_token: Optional[str] = None
|
||||
|
||||
|
||||
class WorkflowSuspended(BaseModel):
|
||||
"""Event data for workflow suspended."""
|
||||
|
||||
suspended_at_node_ids: list[str]
|
||||
|
||||
|
||||
class PauseTypeHumanInput(BaseModel):
|
||||
"""Pause type for human input."""
|
||||
|
||||
type: Literal["human_input"]
|
||||
form_id: str
|
||||
|
||||
|
||||
class PauseTypeBreakpoint(BaseModel):
|
||||
"""Pause type for breakpoint (debugging)."""
|
||||
|
||||
type: Literal["breakpoint"]
|
||||
|
||||
|
||||
PauseType = Union[PauseTypeHumanInput, PauseTypeBreakpoint]
|
||||
|
||||
|
||||
class PausedNode(BaseModel):
|
||||
"""Information about a paused node."""
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
pause_type: PauseType
|
||||
|
||||
|
||||
class WorkflowPauseDetails(BaseModel):
|
||||
"""Details about workflow pause."""
|
||||
|
||||
paused_at: str # ISO datetime
|
||||
paused_nodes: list[PausedNode]
|
||||
|
||||
|
||||
class FormSubmissionRequest(BaseModel):
|
||||
"""Form submission request data."""
|
||||
|
||||
inputs: dict[str, str] # mapping of output_variable_name to user input
|
||||
action: str # UserAction id
|
||||
|
||||
|
||||
class FormGetResponse(BaseModel):
|
||||
"""Response for form get API."""
|
||||
|
||||
site: Optional[dict[str, Any]] = None # Site information for webapp
|
||||
class FormDefinition(BaseModel):
|
||||
form_content: str
|
||||
inputs: list[FormInput]
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
user_actions: list[UserAction] = Field(default_factory=list)
|
||||
rendered_content: str
|
||||
|
||||
|
||||
class FormSubmissionResponse(BaseModel):
|
||||
"""Response for successful form submission."""
|
||||
|
||||
pass # Empty response for success
|
||||
|
||||
|
||||
class FormErrorResponse(BaseModel):
|
||||
"""Response for form submission errors."""
|
||||
|
||||
error_code: str
|
||||
description: str
|
||||
|
||||
|
||||
class ResumeWaitResponse(BaseModel):
|
||||
"""Response for resume wait API."""
|
||||
|
||||
status: Literal["paused", "running", "ended"]
|
||||
timeout: int
|
||||
timeout_unit: TimeoutUnit
|
||||
|
||||
@ -1,13 +1,27 @@
|
||||
from collections.abc import Mapping
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events.base import NodeEventBase
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.repositories.human_input_form_repository import FormCreateParams, HumanInputFormRepository
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
|
||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _FormSubmissionResult:
|
||||
action_id: str
|
||||
|
||||
|
||||
class HumanInputNode(Node[HumanInputNodeData]):
|
||||
node_type = NodeType.HUMAN_INPUT
|
||||
@ -17,7 +31,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
"edge_source_handle",
|
||||
"edgeSourceHandle",
|
||||
"source_handle",
|
||||
"selected_branch",
|
||||
_SELECTED_BRANCH_KEY,
|
||||
"selectedBranch",
|
||||
"branch",
|
||||
"branch_id",
|
||||
@ -25,43 +39,12 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
"handle",
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self): # type: ignore[override]
|
||||
if self._is_completion_ready():
|
||||
branch_handle = self._resolve_branch_selection()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
edge_source_handle=branch_handle or "source",
|
||||
)
|
||||
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
# TODO(QuantumGhost): yield a real form id.
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
if not self.node_data.required_variables:
|
||||
return False
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for selector_str in self.node_data.required_variables:
|
||||
parts = selector_str.split(".")
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
segment = variable_pool.get(parts)
|
||||
if segment is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_branch_selection(self) -> str | None:
|
||||
"""Determine the branch handle selected by human input if available."""
|
||||
|
||||
@ -108,3 +91,106 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
def _create_form_repository(self) -> HumanInputFormRepository:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _pause_generator(event: PauseRequestedEvent) -> Generator[NodeEventBase, None, None]:
|
||||
yield event
|
||||
|
||||
@property
|
||||
def _workflow_execution_id(self) -> str:
|
||||
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
assert workflow_exec_id is not None
|
||||
return workflow_exec_id
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Execute the human input node.
|
||||
|
||||
This method will:
|
||||
1. Generate a unique form ID
|
||||
2. Create form content with variable substitution
|
||||
3. Create form in database
|
||||
4. Send form via configured delivery methods
|
||||
5. Suspend workflow execution
|
||||
6. Wait for form submission to resume
|
||||
"""
|
||||
repo = self._create_form_repository()
|
||||
submission_result = repo.get_form_submission(self._workflow_execution_id, self.app_id)
|
||||
if submission_result:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"action_id": submission_result.selected_action_id,
|
||||
},
|
||||
edge_source_handle=submission_result.selected_action_id,
|
||||
)
|
||||
try:
|
||||
repo = self._create_form_repository()
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self._render_form_content(),
|
||||
)
|
||||
result = repo.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
required_event = HumanInputRequired(
|
||||
form_id=result.id,
|
||||
form_content=self._node_data.form_content,
|
||||
inputs=self._node_data.inputs,
|
||||
web_app_form_token=result.web_app_token,
|
||||
)
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
|
||||
# Create workflow suspended event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
result.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Human Input node failed to execute, node_id=%s", self.id)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="HumanInputNodeError",
|
||||
)
|
||||
return self._pause_generator(pause_requested_event)
|
||||
|
||||
def _render_form_content(self) -> str:
|
||||
"""
|
||||
Process form content by substituting variables.
|
||||
|
||||
This method should:
|
||||
1. Parse the form_content markdown
|
||||
2. Substitute {{#node_name.var_name#}} with actual values
|
||||
3. Keep {{#$output.field_name#}} placeholders for form inputs
|
||||
"""
|
||||
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
|
||||
self._node_data.form_content,
|
||||
)
|
||||
return rendered_form_content.markdown
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input placeholders.
|
||||
|
||||
This method should parse:
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input placeholders
|
||||
"""
|
||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
|
||||
@ -1,259 +0,0 @@
|
||||
"""
|
||||
Human Input node implementation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
HumanInputNodeData,
|
||||
HumanInputRequired,
|
||||
WorkflowSuspended,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from services.human_input_form_service import HumanInputFormService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputNode(BaseNode):
|
||||
"""
|
||||
Human Input Node implementation.
|
||||
|
||||
This node pauses workflow execution and waits for human input through
|
||||
configured delivery methods (webapp or email). The workflow resumes
|
||||
once the form is submitted.
|
||||
"""
|
||||
|
||||
_node_type: NodeType = NodeType.HUMAN_INPUT
|
||||
_node_data_cls = HumanInputNodeData
|
||||
node_data: HumanInputNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
"""Initialize node data from configuration."""
|
||||
self.node_data = self._node_data_cls.model_validate(data)
|
||||
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, InNodeEvent], None, None]:
|
||||
"""
|
||||
Execute the human input node.
|
||||
|
||||
This method will:
|
||||
1. Generate a unique form ID
|
||||
2. Create form content with variable substitution
|
||||
3. Create form in database
|
||||
4. Send form via configured delivery methods
|
||||
5. Suspend workflow execution
|
||||
6. Wait for form submission to resume
|
||||
"""
|
||||
try:
|
||||
# Generate unique form ID
|
||||
form_id = str(uuid.uuid4())
|
||||
|
||||
# Create form content with variable substitution
|
||||
form_content = self._process_form_content()
|
||||
|
||||
# Generate webapp token if webapp delivery is enabled
|
||||
web_app_form_token = None
|
||||
webapp_enabled = any(dm.enabled and dm.type.value == "webapp" for dm in self.node_data.delivery_methods)
|
||||
if webapp_enabled:
|
||||
web_app_form_token = str(uuid.uuid4()).replace("-", "")
|
||||
|
||||
# Create form definition for database storage
|
||||
form_definition = {
|
||||
"node_id": self.node_id,
|
||||
"title": self.node_data.title,
|
||||
"inputs": [inp.model_dump() for inp in self.node_data.inputs],
|
||||
"user_actions": [action.model_dump() for action in self.node_data.user_actions],
|
||||
"timeout": self.node_data.timeout,
|
||||
"timeout_unit": self.node_data.timeout_unit.value,
|
||||
"delivery_methods": [dm.model_dump() for dm in self.node_data.delivery_methods],
|
||||
}
|
||||
|
||||
# Create form in database
|
||||
service = HumanInputFormService(db.session())
|
||||
service.create_form(
|
||||
form_id=form_id,
|
||||
workflow_run_id=self.graph_runtime_state.workflow_run_id,
|
||||
tenant_id=self.graph_init_params.tenant_id,
|
||||
app_id=self.graph_init_params.app_id,
|
||||
form_definition=json.dumps(form_definition),
|
||||
rendered_content=form_content,
|
||||
web_app_token=web_app_form_token,
|
||||
)
|
||||
|
||||
# Create human input required event
|
||||
human_input_event = HumanInputRequired(
|
||||
form_id=form_id,
|
||||
node_id=self.node_id,
|
||||
form_content=form_content,
|
||||
inputs=self.node_data.inputs,
|
||||
web_app_form_token=web_app_form_token,
|
||||
)
|
||||
|
||||
# Create workflow suspended event
|
||||
suspended_event = WorkflowSuspended(suspended_at_node_ids=[self.node_id])
|
||||
|
||||
logger.info(f"Human Input node {self.node_id} suspended workflow for form {form_id}")
|
||||
|
||||
# Return suspension result
|
||||
# The workflow engine should handle the suspension and resume logic
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.RUNNING, # Node is still running, waiting for input
|
||||
inputs={},
|
||||
outputs={},
|
||||
metadata={
|
||||
"form_id": form_id,
|
||||
"web_app_form_token": web_app_form_token,
|
||||
"human_input_event": human_input_event.model_dump(),
|
||||
"suspended_event": suspended_event.model_dump(),
|
||||
"suspended": True, # Flag to indicate this node caused suspension
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Human Input node {self.node_id} failed to execute")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="HumanInputNodeError",
|
||||
)
|
||||
|
||||
def _process_form_content(self) -> str:
|
||||
"""
|
||||
Process form content by substituting variables.
|
||||
|
||||
This method should:
|
||||
1. Parse the form_content markdown
|
||||
2. Substitute {{#node_name.var_name#}} with actual values
|
||||
3. Keep {{#$output.field_name#}} placeholders for form inputs
|
||||
"""
|
||||
# TODO: Implement variable substitution logic
|
||||
# For now, return the raw form content
|
||||
# This should integrate with the existing variable template parser
|
||||
return self.node_data.form_content
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input placeholders.
|
||||
|
||||
This method should parse:
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input placeholders
|
||||
"""
|
||||
# TODO: Implement variable extraction logic
|
||||
# This should parse the form_content and placeholder configurations
|
||||
# to extract all referenced variables
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
"""Get default configuration for human input node."""
|
||||
return {
|
||||
"type": "human_input",
|
||||
"config": {
|
||||
"delivery_methods": [{"type": "webapp", "enabled": True, "config": {}}],
|
||||
"form_content": "# Human Input\n\nPlease provide your input:\n\n{{#$output.input#}}",
|
||||
"inputs": [
|
||||
{
|
||||
"type": "text-input",
|
||||
"output_variable_name": "input",
|
||||
"placeholder": {"type": "constant", "value": "Enter your response here..."},
|
||||
}
|
||||
],
|
||||
"user_actions": [{"id": "submit", "title": "Submit", "button_style": "primary"}],
|
||||
"timeout": 24,
|
||||
"timeout_unit": "hour",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of the human input node."""
|
||||
return "1"
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
"""Get the error strategy for this node."""
|
||||
return self.node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
return self.node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
return self.node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
"""Get the node description."""
|
||||
return self.node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self.node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
return self.node_data
|
||||
|
||||
def resume_from_human_input(self, form_submission_data: dict[str, Any]) -> NodeRunResult:
|
||||
"""
|
||||
Resume node execution after human input form is submitted.
|
||||
|
||||
Args:
|
||||
form_submission_data: Dict containing:
|
||||
- inputs: Dict of input field values
|
||||
- action: The user action taken
|
||||
|
||||
Returns:
|
||||
NodeRunResult with the form inputs as outputs
|
||||
"""
|
||||
try:
|
||||
inputs = form_submission_data.get("inputs", {})
|
||||
action = form_submission_data.get("action", "")
|
||||
|
||||
# Create output dictionary with form inputs
|
||||
outputs = {}
|
||||
for input_field in self.node_data.inputs:
|
||||
field_name = input_field.output_variable_name
|
||||
if field_name in inputs:
|
||||
outputs[field_name] = inputs[field_name]
|
||||
|
||||
# Add the action to outputs
|
||||
outputs["_action"] = action
|
||||
|
||||
logger.info(f"Human Input node {self.node_id} resumed with action {action}")
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
"form_submitted": True,
|
||||
"submitted_action": action,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Human Input node {self.node_id} failed to resume")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="HumanInputResumeError",
|
||||
)
|
||||
@ -1,6 +1,65 @@
|
||||
from typing import Protocol
|
||||
import abc
|
||||
import dataclasses
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.workflow.entities.human_input_form import HumanInputForm
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
|
||||
|
||||
class HumanInputError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FormCreateParams:
|
||||
workflow_execution_id: str
|
||||
|
||||
# node_id is the identifier for a specific
|
||||
# node in the graph.
|
||||
#
|
||||
# TODO: for node inside loop / iteration, this would
|
||||
# cause problems, as a single node may be executed multiple times.
|
||||
node_id: str
|
||||
|
||||
form_config: HumanInputNodeData
|
||||
rendered_content: str
|
||||
|
||||
|
||||
class HumanInputFormEntity(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> str:
|
||||
"""id returns the identifer of the form."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def web_app_token(self) -> str | None:
|
||||
"""web_app_token returns the token for submission inside webapp.
|
||||
|
||||
If web app delivery is not enabled, this method would return `None`.
|
||||
"""
|
||||
|
||||
# TODO: what if the users are allowed to add multiple
|
||||
# webapp delivery?
|
||||
pass
|
||||
|
||||
|
||||
class FormSubmissionEntity(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def selected_action_id(self) -> str:
|
||||
"""The identifier of action user has selected, correspond to `UserAction.id`."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def form_data(self) -> Mapping[str, Any]:
|
||||
"""The data submitted for this form"""
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputFormRepository(Protocol):
|
||||
@ -16,93 +75,17 @@ class HumanInputFormRepository(Protocol):
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, form: HumanInputForm) -> None:
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
"""
|
||||
Save or update a HumanInputForm instance.
|
||||
|
||||
This method handles both creating new records and updating existing ones.
|
||||
The implementation should determine whether to create or update based on
|
||||
the form's ID or other identifying fields.
|
||||
|
||||
Args:
|
||||
form: The HumanInputForm instance to save or update
|
||||
Create a human input form from form definition.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_id(self, form_id: str) -> HumanInputForm:
|
||||
"""
|
||||
Get a form by its ID.
|
||||
def get_form_submission(self, workflow_execution_id: str, node_id: str) -> FormSubmissionEntity | None:
|
||||
"""Retrieve the submission for a specific human input node.
|
||||
|
||||
Args:
|
||||
form_id: The ID of the form to retrieve
|
||||
Returns `FormSubmission` if the form has been submitted, or `None` if not.
|
||||
|
||||
Returns:
|
||||
The HumanInputForm instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the form is not found
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_web_app_token(self, web_app_token: str) -> HumanInputForm:
|
||||
"""
|
||||
Get a form by its web app token.
|
||||
|
||||
Args:
|
||||
web_app_token: The web app token to search for
|
||||
|
||||
Returns:
|
||||
The HumanInputForm instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the form is not found
|
||||
"""
|
||||
...
|
||||
|
||||
def get_pending_forms_for_workflow_run(self, workflow_run_id: str) -> list[HumanInputForm]:
|
||||
"""
|
||||
Get all pending human input forms for a workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID to filter by
|
||||
|
||||
Returns:
|
||||
List of pending HumanInputForm instances
|
||||
"""
|
||||
...
|
||||
|
||||
def mark_expired_forms(self, expiry_hours: int = 48) -> int:
|
||||
"""
|
||||
Mark expired forms as expired.
|
||||
|
||||
Args:
|
||||
expiry_hours: Number of hours after which forms should be expired
|
||||
|
||||
Returns:
|
||||
Number of forms marked as expired
|
||||
"""
|
||||
...
|
||||
|
||||
def exists_by_id(self, form_id: str) -> bool:
|
||||
"""
|
||||
Check if a form exists by ID.
|
||||
|
||||
Args:
|
||||
form_id: The ID of the form to check
|
||||
|
||||
Returns:
|
||||
True if the form exists, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def exists_by_web_app_token(self, web_app_token: str) -> bool:
|
||||
"""
|
||||
Check if a form exists by web app token.
|
||||
|
||||
Args:
|
||||
web_app_token: The web app token to check
|
||||
|
||||
Returns:
|
||||
True if the form exists, False otherwise
|
||||
Raises `FormNotFoundError` if correspond form record is not found.
|
||||
"""
|
||||
...
|
||||
|
||||
@ -41,7 +41,7 @@ class DefaultFieldsMixin:
|
||||
)
|
||||
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
__name_pos=DateTime,
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=naive_utc_now,
|
||||
server_default=func.current_timestamp(),
|
||||
|
||||
@ -1,28 +1,21 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, ClassVar, Literal, Self, final
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryMethodType,
|
||||
EmailRecipientType,
|
||||
HumanInputFormStatus,
|
||||
)
|
||||
from libs.helper import generate_string
|
||||
|
||||
from .base import Base, ModelMixin
|
||||
from .base import Base, DefaultFieldsMixin
|
||||
from .types import EnumText, StringUUID
|
||||
|
||||
|
||||
class HumanInputFormStatus(StrEnum):
|
||||
WAITING = "waiting"
|
||||
EXPIRED = "expired"
|
||||
SUBMITTED = "submitted"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class HumanInputSubmissionType(StrEnum):
|
||||
web_form = "web_form"
|
||||
web_app = "web_app"
|
||||
email = "email"
|
||||
|
||||
|
||||
_token_length = 22
|
||||
# A 32-character string can store a base64-encoded value with 192 bits of entropy
|
||||
# or a base62-encoded value with over 180 bits of entropy, providing sufficient
|
||||
@ -31,36 +24,18 @@ _token_field_length = 32
|
||||
_email_field_length = 330
|
||||
|
||||
|
||||
def _generate_token():
|
||||
def _generate_token() -> str:
|
||||
return generate_string(_token_length)
|
||||
|
||||
|
||||
class HumanInputForm(ModelMixin, Base):
|
||||
class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_forms"
|
||||
|
||||
# `tenant_id` identifies the tenant associated with this suspension,
|
||||
# corresponding to the `id` field in the `Tenant` model.
|
||||
tenant_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# `app_id` represents the application identifier associated with this state.
|
||||
# It corresponds to the `id` field in the `App` model.
|
||||
#
|
||||
# While this field is technically redundant (as the corresponding app can be
|
||||
# determined by querying the `Workflow`), it is retained to simplify data
|
||||
# cleanup and management processes.
|
||||
app_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
workflow_run_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=False,
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
# The human input node the current form corresponds to.
|
||||
node_id: Mapped[str] = mapped_column(sa.String(60), nullable=False)
|
||||
form_definition: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
rendered_content: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
status: Mapped[HumanInputFormStatus] = mapped_column(
|
||||
@ -69,56 +44,137 @@ class HumanInputForm(ModelMixin, Base):
|
||||
default=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
|
||||
web_app_token: Mapped[str] = mapped_column(
|
||||
expiration_time: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Submission-related fields (nullable until a submission happens).
|
||||
selected_action_id: Mapped[str | None] = mapped_column(sa.String(200), nullable=True)
|
||||
submitted_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
|
||||
submitted_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
|
||||
submission_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
submission_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
completed_by_recipient_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID,
|
||||
sa.ForeignKey("human_input_recipients.id"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
deliveries: Mapped[list["HumanInputDelivery"]] = relationship(
|
||||
"HumanInputDelivery",
|
||||
back_populates="form",
|
||||
lazy="raise",
|
||||
)
|
||||
completed_by_recipient: Mapped["HumanInputRecipient | None"] = relationship(
|
||||
"HumanInputRecipient",
|
||||
primaryjoin="HumanInputForm.completed_by_recipient_id == HumanInputRecipient.id",
|
||||
lazy="raise",
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
|
||||
class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_deliveries"
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
sa.ForeignKey("human_input_forms.id"),
|
||||
nullable=False,
|
||||
)
|
||||
delivery_method_type: Mapped[DeliveryMethodType] = mapped_column(
|
||||
EnumText(DeliveryMethodType),
|
||||
nullable=False,
|
||||
)
|
||||
delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
channel_payload: Mapped[None] = mapped_column(sa.Text, nullable=True)
|
||||
|
||||
form: Mapped[HumanInputForm] = relationship(
|
||||
"HumanInputForm",
|
||||
back_populates="deliveries",
|
||||
lazy="raise",
|
||||
)
|
||||
recipients: Mapped[list["HumanInputRecipient"]] = relationship(
|
||||
"HumanInputRecipient",
|
||||
back_populates="delivery",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
|
||||
class RecipientType(StrEnum):
|
||||
# EMAIL_MEMBER member means that the
|
||||
EMAIL_MEMBER = "email_member"
|
||||
EMAIL_EXTERNAL = "email_external"
|
||||
WEBAPP = "webapp"
|
||||
|
||||
|
||||
@final
|
||||
class EmailMemberRecipientPayload(BaseModel):
|
||||
TYPE: Literal[RecipientType.EMAIL_MEMBER] = RecipientType.EMAIL_MEMBER
|
||||
user_id: str
|
||||
|
||||
# The `email` field here is only used for mail sending.
|
||||
email: str
|
||||
|
||||
|
||||
@final
|
||||
class EmailExternalRecipientPayload(BaseModel):
|
||||
TYPE: Literal[RecipientType.EMAIL_EXTERNAL] = RecipientType.EMAIL_EXTERNAL
|
||||
email: str
|
||||
|
||||
|
||||
@final
|
||||
class WebAppRecipientPayload(BaseModel):
|
||||
TYPE: Literal[RecipientType.WEBAPP] = RecipientType.WEBAPP
|
||||
|
||||
|
||||
RecipientPayload = Annotated[
|
||||
EmailMemberRecipientPayload | EmailExternalRecipientPayload | WebAppRecipientPayload,
|
||||
Field(discriminator="TYPE"),
|
||||
]
|
||||
|
||||
|
||||
class HumanInputRecipient(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_recipients"
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=False,
|
||||
)
|
||||
delivery_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=False,
|
||||
)
|
||||
recipient_type: Mapped["RecipientType"] = mapped_column(EnumText(RecipientType), nullable=False)
|
||||
recipient_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
# Token primarily used for authenticated resume links (email, etc.).
|
||||
access_token: Mapped[str | None] = mapped_column(
|
||||
sa.VARCHAR(_token_field_length),
|
||||
nullable=True,
|
||||
default=_generate_token,
|
||||
)
|
||||
|
||||
# The following fields are not null if the form is submitted.
|
||||
|
||||
# The inputs provided by the user when resuming the suspended workflow.
|
||||
# These inputs are serialized as a JSON-formatted string (e.g., `{}`).
|
||||
#
|
||||
# This field is `NULL` if no inputs were submitted by the user.
|
||||
submitted_data: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
|
||||
submitted_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=True)
|
||||
|
||||
submission_type: Mapped[HumanInputSubmissionType] = mapped_column(
|
||||
EnumText(HumanInputSubmissionType),
|
||||
nullable=True,
|
||||
delivery: Mapped[HumanInputDelivery] = relationship(
|
||||
"HumanInputDelivery",
|
||||
back_populates="recipients",
|
||||
lazy="raise",
|
||||
)
|
||||
|
||||
# If the submission happens in dashboard (Studio for orchestrate the workflow, or
|
||||
# Explore for using published apps), which requires user to login before submission.
|
||||
# Then the `submission_user_id` records the user id
|
||||
# of submitter, else `None`.
|
||||
submission_user_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
# If the submission happens in WebApp (which does not requires user to login before submission)
|
||||
# Then the `submission_user_id` records the end_user_id of submitter, else `None`.
|
||||
submission_end_user_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
# IF the submitter receives a email and
|
||||
submitter_email: Mapped[str] = mapped_column(sa.VARCHAR(_email_field_length), nullable=True)
|
||||
|
||||
|
||||
# class HumanInputEmailDelivery(ModelMixin, Base):
|
||||
# # form_id refers to `HumanInputForm.id`
|
||||
# form_id: Mapped[str] = mapped_column(
|
||||
# StringUUID,
|
||||
# nullable=False,
|
||||
# )
|
||||
|
||||
# # IF the submitter receives a email and
|
||||
# email: Mapped[str] = mapped_column(__name_pos=sa.VARCHAR(_email_field_length), nullable=False)
|
||||
# user_id: Mapped[str] = mapped_column(
|
||||
# StringUUID,
|
||||
# nullable=True,
|
||||
# )
|
||||
|
||||
# email_link_token: Mapped[str] = mapped_column(
|
||||
# sa.VARCHAR(_token_field_length),
|
||||
# nullable=False,
|
||||
# default=_generate_token,
|
||||
# )
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
payload: RecipientPayload,
|
||||
) -> Self:
|
||||
recipient_model = cls(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=payload.TYPE,
|
||||
recipient_payload=payload.model_dump_json(),
|
||||
access_token=_generate_token(),
|
||||
)
|
||||
return recipient_model
|
||||
|
||||
@ -30,7 +30,7 @@ from core.workflow.constants import (
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.enums import NodeType, WorkflowExecutionStatus
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -607,7 +607,10 @@ class WorkflowRun(Base):
|
||||
version: Mapped[str] = mapped_column(String(255))
|
||||
graph: Mapped[str | None] = mapped_column(LongText)
|
||||
inputs: Mapped[str | None] = mapped_column(LongText)
|
||||
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
||||
status: Mapped[WorkflowExecutionStatus] = mapped_column(
|
||||
EnumText(WorkflowExecutionStatus, length=255),
|
||||
nullable=False,
|
||||
)
|
||||
outputs: Mapped[str | None] = mapped_column(LongText, default="{}")
|
||||
error: Mapped[str | None] = mapped_column(LongText)
|
||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
|
||||
@ -16,6 +16,10 @@ from extensions.otel import AppGenerateHandler, trace_span
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks.app_generate.workflow_execute_task import ChatflowExecutionParams, chatflow_execute_task
|
||||
|
||||
@ -246,3 +250,19 @@ class AppGenerateService:
|
||||
raise ValueError("Workflow not published")
|
||||
|
||||
return workflow
|
||||
|
||||
@classmethod
|
||||
def get_response_generator(
|
||||
cls,
|
||||
app_model: App,
|
||||
workflow_run: WorkflowRun,
|
||||
):
|
||||
if workflow_run.status.is_ended():
|
||||
# TODO(QuantumGhost): handled the ended scenario.
|
||||
return
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
return generator.convert_to_event_stream(
|
||||
generator.retrieve_events(app_model.mode, workflow_run.id),
|
||||
)
|
||||
|
||||
@ -1,272 +0,0 @@
|
||||
"""
|
||||
Service for managing human input forms using domain models.
|
||||
|
||||
This service layer operates on domain models and uses repositories for persistence,
|
||||
keeping the business logic clean and independent of database concerns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.entities.human_input_form import HumanInputForm, HumanInputFormStatus, HumanInputSubmissionType
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormNotFoundError(BaseServiceError):
|
||||
"""Raised when a human input form is not found."""
|
||||
|
||||
def __init__(self, identifier: str):
|
||||
super().__init__(f"Human input form not found: {identifier}")
|
||||
self.identifier = identifier
|
||||
|
||||
|
||||
class HumanInputFormExpiredError(BaseServiceError):
|
||||
"""Raised when a human input form has expired."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Human input form has expired")
|
||||
|
||||
|
||||
class HumanInputFormAlreadySubmittedError(BaseServiceError):
|
||||
"""Raised when trying to operate on an already submitted form."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Human input form has already been submitted")
|
||||
|
||||
|
||||
class InvalidFormDataError(BaseServiceError):
|
||||
"""Raised when form submission data is invalid."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(f"Invalid form data: {message}")
|
||||
self.message = message
|
||||
|
||||
|
||||
class HumanInputFormDomainService:
|
||||
"""Service for managing human input forms using domain models."""
|
||||
|
||||
def __init__(self, repository: HumanInputFormRepository):
|
||||
"""
|
||||
Initialize the service with a repository.
|
||||
|
||||
Args:
|
||||
repository: Repository for human input form persistence
|
||||
"""
|
||||
self._repository = repository
|
||||
|
||||
def create_form(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
workflow_run_id: str,
|
||||
form_definition: dict[str, Any],
|
||||
rendered_content: str,
|
||||
web_app_token: Optional[str] = None,
|
||||
) -> HumanInputForm:
|
||||
"""
|
||||
Create a new human input form.
|
||||
|
||||
Args:
|
||||
form_id: Unique identifier for the form
|
||||
workflow_run_id: ID of the associated workflow run
|
||||
form_definition: Form definition as a dictionary
|
||||
rendered_content: Rendered HTML content of the form
|
||||
web_app_token: Optional token for web app access
|
||||
|
||||
Returns:
|
||||
Created HumanInputForm domain model
|
||||
"""
|
||||
form = HumanInputForm.create(
|
||||
id_=form_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
form_definition=form_definition,
|
||||
rendered_content=rendered_content,
|
||||
web_app_token=web_app_token,
|
||||
)
|
||||
|
||||
self._repository.save(form)
|
||||
logger.info("Created human input form %s", form_id)
|
||||
return form
|
||||
|
||||
def get_form_by_id(self, form_id: str) -> HumanInputForm:
|
||||
"""
|
||||
Get a form by its ID.
|
||||
|
||||
Args:
|
||||
form_id: The ID of the form to retrieve
|
||||
|
||||
Returns:
|
||||
HumanInputForm domain model
|
||||
|
||||
Raises:
|
||||
HumanInputFormNotFoundError: If the form is not found
|
||||
"""
|
||||
try:
|
||||
return self._repository.get_by_id(form_id)
|
||||
except ValueError as e:
|
||||
raise HumanInputFormNotFoundError(form_id) from e
|
||||
|
||||
def get_form_by_token(self, web_app_token: str) -> HumanInputForm:
|
||||
"""
|
||||
Get a form by its web app token.
|
||||
|
||||
Args:
|
||||
web_app_token: The web app token to search for
|
||||
|
||||
Returns:
|
||||
HumanInputForm domain model
|
||||
|
||||
Raises:
|
||||
HumanInputFormNotFoundError: If the form is not found
|
||||
"""
|
||||
try:
|
||||
return self._repository.get_by_web_app_token(web_app_token)
|
||||
except ValueError as e:
|
||||
raise HumanInputFormNotFoundError(web_app_token) from e
|
||||
|
||||
def get_form_definition(
|
||||
self,
|
||||
identifier: str,
|
||||
is_token: bool = False,
|
||||
include_site_info: bool = False,
|
||||
app_id: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get form definition for display.
|
||||
|
||||
Args:
|
||||
identifier: Form ID or web app token
|
||||
is_token: True if identifier is a web app token, False if it's a form ID
|
||||
include_site_info: Whether to include site information in the response
|
||||
app_id: App ID for site information (if include_site_info is True)
|
||||
|
||||
Returns:
|
||||
Form definition dictionary for display
|
||||
|
||||
Raises:
|
||||
HumanInputFormNotFoundError: If the form is not found
|
||||
HumanInputFormExpiredError: If the form has expired
|
||||
HumanInputFormAlreadySubmittedError: If the form has already been submitted
|
||||
"""
|
||||
if is_token:
|
||||
form = self.get_form_by_token(identifier)
|
||||
else:
|
||||
form = self.get_form_by_id(identifier)
|
||||
|
||||
try:
|
||||
form_definition = form.get_form_definition_for_display(include_site_info=include_site_info)
|
||||
except ValueError as e:
|
||||
if "expired" in str(e).lower():
|
||||
raise HumanInputFormExpiredError() from e
|
||||
elif "submitted" in str(e).lower():
|
||||
raise HumanInputFormAlreadySubmittedError() from e
|
||||
else:
|
||||
raise InvalidFormDataError(str(e)) from e
|
||||
|
||||
# Add site info if requested and app_id is provided
|
||||
if include_site_info and app_id and "site" in form_definition:
|
||||
form_definition["site"]["app_id"] = app_id
|
||||
|
||||
return form_definition
|
||||
|
||||
def submit_form(
|
||||
self,
|
||||
identifier: str,
|
||||
form_data: dict[str, Any],
|
||||
action: str,
|
||||
is_token: bool = False,
|
||||
submission_type: HumanInputSubmissionType = HumanInputSubmissionType.web_form,
|
||||
submission_user_id: Optional[str] = None,
|
||||
submission_end_user_id: Optional[str] = None,
|
||||
) -> HumanInputForm:
|
||||
"""
|
||||
Submit a form.
|
||||
|
||||
Args:
|
||||
identifier: Form ID or web app token
|
||||
form_data: The submitted form data
|
||||
action: The action taken by the user
|
||||
is_token: True if identifier is a web app token, False if it's a form ID
|
||||
submission_type: Type of submission (web_form, web_app, email)
|
||||
submission_user_id: ID of the user who submitted (for console submissions)
|
||||
submission_end_user_id: ID of the end user who submitted (for webapp submissions)
|
||||
|
||||
Returns:
|
||||
Updated HumanInputForm domain model
|
||||
|
||||
Raises:
|
||||
HumanInputFormNotFoundError: If the form is not found
|
||||
HumanInputFormExpiredError: If the form has expired
|
||||
HumanInputFormAlreadySubmittedError: If the form has already been submitted
|
||||
InvalidFormDataError: If the submission data is invalid
|
||||
"""
|
||||
if is_token:
|
||||
form = self.get_form_by_token(identifier)
|
||||
else:
|
||||
form = self.get_form_by_id(identifier)
|
||||
|
||||
if form.is_expired:
|
||||
raise HumanInputFormExpiredError()
|
||||
|
||||
if form.is_submitted:
|
||||
raise HumanInputFormAlreadySubmittedError()
|
||||
|
||||
try:
|
||||
form.submit(
|
||||
data=form_data,
|
||||
action=action,
|
||||
submission_type=submission_type,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise InvalidFormDataError(str(e)) from e
|
||||
|
||||
self._repository.save(form)
|
||||
logger.info("Form %s submitted with action %s", form.id_, action)
|
||||
return form
|
||||
|
||||
def cleanup_expired_forms(self, expiry_hours: int = 48) -> int:
|
||||
"""
|
||||
Clean up expired forms.
|
||||
|
||||
Args:
|
||||
expiry_hours: Number of hours after which forms should be expired
|
||||
|
||||
Returns:
|
||||
Number of forms cleaned up
|
||||
"""
|
||||
count = self._repository.mark_expired_forms(expiry_hours)
|
||||
logger.info("Cleaned up %d expired forms", count)
|
||||
return count
|
||||
|
||||
def get_pending_forms_for_workflow_run(self, workflow_run_id: str) -> list[HumanInputForm]:
|
||||
"""
|
||||
Get all pending human input forms for a workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID to filter by
|
||||
|
||||
Returns:
|
||||
List of pending HumanInputForm domain models
|
||||
"""
|
||||
return self._repository.get_pending_forms_for_workflow_run(workflow_run_id)
|
||||
|
||||
def form_exists(self, identifier: str, is_token: bool = False) -> bool:
|
||||
"""
|
||||
Check if a form exists.
|
||||
|
||||
Args:
|
||||
identifier: Form ID or web app token
|
||||
is_token: True if identifier is a web app token, False if it's a form ID
|
||||
|
||||
Returns:
|
||||
True if the form exists, False otherwise
|
||||
"""
|
||||
if is_token:
|
||||
return self._repository.exists_by_web_app_token(identifier)
|
||||
else:
|
||||
return self._repository.exists_by_id(identifier)
|
||||
@ -1,352 +0,0 @@
|
||||
"""
|
||||
Service for managing human input forms.
|
||||
|
||||
This service maintains backward compatibility while internally using domain models
|
||||
and repositories for better architecture.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.human_input_form import HumanInputForm as DomainHumanInputForm, HumanInputSubmissionType
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
|
||||
from models.human_input import (
|
||||
HumanInputForm,
|
||||
HumanInputFormStatus,
|
||||
HumanInputSubmissionType as DBHumanInputSubmissionType,
|
||||
)
|
||||
from services.errors.base import BaseServiceError
|
||||
from services.human_input_form_domain_service import (
|
||||
HumanInputFormDomainService,
|
||||
HumanInputFormNotFoundError as DomainNotFoundError,
|
||||
HumanInputFormExpiredError as DomainExpiredError,
|
||||
HumanInputFormAlreadySubmittedError as DomainAlreadySubmittedError,
|
||||
InvalidFormDataError as DomainInvalidFormDataError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormNotFoundError(BaseServiceError):
|
||||
"""Raised when a human input form is not found."""
|
||||
|
||||
def __init__(self, identifier: str):
|
||||
super().__init__(f"Human input form not found: {identifier}")
|
||||
self.identifier = identifier
|
||||
|
||||
|
||||
class HumanInputFormExpiredError(BaseServiceError):
|
||||
"""Raised when a human input form has expired."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Human input form has expired")
|
||||
|
||||
|
||||
class HumanInputFormAlreadySubmittedError(BaseServiceError):
|
||||
"""Raised when trying to operate on an already submitted form."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Human input form has already been submitted")
|
||||
|
||||
|
||||
class InvalidFormDataError(BaseServiceError):
|
||||
"""Raised when form submission data is invalid."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(f"Invalid form data: {message}")
|
||||
self.message = message
|
||||
|
||||
|
||||
class HumanInputFormService:
|
||||
"""Service for managing human input forms using domain models internally."""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
"""
|
||||
Initialize the service with a database session.
|
||||
|
||||
Args:
|
||||
session: SQLAlchemy session
|
||||
"""
|
||||
self._session = session
|
||||
# For backward compatibility, we need user and app_id context
|
||||
# These would typically be available from the request context
|
||||
# For now, we'll extract them from the session or use defaults
|
||||
self._user = None # This should be set from request context
|
||||
self._app_id = None # This should be set from request context
|
||||
self._domain_service = None
|
||||
|
||||
def _get_domain_service(self) -> HumanInputFormDomainService:
|
||||
"""
|
||||
Get the domain service instance.
|
||||
|
||||
Note: This requires user and app_id context to be properly set.
|
||||
In a real implementation, these would be extracted from the request context.
|
||||
"""
|
||||
if self._domain_service is None:
|
||||
if not self._user:
|
||||
# For backward compatibility, we need to handle this case
|
||||
# In practice, the user should be available from the request context
|
||||
raise ValueError("User context is required for domain operations")
|
||||
|
||||
repository = DifyCoreRepositoryFactory.create_human_input_form_repository(
|
||||
session_factory=self._session,
|
||||
user=self._user,
|
||||
app_id=self._app_id or "",
|
||||
)
|
||||
self._domain_service = HumanInputFormDomainService(repository)
|
||||
|
||||
return self._domain_service
|
||||
|
||||
def _domain_to_db_model(self, domain_form: DomainHumanInputForm) -> HumanInputForm:
|
||||
"""Convert domain model to database model for backward compatibility."""
|
||||
# Find existing DB model or create new one
|
||||
db_model = self._session.get(HumanInputForm, domain_form.id_)
|
||||
if db_model is None:
|
||||
db_model = HumanInputForm()
|
||||
db_model.id = domain_form.id_
|
||||
# Set tenant_id and app_id from context
|
||||
if hasattr(self._user, "current_tenant_id"):
|
||||
db_model.tenant_id = self._user.current_tenant_id
|
||||
elif hasattr(self._user, "tenant_id"):
|
||||
db_model.tenant_id = self._user.tenant_id
|
||||
if self._app_id:
|
||||
db_model.app_id = self._app_id
|
||||
|
||||
# Update fields
|
||||
db_model.workflow_run_id = domain_form.workflow_run_id
|
||||
db_model.form_definition = json.dumps(domain_form.form_definition)
|
||||
db_model.rendered_content = domain_form.rendered_content
|
||||
db_model.status = HumanInputFormStatus(domain_form.status.value)
|
||||
db_model.web_app_token = domain_form.web_app_token
|
||||
db_model.created_at = domain_form.created_at
|
||||
|
||||
# Handle submission data
|
||||
if domain_form.submission:
|
||||
db_model.submitted_data = json.dumps(domain_form.submission.data)
|
||||
db_model.submitted_at = domain_form.submission.submitted_at
|
||||
db_model.submission_type = DBHumanInputSubmissionType(domain_form.submission.submission_type.value)
|
||||
db_model.submission_user_id = domain_form.submission.submission_user_id
|
||||
db_model.submission_end_user_id = domain_form.submission.submission_end_user_id
|
||||
# Note: submitter_email is not in the current DB model schema
|
||||
|
||||
return db_model
|
||||
|
||||
def set_context(self, user, app_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Set user and app context for the service.
|
||||
|
||||
Args:
|
||||
user: User object (Account or EndUser)
|
||||
app_id: Application ID
|
||||
"""
|
||||
self._user = user
|
||||
self._app_id = app_id
|
||||
self._domain_service = None # Reset to force recreation with new context
|
||||
|
||||
def create_form(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
workflow_run_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
form_definition: str,
|
||||
rendered_content: str,
|
||||
web_app_token: Optional[str] = None,
|
||||
) -> HumanInputForm:
|
||||
"""Create a new human input form."""
|
||||
# Set context for this operation
|
||||
self._app_id = app_id
|
||||
|
||||
try:
|
||||
domain_service = self._get_domain_service()
|
||||
domain_form = domain_service.create_form(
|
||||
form_id=form_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
form_definition=json.loads(form_definition),
|
||||
rendered_content=rendered_content,
|
||||
web_app_token=web_app_token,
|
||||
)
|
||||
|
||||
# Convert back to DB model for backward compatibility
|
||||
db_model = self._domain_to_db_model(domain_form)
|
||||
self._session.add(db_model)
|
||||
self._session.commit()
|
||||
|
||||
logger.info("Created human input form %s", form_id)
|
||||
return db_model
|
||||
|
||||
except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e:
|
||||
# Convert domain exceptions to service exceptions
|
||||
if isinstance(e, DomainNotFoundError):
|
||||
raise HumanInputFormNotFoundError(e.identifier) from e
|
||||
elif isinstance(e, DomainExpiredError):
|
||||
raise HumanInputFormExpiredError() from e
|
||||
elif isinstance(e, DomainAlreadySubmittedError):
|
||||
raise HumanInputFormAlreadySubmittedError() from e
|
||||
elif isinstance(e, DomainInvalidFormDataError):
|
||||
raise InvalidFormDataError(e.message) from e
|
||||
|
||||
def get_form_by_id(self, form_id: str) -> HumanInputForm:
|
||||
"""Get a form by its ID."""
|
||||
try:
|
||||
domain_service = self._get_domain_service()
|
||||
domain_form = domain_service.get_form_by_id(form_id)
|
||||
return self._domain_to_db_model(domain_form)
|
||||
except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e:
|
||||
if isinstance(e, DomainNotFoundError):
|
||||
raise HumanInputFormNotFoundError(e.identifier) from e
|
||||
elif isinstance(e, DomainExpiredError):
|
||||
raise HumanInputFormExpiredError() from e
|
||||
elif isinstance(e, DomainAlreadySubmittedError):
|
||||
raise HumanInputFormAlreadySubmittedError() from e
|
||||
elif isinstance(e, DomainInvalidFormDataError):
|
||||
raise InvalidFormDataError(e.message) from e
|
||||
|
||||
def get_form_by_token(self, web_app_token: str) -> HumanInputForm:
|
||||
"""Get a form by its web app token."""
|
||||
try:
|
||||
domain_service = self._get_domain_service()
|
||||
domain_form = domain_service.get_form_by_token(web_app_token)
|
||||
return self._domain_to_db_model(domain_form)
|
||||
except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e:
|
||||
if isinstance(e, DomainNotFoundError):
|
||||
raise HumanInputFormNotFoundError(e.identifier) from e
|
||||
elif isinstance(e, DomainExpiredError):
|
||||
raise HumanInputFormExpiredError() from e
|
||||
elif isinstance(e, DomainAlreadySubmittedError):
|
||||
raise HumanInputFormAlreadySubmittedError() from e
|
||||
elif isinstance(e, DomainInvalidFormDataError):
|
||||
raise InvalidFormDataError(e.message) from e
|
||||
|
||||
def get_form_definition(
|
||||
self,
|
||||
identifier: str,
|
||||
is_token: bool = False,
|
||||
include_site_info: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get form definition for display.
|
||||
|
||||
Args:
|
||||
identifier: Form ID or web app token
|
||||
is_token: True if identifier is a web app token, False if it's a form ID
|
||||
include_site_info: Whether to include site information in the response
|
||||
"""
|
||||
try:
|
||||
domain_service = self._get_domain_service()
|
||||
return domain_service.get_form_definition(
|
||||
identifier=identifier,
|
||||
is_token=is_token,
|
||||
include_site_info=include_site_info,
|
||||
app_id=self._app_id,
|
||||
)
|
||||
except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e:
|
||||
if isinstance(e, DomainNotFoundError):
|
||||
raise HumanInputFormNotFoundError(e.identifier) from e
|
||||
elif isinstance(e, DomainExpiredError):
|
||||
raise HumanInputFormExpiredError() from e
|
||||
elif isinstance(e, DomainAlreadySubmittedError):
|
||||
raise HumanInputFormAlreadySubmittedError() from e
|
||||
elif isinstance(e, DomainInvalidFormDataError):
|
||||
raise InvalidFormDataError(e.message) from e
|
||||
|
||||
def submit_form(
|
||||
self,
|
||||
identifier: str,
|
||||
form_data: dict[str, Any],
|
||||
action: str,
|
||||
is_token: bool = False,
|
||||
submission_type: HumanInputSubmissionType = HumanInputSubmissionType.web_form,
|
||||
submission_user_id: Optional[str] = None,
|
||||
submission_end_user_id: Optional[str] = None,
|
||||
) -> HumanInputForm:
|
||||
"""
|
||||
Submit a form.
|
||||
|
||||
Args:
|
||||
identifier: Form ID or web app token
|
||||
form_data: The submitted form data
|
||||
action: The action taken by the user
|
||||
is_token: True if identifier is a web app token, False if it's a form ID
|
||||
submission_type: Type of submission (web_form, web_app, email)
|
||||
submission_user_id: ID of the user who submitted (for console submissions)
|
||||
submission_end_user_id: ID of the end user who submitted (for webapp submissions)
|
||||
"""
|
||||
try:
|
||||
domain_service = self._get_domain_service()
|
||||
domain_form = domain_service.submit_form(
|
||||
identifier=identifier,
|
||||
form_data=form_data,
|
||||
action=action,
|
||||
is_token=is_token,
|
||||
submission_type=submission_type,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
|
||||
# Convert back to DB model for backward compatibility
|
||||
db_model = self._domain_to_db_model(domain_form)
|
||||
self._session.merge(db_model)
|
||||
self._session.commit()
|
||||
|
||||
return db_model
|
||||
|
||||
except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e:
|
||||
if isinstance(e, DomainNotFoundError):
|
||||
raise HumanInputFormNotFoundError(e.identifier) from e
|
||||
elif isinstance(e, DomainExpiredError):
|
||||
raise HumanInputFormExpiredError() from e
|
||||
elif isinstance(e, DomainAlreadySubmittedError):
|
||||
raise HumanInputFormAlreadySubmittedError() from e
|
||||
elif isinstance(e, DomainInvalidFormDataError):
|
||||
raise InvalidFormDataError(e.message) from e
|
||||
|
||||
def _validate_submission(self, form: HumanInputForm, form_data: dict[str, Any], action: str) -> None:
|
||||
"""Validate form submission data."""
|
||||
form_definition = json.loads(form.form_definition)
|
||||
|
||||
# Check that the action is valid
|
||||
valid_actions = {act.get("id") for act in form_definition.get("user_actions", [])}
|
||||
if action not in valid_actions:
|
||||
raise InvalidFormDataError(f"Invalid action: {action}")
|
||||
|
||||
# Note: We don't validate required inputs here as the original implementation
|
||||
# allows extra inputs and doesn't strictly enforce all inputs to be present
|
||||
|
||||
def cleanup_expired_forms(self) -> int:
|
||||
"""Clean up expired forms. Returns the number of forms cleaned up."""
|
||||
try:
|
||||
domain_service = self._get_domain_service()
|
||||
return domain_service.cleanup_expired_forms()
|
||||
except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e:
|
||||
if isinstance(e, DomainNotFoundError):
|
||||
raise HumanInputFormNotFoundError(e.identifier) from e
|
||||
elif isinstance(e, DomainExpiredError):
|
||||
raise HumanInputFormExpiredError() from e
|
||||
elif isinstance(e, DomainAlreadySubmittedError):
|
||||
raise HumanInputFormAlreadySubmittedError() from e
|
||||
elif isinstance(e, DomainInvalidFormDataError):
|
||||
raise InvalidFormDataError(e.message) from e
|
||||
|
||||
def get_pending_forms_for_workflow_run(self, workflow_run_id: str) -> list[HumanInputForm]:
|
||||
"""Get all pending human input forms for a workflow run."""
|
||||
try:
|
||||
domain_service = self._get_domain_service()
|
||||
domain_forms = domain_service.get_pending_forms_for_workflow_run(workflow_run_id)
|
||||
return [self._domain_to_db_model(domain_form) for domain_form in domain_forms]
|
||||
except (DomainNotFoundError, DomainExpiredError, DomainAlreadySubmittedError, DomainInvalidFormDataError) as e:
|
||||
if isinstance(e, DomainNotFoundError):
|
||||
raise HumanInputFormNotFoundError(e.identifier) from e
|
||||
elif isinstance(e, DomainExpiredError):
|
||||
raise HumanInputFormExpiredError() from e
|
||||
elif isinstance(e, DomainAlreadySubmittedError):
|
||||
raise HumanInputFormAlreadySubmittedError() from e
|
||||
elif isinstance(e, DomainInvalidFormDataError):
|
||||
raise InvalidFormDataError(e.message) from e
|
||||
63
api/services/human_input_service.py
Normal file
63
api/services/human_input_service.py
Normal file
@ -0,0 +1,63 @@
|
||||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from libs.exception import BaseHTTPException
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
class Form(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def get_definition(self) -> FormDefinition:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def submitted(self) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FormSubmittedError(HumanInputError, BaseHTTPException):
|
||||
error_code = "human_input_form_submitted"
|
||||
description = "This form has already been submitted by another user, form_id={form_id}"
|
||||
code = 412
|
||||
|
||||
def __init__(self, form_id: str):
|
||||
description = self.description.format(form_id=form_id)
|
||||
super().__init__(description=description)
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError, BaseException):
|
||||
error_code = "human_input_form_not_found"
|
||||
code = 404
|
||||
|
||||
|
||||
class HumanInputService:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form:
|
||||
pass
|
||||
|
||||
def get_form_definition_by_id(self, form_id: str) -> Form | None:
|
||||
pass
|
||||
|
||||
def submit_form_by_id(self, form_id: str, selected_action_id: str, form_data: Mapping[str, Any]):
|
||||
pass
|
||||
|
||||
def submit_form_by_token(
|
||||
self, recipient_type: RecipientType, form_token: str, selected_action_id: str, form_data: Mapping[str, Any]
|
||||
):
|
||||
pass
|
||||
@ -136,7 +136,7 @@ class _ChatflowRunner:
|
||||
|
||||
chat_generator = AdvancedChatAppGenerator()
|
||||
|
||||
workflow_run_id = uuid.uuid4()
|
||||
workflow_run_id = exec_params.workflow_run_id
|
||||
|
||||
with self._setup_flask_context(user):
|
||||
response = chat_generator.generate(
|
||||
|
||||
@ -0,0 +1,75 @@
|
||||
"""
|
||||
Tests for PauseReason discriminated union serialization/deserialization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.workflow.entities.pause_reason import (
|
||||
HumanInputRequired,
|
||||
PauseReason,
|
||||
SchedulingPause,
|
||||
)
|
||||
|
||||
|
||||
class _Holder(BaseModel):
|
||||
"""Helper model that embeds PauseReason for union tests."""
|
||||
|
||||
reason: PauseReason
|
||||
|
||||
|
||||
class TestPauseReasonDiscriminator:
|
||||
"""Test suite for PauseReason union discriminator."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("dict_value", "expected"),
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
"reason": {
|
||||
"TYPE": "human_input_required",
|
||||
"human_input_form_id": "form_id",
|
||||
},
|
||||
},
|
||||
HumanInputRequired(human_input_form_id="form_id"),
|
||||
id="HumanInputRequired",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"reason": {
|
||||
"TYPE": "scheduled_pause",
|
||||
"message": "Hold on",
|
||||
}
|
||||
},
|
||||
SchedulingPause(message="Hold on"),
|
||||
id="SchedulingPause",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_model_validate(self, dict_value, expected):
|
||||
"""Ensure scheduled pause payloads with lowercase TYPE deserialize."""
|
||||
holder = _Holder.model_validate(dict_value)
|
||||
|
||||
assert type(holder.reason) == type(expected)
|
||||
assert holder.reason == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reason",
|
||||
[
|
||||
HumanInputRequired(human_input_form_id="form_id"),
|
||||
SchedulingPause(message="Hold on"),
|
||||
],
|
||||
ids=lambda x: type(x).__name__,
|
||||
)
|
||||
def test_model_construct(self, reason):
|
||||
holder = _Holder(reason=reason)
|
||||
assert holder.reason == reason
|
||||
|
||||
def test_model_construct_with_invalid_type(self):
|
||||
with pytest.raises(ValidationError):
|
||||
holder = _Holder(reason=object()) # type: ignore
|
||||
|
||||
def test_unknown_type_fails_validation(self):
|
||||
"""Unknown TYPE values should raise a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
_Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}})
|
||||
@ -7,10 +7,11 @@ from pydantic import ValidationError
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
ButtonStyle,
|
||||
DeliveryMethod,
|
||||
DeliveryMethodType,
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
EmailRecipientType,
|
||||
ExternalRecipient,
|
||||
FormInput,
|
||||
FormInputPlaceholder,
|
||||
@ -18,10 +19,10 @@ from core.workflow.nodes.human_input.entities import (
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
PlaceholderType,
|
||||
RecipientType,
|
||||
TimeoutUnit,
|
||||
UserAction,
|
||||
WebAppDeliveryConfig,
|
||||
WebAppDeliveryMethod,
|
||||
_WebAppDeliveryConfig,
|
||||
)
|
||||
|
||||
|
||||
@ -30,19 +31,19 @@ class TestDeliveryMethod:
|
||||
|
||||
def test_webapp_delivery_method(self):
|
||||
"""Test webapp delivery method creation."""
|
||||
delivery_method = DeliveryMethod(type=DeliveryMethodType.WEBAPP, enabled=True, config=WebAppDeliveryConfig())
|
||||
delivery_method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())
|
||||
|
||||
assert delivery_method.type == DeliveryMethodType.WEBAPP
|
||||
assert delivery_method.enabled is True
|
||||
assert isinstance(delivery_method.config, WebAppDeliveryConfig)
|
||||
assert isinstance(delivery_method.config, _WebAppDeliveryConfig)
|
||||
|
||||
def test_email_delivery_method(self):
|
||||
"""Test email delivery method creation."""
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(type=RecipientType.MEMBER, user_id="test-user-123"),
|
||||
ExternalRecipient(type=RecipientType.EXTERNAL, email="test@example.com"),
|
||||
MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"),
|
||||
ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"),
|
||||
],
|
||||
)
|
||||
|
||||
@ -50,7 +51,7 @@ class TestDeliveryMethod:
|
||||
recipients=recipients, subject="Test Subject", body="Test body with {{#url#}} placeholder"
|
||||
)
|
||||
|
||||
delivery_method = DeliveryMethod(type=DeliveryMethodType.EMAIL, enabled=True, config=config)
|
||||
delivery_method = EmailDeliveryMethod(enabled=True, config=config)
|
||||
|
||||
assert delivery_method.type == DeliveryMethodType.EMAIL
|
||||
assert delivery_method.enabled is True
|
||||
@ -118,7 +119,7 @@ class TestHumanInputNodeData:
|
||||
|
||||
def test_valid_node_data_creation(self):
|
||||
"""Test creating valid human input node data."""
|
||||
delivery_methods = [DeliveryMethod(type=DeliveryMethodType.WEBAPP, enabled=True, config=WebAppDeliveryConfig())]
|
||||
delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())]
|
||||
|
||||
inputs = [
|
||||
FormInput(
|
||||
@ -153,11 +154,12 @@ class TestHumanInputNodeData:
|
||||
def test_node_data_with_multiple_delivery_methods(self):
|
||||
"""Test node data with multiple delivery methods."""
|
||||
delivery_methods = [
|
||||
DeliveryMethod(type=DeliveryMethodType.WEBAPP, enabled=True, config=WebAppDeliveryConfig()),
|
||||
DeliveryMethod(
|
||||
type=DeliveryMethodType.EMAIL,
|
||||
WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()),
|
||||
EmailDeliveryMethod(
|
||||
enabled=False, # Disabled method should be fine
|
||||
config=None,
|
||||
config=EmailDeliveryConfig(
|
||||
subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@ -182,28 +184,64 @@ class TestHumanInputNodeData:
|
||||
assert node_data.timeout == 36
|
||||
assert node_data.timeout_unit == TimeoutUnit.HOUR
|
||||
|
||||
def test_duplicate_input_output_variable_name_raises_validation_error(self):
|
||||
"""Duplicate form input output_variable_name should raise validation error."""
|
||||
duplicate_inputs = [
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"):
|
||||
HumanInputNodeData(title="Test Node", inputs=duplicate_inputs)
|
||||
|
||||
def test_duplicate_user_action_ids_raise_validation_error(self):
|
||||
"""Duplicate user action ids should raise validation error."""
|
||||
duplicate_actions = [
|
||||
UserAction(id="submit", title="Submit"),
|
||||
UserAction(id="submit", title="Submit Again"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError, match="duplicated user action id 'submit'"):
|
||||
HumanInputNodeData(title="Test Node", user_actions=duplicate_actions)
|
||||
|
||||
def test_extract_outputs_field_names(self):
|
||||
content = r"""This is titile {{#start.title#}}
|
||||
|
||||
A content is required:
|
||||
|
||||
{{#$outputs.content#}}
|
||||
|
||||
A ending is required:
|
||||
|
||||
{{#$outputs.ending#}}
|
||||
"""
|
||||
|
||||
node_data = HumanInputNodeData(title="Human Input", form_content=content)
|
||||
field_names = node_data.outputs_field_names()
|
||||
assert field_names == ["content", "ending"]
|
||||
|
||||
|
||||
class TestRecipients:
|
||||
"""Test email recipient entities."""
|
||||
|
||||
def test_member_recipient(self):
|
||||
"""Test member recipient creation."""
|
||||
recipient = MemberRecipient(type=RecipientType.MEMBER, user_id="user-123")
|
||||
recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")
|
||||
|
||||
assert recipient.type == RecipientType.MEMBER
|
||||
assert recipient.type == EmailRecipientType.MEMBER
|
||||
assert recipient.user_id == "user-123"
|
||||
|
||||
def test_external_recipient(self):
|
||||
"""Test external recipient creation."""
|
||||
recipient = ExternalRecipient(type=RecipientType.EXTERNAL, email="test@example.com")
|
||||
recipient = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com")
|
||||
|
||||
assert recipient.type == RecipientType.EXTERNAL
|
||||
assert recipient.type == EmailRecipientType.EXTERNAL
|
||||
assert recipient.email == "test@example.com"
|
||||
|
||||
def test_email_recipients_whole_workspace(self):
|
||||
"""Test email recipients with whole workspace enabled."""
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=True, items=[MemberRecipient(type=RecipientType.MEMBER, user_id="user-123")]
|
||||
whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")]
|
||||
)
|
||||
|
||||
assert recipients.whole_workspace is True
|
||||
@ -214,8 +252,8 @@ class TestRecipients:
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(type=RecipientType.MEMBER, user_id="user-123"),
|
||||
ExternalRecipient(type=RecipientType.EXTERNAL, email="external@example.com"),
|
||||
MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"),
|
||||
ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user