WIP: feat: ExecutionExtraContent model

This commit is contained in:
QuantumGhost
2025-12-05 02:44:34 +08:00
parent 08175ab32a
commit 095eaabc0d
5 changed files with 200 additions and 20 deletions

View File

@ -0,0 +1,42 @@
"""Add ExecutionExtraContent model
Revision ID: e63797cc11c2
Revises: d411af417245
Create Date: 2025-12-03 17:23:05.140844
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e63797cc11c2"
down_revision = "d411af417245"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"execution_extra_contents",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("type", sa.String(30), nullable=False),
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
sa.Column("message_id", models.types.StringUUID(), nullable=True),
sa.Column("form_id", models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("execution_extra_contents_pkey")),
)
with op.batch_alter_table("execution_extra_contents", schema=None) as batch_op:
batch_op.create_index(batch_op.f("execution_extra_contents_message_id_idx"), ["message_id"], unique=False)
batch_op.create_index(
batch_op.f("execution_extra_contents_workflow_run_id_idx"), ["workflow_run_id"], unique=False
)
def downgrade():
op.drop_table("execution_extra_contents")

View File

@ -34,6 +34,7 @@ from .enums import (
WorkflowRunTriggeredFrom,
WorkflowTriggerStatus,
)
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
from .human_input import HumanInputForm
from .model import (
ApiRequest,
@ -54,7 +55,6 @@ from .model import (
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageExtraContent,
MessageFeedback,
MessageFile,
OperationLog,
@ -151,8 +151,10 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"ExecutionExtraContent",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"HumanInputContent",
"HumanInputForm",
"IconType",
"InstalledApp",
@ -162,7 +164,6 @@ __all__ = [
"MessageAgentThought",
"MessageAnnotation",
"MessageChain",
"MessageExtraContent",
"MessageFeedback",
"MessageFile",
"OperationLog",

View File

@ -0,0 +1,78 @@
from enum import StrEnum, auto
from typing import TYPE_CHECKING
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base, DefaultFieldsMixin
from .types import EnumText, StringUUID
if TYPE_CHECKING:
from human_input import HumanInputForm
class ExecutionContentType(StrEnum):
HUMAN_INPUT = auto()
class ExecutionExtraContent(DefaultFieldsMixin, Base):
"""ExecutionExtraContent stores extra contents produced during workflow / chatflow execution."""
# The `ExecutionExtraContent` uses single table inheritance to model different
# kinds of contents produced during message generation.
#
# See: https://docs.sqlalchemy.org/en/20/orm/inheritance.html#single-table-inheritance
__tablename__ = "execution_extra_contents"
__mapper_args__ = {
"polymorphic_abstract": True,
"polymorphic_on": "type",
"with_polymorphic": "*",
}
# type records the type of the content. It serves as the `discriminator` for the
# single table inheritance.
type: Mapped[ExecutionContentType] = mapped_column(
EnumText(ExecutionContentType, length=30),
nullable=False,
)
# `workflow_run_id` records the workflow execution which generates this content, correspond to
# `WorkflowRun.id`.
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
# `message_id` records the messages generated by the execution associated with this `ExecutionExtraContent`.
# It references to `Message.id`.
#
# For workflow execution, this field is `None`.
#
# For chatflow execution, `message_id`` is not None, and the following condition holds:
#
# The message referenced by `message_id` has `message.workflow_run_id == execution_extra_content.workflow_run_id`
#
message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, index=True)
class HumanInputContent(ExecutionExtraContent):
"""HumanInputContent is a concrete class that represents human input content.
It should only be initialized with the `new` class method."""
__mapper_args__ = {
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT,
}
# A relation to HumanInputForm table.
#
# While the form_id column is nullable in database (due to the nature of single table inheritance),
# the form_id field should not be null for a given `HumanInputContent` instance.
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
@classmethod
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
return cls(form_id=form_id, message_id=message_id)
form: Mapped["HumanInputForm"] = relationship(
"HumanInputForm",
foreign_keys=[form_id],
uselist=False,
lazy="raise",
primaryjoin="foreign(HumanInputContent.form_id) == HumanInputForm.id",
)

View File

@ -24,11 +24,11 @@ from libs.helper import generate_string # type: ignore[import-not-found]
from libs.uuid_utils import uuidv7
from .account import Account, Tenant
from .base import Base, DefaultFieldsMixin, TypeBase
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
from .types import LongText, StringUUID
if TYPE_CHECKING:
from .workflow import Workflow
@ -2065,19 +2065,3 @@ class TraceAppConfig(TypeBase):
"created_at": str(self.created_at) if self.created_at else None,
"updated_at": str(self.updated_at) if self.updated_at else None,
}
class MessageExtraContentType(StrEnum):
human_input_result = auto()
class MessageExtraContent(DefaultFieldsMixin):
__tablename__ = "message_extra_contents"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_extra_content_pkey"),
sa.Index("message_extra_content_message_id_idx", "message_id"),
)
message_id = mapped_column(StringUUID, nullable=False, index=True)
type: Mapped[MessageExtraContentType] = mapped_column(EnumText(MessageExtraContentType, length=30), nullable=False)
content = mapped_column(sa.Text, nullable=False)

View File

@ -0,0 +1,75 @@
import uuid
from decimal import Decimal
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from libs.uuid_utils import uuidv7
from models.enums import CreatorUserRole
from models.model import AppMode, Conversation, Message
def _create_conversation(session) -> Conversation:
conversation = Conversation(
app_id=str(uuid.uuid4()),
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source=CreatorUserRole.ACCOUNT,
from_account_id=str(uuid.uuid4()),
)
conversation.inputs = {}
session.add(conversation)
session.commit()
return conversation
def _create_message(session, conversation: Conversation) -> Message:
message = Message(
app_id=conversation.app_id,
conversation_id=conversation.id,
query="Need manual approval",
message={"type": "text", "content": "Need manual approval"},
answer="Acknowledged",
message_tokens=10,
answer_tokens=20,
message_unit_price=Decimal("0.001"),
answer_unit_price=Decimal("0.001"),
message_price_unit=Decimal("0.001"),
answer_price_unit=Decimal("0.001"),
currency="USD",
status="normal",
from_source=CreatorUserRole.ACCOUNT,
)
message.inputs = {}
session.add(message)
session.commit()
return message
def test_message_auto_loads_multiple_extra_variants(db_session_with_containers):
conversation = _create_conversation(db_session_with_containers)
message = _create_message(db_session_with_containers, conversation)
human_input_result_content_id = str(uuidv7())
human_input_result_content = HumanInputResultRelation(
id=human_input_result_content_id,
message_id=message.id,
form_id=None,
)
db_session_with_containers.add(human_input_result_content)
db_session_with_containers.commit()
# polymorphic_extra = with_polymorphic(
# MessageExtraContent,
# [HumanInputResultRelation],
# )
stmt = select(Message).options(selectinload(Message.extra_content)).where(Message.id == message.id)
loaded_message = db_session_with_containers.execute(stmt).scalar_one()
assert len(loaded_message.extra_content) == 1
assert human_input_result_content_id in {extra.id for extra in loaded_message.extra_content}
loaded_types = {type(extra) for extra in loaded_message.extra_content}
assert HumanInputResultRelation in loaded_types