WIP: message extra contet

This commit is contained in:
QuantumGhost
2025-12-08 01:18:44 +08:00
parent 095eaabc0d
commit 1f64281ce5
13 changed files with 537 additions and 4 deletions

View File

@ -80,6 +80,7 @@ class MessageListApi(WebApiResource):
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"extra_contents": fields.List(fields.Raw),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),

View File

@ -0,0 +1,27 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, TypeAlias
from models.execution_extra_content import ExecutionContentType
@dataclass(frozen=True, kw_only=True)
class HumanInputContent:
action_id: str
action_text: str
rendered_content: str
type: ExecutionContentType = field(default=ExecutionContentType.HUMAN_INPUT_RESULT, init=False)
def to_dict(self) -> dict[str, Any]:
return {
"type": self.type.value,
"action_id": self.action_id,
"action_text": self.action_text,
"rendered_content": self.rendered_content,
}
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
__all__ = ["ExecutionExtraContentDomainModel", "HumanInputContent"]

View File

@ -11,7 +11,7 @@ if TYPE_CHECKING:
class ExecutionContentType(StrEnum):
HUMAN_INPUT = auto()
HUMAN_INPUT_RESULT = auto()
class ExecutionExtraContent(DefaultFieldsMixin, Base):
@ -56,7 +56,7 @@ class HumanInputContent(ExecutionExtraContent):
It should only be initialized with the `new` class method."""
__mapper_args__ = {
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT,
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT_RESULT,
}
# A relation to HumanInputForm table.

View File

@ -1,7 +1,7 @@
import json
import re
import uuid
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
@ -1269,6 +1269,14 @@ class Message(Base):
db.session.commit()
return result
# TODO(QuantumGhost): dirty hacks, fix this later.
def set_extra_contents(self, contents: Sequence[dict[str, Any]]) -> None:
self._extra_contents = list(contents)
@property
def extra_contents(self) -> list[dict[str, Any]]:
return getattr(self, "_extra_contents", [])
@property
def workflow_run(self):
if self.workflow_run_id:

View File

@ -0,0 +1,14 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
class ExecutionExtraContentRepository(Protocol):
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
...
__all__ = ["ExecutionExtraContentRepository"]

View File

@ -0,0 +1,98 @@
from __future__ import annotations
import logging
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.entities.execution_extra_content import (
ExecutionExtraContentDomainModel,
)
from core.entities.execution_extra_content import (
HumanInputContent as HumanInputContentDomainModel,
)
from core.workflow.nodes.human_input.entities import FormDefinition
from models.execution_extra_content import (
ExecutionExtraContent as ExecutionExtraContentModel,
)
from models.execution_extra_content import (
HumanInputContent as HumanInputContentModel,
)
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
logger = logging.getLogger(__name__)
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
def __init__(self, session_maker: sessionmaker[Session]):
self._session_maker = session_maker
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
if not message_ids:
return []
grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = {
message_id: [] for message_id in message_ids
}
stmt = (
select(ExecutionExtraContentModel)
.where(ExecutionExtraContentModel.message_id.in_(message_ids))
.options(selectinload(HumanInputContentModel.form))
.order_by(ExecutionExtraContentModel.created_at.asc())
)
with self._session_maker() as session:
results = session.scalars(stmt).all()
for content in results:
message_id = content.message_id
if not message_id or message_id not in grouped_contents:
continue
domain_model = self._map_model_to_domain(content)
if domain_model is None:
continue
grouped_contents[message_id].append(domain_model)
return [grouped_contents[message_id] for message_id in message_ids]
def _map_model_to_domain(self, model: ExecutionExtraContentModel) -> ExecutionExtraContentDomainModel | None:
if isinstance(model, HumanInputContentModel):
return self._map_human_input_content(model)
logger.debug("Unsupported execution extra content type encountered: %s", model.type)
return None
def _map_human_input_content(self, model: HumanInputContentModel) -> HumanInputContentDomainModel | None:
form = model.form
if form is None:
logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
return None
selected_action_id = form.selected_action_id
if not selected_action_id:
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
return None
try:
form_definition = FormDefinition.model_validate_json(form.form_definition)
except ValueError:
logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
return None
action_text = next(
(action.title for action in form_definition.user_actions if action.id == selected_action_id),
selected_action_id,
)
return HumanInputContentDomainModel(
action_id=selected_action_id,
action_text=action_text,
rendered_content=form.rendered_content,
)
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]

View File

@ -1,7 +1,9 @@
import json
from collections.abc import Sequence
from typing import Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from sqlalchemy.orm import sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory
@ -14,6 +16,10 @@ from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
from repositories.sqlalchemy_execution_extra_content_repository import (
SQLAlchemyExecutionExtraContentRepository,
)
from services.conversation_service import ConversationService
from services.errors.message import (
FirstMessageNotExistsError,
@ -24,6 +30,23 @@ from services.errors.message import (
from services.workflow_service import WorkflowService
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
return SQLAlchemyExecutionExtraContentRepository(session_maker=session_maker)
def _attach_message_extra_contents(messages: Sequence[Message]) -> None:
if not messages:
return
repository = _create_execution_extra_content_repository()
extra_contents_lists = repository.get_by_message_ids([message.id for message in messages])
for index, message in enumerate(messages):
contents = extra_contents_lists[index] if index < len(extra_contents_lists) else []
message.set_extra_contents([content.to_dict() for content in contents])
class MessageService:
@classmethod
def pagination_by_first_id(
@ -85,6 +108,8 @@ class MessageService:
if order == "asc":
history_messages = list(reversed(history_messages))
_attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
@classmethod

View File

@ -0,0 +1 @@
"""Helper utilities for integration tests."""

View File

@ -0,0 +1,149 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta
from decimal import Decimal
from uuid import uuid4
from core.workflow.nodes.human_input.entities import FormDefinition, TimeoutUnit, UserAction
from models.account import Account, Tenant, TenantAccountJoin
from models.execution_extra_content import HumanInputContent
from models.human_input import HumanInputForm, HumanInputFormStatus
from models.model import App, Conversation, Message
@dataclass
class HumanInputMessageFixture:
app: App
account: Account
conversation: Conversation
message: Message
form: HumanInputForm
action_id: str
action_text: str
def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
tenant = Tenant(name=f"Tenant {uuid4()}")
db_session.add(tenant)
db_session.flush()
account = Account(
name=f"Account {uuid4()}",
email=f"human_input_{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session.add(account)
db_session.flush()
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
current=True,
)
db_session.add(tenant_join)
db_session.flush()
app = App(
tenant_id=tenant.id,
name=f"App {uuid4()}",
description="",
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=account.id,
updated_by=account.id,
)
db_session.add(app)
db_session.flush()
conversation = Conversation(
app_id=app.id,
mode="chat",
name="Test Conversation",
summary="",
introduction="",
system_instruction="",
status="normal",
invoke_from="console",
from_source="console",
from_account_id=account.id,
from_end_user_id=None,
)
conversation.inputs = {}
db_session.add(conversation)
db_session.flush()
workflow_run_id = str(uuid4())
message = Message(
app_id=app.id,
conversation_id=conversation.id,
inputs={},
query="Human input query",
message={"messages": []},
answer="Human input answer",
message_tokens=50,
message_unit_price=Decimal("0.001"),
answer_tokens=80,
answer_unit_price=Decimal("0.001"),
provider_response_latency=0.5,
currency="USD",
from_source="console",
from_account_id=account.id,
workflow_run_id=workflow_run_id,
)
db_session.add(message)
db_session.flush()
action_id = "approve"
action_text = "Approve request"
form_definition = FormDefinition(
form_content="content",
inputs=[],
user_actions=[UserAction(id=action_id, title=action_text)],
rendered_content="Rendered block",
timeout=1,
timeout_unit=TimeoutUnit.HOUR,
)
form = HumanInputForm(
tenant_id=tenant.id,
workflow_run_id=workflow_run_id,
node_id="node-id",
form_definition=form_definition.model_dump_json(),
rendered_content="Rendered block",
status=HumanInputFormStatus.SUBMITTED,
expiration_time=datetime.utcnow() + timedelta(days=1),
selected_action_id=action_id,
)
db_session.add(form)
db_session.flush()
content = HumanInputContent(
workflow_run_id=workflow_run_id,
message_id=message.id,
form_id=form.id,
)
db_session.add(content)
db_session.commit()
return HumanInputMessageFixture(
app=app,
account=account,
conversation=conversation,
message=message,
form=form,
action_id=action_id,
action_text=action_text,
)

View File

@ -0,0 +1,25 @@
from __future__ import annotations
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
create_human_input_message_fixture,
)
def test_get_by_message_ids_returns_human_input_content(db_session_with_containers):
fixture = create_human_input_message_fixture(db_session_with_containers)
repository = SQLAlchemyExecutionExtraContentRepository(
session_maker=sessionmaker(bind=db.engine, expire_on_commit=False)
)
results = repository.get_by_message_ids([fixture.message.id])
assert len(results) == 1
assert len(results[0]) == 1
content = results[0][0]
assert content.action_id == fixture.action_id
assert content.action_text == fixture.action_text
assert content.rendered_content == fixture.form.rendered_content

View File

@ -0,0 +1,32 @@
from __future__ import annotations
import pytest
from services.message_service import MessageService
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
create_human_input_message_fixture,
)
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
def test_pagination_returns_extra_contents(db_session_with_containers):
fixture = create_human_input_message_fixture(db_session_with_containers)
pagination = MessageService.pagination_by_first_id(
app_model=fixture.app,
user=fixture.account,
conversation_id=fixture.conversation.id,
first_id=None,
limit=10,
)
assert pagination.data
message = pagination.data[0]
assert message.extra_contents == [
{
"type": "human_input_result",
"action_id": fixture.action_id,
"action_text": fixture.action_text,
"rendered_content": fixture.form.rendered_content,
}
]

View File

@ -0,0 +1,104 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain
from core.workflow.nodes.human_input.entities import (
FormDefinition,
HumanInputFormStatus,
TimeoutUnit,
UserAction,
)
from models.execution_extra_content import HumanInputContent as HumanInputContentModel
from models.human_input import HumanInputForm
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
class _FakeScalarResult:
def __init__(self, values: Sequence[HumanInputContentModel]):
self._values = list(values)
def all(self) -> list[HumanInputContentModel]:
return list(self._values)
class _FakeSession:
def __init__(self, values: Sequence[HumanInputContentModel]):
self._values = values
def scalars(self, _stmt):
return _FakeScalarResult(self._values)
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
@dataclass
class _FakeSessionMaker:
session: _FakeSession
def __call__(self) -> _FakeSession:
return self.session
def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm:
definition = FormDefinition(
form_content="content",
inputs=[],
user_actions=[UserAction(id=action_id, title=action_title)],
rendered_content="rendered",
timeout=1,
timeout_unit=TimeoutUnit.HOUR,
)
form = HumanInputForm(
id=f"form-{action_id}",
tenant_id="tenant-id",
workflow_run_id="workflow-run",
node_id="node-id",
form_definition=definition.model_dump_json(),
rendered_content=rendered_content,
status=HumanInputFormStatus.SUBMITTED,
expiration_time=datetime.now(UTC) + timedelta(days=1),
)
form.selected_action_id = action_id
return form
def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel:
form = _build_form(
action_id=action_id,
action_title=action_title,
rendered_content=f"Rendered {action_title}",
)
content = HumanInputContentModel(
id=f"content-{message_id}",
form_id=form.id,
message_id=message_id,
workflow_run_id=form.workflow_run_id,
)
content.form = form
return content
def test_get_by_message_ids_groups_contents_by_message() -> None:
message_ids = ["msg-1", "msg-2"]
repository = SQLAlchemyExecutionExtraContentRepository(
session_maker=_FakeSessionMaker(session=_FakeSession(values=[_build_content("msg-1", "approve", "Approve")]))
)
result = repository.get_by_message_ids(message_ids)
assert len(result) == 2
assert [content.to_dict() for content in result[0]] == [
HumanInputContentDomain(
action_id="approve",
action_text="Approve",
rendered_content="Rendered Approve",
).to_dict()
]
assert result[1] == []

View File

@ -0,0 +1,49 @@
from __future__ import annotations
import pytest
from core.entities.execution_extra_content import HumanInputContent
from services import message_service
class _FakeMessage:
def __init__(self, message_id: str):
self.id = message_id
self.extra_contents = None
def set_extra_contents(self, contents):
self.extra_contents = contents
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
repo = type(
"Repo",
(),
{
"get_by_message_ids": lambda _self, message_ids: [
[
HumanInputContent(
action_id="approve",
action_text="Approve",
rendered_content="Rendered",
)
],
[],
]
},
)()
monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo)
message_service._attach_message_extra_contents(messages)
assert messages[0].extra_contents == [
{
"type": "human_input_result",
"action_id": "approve",
"action_text": "Approve",
"rendered_content": "Rendered",
}
]
assert messages[1].extra_contents == []