feat(api): add human input data to extra contents

This commit is contained in:
QuantumGhost
2026-01-08 10:20:24 +08:00
parent dac94b573e
commit de428bc9bb
14 changed files with 652 additions and 58 deletions

View File

@ -32,7 +32,7 @@ from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
from services.message_service import MessageService, _attach_message_extra_contents
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -198,6 +198,7 @@ message_detail_model = console_ns.model(
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"extra_contents": fields.List(fields.Raw),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
@ -290,6 +291,7 @@ class ChatMessageListApi(Resource):
has_more = False
history_messages = list(reversed(history_messages))
_attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@ -474,4 +476,5 @@ class MessageApi(Resource):
if not message:
raise NotFound("Message Not Exists.")
_attach_message_extra_contents([message])
return message

View File

@ -65,6 +65,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@ -73,7 +75,8 @@ from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -137,6 +140,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow_tenant_id = workflow.tenant_id
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
@ -146,6 +150,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._message_saved_on_pause = False
self._seed_graph_runtime_state_from_queue_manager()
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -539,7 +544,22 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event=event,
task_id=self._application_generate_entity.task_id,
)
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
yield from responses
resolved_state: GraphRuntimeState | None = None
try:
resolved_state = self._ensure_graph_runtime_initialized()
except ValueError:
resolved_state = None
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
message = self._get_message(session=session)
if message is not None:
message.status = MessageStatus.PAUSED
self._message_saved_on_pause = True
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(
@ -629,9 +649,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
# Save message unless it has already been persisted on pause.
if not self._message_saved_on_pause:
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@ -661,10 +682,53 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
self._persist_human_input_extra_content(node_id=event.node_id)
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
if not self._workflow_run_id or not self._message_id:
return
if form_id is None:
if node_id is None:
return
form_id = self._load_human_input_form_id(node_id=node_id)
if form_id is None:
logger.warning(
"HumanInput form not found for workflow run %s node %s",
self._workflow_run_id,
node_id,
)
return
with self._database_session() as session:
exists_stmt = select(HumanInputContent).where(
HumanInputContent.workflow_run_id == self._workflow_run_id,
HumanInputContent.message_id == self._message_id,
HumanInputContent.form_id == form_id,
)
if session.scalar(exists_stmt) is not None:
return
content = HumanInputContent(
workflow_run_id=self._workflow_run_id,
message_id=self._message_id,
form_id=form_id,
)
session.add(content)
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self._workflow_tenant_id,
)
form = form_repository.get_form(self._workflow_run_id, node_id)
if form is None:
return None
return form.id
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
@ -800,6 +864,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
if message is None:
return
if message.status == MessageStatus.PAUSED:
message.status = MessageStatus.NORMAL
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer

View File

@ -1,27 +1,81 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, TypeAlias
from core.workflow.nodes.human_input.entities import FormInput, UserAction
from models.execution_extra_content import ExecutionContentType
@dataclass(frozen=True, kw_only=True)
class HumanInputContent:
action_id: str
action_text: str
rendered_content: str
type: ExecutionContentType = field(default=ExecutionContentType.HUMAN_INPUT_RESULT, init=False)
class HumanInputFormDefinition:
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = field(default_factory=list)
actions: Sequence[UserAction] = field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_placeholder_values: Mapping[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
return {
"type": self.type.value,
"form_id": self.form_id,
"node_id": self.node_id,
"node_title": self.node_title,
"form_content": self.form_content,
"inputs": [item.model_dump(mode="json") for item in self.inputs],
"actions": [item.model_dump(mode="json") for item in self.actions],
"display_in_ui": self.display_in_ui,
"form_token": self.form_token,
"resolved_placeholder_values": self.resolved_placeholder_values,
}
@dataclass(frozen=True, kw_only=True)
class HumanInputFormSubmissionData:
node_id: str
node_title: str
rendered_content: str
action_id: str
action_text: str
def to_dict(self) -> dict[str, Any]:
return {
"node_id": self.node_id,
"node_title": self.node_title,
"rendered_content": self.rendered_content,
"action_id": self.action_id,
"action_text": self.action_text,
"rendered_content": self.rendered_content,
}
@dataclass(frozen=True, kw_only=True)
class HumanInputContent:
submitted: bool
form_definition: HumanInputFormDefinition | None = None
form_submission_data: HumanInputFormSubmissionData | None = None
type: ExecutionContentType = field(default=ExecutionContentType.HUMAN_INPUT, init=False)
def to_dict(self) -> dict[str, Any]:
payload: dict[str, Any] = {
"type": self.type.value,
"submitted": self.submitted,
}
if self.form_definition is not None:
payload["form_definition"] = self.form_definition.to_dict()
if self.form_submission_data is not None:
payload["form_submission_data"] = self.form_submission_data.to_dict()
return payload
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
__all__ = ["ExecutionExtraContentDomainModel", "HumanInputContent"]
__all__ = [
"ExecutionExtraContentDomainModel",
"HumanInputContent",
"HumanInputFormDefinition",
"HumanInputFormSubmissionData",
]

View File

@ -63,6 +63,7 @@ message_fields = {
"answer": fields.String(attribute="re_sign_file_url_answer"),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"extra_contents": fields.List(cls_or_instance=fields.Raw),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields)),

View File

@ -11,7 +11,7 @@ if TYPE_CHECKING:
class ExecutionContentType(StrEnum):
HUMAN_INPUT_RESULT = auto()
HUMAN_INPUT = auto()
class ExecutionExtraContent(DefaultFieldsMixin, Base):
@ -56,7 +56,7 @@ class HumanInputContent(ExecutionExtraContent):
It should only be initialized with the `new` class method."""
__mapper_args__ = {
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT_RESULT,
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT,
}
# A relation to HumanInputForm table.

View File

@ -1,28 +1,45 @@
from __future__ import annotations
import json
import logging
import re
from collections import defaultdict
from collections.abc import Sequence
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.entities.execution_extra_content import (
ExecutionExtraContentDomainModel,
HumanInputFormDefinition,
HumanInputFormSubmissionData,
)
from core.entities.execution_extra_content import (
HumanInputContent as HumanInputContentDomainModel,
)
from core.workflow.nodes.human_input.entities import FormDefinition
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from models.execution_extra_content import (
ExecutionExtraContent as ExecutionExtraContentModel,
)
from models.execution_extra_content import (
HumanInputContent as HumanInputContentModel,
)
from models.human_input import HumanInputFormRecipient, RecipientType
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
logger = logging.getLogger(__name__)
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
def _extract_output_field_names(form_content: str) -> list[str]:
if not form_content:
return []
return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)]
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
def __init__(self, session_maker: sessionmaker[Session]):
@ -46,12 +63,26 @@ class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository)
with self._session_maker() as session:
results = session.scalars(stmt).all()
form_ids = {
content.form_id
for content in results
if isinstance(content, HumanInputContentModel) and content.form_id is not None
}
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list)
if form_ids:
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
recipients = session.scalars(recipient_stmt).all()
for recipient in recipients:
recipients_by_form_id[recipient.form_id].append(recipient)
else:
recipients_by_form_id = {}
for content in results:
message_id = content.message_id
if not message_id or message_id not in grouped_contents:
continue
domain_model = self._map_model_to_domain(content)
domain_model = self._map_model_to_domain(content, recipients_by_form_id)
if domain_model is None:
continue
@ -59,40 +90,105 @@ class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository)
return [grouped_contents[message_id] for message_id in message_ids]
def _map_model_to_domain(self, model: ExecutionExtraContentModel) -> ExecutionExtraContentDomainModel | None:
def _map_model_to_domain(
self,
model: ExecutionExtraContentModel,
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
) -> ExecutionExtraContentDomainModel | None:
if isinstance(model, HumanInputContentModel):
return self._map_human_input_content(model)
return self._map_human_input_content(model, recipients_by_form_id)
logger.debug("Unsupported execution extra content type encountered: %s", model.type)
return None
def _map_human_input_content(self, model: HumanInputContentModel) -> HumanInputContentDomainModel | None:
def _map_human_input_content(
self,
model: HumanInputContentModel,
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
) -> HumanInputContentDomainModel | None:
form = model.form
if form is None:
logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
return None
selected_action_id = form.selected_action_id
if not selected_action_id:
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
return None
try:
form_definition = FormDefinition.model_validate_json(form.form_definition)
except ValueError:
logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
return None
node_title = form_definition.node_title or form.node_id
display_in_ui = bool(form_definition.display_in_ui)
submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
if not submitted:
form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, []))
return HumanInputContentDomainModel(
submitted=False,
form_definition=HumanInputFormDefinition(
form_id=form.id,
node_id=form.node_id,
node_title=node_title,
form_content=form.rendered_content,
inputs=form_definition.inputs,
actions=form_definition.user_actions,
display_in_ui=display_in_ui,
form_token=form_token,
resolved_placeholder_values=form_definition.placeholder_values,
),
)
selected_action_id = form.selected_action_id
if not selected_action_id:
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
return None
action_text = next(
(action.title for action in form_definition.user_actions if action.id == selected_action_id),
selected_action_id,
)
return HumanInputContentDomainModel(
action_id=selected_action_id,
action_text=action_text,
rendered_content=form.rendered_content,
submitted_data: dict[str, Any] = {}
if form.submitted_data:
try:
submitted_data = json.loads(form.submitted_data)
except ValueError:
logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id)
return None
rendered_content = HumanInputNode._render_form_content_with_outputs(
form.rendered_content,
submitted_data,
_extract_output_field_names(form_definition.form_content),
)
return HumanInputContentDomainModel(
submitted=True,
form_submission_data=HumanInputFormSubmissionData(
node_id=form.node_id,
node_title=node_title,
rendered_content=rendered_content,
action_id=selected_action_id,
action_text=action_text,
),
)
@staticmethod
def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None:
console_recipient = next(
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE),
None,
)
if console_recipient and console_recipient.access_token:
return console_recipient.access_token
web_app_recipient = next(
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP),
None,
)
if web_app_recipient and web_app_recipient.access_token:
return web_app_recipient.access_token
return None
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]

View File

@ -136,7 +136,7 @@ class AudioService:
message = db.session.query(Message).where(Message.id == message_id).first()
if message is None:
return None
if message.answer == "" and message.status == MessageStatus.NORMAL:
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
return None
else:

View File

@ -20,6 +20,8 @@ def test_get_by_message_ids_returns_human_input_content(db_session_with_containe
assert len(results) == 1
assert len(results[0]) == 1
content = results[0][0]
assert content.action_id == fixture.action_id
assert content.action_text == fixture.action_text
assert content.rendered_content == fixture.form.rendered_content
assert content.submitted is True
assert content.form_submission_data is not None
assert content.form_submission_data.action_id == fixture.action_id
assert content.form_submission_data.action_text == fixture.action_text
assert content.form_submission_data.rendered_content == fixture.form.rendered_content

View File

@ -24,9 +24,14 @@ def test_pagination_returns_extra_contents(db_session_with_containers):
message = pagination.data[0]
assert message.extra_contents == [
{
"type": "human_input_result",
"action_id": fixture.action_id,
"action_text": fixture.action_text,
"rendered_content": fixture.form.rendered_content,
"type": "human_input",
"submitted": True,
"form_submission_data": {
"node_id": fixture.form.node_id,
"node_title": fixture.node_title,
"rendered_content": fixture.form.rendered_content,
"action_id": fixture.action_id,
"action_text": fixture.action_text,
},
}
]

View File

@ -0,0 +1,156 @@
import json
import time
import uuid
from typing import Any
from unittest.mock import patch
from flask.testing import FlaskClient
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.task_entities import StreamEvent
from models import Account, Tenant, TenantAccountJoin
from models.account import TenantAccountRole
from models.model import ApiToken, App, AppMode
from models.workflow import Workflow
def _create_tenant_and_owner() -> tuple[Tenant, Account, TenantAccountJoin]:
tenant = Tenant(
name="Test Tenant",
status="normal",
)
tenant.id = str(uuid.uuid4())
account = Account(
email=f"owner-{uuid.uuid4()}@example.com",
name="Owner",
interface_language="en-US",
status="active",
)
account.id = str(uuid.uuid4())
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
return tenant, account, tenant_join
def _create_workflow_app(tenant: Tenant, account: Account) -> tuple[App, Workflow]:
app = App(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name="Test Workflow App",
description="",
mode=AppMode.WORKFLOW,
icon_type="emoji",
icon="robot",
icon_background="#FFFFFF",
enable_site=True,
enable_api=True,
created_by=account.id,
updated_by=account.id,
)
workflow = Workflow(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
app_id=app.id,
type="workflow",
version="v1",
graph=json.dumps({"nodes": [], "edges": []}),
features=json.dumps({"features": []}),
created_by=account.id,
updated_by=account.id,
environment_variables=[],
conversation_variables=[],
)
app.workflow_id = workflow.id
return app, workflow
def _create_api_token(app: App) -> ApiToken:
return ApiToken(
app_id=app.id,
tenant_id=app.tenant_id,
type="app",
token=f"app-token-{uuid.uuid4()}",
)
def _collect_sse_events(response, max_events: int = 2, timeout: float = 3.0) -> list[dict[str, Any]]:
events: list[dict[str, Any]] = []
buffer = ""
start_time = time.time()
for chunk in response.response:
if not chunk:
if time.time() - start_time > timeout:
break
continue
buffer += chunk.decode("utf-8")
while "\n\n" in buffer:
block, buffer = buffer.split("\n\n", 1)
for line in block.splitlines():
if not line.startswith("data: "):
continue
payload = line[len("data: ") :]
events.append(json.loads(payload))
if len(events) >= max_events:
return events
if time.time() - start_time > timeout:
break
return events
class TestSSEStartGateIntegration:
def test_workflow_streaming_sse_starts_after_subscribe(
self,
db_session_with_containers,
test_client_with_containers: FlaskClient,
):
tenant, account, tenant_join = _create_tenant_and_owner()
app, workflow = _create_workflow_app(tenant, account)
api_token = _create_api_token(app)
db_session_with_containers.add_all([tenant, account, tenant_join, app, workflow, api_token])
db_session_with_containers.commit()
def _fake_delay(payload_json: str):
payload = json.loads(payload_json)
workflow_run_id = payload["workflow_run_id"]
app_mode = AppMode.value_of(payload["app_mode"])
topic = MessageBasedAppGenerator.get_response_topic(app_mode, uuid.UUID(workflow_run_id))
events = [
{
"event": StreamEvent.WORKFLOW_STARTED.value,
"workflow_run_id": workflow_run_id,
"created_at": int(time.time()),
},
{
"event": StreamEvent.WORKFLOW_FINISHED.value,
"workflow_run_id": workflow_run_id,
"created_at": int(time.time()),
},
]
for event in events:
topic.publish(json.dumps(event).encode())
payload = {
"inputs": {},
"response_mode": "streaming",
"user": "test-end-user",
}
with patch("services.app_generate_service.chatflow_execute_task.delay", side_effect=_fake_delay):
response = test_client_with_containers.post(
"/v1/workflows/run",
json=payload,
headers={"Authorization": f"Bearer {api_token.token}"},
buffered=False,
)
assert response.status_code == 200
events = _collect_sse_events(response)
assert len(events) == 2
assert events[0]["event"] == StreamEvent.WORKFLOW_STARTED.value
assert events[1]["event"] == StreamEvent.WORKFLOW_FINISHED.value

View File

@ -0,0 +1,132 @@
from __future__ import annotations
from contextlib import contextmanager
from types import SimpleNamespace
from unittest import mock
import pytest
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.workflow.entities.pause_reason import HumanInputRequired
from models.enums import MessageStatus
from models.execution_extra_content import HumanInputContent
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__(
pipeline_module.AdvancedChatAppGenerateTaskPipeline
)
pipeline._workflow_run_id = "run-1"
pipeline._message_id = "message-1"
pipeline._workflow_tenant_id = "tenant-1"
return pipeline
def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None:
pipeline = _build_pipeline()
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
captured_session: dict[str, mock.Mock] = {}
@contextmanager
def fake_session():
session = mock.Mock()
session.scalar.return_value = None
captured_session["session"] = session
yield session
pipeline._database_session = fake_session # type: ignore[method-assign]
pipeline._persist_human_input_extra_content(node_id="node-1")
session = captured_session["session"]
session.add.assert_called_once()
content = session.add.call_args.args[0]
assert isinstance(content, HumanInputContent)
assert content.workflow_run_id == "run-1"
assert content.message_id == "message-1"
assert content.form_id == "form-1"
def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
pipeline = _build_pipeline()
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None)
called = {"value": False}
@contextmanager
def fake_session():
called["value"] = True
session = mock.Mock()
yield session
pipeline._database_session = fake_session # type: ignore[method-assign]
pipeline._persist_human_input_extra_content(node_id="node-1")
assert called["value"] is False
def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None:
pipeline = _build_pipeline()
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
captured_session: dict[str, mock.Mock] = {}
@contextmanager
def fake_session():
session = mock.Mock()
session.scalar.return_value = HumanInputContent(
workflow_run_id="run-1",
message_id="message-1",
form_id="form-1",
)
captured_session["session"] = session
yield session
pipeline._database_session = fake_session # type: ignore[method-assign]
pipeline._persist_human_input_extra_content(node_id="node-1")
session = captured_session["session"]
session.add.assert_not_called()
def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None:
pipeline = _build_pipeline()
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
pipeline._workflow_response_converter = mock.Mock()
pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
pipeline._ensure_graph_runtime_initialized = mock.Mock(side_effect=ValueError())
pipeline._save_message = mock.Mock()
message = SimpleNamespace(status=MessageStatus.NORMAL)
pipeline._get_message = mock.Mock(return_value=message)
pipeline._persist_human_input_extra_content = mock.Mock()
pipeline._base_task_pipeline = mock.Mock()
pipeline._base_task_pipeline.queue_manager = mock.Mock()
pipeline._message_saved_on_pause = False
@contextmanager
def fake_session():
session = mock.Mock()
yield session
pipeline._database_session = fake_session # type: ignore[method-assign]
reason = HumanInputRequired(
form_id="form-1",
form_content="content",
inputs=[],
actions=[],
node_id="node-1",
node_title="Approval",
form_token="token-1",
resolved_placeholder_values={},
)
event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"])
list(pipeline._handle_workflow_paused_event(event))
pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1")
assert message.status == MessageStatus.PAUSED

View File

@ -38,7 +38,6 @@ class _InMemoryFormEntity(HumanInputFormEntity):
form_id: str
rendered: str
token: str | None = None
console_token_value: str | None = None
action_id: str | None = None
data: Mapping[str, Any] | None = None
is_submitted: bool = False
@ -51,8 +50,6 @@ class _InMemoryFormEntity(HumanInputFormEntity):
@property
def web_app_token(self) -> str | None:
if self.console_token_value is not None:
return self.console_token_value
return self.token
@property
@ -97,12 +94,11 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
self.created_params.append(params)
self._form_counter += 1
form_id = f"form-{self._form_counter}"
console_token = f"console-{form_id}" if params.console_recipient_required else None
token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}"
entity = _InMemoryFormEntity(
form_id=form_id,
rendered=params.rendered_content,
token=f"token-{form_id}",
console_token_value=console_token,
token=token,
)
self.created_forms.append(entity)
self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain
from core.entities.execution_extra_content import HumanInputFormSubmissionData
from core.workflow.nodes.human_input.entities import (
FormDefinition,
UserAction,
@ -14,7 +15,7 @@ from core.workflow.nodes.human_input.enums import (
TimeoutUnit,
)
from models.execution_extra_content import HumanInputContent as HumanInputContentModel
from models.human_input import HumanInputForm
from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
@ -27,11 +28,13 @@ class _FakeScalarResult:
class _FakeSession:
def __init__(self, values: Sequence[HumanInputContentModel]):
self._values = values
def __init__(self, values: Sequence[Sequence[object]]):
self._values = list(values)
def scalars(self, _stmt):
return _FakeScalarResult(self._values)
if not self._values:
return _FakeScalarResult([])
return _FakeScalarResult(self._values.pop(0))
def __enter__(self):
return self
@ -56,6 +59,8 @@ def _build_form(action_id: str, action_title: str, rendered_content: str) -> Hum
rendered_content="rendered",
timeout=1,
timeout_unit=TimeoutUnit.HOUR,
node_title="Approval",
display_in_ui=True,
)
form = HumanInputForm(
id=f"form-{action_id}",
@ -89,8 +94,9 @@ def _build_content(message_id: str, action_id: str, action_title: str) -> HumanI
def test_get_by_message_ids_groups_contents_by_message() -> None:
message_ids = ["msg-1", "msg-2"]
contents = [_build_content("msg-1", "approve", "Approve")]
repository = SQLAlchemyExecutionExtraContentRepository(
session_maker=_FakeSessionMaker(session=_FakeSession(values=[_build_content("msg-1", "approve", "Approve")]))
session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []]))
)
result = repository.get_by_message_ids(message_ids)
@ -98,9 +104,73 @@ def test_get_by_message_ids_groups_contents_by_message() -> None:
assert len(result) == 2
assert [content.to_dict() for content in result[0]] == [
HumanInputContentDomain(
action_id="approve",
action_text="Approve",
rendered_content="Rendered Approve",
submitted=True,
form_submission_data=HumanInputFormSubmissionData(
node_id="node-id",
node_title="Approval",
rendered_content="Rendered Approve",
action_id="approve",
action_text="Approve",
),
).to_dict()
]
assert result[1] == []
def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None:
definition = FormDefinition(
form_content="content",
inputs=[],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
timeout=1,
timeout_unit=TimeoutUnit.HOUR,
placeholder_values={"name": "John"},
node_title="Approval",
display_in_ui=True,
)
form = HumanInputForm(
id="form-1",
tenant_id="tenant-id",
workflow_run_id="workflow-run",
node_id="node-id",
form_definition=definition.model_dump_json(),
rendered_content="Rendered block",
status=HumanInputFormStatus.WAITING,
expiration_time=datetime.now(UTC) + timedelta(days=1),
)
content = HumanInputContentModel(
id="content-msg-1",
form_id=form.id,
message_id="msg-1",
workflow_run_id=form.workflow_run_id,
)
content.form = form
recipient = HumanInputFormRecipient(
form_id=form.id,
delivery_id="delivery-1",
recipient_type=RecipientType.CONSOLE,
recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(),
access_token="token-1",
)
repository = SQLAlchemyExecutionExtraContentRepository(
session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]]))
)
result = repository.get_by_message_ids(["msg-1"])
assert len(result) == 1
assert len(result[0]) == 1
domain_content = result[0][0]
assert domain_content.submitted is False
assert domain_content.form_definition is not None
form_definition = domain_content.form_definition
assert form_definition.form_id == "form-1"
assert form_definition.node_id == "node-id"
assert form_definition.node_title == "Approval"
assert form_definition.form_content == "Rendered block"
assert form_definition.display_in_ui is True
assert form_definition.form_token == "token-1"
assert form_definition.resolved_placeholder_values == {"name": "John"}

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import pytest
from core.entities.execution_extra_content import HumanInputContent
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
from services import message_service
@ -24,9 +24,14 @@ def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: p
"get_by_message_ids": lambda _self, message_ids: [
[
HumanInputContent(
action_id="approve",
action_text="Approve",
rendered_content="Rendered",
submitted=True,
form_submission_data=HumanInputFormSubmissionData(
node_id="node-1",
node_title="Approval",
rendered_content="Rendered",
action_id="approve",
action_text="Approve",
),
)
],
[],
@ -40,10 +45,15 @@ def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: p
assert messages[0].extra_contents == [
{
"type": "human_input_result",
"action_id": "approve",
"action_text": "Approve",
"rendered_content": "Rendered",
"type": "human_input",
"submitted": True,
"form_submission_data": {
"node_id": "node-1",
"node_title": "Approval",
"rendered_content": "Rendered",
"action_id": "approve",
"action_text": "Approve",
},
}
]
assert messages[1].extra_contents == []