mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
WIP: message extra contet
This commit is contained in:
@ -80,6 +80,7 @@ class MessageListApi(WebApiResource):
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"extra_contents": fields.List(fields.Raw),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
|
||||
27
api/core/entities/execution_extra_content.py
Normal file
27
api/core/entities/execution_extra_content.py
Normal file
@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class HumanInputContent:
|
||||
action_id: str
|
||||
action_text: str
|
||||
rendered_content: str
|
||||
type: ExecutionContentType = field(default=ExecutionContentType.HUMAN_INPUT_RESULT, init=False)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": self.type.value,
|
||||
"action_id": self.action_id,
|
||||
"action_text": self.action_text,
|
||||
"rendered_content": self.rendered_content,
|
||||
}
|
||||
|
||||
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
||||
|
||||
__all__ = ["ExecutionExtraContentDomainModel", "HumanInputContent"]
|
||||
@ -11,7 +11,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ExecutionContentType(StrEnum):
|
||||
HUMAN_INPUT = auto()
|
||||
HUMAN_INPUT_RESULT = auto()
|
||||
|
||||
|
||||
class ExecutionExtraContent(DefaultFieldsMixin, Base):
|
||||
@ -56,7 +56,7 @@ class HumanInputContent(ExecutionExtraContent):
|
||||
It should only be initialized with the `new` class method."""
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT,
|
||||
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT_RESULT,
|
||||
}
|
||||
|
||||
# A relation to HumanInputForm table.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
@ -1269,6 +1269,14 @@ class Message(Base):
|
||||
db.session.commit()
|
||||
return result
|
||||
|
||||
# TODO(QuantumGhost): dirty hacks, fix this later.
|
||||
def set_extra_contents(self, contents: Sequence[dict[str, Any]]) -> None:
|
||||
self._extra_contents = list(contents)
|
||||
|
||||
@property
|
||||
def extra_contents(self) -> list[dict[str, Any]]:
|
||||
return getattr(self, "_extra_contents", [])
|
||||
|
||||
@property
|
||||
def workflow_run(self):
|
||||
if self.workflow_run_id:
|
||||
|
||||
14
api/repositories/execution_extra_content_repository.py
Normal file
14
api/repositories/execution_extra_content_repository.py
Normal file
@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
|
||||
|
||||
class ExecutionExtraContentRepository(Protocol):
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
|
||||
...
|
||||
|
||||
|
||||
__all__ = ["ExecutionExtraContentRepository"]
|
||||
@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.entities.execution_extra_content import (
|
||||
ExecutionExtraContentDomainModel,
|
||||
)
|
||||
from core.entities.execution_extra_content import (
|
||||
HumanInputContent as HumanInputContentDomainModel,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from models.execution_extra_content import (
|
||||
ExecutionExtraContent as ExecutionExtraContentModel,
|
||||
)
|
||||
from models.execution_extra_content import (
|
||||
HumanInputContent as HumanInputContentModel,
|
||||
)
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
self._session_maker = session_maker
|
||||
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
|
||||
if not message_ids:
|
||||
return []
|
||||
|
||||
grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = {
|
||||
message_id: [] for message_id in message_ids
|
||||
}
|
||||
|
||||
stmt = (
|
||||
select(ExecutionExtraContentModel)
|
||||
.where(ExecutionExtraContentModel.message_id.in_(message_ids))
|
||||
.options(selectinload(HumanInputContentModel.form))
|
||||
.order_by(ExecutionExtraContentModel.created_at.asc())
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
results = session.scalars(stmt).all()
|
||||
|
||||
for content in results:
|
||||
message_id = content.message_id
|
||||
if not message_id or message_id not in grouped_contents:
|
||||
continue
|
||||
|
||||
domain_model = self._map_model_to_domain(content)
|
||||
if domain_model is None:
|
||||
continue
|
||||
|
||||
grouped_contents[message_id].append(domain_model)
|
||||
|
||||
return [grouped_contents[message_id] for message_id in message_ids]
|
||||
|
||||
def _map_model_to_domain(self, model: ExecutionExtraContentModel) -> ExecutionExtraContentDomainModel | None:
|
||||
if isinstance(model, HumanInputContentModel):
|
||||
return self._map_human_input_content(model)
|
||||
|
||||
logger.debug("Unsupported execution extra content type encountered: %s", model.type)
|
||||
return None
|
||||
|
||||
def _map_human_input_content(self, model: HumanInputContentModel) -> HumanInputContentDomainModel | None:
|
||||
form = model.form
|
||||
if form is None:
|
||||
logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
|
||||
return None
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if not selected_action_id:
|
||||
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
|
||||
return None
|
||||
|
||||
try:
|
||||
form_definition = FormDefinition.model_validate_json(form.form_definition)
|
||||
except ValueError:
|
||||
logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
|
||||
return None
|
||||
|
||||
action_text = next(
|
||||
(action.title for action in form_definition.user_actions if action.id == selected_action_id),
|
||||
selected_action_id,
|
||||
)
|
||||
|
||||
return HumanInputContentDomainModel(
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
rendered_content=form.rendered_content,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]
|
||||
@ -1,7 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -14,6 +16,10 @@ from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import (
|
||||
SQLAlchemyExecutionExtraContentRepository,
|
||||
)
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
@ -24,6 +30,23 @@ from services.errors.message import (
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
return SQLAlchemyExecutionExtraContentRepository(session_maker=session_maker)
|
||||
|
||||
|
||||
def _attach_message_extra_contents(messages: Sequence[Message]) -> None:
|
||||
if not messages:
|
||||
return
|
||||
|
||||
repository = _create_execution_extra_content_repository()
|
||||
extra_contents_lists = repository.get_by_message_ids([message.id for message in messages])
|
||||
|
||||
for index, message in enumerate(messages):
|
||||
contents = extra_contents_lists[index] if index < len(extra_contents_lists) else []
|
||||
message.set_extra_contents([content.to_dict() for content in contents])
|
||||
|
||||
|
||||
class MessageService:
|
||||
@classmethod
|
||||
def pagination_by_first_id(
|
||||
@ -85,6 +108,8 @@ class MessageService:
|
||||
if order == "asc":
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
_attach_message_extra_contents(history_messages)
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -0,0 +1 @@
|
||||
"""Helper utilities for integration tests."""
|
||||
@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, TimeoutUnit, UserAction
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.human_input import HumanInputForm, HumanInputFormStatus
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
|
||||
@dataclass
|
||||
class HumanInputMessageFixture:
|
||||
app: App
|
||||
account: Account
|
||||
conversation: Conversation
|
||||
message: Message
|
||||
form: HumanInputForm
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
db_session.add(tenant)
|
||||
db_session.flush()
|
||||
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"human_input_{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session.add(account)
|
||||
db_session.flush()
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role="owner",
|
||||
current=True,
|
||||
)
|
||||
db_session.add(tenant_join)
|
||||
db_session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"App {uuid4()}",
|
||||
description="",
|
||||
mode="chat",
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=False,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=100,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db_session.add(app)
|
||||
db_session.flush()
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
mode="chat",
|
||||
name="Test Conversation",
|
||||
summary="",
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
status="normal",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
from_account_id=account.id,
|
||||
from_end_user_id=None,
|
||||
)
|
||||
conversation.inputs = {}
|
||||
db_session.add(conversation)
|
||||
db_session.flush()
|
||||
|
||||
workflow_run_id = str(uuid4())
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
inputs={},
|
||||
query="Human input query",
|
||||
message={"messages": []},
|
||||
answer="Human input answer",
|
||||
message_tokens=50,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
answer_tokens=80,
|
||||
answer_unit_price=Decimal("0.001"),
|
||||
provider_response_latency=0.5,
|
||||
currency="USD",
|
||||
from_source="console",
|
||||
from_account_id=account.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.flush()
|
||||
|
||||
action_id = "approve"
|
||||
action_text = "Approve request"
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id=action_id, title=action_text)],
|
||||
rendered_content="Rendered block",
|
||||
timeout=1,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
tenant_id=tenant.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id="node-id",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="Rendered block",
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=datetime.utcnow() + timedelta(days=1),
|
||||
selected_action_id=action_id,
|
||||
)
|
||||
db_session.add(form)
|
||||
db_session.flush()
|
||||
|
||||
content = HumanInputContent(
|
||||
workflow_run_id=workflow_run_id,
|
||||
message_id=message.id,
|
||||
form_id=form.id,
|
||||
)
|
||||
db_session.add(content)
|
||||
db_session.commit()
|
||||
|
||||
return HumanInputMessageFixture(
|
||||
app=app,
|
||||
account=account,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
form=form,
|
||||
action_id=action_id,
|
||||
action_text=action_text,
|
||||
)
|
||||
@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
|
||||
create_human_input_message_fixture,
|
||||
)
|
||||
|
||||
|
||||
def test_get_by_message_ids_returns_human_input_content(db_session_with_containers):
|
||||
fixture = create_human_input_message_fixture(db_session_with_containers)
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
)
|
||||
|
||||
results = repository.get_by_message_ids([fixture.message.id])
|
||||
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) == 1
|
||||
content = results[0][0]
|
||||
assert content.action_id == fixture.action_id
|
||||
assert content.action_text == fixture.action_text
|
||||
assert content.rendered_content == fixture.form.rendered_content
|
||||
@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from services.message_service import MessageService
|
||||
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
|
||||
create_human_input_message_fixture,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||
def test_pagination_returns_extra_contents(db_session_with_containers):
|
||||
fixture = create_human_input_message_fixture(db_session_with_containers)
|
||||
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model=fixture.app,
|
||||
user=fixture.account,
|
||||
conversation_id=fixture.conversation.id,
|
||||
first_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert pagination.data
|
||||
message = pagination.data[0]
|
||||
assert message.extra_contents == [
|
||||
{
|
||||
"type": "human_input_result",
|
||||
"action_id": fixture.action_id,
|
||||
"action_text": fixture.action_text,
|
||||
"rendered_content": fixture.form.rendered_content,
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormDefinition,
|
||||
HumanInputFormStatus,
|
||||
TimeoutUnit,
|
||||
UserAction,
|
||||
)
|
||||
from models.execution_extra_content import HumanInputContent as HumanInputContentModel
|
||||
from models.human_input import HumanInputForm
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, values: Sequence[HumanInputContentModel]):
|
||||
self._values = list(values)
|
||||
|
||||
def all(self) -> list[HumanInputContentModel]:
|
||||
return list(self._values)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, values: Sequence[HumanInputContentModel]):
|
||||
self._values = values
|
||||
|
||||
def scalars(self, _stmt):
|
||||
return _FakeScalarResult(self._values)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeSessionMaker:
|
||||
session: _FakeSession
|
||||
|
||||
def __call__(self) -> _FakeSession:
|
||||
return self.session
|
||||
|
||||
|
||||
def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm:
|
||||
definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id=action_id, title=action_title)],
|
||||
rendered_content="rendered",
|
||||
timeout=1,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
id=f"form-{action_id}",
|
||||
tenant_id="tenant-id",
|
||||
workflow_run_id="workflow-run",
|
||||
node_id="node-id",
|
||||
form_definition=definition.model_dump_json(),
|
||||
rendered_content=rendered_content,
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=datetime.now(UTC) + timedelta(days=1),
|
||||
)
|
||||
form.selected_action_id = action_id
|
||||
return form
|
||||
|
||||
|
||||
def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel:
|
||||
form = _build_form(
|
||||
action_id=action_id,
|
||||
action_title=action_title,
|
||||
rendered_content=f"Rendered {action_title}",
|
||||
)
|
||||
content = HumanInputContentModel(
|
||||
id=f"content-{message_id}",
|
||||
form_id=form.id,
|
||||
message_id=message_id,
|
||||
workflow_run_id=form.workflow_run_id,
|
||||
)
|
||||
content.form = form
|
||||
return content
|
||||
|
||||
|
||||
def test_get_by_message_ids_groups_contents_by_message() -> None:
|
||||
message_ids = ["msg-1", "msg-2"]
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=_FakeSessionMaker(session=_FakeSession(values=[_build_content("msg-1", "approve", "Approve")]))
|
||||
)
|
||||
|
||||
result = repository.get_by_message_ids(message_ids)
|
||||
|
||||
assert len(result) == 2
|
||||
assert [content.to_dict() for content in result[0]] == [
|
||||
HumanInputContentDomain(
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
rendered_content="Rendered Approve",
|
||||
).to_dict()
|
||||
]
|
||||
assert result[1] == []
|
||||
@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent
|
||||
from services import message_service
|
||||
|
||||
|
||||
class _FakeMessage:
|
||||
def __init__(self, message_id: str):
|
||||
self.id = message_id
|
||||
self.extra_contents = None
|
||||
|
||||
def set_extra_contents(self, contents):
|
||||
self.extra_contents = contents
|
||||
|
||||
|
||||
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_by_message_ids": lambda _self, message_ids: [
|
||||
[
|
||||
HumanInputContent(
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
rendered_content="Rendered",
|
||||
)
|
||||
],
|
||||
[],
|
||||
]
|
||||
},
|
||||
)()
|
||||
|
||||
monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo)
|
||||
|
||||
message_service._attach_message_extra_contents(messages)
|
||||
|
||||
assert messages[0].extra_contents == [
|
||||
{
|
||||
"type": "human_input_result",
|
||||
"action_id": "approve",
|
||||
"action_text": "Approve",
|
||||
"rendered_content": "Rendered",
|
||||
}
|
||||
]
|
||||
assert messages[1].extra_contents == []
|
||||
Reference in New Issue
Block a user