Files
dify/api/services/human_input_service.py

464 lines
16 KiB
Python

import logging
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Any, Protocol, cast
from pydantic import JsonValue
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormSubmissionRepository,
)
from factories.file_factory import build_from_mapping, build_from_mappings
from graphon.file import FileUploadConfig
from graphon.nodes.human_input.entities import (
FileInputConfig,
FileListInputConfig,
FormDefinition,
FormInputConfig,
HumanInputSubmissionValidationError,
SelectInputConfig,
UserActionConfig,
)
from graphon.nodes.human_input.entities import (
validate_human_input_submission as graphon_validate_human_input_submission,
)
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus, ValueSourceType
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from libs.exception import BaseHTTPException
from models.human_input import RecipientType
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from tasks.app_generate.workflow_execute_task import resume_app_execution
# Module-level controller instance; the file normalization helpers in this
# module pass it to the file factories as `access_controller`.
_file_access_controller = DatabaseFileAccessController()
class Form:
    """Thin read-only wrapper around a ``HumanInputFormRecord``.

    Every accessor forwards directly to the attribute of the same meaning on
    the wrapped record; no value is copied or transformed.
    """

    def __init__(self, record: HumanInputFormRecord):
        self._record = record

    def get_definition(self) -> FormDefinition:
        """Return the record's form definition."""
        return self._record.definition

    # ---- identity -------------------------------------------------------

    @property
    def id(self) -> str:
        """Form identifier (the record's ``form_id``)."""
        return self._record.form_id

    @property
    def tenant_id(self) -> str:
        """Tenant id stored on the record."""
        return self._record.tenant_id

    @property
    def app_id(self) -> str:
        """App id stored on the record."""
        return self._record.app_id

    @property
    def workflow_run_id(self) -> str | None:
        """Workflow run id for runtime forms; None for delivery tests."""
        return self._record.workflow_run_id

    # ---- recipient ------------------------------------------------------

    @property
    def recipient_id(self) -> str | None:
        """Recipient id stored on the record, if any."""
        return self._record.recipient_id

    @property
    def recipient_type(self) -> RecipientType | None:
        """Recipient type stored on the record, if any."""
        return self._record.recipient_type

    # ---- state ----------------------------------------------------------

    @property
    def submitted(self) -> bool:
        """Whether the record reports the form as submitted."""
        return self._record.submitted

    @property
    def status(self) -> HumanInputFormStatus:
        """Status stored on the record."""
        return self._record.status

    @property
    def form_kind(self) -> HumanInputFormKind:
        """Form kind stored on the record."""
        return self._record.form_kind

    # ---- timing ---------------------------------------------------------

    @property
    def created_at(self) -> "datetime":
        """Creation timestamp stored on the record."""
        return self._record.created_at

    @property
    def expiration_time(self) -> "datetime":
        """Expiration timestamp stored on the record."""
        return self._record.expiration_time
class HumanInputError(Exception):
    """Common base class for the human input errors defined in this module."""
class FormSubmittedError(BaseHTTPException, HumanInputError):
    """Raised when a form has already been submitted (reported with code 412)."""

    error_code = "human_input_form_submitted"
    description = "This form has already been submitted by another user, form_id={form_id}"
    code = 412

    def __init__(self, form_id: str):
        # The class-level `description` doubles as the message template; the
        # literal below is used only if that attribute is empty.
        fallback = "This form has already been submitted by another user, form_id={form_id}"
        message = (self.description or fallback).format(form_id=form_id)
        BaseHTTPException.__init__(self, description=message)
class FormNotFoundError(BaseHTTPException, HumanInputError):
    """Error for a form that cannot be found (reported with code 404)."""

    code = 404
    error_code = "human_input_form_not_found"
class InvalidFormDataError(BaseHTTPException, HumanInputError):
    """Error for rejected submission data (reported with code 400).

    The caller supplies the description text.
    """

    code = 400
    error_code = "invalid_form_data"

    def __init__(self, description: str):
        BaseHTTPException.__init__(self, description=description)
# NOTE(review): `BaseException` is already inherited through `HumanInputError`,
# so listing it here is redundant. It may have been meant to be
# `BaseHTTPException` like the sibling errors — confirm before changing.
class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
    """Raised when a token resolves to no form for the requested recipient type."""
class FormExpiredError(BaseHTTPException, HumanInputError):
    """Raised when a form can no longer be submitted because it has expired (code 412)."""

    code = 412
    error_code = "human_input_form_expired"

    def __init__(self, form_id: str):
        message = f"This form has expired, form_id={form_id}"
        BaseHTTPException.__init__(self, description=message)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class FormDefinitionProtocol(Protocol):
    """Structural type for anything exposing ``inputs`` and ``user_actions``."""

    @property
    def inputs(self) -> Sequence[FormInputConfig]:
        """Input configurations declared by the form."""
        ...

    @property
    def user_actions(self) -> Sequence[UserActionConfig]:
        """User action configurations declared by the form."""
        ...
class HumanInputService:
def __init__(
    self,
    session_factory: sessionmaker[Session] | Engine,
    form_repository: HumanInputFormSubmissionRepository | None = None,
):
    """Store the session factory and the form repository.

    Args:
        session_factory: A ``sessionmaker``, or an ``Engine`` that is wrapped
            in ``sessionmaker(bind=...)``.
        form_repository: Repository to use; a default
            ``HumanInputFormSubmissionRepository`` is created when omitted.
    """
    factory = sessionmaker(bind=session_factory) if isinstance(session_factory, Engine) else session_factory
    self._session_factory = factory
    self._form_repository = form_repository or HumanInputFormSubmissionRepository()
def get_form_by_token(self, form_token: str) -> Form | None:
    """Look up a form by token; return it wrapped in ``Form``, or None if absent."""
    found = self._form_repository.get_by_token(form_token)
    return None if found is None else Form(found)
def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None:
    """Return the form for ``form_token`` if it targets ``recipient_type``.

    Returns None when no form matches the token or the recipient type differs.
    Raises ``FormSubmittedError`` (via ``_ensure_not_submitted``) when the
    form was already submitted.
    """
    candidate = self.get_form_by_token(form_token)
    if candidate is None:
        return None
    if candidate.recipient_type != recipient_type:
        return None

    self._ensure_not_submitted(candidate)
    return candidate
def get_form_definition_by_token_for_console(self, form_token: str) -> Form | None:
    """Return the form for ``form_token`` if its recipient is CONSOLE or BACKSTAGE.

    Returns None when no form matches or the recipient type is anything else.
    Raises ``FormSubmittedError`` (via ``_ensure_not_submitted``) when the
    form was already submitted.
    """
    candidate = self.get_form_by_token(form_token)
    if candidate is None:
        return None

    console_recipient_types = {RecipientType.CONSOLE, RecipientType.BACKSTAGE}
    if candidate.recipient_type not in console_recipient_types:
        return None

    self._ensure_not_submitted(candidate)
    return candidate
def submit_form_by_token(
    self,
    recipient_type: RecipientType,
    form_token: str,
    selected_action_id: str,
    form_data: Mapping[str, JsonValue],
    submission_end_user_id: str | None = None,
    submission_user_id: str | None = None,
):
    """Validate and record a submission, then resume the workflow if applicable.

    Steps: resolve the form by token, check it is still active, validate and
    normalize the payload, mark the form submitted in the repository, and
    enqueue a resume when the result is a RUNTIME form with a workflow run id.

    Raises:
        WebAppDeliveryNotEnabledError: No form matches the token, or its
            recipient type differs from ``recipient_type``.
        FormSubmittedError / FormExpiredError: From ``ensure_form_active``.
        InvalidFormDataError: From ``_validate_submission``.
    """
    form = self.get_form_by_token(form_token)
    if form is None or form.recipient_type != recipient_type:
        raise WebAppDeliveryNotEnabledError()

    self.ensure_form_active(form)

    cleaned_data = self._validate_submission(
        form=form,
        selected_action_id=selected_action_id,
        form_data=form_data,
    )
    outcome = self._form_repository.mark_submitted(
        form_id=form.id,
        recipient_id=form.recipient_id,
        selected_action_id=selected_action_id,
        form_data=cleaned_data,
        submission_user_id=submission_user_id,
        submission_end_user_id=submission_end_user_id,
    )

    # Only runtime forms attached to a workflow run have anything to resume.
    if outcome.form_kind != HumanInputFormKind.RUNTIME or outcome.workflow_run_id is None:
        return
    self.enqueue_resume(outcome.workflow_run_id)
def ensure_form_active(self, form: Form) -> None:
    """Raise unless ``form`` can still accept a submission.

    Raises:
        FormSubmittedError: The form is already submitted.
        FormExpiredError: The status is TIMEOUT or EXPIRED, the form's own
            expiration time has been reached, or the global timeout applies.
    """
    if form.submitted:
        raise FormSubmittedError(form.id)

    terminal_statuses = {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
    if form.status in terminal_statuses:
        raise FormExpiredError(form.id)

    # Use one timestamp for both deadline checks so they agree with each other.
    now = naive_utc_now()
    own_deadline_passed = ensure_naive_utc(form.expiration_time) <= now
    if own_deadline_passed or self._is_globally_expired(form, now=now):
        raise FormExpiredError(form.id)
def _ensure_not_submitted(self, form: Form) -> None:
if form.submitted:
raise FormSubmittedError(form.id)
def _validate_submission(
self,
form: Form,
selected_action_id: str,
form_data: Mapping[str, Any],
) -> dict[str, JsonValue]:
definition = form.get_definition()
try:
return self.validate_and_normalize_submission(
tenant_id=form.tenant_id,
form_definition=definition,
selected_action_id=selected_action_id,
form_data=form_data,
)
except HumanInputSubmissionValidationError as exc:
raise InvalidFormDataError(str(exc)) from exc
def enqueue_resume(self, workflow_run_id: str) -> None:
    """Enqueue the resume task for the app that owns ``workflow_run_id``.

    Logs and returns without enqueueing when the app is missing or its mode
    is neither WORKFLOW nor ADVANCED_CHAT. A failure to enqueue is logged,
    not raised.

    Raises:
        AssertionError: The workflow run does not exist.
    """
    run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
    run = run_repo.get_workflow_run_by_id_without_tenant(workflow_run_id)
    if run is None:
        raise AssertionError(f"WorkflowRun not found, id={workflow_run_id}")

    with self._session_factory(expire_on_commit=False) as session:
        app = session.execute(select(App).where(App.id == run.app_id)).scalar_one_or_none()

    if app is None:
        logger.error(
            "App not found for WorkflowRun, workflow_run_id=%s, app_id=%s", workflow_run_id, run.app_id
        )
        return

    resumable_modes = {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}
    if app.mode not in resumable_modes:
        logger.warning("App mode %s does not support resume for workflow run %s", app.mode, workflow_run_id)
        return

    try:
        resume_app_execution.apply_async(
            kwargs={"payload": {"workflow_run_id": workflow_run_id}},
        )
    except Exception:  # pragma: no cover
        # Best effort: a failed enqueue is logged rather than propagated.
        logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
def _is_globally_expired(self, form: Form, *, now: datetime | None = None) -> bool:
    """Report whether ``form`` has outlived the configured global timeout.

    Always False when ``HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS`` is zero or
    negative, or when the form has no workflow run id. Otherwise True once
    ``created_at`` plus the timeout is at or before ``now`` (defaults to the
    current naive UTC time).
    """
    timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS
    if timeout_seconds <= 0 or form.workflow_run_id is None:
        return False

    reference_time = now or naive_utc_now()
    deadline = ensure_naive_utc(form.created_at) + timedelta(seconds=timeout_seconds)
    return deadline <= reference_time
@staticmethod
def validate_human_input_submission(
    *,
    form_definition: FormDefinitionProtocol,
    selected_action_id: str,
    form_data: Mapping[str, Any],
) -> None:
    """Delegate submission validation to graphon using the definition's inputs and actions."""
    declared_inputs = form_definition.inputs
    declared_actions = form_definition.user_actions
    graphon_validate_human_input_submission(
        inputs=declared_inputs,
        user_actions=declared_actions,
        selected_action_id=selected_action_id,
        form_data=form_data,
    )
@classmethod
def validate_and_normalize_submission(
    cls,
    *,
    tenant_id: str,
    form_definition: FormDefinitionProtocol,
    selected_action_id: str,
    form_data: Mapping[str, Any],
) -> dict[str, JsonValue]:
    """
    Normalize the payload on the Dify side, then let graphon validate its shape.

    graphon owns the form schema and validation rules; Dify owns tenant-aware
    file reconstruction and persistence compatibility for submitted payloads.
    The normalized payload is what gets validated and returned.
    """
    normalized = cls.normalize_submission_data(
        tenant_id=tenant_id,
        form_definition=form_definition,
        form_data=form_data,
    )
    declared_inputs = form_definition.inputs
    declared_actions = form_definition.user_actions
    graphon_validate_human_input_submission(
        inputs=declared_inputs,
        user_actions=declared_actions,
        selected_action_id=selected_action_id,
        form_data=normalized,
    )
    return normalized
@classmethod
def normalize_submission_data(
    cls,
    *,
    tenant_id: str,
    form_definition: FormDefinitionProtocol,
    form_data: Mapping[str, Any],
) -> dict[str, JsonValue]:
    """Return a copy of ``form_data`` with declared inputs normalized.

    Keys that match a declared input's ``output_variable_name`` are passed
    through ``_normalize_input_value``; all other keys are copied unchanged.
    Declared inputs missing from ``form_data`` are skipped.
    """
    normalized = cast(dict[str, JsonValue], dict(form_data.items()))

    # Index declared inputs by output name; a later duplicate replaces an earlier one.
    configs_by_field = {config.output_variable_name: config for config in form_definition.inputs}
    for field, config in configs_by_field.items():
        if field in form_data:
            normalized[field] = cls._normalize_input_value(
                tenant_id=tenant_id,
                form_input=config,
                value=form_data[field],
            )
    return normalized
@classmethod
def _normalize_input_value(
cls,
*,
tenant_id: str,
form_input: FormInputConfig,
value: Any,
) -> JsonValue:
if isinstance(form_input, SelectInputConfig):
return cls._normalize_select_value(form_input=form_input, value=value)
if isinstance(form_input, FileInputConfig):
return cls._normalize_file_value(
tenant_id=tenant_id,
form_input=form_input,
value=value,
)
if isinstance(form_input, FileListInputConfig):
return cls._normalize_file_list_value(
tenant_id=tenant_id,
form_input=form_input,
value=value,
)
return cast(JsonValue, value)
@classmethod
def _normalize_select_value(
    cls,
    *,
    form_input: SelectInputConfig,
    value: Any,
) -> JsonValue:
    """Check a select submission and return it unchanged.

    Raises:
        HumanInputSubmissionValidationError: ``value`` is not a string, or the
            option source is CONSTANT and ``value`` is not among its values.
    """
    if not isinstance(value, str):
        raise HumanInputSubmissionValidationError(
            f"Invalid value for select input '{form_input.output_variable_name}': expected string"
        )

    source = form_input.option_source
    has_constant_options = source.type == ValueSourceType.CONSTANT
    # Membership is only checked when the options are constants.
    if has_constant_options and value not in source.value:
        raise HumanInputSubmissionValidationError(
            f"Invalid value for select input '{form_input.output_variable_name}': {value}"
        )
    return value
@classmethod
def _normalize_file_value(
    cls,
    *,
    tenant_id: str,
    form_input: FileInputConfig,
    value: Any,
) -> JsonValue:
    """Rebuild a single submitted file mapping and return its ``to_dict()`` form.

    Raises:
        HumanInputSubmissionValidationError: ``value`` is not a mapping, or
            ``build_from_mapping`` rejects it with ``ValueError``.
    """
    if not isinstance(value, Mapping):
        raise HumanInputSubmissionValidationError(
            f"Invalid value for file input '{form_input.output_variable_name}': expected mapping"
        )

    # A single-file input is built with a limit of exactly one file.
    single_file_config = cls._build_file_upload_config(form_input=form_input, number_limits=1)
    try:
        # `build_from_mapping` enforces tenant ownership for persisted upload references.
        rebuilt = build_from_mapping(
            mapping=value,
            tenant_id=tenant_id,
            config=single_file_config,
            strict_type_validation=True,
            access_controller=_file_access_controller,
        )
    except ValueError as exc:
        raise HumanInputSubmissionValidationError(
            f"Invalid value for file input '{form_input.output_variable_name}': {exc}"
        ) from exc
    return cast(JsonValue, rebuilt.to_dict())
@classmethod
def _normalize_file_list_value(
    cls,
    *,
    tenant_id: str,
    form_input: FileListInputConfig,
    value: Any,
) -> JsonValue:
    """Rebuild a submitted list of file mappings and return their ``to_dict()`` forms.

    Raises:
        HumanInputSubmissionValidationError: ``value`` is not a list, an item
            is not a mapping, or ``build_from_mappings`` rejects the batch
            with ``ValueError``.
    """
    if not isinstance(value, list):
        raise HumanInputSubmissionValidationError(
            f"Invalid value for file list input '{form_input.output_variable_name}': expected list"
        )
    if not all(isinstance(entry, Mapping) for entry in value):
        raise HumanInputSubmissionValidationError(
            f"Invalid value for file list input '{form_input.output_variable_name}': expected list of mappings"
        )

    list_config = cls._build_file_upload_config(
        form_input=form_input,
        number_limits=form_input.number_limits,
    )
    try:
        # `build_from_mappings` performs the same tenant-aware ownership validation in batch.
        rebuilt_files = build_from_mappings(
            mappings=cast(Sequence[Mapping[str, Any]], value),
            tenant_id=tenant_id,
            config=list_config,
            strict_type_validation=True,
            access_controller=_file_access_controller,
        )
    except ValueError as exc:
        raise HumanInputSubmissionValidationError(
            f"Invalid value for file list input '{form_input.output_variable_name}': {exc}"
        ) from exc
    return cast(JsonValue, [rebuilt.to_dict() for rebuilt in rebuilt_files])
@staticmethod
def _build_file_upload_config(
    *,
    form_input: FileInputConfig | FileListInputConfig,
    number_limits: int,
) -> FileUploadConfig:
    """Build a ``FileUploadConfig`` from the input's allow-lists and the given limit.

    Each allow-list is copied into a fresh list before being handed over.
    """
    file_types = [*form_input.allowed_file_types]
    file_extensions = [*form_input.allowed_file_extensions]
    upload_methods = [*form_input.allowed_file_upload_methods]
    return FileUploadConfig(
        allowed_file_types=file_types,
        allowed_file_extensions=file_extensions,
        allowed_file_upload_methods=upload_methods,
        number_limits=number_limits,
    )