mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 09:27:39 +08:00
feat: add service api of HITL (#32826)
Co-authored-by: Blackoutta <hyytez@gmail.com> Co-authored-by: QuantumGhost <QuantumGhost@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
This commit is contained in:
@ -12,20 +12,16 @@ from collections.abc import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
|
||||
from extensions.ext_database import db
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
_FORM_TOKEN_PRIORITY = {
|
||||
RecipientType.BACKSTAGE: 0,
|
||||
RecipientType.CONSOLE: 1,
|
||||
RecipientType.STANDALONE_WEB_APP: 2,
|
||||
}
|
||||
|
||||
|
||||
def load_form_tokens_by_form_id(
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
session: Session | None = None,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Load the preferred access token for each human input form."""
|
||||
unique_form_ids = list(dict.fromkeys(form_ids))
|
||||
@ -33,23 +29,43 @@ def load_form_tokens_by_form_id(
|
||||
return {}
|
||||
|
||||
if session is not None:
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids)
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as new_session:
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
|
||||
|
||||
|
||||
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
|
||||
tokens_by_form_id: dict[str, tuple[int, str]] = {}
|
||||
def _load_form_tokens_by_form_id(
|
||||
session: Session,
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
|
||||
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(stmt):
|
||||
priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
|
||||
if priority is None or not recipient.access_token:
|
||||
if not recipient.access_token:
|
||||
continue
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(
|
||||
(recipient.recipient_type, recipient.access_token)
|
||||
)
|
||||
|
||||
candidate = (priority, recipient.access_token)
|
||||
current = tokens_by_form_id.get(recipient.form_id)
|
||||
if current is None or candidate[0] < current[0]:
|
||||
tokens_by_form_id[recipient.form_id] = candidate
|
||||
tokens_by_form_id: dict[str, str] = {}
|
||||
for form_id, recipients in recipients_by_form_id.items():
|
||||
token = _get_surface_form_token(recipients, surface=surface)
|
||||
if token is not None:
|
||||
tokens_by_form_id[form_id] = token
|
||||
return tokens_by_form_id
|
||||
|
||||
return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()}
|
||||
|
||||
def _get_surface_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> str | None:
|
||||
if surface == HumanInputSurface.SERVICE_API:
|
||||
for recipient_type, token in recipients:
|
||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||
return token
|
||||
|
||||
return get_preferred_form_token(recipients)
|
||||
|
||||
73
api/core/workflow/human_input_policy.py
Normal file
73
api/core/workflow/human_input_policy.py
Normal file
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
class HumanInputSurface(StrEnum):
|
||||
SERVICE_API = "service_api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
# Service API is intentionally narrower than other surfaces: app-token callers
|
||||
# should only be able to act on end-user web forms, not internal console flows.
|
||||
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
||||
}
|
||||
|
||||
# A single HITL form can have multiple recipient records; this shared priority
|
||||
# keeps every API surface consistent about which resume token to expose.
|
||||
_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = {
|
||||
RecipientType.BACKSTAGE: 0,
|
||||
RecipientType.CONSOLE: 1,
|
||||
RecipientType.STANDALONE_WEB_APP: 2,
|
||||
}
|
||||
|
||||
|
||||
def is_recipient_type_allowed_for_surface(
|
||||
recipient_type: RecipientType | None,
|
||||
surface: HumanInputSurface,
|
||||
) -> bool:
|
||||
if recipient_type is None:
|
||||
return False
|
||||
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
|
||||
|
||||
def get_preferred_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
) -> str | None:
|
||||
chosen_token: str | None = None
|
||||
chosen_priority: int | None = None
|
||||
for recipient_type, token in recipients:
|
||||
priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type)
|
||||
if priority is None or not token:
|
||||
continue
|
||||
if chosen_priority is None or priority < chosen_priority:
|
||||
chosen_priority = priority
|
||||
chosen_token = token
|
||||
return chosen_token
|
||||
|
||||
|
||||
def enrich_human_input_pause_reasons(
|
||||
reasons: Sequence[Mapping[str, Any]],
|
||||
*,
|
||||
form_tokens_by_form_id: Mapping[str, str],
|
||||
expiration_times_by_form_id: Mapping[str, int],
|
||||
) -> list[dict[str, Any]]:
|
||||
enriched: list[dict[str, Any]] = []
|
||||
for reason in reasons:
|
||||
updated = dict(reason)
|
||||
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_id = updated.get("form_id")
|
||||
if isinstance(form_id, str):
|
||||
updated["form_token"] = form_tokens_by_form_id.get(form_id)
|
||||
expiration_time = expiration_times_by_form_id.get(form_id)
|
||||
if expiration_time is not None:
|
||||
updated["expiration_time"] = expiration_time
|
||||
enriched.append(updated)
|
||||
return enriched
|
||||
Reference in New Issue
Block a user