mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 13:45:57 +08:00
feat(api): add human input data to extra contents
This commit is contained in:
@ -32,7 +32,7 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
from services.message_service import MessageService, _attach_message_extra_contents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
@ -198,6 +198,7 @@ message_detail_model = console_ns.model(
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"extra_contents": fields.List(fields.Raw),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
@ -290,6 +291,7 @@ class ChatMessageListApi(Resource):
|
||||
has_more = False
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
_attach_message_extra_contents(history_messages)
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
|
||||
@ -474,4 +476,5 @@ class MessageApi(Resource):
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
_attach_message_extra_contents([message])
|
||||
return message
|
||||
|
||||
@ -65,6 +65,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
@ -73,7 +75,8 @@ from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -137,6 +140,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._workflow_tenant_id = workflow.tenant_id
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
self._message_id = message.id
|
||||
@ -146,6 +150,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
self._message_saved_on_pause = False
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
@ -539,7 +544,22 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
)
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
|
||||
yield from responses
|
||||
resolved_state: GraphRuntimeState | None = None
|
||||
try:
|
||||
resolved_state = self._ensure_graph_runtime_initialized()
|
||||
except ValueError:
|
||||
resolved_state = None
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
message = self._get_message(session=session)
|
||||
if message is not None:
|
||||
message.status = MessageStatus.PAUSED
|
||||
self._message_saved_on_pause = True
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_failed_event(
|
||||
@ -629,9 +649,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
# Save message unless it has already been persisted on pause.
|
||||
if not self._message_saved_on_pause:
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@ -661,10 +682,53 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form filled events."""
|
||||
self._persist_human_input_extra_content(node_id=event.node_id)
|
||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
|
||||
if not self._workflow_run_id or not self._message_id:
|
||||
return
|
||||
|
||||
if form_id is None:
|
||||
if node_id is None:
|
||||
return
|
||||
form_id = self._load_human_input_form_id(node_id=node_id)
|
||||
if form_id is None:
|
||||
logger.warning(
|
||||
"HumanInput form not found for workflow run %s node %s",
|
||||
self._workflow_run_id,
|
||||
node_id,
|
||||
)
|
||||
return
|
||||
|
||||
with self._database_session() as session:
|
||||
exists_stmt = select(HumanInputContent).where(
|
||||
HumanInputContent.workflow_run_id == self._workflow_run_id,
|
||||
HumanInputContent.message_id == self._message_id,
|
||||
HumanInputContent.form_id == form_id,
|
||||
)
|
||||
if session.scalar(exists_stmt) is not None:
|
||||
return
|
||||
|
||||
content = HumanInputContent(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
message_id=self._message_id,
|
||||
form_id=form_id,
|
||||
)
|
||||
session.add(content)
|
||||
|
||||
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self._workflow_tenant_id,
|
||||
)
|
||||
form = form_repository.get_form(self._workflow_run_id, node_id)
|
||||
if form is None:
|
||||
return None
|
||||
return form.id
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle agent log events."""
|
||||
yield self._workflow_response_converter.handle_agent_log(
|
||||
@ -800,6 +864,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
message = self._get_message(session=session)
|
||||
if message is None:
|
||||
return
|
||||
|
||||
if message.status == MessageStatus.PAUSED:
|
||||
message.status = MessageStatus.NORMAL
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
|
||||
@ -1,27 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class HumanInputContent:
|
||||
action_id: str
|
||||
action_text: str
|
||||
rendered_content: str
|
||||
type: ExecutionContentType = field(default=ExecutionContentType.HUMAN_INPUT_RESULT, init=False)
|
||||
class HumanInputFormDefinition:
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = field(default_factory=list)
|
||||
actions: Sequence[UserAction] = field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_placeholder_values: Mapping[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": self.type.value,
|
||||
"form_id": self.form_id,
|
||||
"node_id": self.node_id,
|
||||
"node_title": self.node_title,
|
||||
"form_content": self.form_content,
|
||||
"inputs": [item.model_dump(mode="json") for item in self.inputs],
|
||||
"actions": [item.model_dump(mode="json") for item in self.actions],
|
||||
"display_in_ui": self.display_in_ui,
|
||||
"form_token": self.form_token,
|
||||
"resolved_placeholder_values": self.resolved_placeholder_values,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class HumanInputFormSubmissionData:
|
||||
node_id: str
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"node_id": self.node_id,
|
||||
"node_title": self.node_title,
|
||||
"rendered_content": self.rendered_content,
|
||||
"action_id": self.action_id,
|
||||
"action_text": self.action_text,
|
||||
"rendered_content": self.rendered_content,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class HumanInputContent:
|
||||
submitted: bool
|
||||
form_definition: HumanInputFormDefinition | None = None
|
||||
form_submission_data: HumanInputFormSubmissionData | None = None
|
||||
type: ExecutionContentType = field(default=ExecutionContentType.HUMAN_INPUT, init=False)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"type": self.type.value,
|
||||
"submitted": self.submitted,
|
||||
}
|
||||
if self.form_definition is not None:
|
||||
payload["form_definition"] = self.form_definition.to_dict()
|
||||
if self.form_submission_data is not None:
|
||||
payload["form_submission_data"] = self.form_submission_data.to_dict()
|
||||
return payload
|
||||
|
||||
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
||||
|
||||
__all__ = ["ExecutionExtraContentDomainModel", "HumanInputContent"]
|
||||
__all__ = [
|
||||
"ExecutionExtraContentDomainModel",
|
||||
"HumanInputContent",
|
||||
"HumanInputFormDefinition",
|
||||
"HumanInputFormSubmissionData",
|
||||
]
|
||||
|
||||
@ -63,6 +63,7 @@ message_fields = {
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"extra_contents": fields.List(cls_or_instance=fields.Raw),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
|
||||
@ -11,7 +11,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ExecutionContentType(StrEnum):
|
||||
HUMAN_INPUT_RESULT = auto()
|
||||
HUMAN_INPUT = auto()
|
||||
|
||||
|
||||
class ExecutionExtraContent(DefaultFieldsMixin, Base):
|
||||
@ -56,7 +56,7 @@ class HumanInputContent(ExecutionExtraContent):
|
||||
It should only be initialized with the `new` class method."""
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT_RESULT,
|
||||
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT,
|
||||
}
|
||||
|
||||
# A relation to HumanInputForm table.
|
||||
|
||||
@ -1,28 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.entities.execution_extra_content import (
|
||||
ExecutionExtraContentDomainModel,
|
||||
HumanInputFormDefinition,
|
||||
HumanInputFormSubmissionData,
|
||||
)
|
||||
from core.entities.execution_extra_content import (
|
||||
HumanInputContent as HumanInputContentDomainModel,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from models.execution_extra_content import (
|
||||
ExecutionExtraContent as ExecutionExtraContentModel,
|
||||
)
|
||||
from models.execution_extra_content import (
|
||||
HumanInputContent as HumanInputContentModel,
|
||||
)
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
||||
|
||||
|
||||
def _extract_output_field_names(form_content: str) -> list[str]:
|
||||
if not form_content:
|
||||
return []
|
||||
return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)]
|
||||
|
||||
|
||||
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
@ -46,12 +63,26 @@ class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository)
|
||||
with self._session_maker() as session:
|
||||
results = session.scalars(stmt).all()
|
||||
|
||||
form_ids = {
|
||||
content.form_id
|
||||
for content in results
|
||||
if isinstance(content, HumanInputContentModel) and content.form_id is not None
|
||||
}
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list)
|
||||
if form_ids:
|
||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
recipients = session.scalars(recipient_stmt).all()
|
||||
for recipient in recipients:
|
||||
recipients_by_form_id[recipient.form_id].append(recipient)
|
||||
else:
|
||||
recipients_by_form_id = {}
|
||||
|
||||
for content in results:
|
||||
message_id = content.message_id
|
||||
if not message_id or message_id not in grouped_contents:
|
||||
continue
|
||||
|
||||
domain_model = self._map_model_to_domain(content)
|
||||
domain_model = self._map_model_to_domain(content, recipients_by_form_id)
|
||||
if domain_model is None:
|
||||
continue
|
||||
|
||||
@ -59,40 +90,105 @@ class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository)
|
||||
|
||||
return [grouped_contents[message_id] for message_id in message_ids]
|
||||
|
||||
def _map_model_to_domain(self, model: ExecutionExtraContentModel) -> ExecutionExtraContentDomainModel | None:
|
||||
def _map_model_to_domain(
|
||||
self,
|
||||
model: ExecutionExtraContentModel,
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
|
||||
) -> ExecutionExtraContentDomainModel | None:
|
||||
if isinstance(model, HumanInputContentModel):
|
||||
return self._map_human_input_content(model)
|
||||
return self._map_human_input_content(model, recipients_by_form_id)
|
||||
|
||||
logger.debug("Unsupported execution extra content type encountered: %s", model.type)
|
||||
return None
|
||||
|
||||
def _map_human_input_content(self, model: HumanInputContentModel) -> HumanInputContentDomainModel | None:
|
||||
def _map_human_input_content(
|
||||
self,
|
||||
model: HumanInputContentModel,
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
|
||||
) -> HumanInputContentDomainModel | None:
|
||||
form = model.form
|
||||
if form is None:
|
||||
logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
|
||||
return None
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if not selected_action_id:
|
||||
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
|
||||
return None
|
||||
|
||||
try:
|
||||
form_definition = FormDefinition.model_validate_json(form.form_definition)
|
||||
except ValueError:
|
||||
logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
|
||||
return None
|
||||
node_title = form_definition.node_title or form.node_id
|
||||
display_in_ui = bool(form_definition.display_in_ui)
|
||||
|
||||
submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
|
||||
if not submitted:
|
||||
form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, []))
|
||||
return HumanInputContentDomainModel(
|
||||
submitted=False,
|
||||
form_definition=HumanInputFormDefinition(
|
||||
form_id=form.id,
|
||||
node_id=form.node_id,
|
||||
node_title=node_title,
|
||||
form_content=form.rendered_content,
|
||||
inputs=form_definition.inputs,
|
||||
actions=form_definition.user_actions,
|
||||
display_in_ui=display_in_ui,
|
||||
form_token=form_token,
|
||||
resolved_placeholder_values=form_definition.placeholder_values,
|
||||
),
|
||||
)
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if not selected_action_id:
|
||||
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
|
||||
return None
|
||||
|
||||
action_text = next(
|
||||
(action.title for action in form_definition.user_actions if action.id == selected_action_id),
|
||||
selected_action_id,
|
||||
)
|
||||
|
||||
return HumanInputContentDomainModel(
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
rendered_content=form.rendered_content,
|
||||
submitted_data: dict[str, Any] = {}
|
||||
if form.submitted_data:
|
||||
try:
|
||||
submitted_data = json.loads(form.submitted_data)
|
||||
except ValueError:
|
||||
logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id)
|
||||
return None
|
||||
|
||||
rendered_content = HumanInputNode._render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
submitted_data,
|
||||
_extract_output_field_names(form_definition.form_content),
|
||||
)
|
||||
|
||||
return HumanInputContentDomainModel(
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id=form.node_id,
|
||||
node_title=node_title,
|
||||
rendered_content=rendered_content,
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None:
|
||||
console_recipient = next(
|
||||
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE),
|
||||
None,
|
||||
)
|
||||
if console_recipient and console_recipient.access_token:
|
||||
return console_recipient.access_token
|
||||
|
||||
web_app_recipient = next(
|
||||
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP),
|
||||
None,
|
||||
)
|
||||
if web_app_recipient and web_app_recipient.access_token:
|
||||
return web_app_recipient.access_token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]
|
||||
|
||||
@ -136,7 +136,7 @@ class AudioService:
|
||||
message = db.session.query(Message).where(Message.id == message_id).first()
|
||||
if message is None:
|
||||
return None
|
||||
if message.answer == "" and message.status == MessageStatus.NORMAL:
|
||||
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
|
||||
return None
|
||||
|
||||
else:
|
||||
|
||||
@ -20,6 +20,8 @@ def test_get_by_message_ids_returns_human_input_content(db_session_with_containe
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) == 1
|
||||
content = results[0][0]
|
||||
assert content.action_id == fixture.action_id
|
||||
assert content.action_text == fixture.action_text
|
||||
assert content.rendered_content == fixture.form.rendered_content
|
||||
assert content.submitted is True
|
||||
assert content.form_submission_data is not None
|
||||
assert content.form_submission_data.action_id == fixture.action_id
|
||||
assert content.form_submission_data.action_text == fixture.action_text
|
||||
assert content.form_submission_data.rendered_content == fixture.form.rendered_content
|
||||
|
||||
@ -24,9 +24,14 @@ def test_pagination_returns_extra_contents(db_session_with_containers):
|
||||
message = pagination.data[0]
|
||||
assert message.extra_contents == [
|
||||
{
|
||||
"type": "human_input_result",
|
||||
"action_id": fixture.action_id,
|
||||
"action_text": fixture.action_text,
|
||||
"rendered_content": fixture.form.rendered_content,
|
||||
"type": "human_input",
|
||||
"submitted": True,
|
||||
"form_submission_data": {
|
||||
"node_id": fixture.form.node_id,
|
||||
"node_title": fixture.node_title,
|
||||
"rendered_content": fixture.form.rendered_content,
|
||||
"action_id": fixture.action_id,
|
||||
"action_text": fixture.action_text,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
@ -0,0 +1,156 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.account import TenantAccountRole
|
||||
from models.model import ApiToken, App, AppMode
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
def _create_tenant_and_owner() -> tuple[Tenant, Account, TenantAccountJoin]:
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
tenant.id = str(uuid.uuid4())
|
||||
account = Account(
|
||||
email=f"owner-{uuid.uuid4()}@example.com",
|
||||
name="Owner",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
account.id = str(uuid.uuid4())
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
return tenant, account, tenant_join
|
||||
|
||||
|
||||
def _create_workflow_app(tenant: Tenant, account: Account) -> tuple[App, Workflow]:
|
||||
app = App(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
name="Test Workflow App",
|
||||
description="",
|
||||
mode=AppMode.WORKFLOW,
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
version="v1",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
features=json.dumps({"features": []}),
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
app.workflow_id = workflow.id
|
||||
return app, workflow
|
||||
|
||||
|
||||
def _create_api_token(app: App) -> ApiToken:
|
||||
return ApiToken(
|
||||
app_id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
type="app",
|
||||
token=f"app-token-{uuid.uuid4()}",
|
||||
)
|
||||
|
||||
|
||||
def _collect_sse_events(response, max_events: int = 2, timeout: float = 3.0) -> list[dict[str, Any]]:
|
||||
events: list[dict[str, Any]] = []
|
||||
buffer = ""
|
||||
start_time = time.time()
|
||||
for chunk in response.response:
|
||||
if not chunk:
|
||||
if time.time() - start_time > timeout:
|
||||
break
|
||||
continue
|
||||
buffer += chunk.decode("utf-8")
|
||||
while "\n\n" in buffer:
|
||||
block, buffer = buffer.split("\n\n", 1)
|
||||
for line in block.splitlines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = line[len("data: ") :]
|
||||
events.append(json.loads(payload))
|
||||
if len(events) >= max_events:
|
||||
return events
|
||||
if time.time() - start_time > timeout:
|
||||
break
|
||||
return events
|
||||
|
||||
|
||||
class TestSSEStartGateIntegration:
|
||||
def test_workflow_streaming_sse_starts_after_subscribe(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
test_client_with_containers: FlaskClient,
|
||||
):
|
||||
tenant, account, tenant_join = _create_tenant_and_owner()
|
||||
app, workflow = _create_workflow_app(tenant, account)
|
||||
api_token = _create_api_token(app)
|
||||
|
||||
db_session_with_containers.add_all([tenant, account, tenant_join, app, workflow, api_token])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
def _fake_delay(payload_json: str):
|
||||
payload = json.loads(payload_json)
|
||||
workflow_run_id = payload["workflow_run_id"]
|
||||
app_mode = AppMode.value_of(payload["app_mode"])
|
||||
topic = MessageBasedAppGenerator.get_response_topic(app_mode, uuid.UUID(workflow_run_id))
|
||||
events = [
|
||||
{
|
||||
"event": StreamEvent.WORKFLOW_STARTED.value,
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"created_at": int(time.time()),
|
||||
},
|
||||
{
|
||||
"event": StreamEvent.WORKFLOW_FINISHED.value,
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"created_at": int(time.time()),
|
||||
},
|
||||
]
|
||||
for event in events:
|
||||
topic.publish(json.dumps(event).encode())
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"response_mode": "streaming",
|
||||
"user": "test-end-user",
|
||||
}
|
||||
|
||||
with patch("services.app_generate_service.chatflow_execute_task.delay", side_effect=_fake_delay):
|
||||
response = test_client_with_containers.post(
|
||||
"/v1/workflows/run",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {api_token.token}"},
|
||||
buffered=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
events = _collect_sse_events(response)
|
||||
assert len(events) == 2
|
||||
assert events[0]["event"] == StreamEvent.WORKFLOW_STARTED.value
|
||||
assert events[1]["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from models.enums import MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
|
||||
|
||||
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
|
||||
pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__(
|
||||
pipeline_module.AdvancedChatAppGenerateTaskPipeline
|
||||
)
|
||||
pipeline._workflow_run_id = "run-1"
|
||||
pipeline._message_id = "message-1"
|
||||
pipeline._workflow_tenant_id = "tenant-1"
|
||||
return pipeline
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
|
||||
|
||||
captured_session: dict[str, mock.Mock] = {}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
session.scalar.return_value = None
|
||||
captured_session["session"] = session
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
session = captured_session["session"]
|
||||
session.add.assert_called_once()
|
||||
content = session.add.call_args.args[0]
|
||||
assert isinstance(content, HumanInputContent)
|
||||
assert content.workflow_run_id == "run-1"
|
||||
assert content.message_id == "message-1"
|
||||
assert content.form_id == "form-1"
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None)
|
||||
|
||||
called = {"value": False}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
called["value"] = True
|
||||
session = mock.Mock()
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
assert called["value"] is False
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
|
||||
|
||||
captured_session: dict[str, mock.Mock] = {}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
session.scalar.return_value = HumanInputContent(
|
||||
workflow_run_id="run-1",
|
||||
message_id="message-1",
|
||||
form_id="form-1",
|
||||
)
|
||||
captured_session["session"] = session
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
session = captured_session["session"]
|
||||
session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
|
||||
pipeline._workflow_response_converter = mock.Mock()
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(side_effect=ValueError())
|
||||
pipeline._save_message = mock.Mock()
|
||||
message = SimpleNamespace(status=MessageStatus.NORMAL)
|
||||
pipeline._get_message = mock.Mock(return_value=message)
|
||||
pipeline._persist_human_input_extra_content = mock.Mock()
|
||||
pipeline._base_task_pipeline = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager = mock.Mock()
|
||||
pipeline._message_saved_on_pause = False
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
actions=[],
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
form_token="token-1",
|
||||
resolved_placeholder_values={},
|
||||
)
|
||||
event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"])
|
||||
|
||||
list(pipeline._handle_workflow_paused_event(event))
|
||||
|
||||
pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1")
|
||||
assert message.status == MessageStatus.PAUSED
|
||||
@ -38,7 +38,6 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
token: str | None = None
|
||||
console_token_value: str | None = None
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
is_submitted: bool = False
|
||||
@ -51,8 +50,6 @@ class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
if self.console_token_value is not None:
|
||||
return self.console_token_value
|
||||
return self.token
|
||||
|
||||
@property
|
||||
@ -97,12 +94,11 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
self.created_params.append(params)
|
||||
self._form_counter += 1
|
||||
form_id = f"form-{self._form_counter}"
|
||||
console_token = f"console-{form_id}" if params.console_recipient_required else None
|
||||
token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}"
|
||||
entity = _InMemoryFormEntity(
|
||||
form_id=form_id,
|
||||
rendered=params.rendered_content,
|
||||
token=f"token-{form_id}",
|
||||
console_token_value=console_token,
|
||||
token=token,
|
||||
)
|
||||
self.created_forms.append(entity)
|
||||
self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
|
||||
|
||||
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain
|
||||
from core.entities.execution_extra_content import HumanInputFormSubmissionData
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormDefinition,
|
||||
UserAction,
|
||||
@ -14,7 +15,7 @@ from core.workflow.nodes.human_input.enums import (
|
||||
TimeoutUnit,
|
||||
)
|
||||
from models.execution_extra_content import HumanInputContent as HumanInputContentModel
|
||||
from models.human_input import HumanInputForm
|
||||
from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
|
||||
|
||||
@ -27,11 +28,13 @@ class _FakeScalarResult:
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, values: Sequence[HumanInputContentModel]):
|
||||
self._values = values
|
||||
def __init__(self, values: Sequence[Sequence[object]]):
|
||||
self._values = list(values)
|
||||
|
||||
def scalars(self, _stmt):
|
||||
return _FakeScalarResult(self._values)
|
||||
if not self._values:
|
||||
return _FakeScalarResult([])
|
||||
return _FakeScalarResult(self._values.pop(0))
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@ -56,6 +59,8 @@ def _build_form(action_id: str, action_title: str, rendered_content: str) -> Hum
|
||||
rendered_content="rendered",
|
||||
timeout=1,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
node_title="Approval",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
id=f"form-{action_id}",
|
||||
@ -89,8 +94,9 @@ def _build_content(message_id: str, action_id: str, action_title: str) -> HumanI
|
||||
|
||||
def test_get_by_message_ids_groups_contents_by_message() -> None:
|
||||
message_ids = ["msg-1", "msg-2"]
|
||||
contents = [_build_content("msg-1", "approve", "Approve")]
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=_FakeSessionMaker(session=_FakeSession(values=[_build_content("msg-1", "approve", "Approve")]))
|
||||
session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []]))
|
||||
)
|
||||
|
||||
result = repository.get_by_message_ids(message_ids)
|
||||
@ -98,9 +104,73 @@ def test_get_by_message_ids_groups_contents_by_message() -> None:
|
||||
assert len(result) == 2
|
||||
assert [content.to_dict() for content in result[0]] == [
|
||||
HumanInputContentDomain(
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
rendered_content="Rendered Approve",
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id="node-id",
|
||||
node_title="Approval",
|
||||
rendered_content="Rendered Approve",
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
),
|
||||
).to_dict()
|
||||
]
|
||||
assert result[1] == []
|
||||
|
||||
|
||||
def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None:
|
||||
definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
timeout=1,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
placeholder_values={"name": "John"},
|
||||
node_title="Approval",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
id="form-1",
|
||||
tenant_id="tenant-id",
|
||||
workflow_run_id="workflow-run",
|
||||
node_id="node-id",
|
||||
form_definition=definition.model_dump_json(),
|
||||
rendered_content="Rendered block",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=datetime.now(UTC) + timedelta(days=1),
|
||||
)
|
||||
content = HumanInputContentModel(
|
||||
id="content-msg-1",
|
||||
form_id=form.id,
|
||||
message_id="msg-1",
|
||||
workflow_run_id=form.workflow_run_id,
|
||||
)
|
||||
content.form = form
|
||||
|
||||
recipient = HumanInputFormRecipient(
|
||||
form_id=form.id,
|
||||
delivery_id="delivery-1",
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(),
|
||||
access_token="token-1",
|
||||
)
|
||||
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]]))
|
||||
)
|
||||
|
||||
result = repository.get_by_message_ids(["msg-1"])
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 1
|
||||
domain_content = result[0][0]
|
||||
assert domain_content.submitted is False
|
||||
assert domain_content.form_definition is not None
|
||||
form_definition = domain_content.form_definition
|
||||
assert form_definition.form_id == "form-1"
|
||||
assert form_definition.node_id == "node-id"
|
||||
assert form_definition.node_title == "Approval"
|
||||
assert form_definition.form_content == "Rendered block"
|
||||
assert form_definition.display_in_ui is True
|
||||
assert form_definition.form_token == "token-1"
|
||||
assert form_definition.resolved_placeholder_values == {"name": "John"}
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent
|
||||
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
|
||||
from services import message_service
|
||||
|
||||
|
||||
@ -24,9 +24,14 @@ def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: p
|
||||
"get_by_message_ids": lambda _self, message_ids: [
|
||||
[
|
||||
HumanInputContent(
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
rendered_content="Rendered",
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
rendered_content="Rendered",
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
),
|
||||
)
|
||||
],
|
||||
[],
|
||||
@ -40,10 +45,15 @@ def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: p
|
||||
|
||||
assert messages[0].extra_contents == [
|
||||
{
|
||||
"type": "human_input_result",
|
||||
"action_id": "approve",
|
||||
"action_text": "Approve",
|
||||
"rendered_content": "Rendered",
|
||||
"type": "human_input",
|
||||
"submitted": True,
|
||||
"form_submission_data": {
|
||||
"node_id": "node-1",
|
||||
"node_title": "Approval",
|
||||
"rendered_content": "Rendered",
|
||||
"action_id": "approve",
|
||||
"action_text": "Approve",
|
||||
},
|
||||
}
|
||||
]
|
||||
assert messages[1].extra_contents == []
|
||||
|
||||
Reference in New Issue
Block a user