Merge main HEAD (segment 5) into sandboxed-agent-rebase

Resolve 83 conflicts: 10 backend, 62 frontend, 11 config/lock files.
Preserve sandbox/agent/collaboration features while adopting main's
UI refactorings (Dialog/AlertDialog/Popover), model provider updates,
and enterprise features.

Made-with: Cursor
Novice committed 2026-03-23 14:20:06 +08:00
1671 changed files with 124822 additions and 22302 deletions

View File

@@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase):
@classmethod
def get_by_openid(cls, provider: str, open_id: str):
account_integrate = (
db.session.query(AccountIntegrate)
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none()
)
account_integrate = db.session.execute(
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
).scalar_one_or_none()
if account_integrate:
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
return None
# check current_user.current_tenant.current_role in ['admin', 'owner']
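The hunk above shows the merge's recurring backend rewrite: legacy `db.session.query(...)` chains become 2.0-style `select()` statements executed through `Session.execute()` / `Session.scalar()`. A minimal runnable sketch of the two idioms, assuming a standalone model and an in-memory SQLite engine (names are illustrative, not the project's):

import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class AccountIntegrate(Base):
    __tablename__ = "account_integrates"
    id: Mapped[int] = mapped_column(primary_key=True)
    provider: Mapped[str]
    open_id: Mapped[str]

engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    stmt = select(AccountIntegrate).where(
        AccountIntegrate.provider == "github", AccountIntegrate.open_id == "42"
    )
    # Legacy style: session.query(AccountIntegrate).where(...).one_or_none()
    # 2.0 style: execute the select() and unwrap a single optional result.
    integrate = session.execute(stmt).scalar_one_or_none()
    assert integrate is None  # empty table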

View File

@@ -8,9 +8,10 @@ import os
import pickle
import re
import time
from collections.abc import Sequence
from datetime import datetime
from json import JSONDecodeError
from typing import Any, cast
from typing import Any, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -30,13 +31,81 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
from .account import Account
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .enums import (
CollectionBindingType,
CreatorUserRole,
DatasetMetadataType,
DatasetQuerySource,
DatasetRuntimeMode,
DataSourceType,
DocumentCreatedFrom,
DocumentDocType,
IndexingStatus,
ProcessRuleMode,
SegmentStatus,
SummaryStatus,
)
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
logger = logging.getLogger(__name__)
class PreProcessingRuleItem(TypedDict):
id: str
enabled: bool
class SegmentationConfig(TypedDict):
delimiter: str
max_tokens: int
chunk_overlap: int
class AutomaticRulesConfig(TypedDict):
pre_processing_rules: list[PreProcessingRuleItem]
segmentation: SegmentationConfig
class ProcessRuleDict(TypedDict):
id: str
dataset_id: str
mode: str
rules: dict[str, Any] | None
class DocMetadataDetailItem(TypedDict):
id: str
name: str
type: str
value: Any
class AttachmentItem(TypedDict):
id: str
name: str
size: int
extension: str
mime_type: str
source_url: str
class DatasetBindingItem(TypedDict):
id: str
name: str
class ExternalKnowledgeApiDict(TypedDict):
id: str
tenant_id: str
name: str
description: str
settings: dict[str, Any] | None
dataset_bindings: list[DatasetBindingItem]
created_by: str
created_at: str
class DatasetPermissionEnum(enum.StrEnum):
ONLY_ME = "only_me"
ALL_TEAM = "all_team_members"
@@ -65,7 +134,7 @@ class Dataset(Base):
server_default=sa.text("'only_me'"),
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(String(255))
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
index_struct = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
@@ -82,7 +151,9 @@ class Dataset(Base):
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
runtime_mode = mapped_column(
EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'")
)
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@@ -90,30 +161,25 @@ class Dataset(Base):
@property
def total_documents(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def total_available_documents(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def dataset_keyword_table(self):
dataset_keyword_table = (
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
)
if dataset_keyword_table:
return dataset_keyword_table
return None
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
@property
def index_struct_dict(self):
@@ -140,64 +206,66 @@ class Dataset(Base):
@property
def latest_process_rule(self):
return (
db.session.query(DatasetProcessRule)
return db.session.scalar(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
.limit(1)
)
@property
def app_count(self):
return (
db.session.query(func.count(AppDatasetJoin.id))
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar()
db.session.scalar(
select(func.count(AppDatasetJoin.id)).where(
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
)
)
or 0
)
@property
def document_count(self):
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
@property
def available_document_count(self):
return (
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
)
.scalar()
or 0
)
@property
def available_segment_count(self):
return (
db.session.query(func.count(DocumentSegment.id))
.where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
)
.scalar()
or 0
)
@property
def word_count(self):
return (
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
.where(Document.dataset_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
)
@property
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
if document:
return document.doc_form
return None
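One subtlety behind the `or 0` guards in these count properties: `Session.scalar()` returns None when the statement yields no row at all, while a bare COUNT aggregate still yields one row with value 0. A small standalone illustration (Core table, in-memory SQLite; names assumed):

import sqlalchemy as sa
from sqlalchemy.orm import Session

engine = sa.create_engine("sqlite://")
metadata = sa.MetaData()
documents = sa.Table("documents", metadata, sa.Column("id", sa.Integer, primary_key=True))
metadata.create_all(engine)

with Session(engine) as session:
    # Row lookup with no match: scalar() yields None.
    assert session.scalar(sa.select(documents.c.id).where(documents.c.id == 1)) is None
    # Aggregate over an empty table: still one row, value 0, so `or 0` is defensive.
    assert session.scalar(sa.select(sa.func.count()).select_from(documents)) == 0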
@@ -215,8 +283,8 @@ class Dataset(Base):
@property
def tags(self):
tags = (
db.session.query(Tag)
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@@ -224,8 +292,7 @@ class Dataset(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "knowledge",
)
.all()
)
).all()
return tags or []
@@ -233,8 +300,8 @@
def external_knowledge_info(self):
if self.provider != "external":
return None
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
)
if not external_knowledge_binding:
return None
@@ -255,7 +322,7 @@ class Dataset(Base):
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
if pipeline:
return pipeline.is_published
return False
@@ -327,14 +394,14 @@ class DatasetProcessRule(Base): # bug
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id = mapped_column(StringUUID, nullable=False)
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
rules = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
AUTOMATIC_RULES: dict[str, Any] = {
AUTOMATIC_RULES: AutomaticRulesConfig = {
"pre_processing_rules": [
{"id": "remove_extra_spaces", "enabled": True},
{"id": "remove_urls_emails", "enabled": False},
@@ -342,7 +409,7 @@ class DatasetProcessRule(Base): # bug
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
}
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> ProcessRuleDict:
return {
"id": self.id,
"dataset_id": self.dataset_id,
@@ -373,12 +440,12 @@ class Document(Base):
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False)
data_source_info = mapped_column(LongText, nullable=True)
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
batch: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_api_request_id = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@@ -412,7 +479,9 @@ class Document(Base):
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# basic fields
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'"))
indexing_status = mapped_column(
EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'")
)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
@@ -423,7 +492,7 @@ class Document(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
doc_type = mapped_column(String(40), nullable=True)
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
@@ -466,10 +535,8 @@ class Document(Base):
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none()
file_detail = db.session.scalar(
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
)
if file_detail:
return {
@@ -502,24 +569,23 @@ class Document(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def segment_count(self):
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
return (
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
)
@property
def hit_count(self):
return (
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.where(DocumentSegment.document_id == self.id)
.scalar()
return db.session.scalar(
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
)
@property
def uploader(self):
user = db.session.query(Account).where(Account.id == self.created_by).first()
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
return user.name if user else None
@property
@@ -531,19 +597,18 @@ class Document(Base):
return self.updated_at
@property
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
document_metadatas = db.session.scalars(
select(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
)
metadata_list: list[dict[str, Any]] = []
).all()
metadata_list: list[DocMetadataDetailItem] = []
for metadata in document_metadatas:
metadata_dict: dict[str, Any] = {
metadata_dict: DocMetadataDetailItem = {
"id": metadata.id,
"name": metadata.name,
"type": metadata.type,
@@ -557,13 +622,13 @@ class Document(Base):
return None
@property
def process_rule_dict(self) -> dict[str, Any] | None:
def process_rule_dict(self) -> ProcessRuleDict | None:
if self.dataset_process_rule_id and self.dataset_process_rule:
return self.dataset_process_rule.to_dict()
return None
def get_built_in_fields(self) -> list[dict[str, Any]]:
built_in_fields: list[dict[str, Any]] = []
def get_built_in_fields(self) -> list[DocMetadataDetailItem]:
built_in_fields: list[DocMetadataDetailItem] = []
built_in_fields.append(
{
"id": "built-in",
@@ -736,7 +801,7 @@ class DocumentSegment(Base):
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'"))
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@@ -771,7 +836,7 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self) -> list[Any]:
def child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@@ -780,16 +845,13 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
def get_child_chunks(self) -> list[Any]:
def get_child_chunks(self) -> Sequence[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@@ -798,12 +860,9 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
return child_chunks or []
return []
@@ -877,7 +936,7 @@ class DocumentSegment(Base):
return text
@property
def attachments(self) -> list[dict[str, Any]]:
def attachments(self) -> list[AttachmentItem]:
# Use JOIN to fetch attachments in a single query instead of two separate queries
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
@@ -891,7 +950,7 @@ class DocumentSegment(Base):
).all()
if not attachments_with_bindings:
return []
attachment_list = []
attachment_list: list[AttachmentItem] = []
for _, attachment in attachments_with_bindings:
upload_file_id = attachment.id
nonce = os.urandom(16).hex()
@@ -952,15 +1011,15 @@ class ChildChunk(Base):
@property
def dataset(self):
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def document(self):
return db.session.query(Document).where(Document.id == self.document_id).first()
return db.session.scalar(select(Document).where(Document.id == self.document_id))
@property
def segment(self):
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
class AppDatasetJoin(TypeBase):
@@ -1006,7 +1065,7 @@ class DatasetQuery(TypeBase):
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False)
source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False)
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -1021,7 +1080,7 @@ class DatasetQuery(TypeBase):
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
if file_info:
query["file_info"] = {
"id": file_info.id,
@@ -1086,7 +1145,7 @@ class DatasetKeywordTable(TypeBase):
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
if not dataset:
return None
if self.data_source_type == "database":
@@ -1151,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase):
)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
type: Mapped[str] = mapped_column(
EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False
)
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@@ -1261,7 +1322,7 @@ class ExternalKnowledgeApis(TypeBase):
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> ExternalKnowledgeApiDict:
return {
"id": self.id,
"tenant_id": self.tenant_id,
@@ -1281,13 +1342,13 @@
return None
@property
def dataset_bindings(self) -> list[dict[str, Any]]:
def dataset_bindings(self) -> list[DatasetBindingItem]:
external_knowledge_bindings = db.session.scalars(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
).all()
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
dataset_bindings: list[dict[str, Any]] = []
dataset_bindings: list[DatasetBindingItem] = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
@@ -1378,7 +1439,7 @@ class DatasetMetadata(TypeBase):
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
@@ -1480,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase):
@property
def created_user_name(self):
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
return ""
@@ -1515,7 +1576,7 @@ class Pipeline(TypeBase):
)
def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
class DocumentPipelineExecutionLog(TypeBase):
@@ -1605,7 +1666,9 @@ class DocumentSegmentSummary(Base):
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
status: Mapped[str] = mapped_column(
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
)
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
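The EnumText used throughout this file is the project's own column type from models/types.py, which this diff does not show. As an assumption about its behavior, it can be pictured as a TypeDecorator over String that stores a StrEnum's value and hands back the member on load; a hedged sketch, not the project's actual implementation:

import enum
import sqlalchemy as sa

class EnumText(sa.types.TypeDecorator):
    """Assumed behavior: persist StrEnum values as VARCHAR, validate on bind."""
    impl = sa.String
    cache_ok = True

    def __init__(self, enum_cls: type[enum.StrEnum], length: int = 255):
        self._enum_cls = enum_cls
        super().__init__(length)  # forwarded to String(length)

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        return self._enum_cls(value).value  # accepts a member or its raw string

    def process_result_value(self, value, dialect):
        return self._enum_cls(value) if value is not None else None

Under that assumption, the payoff over the plain String(255) columns it replaces is that an invalid value fails at bind time instead of silently persisting.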

View File

@@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
ACCOUNT = "account"
END_USER = "end_user"
@classmethod
def _missing_(cls, value):
if value == "end-user":
return cls.END_USER
else:
return super()._missing_(value)
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
@@ -96,3 +103,223 @@ class ConversationStatus(StrEnum):
"""Conversation Status Enum"""
NORMAL = "normal"
class DataSourceType(StrEnum):
"""Data Source Type for Dataset and Document"""
UPLOAD_FILE = "upload_file"
NOTION_IMPORT = "notion_import"
WEBSITE_CRAWL = "website_crawl"
LOCAL_FILE = "local_file"
ONLINE_DOCUMENT = "online_document"
class ProcessRuleMode(StrEnum):
"""Dataset Process Rule Mode"""
AUTOMATIC = "automatic"
CUSTOM = "custom"
HIERARCHICAL = "hierarchical"
class IndexingStatus(StrEnum):
"""Document Indexing Status"""
WAITING = "waiting"
PARSING = "parsing"
CLEANING = "cleaning"
SPLITTING = "splitting"
INDEXING = "indexing"
PAUSED = "paused"
COMPLETED = "completed"
ERROR = "error"
class DocumentCreatedFrom(StrEnum):
"""Document Created From"""
WEB = "web"
API = "api"
RAG_PIPELINE = "rag-pipeline"
class ConversationFromSource(StrEnum):
"""Conversation / Message from_source"""
API = "api"
CONSOLE = "console"
class FeedbackFromSource(StrEnum):
"""MessageFeedback from_source"""
USER = "user"
ADMIN = "admin"
class FeedbackRating(StrEnum):
"""MessageFeedback rating"""
LIKE = "like"
DISLIKE = "dislike"
class InvokeFrom(StrEnum):
"""How a conversation/message was invoked"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)
def to_source(self) -> str:
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")
class DocumentDocType(StrEnum):
"""Document doc_type classification"""
BOOK = "book"
WEB_PAGE = "web_page"
PAPER = "paper"
SOCIAL_MEDIA_POST = "social_media_post"
WIKIPEDIA_ENTRY = "wikipedia_entry"
PERSONAL_DOCUMENT = "personal_document"
BUSINESS_DOCUMENT = "business_document"
IM_CHAT_LOG = "im_chat_log"
SYNCED_FROM_NOTION = "synced_from_notion"
SYNCED_FROM_GITHUB = "synced_from_github"
OTHERS = "others"
class TagType(StrEnum):
"""Tag type"""
KNOWLEDGE = "knowledge"
APP = "app"
class DatasetMetadataType(StrEnum):
"""Dataset metadata value type"""
STRING = "string"
NUMBER = "number"
TIME = "time"
class SegmentStatus(StrEnum):
"""Document segment status"""
WAITING = "waiting"
INDEXING = "indexing"
COMPLETED = "completed"
ERROR = "error"
PAUSED = "paused"
RE_SEGMENT = "re_segment"
class DatasetRuntimeMode(StrEnum):
"""Dataset runtime mode"""
GENERAL = "general"
RAG_PIPELINE = "rag_pipeline"
class CollectionBindingType(StrEnum):
"""Dataset collection binding type"""
DATASET = "dataset"
ANNOTATION = "annotation"
class DatasetQuerySource(StrEnum):
"""Dataset query source"""
HIT_TESTING = "hit_testing"
APP = "app"
class TidbAuthBindingStatus(StrEnum):
"""TiDB auth binding status"""
CREATING = "CREATING"
ACTIVE = "ACTIVE"
class MessageFileBelongsTo(StrEnum):
"""MessageFile belongs_to"""
USER = "user"
ASSISTANT = "assistant"
class CredentialSourceType(StrEnum):
"""Load balancing credential source type"""
PROVIDER = "provider"
CUSTOM_MODEL = "custom_model"
class PaymentStatus(StrEnum):
"""Provider order payment status"""
WAIT_PAY = "wait_pay"
PAID = "paid"
FAILED = "failed"
REFUNDED = "refunded"
class BannerStatus(StrEnum):
"""ExporleBanner status"""
ENABLED = "enabled"
DISABLED = "disabled"
class SummaryStatus(StrEnum):
"""Document segment summary status"""
NOT_STARTED = "not_started"
GENERATING = "generating"
COMPLETED = "completed"
ERROR = "error"
TIMEOUT = "timeout"
class MessageChainType(StrEnum):
"""Message chain type"""
SYSTEM = "system"
class ProviderQuotaType(StrEnum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value: str) -> "ProviderQuotaType":
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
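The `_missing_` hook added to CreatorUserRole at the top of this file is the standard enum mechanism Python consults when a lookup value matches no member, which is how the legacy hyphenated "end-user" spelling keeps resolving after the merge. A standalone illustration:

from enum import StrEnum

class CreatorUserRole(StrEnum):
    ACCOUNT = "account"
    END_USER = "end_user"

    @classmethod
    def _missing_(cls, value):
        # Map the legacy hyphenated spelling onto the canonical member.
        if value == "end-user":
            return cls.END_USER
        return super()._missing_(value)

assert CreatorUserRole("end-user") is CreatorUserRole.END_USER
assert CreatorUserRole("end_user") is CreatorUserRole.END_USER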

View File

@@ -30,6 +30,15 @@ def _generate_token() -> str:
class HumanInputForm(DefaultFieldsMixin, Base):
__tablename__ = "human_input_forms"
__table_args__ = (
sa.Index(
"human_input_forms_workflow_run_id_node_id_idx",
"workflow_run_id",
"node_id",
),
sa.Index("human_input_forms_status_expiration_time_idx", "status", "expiration_time"),
sa.Index("human_input_forms_status_created_at_idx", "status", "created_at"),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -84,6 +93,12 @@ class HumanInputForm(DefaultFieldsMixin, Base):
class HumanInputDelivery(DefaultFieldsMixin, Base):
__tablename__ = "human_input_form_deliveries"
__table_args__ = (
sa.Index(
None,
"form_id",
),
)
form_id: Mapped[str] = mapped_column(
StringUUID,
@@ -181,6 +196,10 @@ RecipientPayload = Annotated[
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
__tablename__ = "human_input_form_recipients"
__table_args__ = (
sa.Index(None, "form_id"),
sa.Index(None, "delivery_id"),
)
form_id: Mapped[str] = mapped_column(
StringUUID,
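Passing None as the Index name, as the delivery and recipient tables do above, defers naming to the MetaData naming_convention, which derives a deterministic name at DDL time. A sketch assuming a convention in the style of the explicitly named indexes (the project's actual convention is not shown in this diff):

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.schema import CreateIndex

metadata = sa.MetaData(naming_convention={"ix": "%(table_name)s_%(column_0_name)s_idx"})
deliveries = sa.Table(
    "human_input_form_deliveries",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("form_id", sa.String(36)),
)
index = sa.Index(None, deliveries.c.form_id)  # name filled in by the convention
print(CreateIndex(index).compile(dialect=postgresql.dialect()))
# CREATE INDEX human_input_form_deliveries_form_id_idx ON human_input_form_deliveries (form_id)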

View File

@@ -23,13 +27,27 @@ from core.tools.signature import sign_tool_file
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from dify_graph.file import helpers as file_helpers
from extensions.storage.storage_type import StorageType
from libs.helper import generate_string # type: ignore[import-not-found]
from libs.uuid_utils import uuidv7
from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db
from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus
from .enums import (
AppMCPServerStatus,
AppStatus,
BannerStatus,
ConversationFromSource,
ConversationStatus,
CreatorUserRole,
FeedbackFromSource,
FeedbackRating,
InvokeFrom,
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
)
from .provider_ids import GenericProviderID
from .types import EnumText, LongText, StringUUID
@@ -382,13 +396,12 @@ class App(Base):
@property
def site(self) -> Site | None:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
return db.session.scalar(select(Site).where(Site.app_id == self.id))
@property
def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
return None
@@ -397,7 +410,7 @@
if self.workflow_id:
from .workflow import Workflow
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
return None
@@ -407,8 +420,7 @@
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def is_agent(self) -> bool:
@@ -548,9 +560,9 @@ class App(Base):
return deleted_tools
@property
def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
def tags(self) -> Sequence[Tag]:
tags = db.session.scalars(
select(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@@ -558,15 +570,14 @@
Tag.tenant_id == self.tenant_id,
Tag.type == "app",
)
.all()
)
).all()
return tags or []
@property
def author_name(self) -> str | None:
if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first()
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
if account:
return account.name
@@ -618,8 +629,7 @@ class AppModelConfig(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def model_dict(self) -> ModelConfig:
@@ -654,8 +664,8 @@
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
@@ -847,8 +857,7 @@ class RecommendedApp(Base): # bug
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class InstalledApp(TypeBase):
@@ -875,13 +884,11 @@ class InstalledApp(TypeBase):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class TrialApp(Base):
@@ -901,8 +908,7 @@ class TrialApp(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
class AccountTrialAppRecord(Base):
@@ -921,13 +927,11 @@ class AccountTrialAppRecord(Base):
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
return db.session.scalar(select(App).where(App.id == self.app_id))
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class ExporleBanner(TypeBase):
@@ -937,8 +941,11 @@
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
status: Mapped[str] = mapped_column(
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
status: Mapped[BannerStatus] = mapped_column(
EnumText(BannerStatus, length=255),
nullable=False,
server_default=sa.text("'enabled'::character varying"),
default=BannerStatus.ENABLED,
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@@ -1019,10 +1026,12 @@ class Conversation(Base):
#
# Its value corresponds to the members of `InvokeFrom`.
# (api/core/app/entities/app_invoke_entities.py)
invoke_from = mapped_column(String(255), nullable=True)
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
# ref: ConversationSource.
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[ConversationFromSource] = mapped_column(
EnumText(ConversationFromSource, length=255), nullable=False
)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
read_at = mapped_column(sa.DateTime)
@@ -1119,8 +1128,8 @@ class Conversation(Base):
else:
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
app_model_config = db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
)
if app_model_config:
model_config = app_model_config.to_dict()
@@ -1143,36 +1152,43 @@ class Conversation(Base):
@property
def annotated(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
return (
db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
)
or 0
) > 0
@property
def annotation(self):
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
@property
def message_count(self):
return db.session.query(Message).where(Message.conversation_id == self.id).count()
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
@property
def user_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@@ -1180,23 +1196,25 @@
@property
def admin_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
.count()
or 0
)
dislike = (
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
.count()
or 0
)
return {"like": like, "dislike": dislike}
@@ -1258,22 +1276,19 @@ class Conversation(Base):
@property
def first_message(self):
return (
db.session.query(Message)
.where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc())
.first()
return db.session.scalar(
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
)
@property
def app(self) -> App | None:
with Session(db.engine, expire_on_commit=False) as session:
return session.query(App).where(App.id == self.app_id).first()
return session.scalar(select(App).where(App.id == self.app_id))
@property
def from_end_user_session_id(self):
if self.from_end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
if end_user:
return end_user.session_id
@@ -1282,7 +1297,7 @@
@property
def from_account_name(self) -> str | None:
if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
if account:
return account.name
@@ -1365,8 +1380,10 @@ class Message(Base):
)
error: Mapped[str | None] = mapped_column(LongText)
message_metadata: Mapped[str | None] = mapped_column(LongText)
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
from_source: Mapped[ConversationFromSource] = mapped_column(
EnumText(ConversationFromSource, length=255), nullable=False
)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
@@ -1507,21 +1524,15 @@ class Message(Base):
@property
def user_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
)
return feedback
@property
def admin_feedback(self):
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first()
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
)
return feedback
@property
def feedbacks(self):
@@ -1530,28 +1541,27 @@ class Message(Base):
@property
def annotation(self):
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
return annotation
@property
def annotation_hit_history(self):
annotation_history = (
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
annotation_history = db.session.scalar(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
)
if annotation_history:
annotation = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id == annotation_history.annotation_id)
.first()
return db.session.scalar(
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
)
return annotation
return None
@property
def app_model_config(self):
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
if conversation:
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
return db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
)
return None
@@ -1564,13 +1574,12 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
return db.session.scalars(
select(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
).all()
# FIXME (Novice) -- It's easy to cause an N+1 query problem here.
@property
@@ -1593,7 +1602,7 @@ class Message(Base):
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
if not current_app:
raise ValueError(f"App {self.app_id} not found")
@@ -1739,8 +1748,8 @@ class MessageFeedback(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
rating: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
@@ -1757,8 +1766,7 @@ class MessageFeedback(TypeBase):
@property
def from_account(self) -> Account | None:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
def to_dict(self) -> MessageFeedbackDict:
return {
@@ -1794,7 +1802,9 @@ class MessageFile(TypeBase):
)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
)
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
@@ -1831,13 +1841,11 @@ class MessageAnnotation(Base):
@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationHitHistory(TypeBase):
@@ -1866,18 +1874,15 @@ class AppAnnotationHitHistory(TypeBase):
@property
def account(self):
account = (
db.session.query(Account)
return db.session.scalar(
select(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.where(MessageAnnotation.id == self.annotation_id)
.first()
)
return account
@property
def annotation_create_account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
return db.session.scalar(select(Account).where(Account.id == self.account_id))
class AppAnnotationSetting(TypeBase):
@@ -1910,12 +1915,9 @@ class AppAnnotationSetting(TypeBase):
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
collection_binding_detail = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == self.collection_binding_id)
.first()
return db.session.scalar(
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
)
return collection_binding_detail
class OperationLog(TypeBase):
@@ -2021,7 +2023,9 @@ class AppMCPServer(TypeBase):
def generate_server_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
while (
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
) > 0:
result = generate_string(n)
return result
@@ -2082,7 +2086,7 @@ class Site(Base):
def generate_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0:
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
result = generate_string(n)
return result
@@ -2130,7 +2134,7 @@ class UploadFile(Base):
# The `server_default` serves as a fallback mechanism.
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
key: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@@ -2174,7 +2178,7 @@
self,
*,
tenant_id: str,
storage_type: str,
storage_type: StorageType,
key: str,
name: str,
size: int,
@@ -2239,7 +2243,7 @@ class MessageChain(TypeBase):
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False)
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
output: Mapped[str | None] = mapped_column(LongText, nullable=True)
created_at: Mapped[datetime] = mapped_column(
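AppMCPServer.generate_server_code and Site.generate_code above share one pattern: draw a random code, then re-draw while a COUNT probe reports a collision. Distilled into a standalone sketch with a pluggable existence check (names illustrative):

import secrets
import string

ALPHABET = string.ascii_letters + string.digits

def generate_string(n: int) -> str:
    return "".join(secrets.choice(ALPHABET) for _ in range(n))

def generate_unique_code(n: int, exists) -> str:
    # `exists` stands in for the COUNT(...) > 0 probe against the table.
    code = generate_string(n)
    while exists(code):
        code = generate_string(n)
    return code

taken = {"abc123"}
code = generate_unique_code(6, taken.__contains__)
assert code not in taken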

View File

@@ -6,13 +6,14 @@ from functools import cached_property
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, text
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus
from .types import EnumText, LongText, StringUUID
@@ -96,7 +97,7 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
@property
def credential_name(self):
@@ -159,10 +160,8 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return (
db.session.query(ProviderModelCredential)
.where(ProviderModelCredential.id == self.credential_id)
.first()
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
@property
@@ -211,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase):
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@@ -239,7 +238,9 @@ class ProviderOrder(TypeBase):
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
currency: Mapped[str | None] = mapped_column(String(40))
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'"))
payment_status: Mapped[PaymentStatus] = mapped_column(
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
)
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
@@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
EnumText(CredentialSourceType, length=40), nullable=True, default=None
)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
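A note on the `@cached_property` kept for Provider.credential and ProviderModel.credential: the first access runs the query and the result is memoized on the instance, so a credential updated elsewhere is not re-fetched through the same object. Stdlib behavior in miniature:

from functools import cached_property

class Provider:
    def __init__(self, credential_id: str):
        self.credential_id = credential_id
        self.queries = 0

    @cached_property
    def credential(self):
        self.queries += 1  # stands in for the SELECT against ProviderCredential
        return f"credential-{self.credential_id}"

p = Provider("abc")
assert p.credential == p.credential
assert p.queries == 1  # second access served from the instance cache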

View File

@@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy import ForeignKey, String, func, select
from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
@@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
def user(self) -> Account | None:
if not self.user_id:
return None
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class ToolLabelBinding(TypeBase):
@@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
@property
def user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
@property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
@@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
@property
def app(self) -> App | None:
return db.session.query(App).where(App.id == self.app_id).first()
return db.session.scalar(select(App).where(App.id == self.app_id))
class MCPToolProvider(TypeBase):
@@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
return db.session.scalar(select(Account).where(Account.id == self.user_id))
@property
def credentials(self) -> dict[str, Any]:

View File

@@ -3,7 +3,7 @@ import time
from collections.abc import Mapping
from datetime import datetime
from functools import cached_property
from typing import Any, cast
from typing import Any, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -23,6 +23,47 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
from .model import Account
from .types import EnumText, LongText, StringUUID
TriggerJsonObject = dict[str, object]
TriggerCredentials = dict[str, str]
class WorkflowTriggerLogDict(TypedDict):
id: str
tenant_id: str
app_id: str
workflow_id: str
workflow_run_id: str | None
root_node_id: str | None
trigger_metadata: Any
trigger_type: str
trigger_data: Any
inputs: Any
outputs: Any
status: str
error: str | None
queue_name: str
celery_task_id: str | None
retry_count: int
elapsed_time: float | None
total_tokens: int | None
created_by_role: str
created_by: str
created_at: str | None
triggered_at: str | None
finished_at: str | None
class WorkflowSchedulePlanDict(TypedDict):
id: str
app_id: str
node_id: str
tenant_id: str
cron_expression: str
timezone: str
next_run_at: str | None
created_at: str
updated_at: str
class TriggerSubscription(TypeBase):
"""
@@ -51,10 +92,14 @@ class TriggerSubscription(TypeBase):
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
)
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
parameters: Mapped[TriggerJsonObject] = mapped_column(
sa.JSON, nullable=False, comment="Subscription parameters JSON"
)
properties: Mapped[TriggerJsonObject] = mapped_column(
sa.JSON, nullable=False, comment="Subscription properties JSON"
)
credentials: Mapped[dict[str, Any]] = mapped_column(
credentials: Mapped[TriggerCredentials] = mapped_column(
sa.JSON, nullable=False, comment="Subscription credentials JSON"
)
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
@@ -162,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
)
@property
def oauth_params(self) -> Mapping[str, Any]:
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
def oauth_params(self) -> Mapping[str, object]:
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
class WorkflowTriggerLog(TypeBase):
@@ -250,7 +295,7 @@ class WorkflowTriggerLog(TypeBase):
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> WorkflowTriggerLogDict:
"""Convert to dictionary for API responses"""
return {
"id": self.id,
@@ -481,7 +526,7 @@ class WorkflowSchedulePlan(TypeBase):
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
)
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> WorkflowSchedulePlanDict:
"""Convert to dictionary representation"""
return {
"id": self.id,

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, func
from sqlalchemy import DateTime, func, select
from sqlalchemy.orm import Mapped, mapped_column
from .base import TypeBase
@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
@property
def message(self):
return db.session.query(Message).where(Message.id == self.message_id).first()
return db.session.scalar(select(Message).where(Message.id == self.message_id))
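# The same 2.0-style migration recurs throughout this commit:
#
#     db.session.query(Model).where(...).first()     # legacy Query API
#     db.session.scalar(select(Model).where(...))    # select() construct
#
# Session.scalar() returns the first column of the first row, or None,
# matching the old .first() semantics for single-entity selects.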
class PinnedConversation(TypeBase):

View File

@ -1,9 +1,10 @@
import copy
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@ -19,21 +20,21 @@ from sqlalchemy import (
orm,
select,
)
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
from dify_graph.file.constants import maybe_file_object
from dify_graph.file.models import File
from dify_graph.variables import utils as variable_utils
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
@ -59,6 +60,25 @@ from .types import EnumText, LongText, StringUUID
logger = logging.getLogger(__name__)
SerializedWorkflowValue = dict[str, Any]
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
class WorkflowContentDict(TypedDict):
graph: Mapping[str, Any]
features: dict[str, Any]
environment_variables: list[dict[str, Any]]
conversation_variables: list[dict[str, Any]]
rag_pipeline_variables: list[dict[str, Any]]
class WorkflowRunSummaryDict(TypedDict):
id: str
status: str
triggered_from: str
elapsed_time: float
total_tokens: int
def is_generation_outputs(outputs: Mapping[str, Any]) -> bool:
if not outputs:
@ -314,26 +334,40 @@ class Workflow(Base): # bug
def features(self) -> str:
"""
Convert old features structure to new features structure.
This property avoids rewriting the underlying JSON when normalization
produces no effective change, to prevent marking the row dirty on read.
"""
if not self._features:
return self._features
features = json.loads(self._features)
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
image_enabled = True
image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
image_transfer_methods = features["file_upload"]["image"].get(
"transfer_methods", ["remote_url", "local_file"]
)
features["file_upload"]["enabled"] = image_enabled
features["file_upload"]["number_limits"] = image_number_limits
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
"allowed_file_extensions", []
)
del features["file_upload"]["image"]
self._features = json.dumps(features)
# Parse once and deep-copy before normalization to detect in-place changes.
original_dict = self._decode_features_payload(self._features)
if original_dict is None:
return self._features
# Fast-path: if the legacy file_upload.image.enabled shape is absent, skip
# deep-copy and normalization entirely and return the stored JSON.
file_upload_payload = original_dict.get("file_upload")
if not isinstance(file_upload_payload, dict):
return self._features
file_upload = cast(dict[str, Any], file_upload_payload)
image_payload = file_upload.get("image")
if not isinstance(image_payload, dict):
return self._features
image = cast(dict[str, Any], image_payload)
if "enabled" not in image:
return self._features
normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict))
if normalized_dict == original_dict:
# No effective change; return stored JSON unchanged.
return self._features
# Normalization changed the payload: persist the normalized JSON.
self._features = json.dumps(normalized_dict)
return self._features
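# Illustrative effect of the guard above (assumes an active SQLAlchemy
# `session` and a `workflow` whose stored payload needs no normalization):
#
#     _ = workflow.features                  # returns stored JSON as-is
#     assert workflow not in session.dirty   # the read queues no UPDATE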
@features.setter
@ -347,6 +381,44 @@ class Workflow(Base): # bug
def get_feature(self, key: WorkflowFeatures) -> WorkflowFeature:
return WorkflowFeature.from_dict(self.features_dict.get(key.value))
@property
def serialized_features(self) -> str:
"""Return the stored features JSON without triggering compatibility rewrites."""
return self._features
@property
def normalized_features_dict(self) -> dict[str, Any]:
"""Decode features with legacy normalization without mutating the model state."""
if not self._features:
return {}
features = self._decode_features_payload(self._features)
return self._normalize_features_payload(features) if features is not None else {}
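# The three access paths now split responsibilities: `features` may rewrite a
# legacy payload in place, `serialized_features` is a pure read of the stored
# JSON, and `normalized_features_dict` yields the upgraded shape without
# touching model state (e.g., for read-only serialization).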
@staticmethod
def _decode_features_payload(features: str) -> dict[str, Any] | None:
"""Decode workflow features JSON when it contains an object payload."""
payload = json.loads(features)
return cast(dict[str, Any], payload) if isinstance(payload, dict) else None
@staticmethod
def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]:
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
image_transfer_methods = features["file_upload"]["image"].get(
"transfer_methods", ["remote_url", "local_file"]
)
features["file_upload"]["enabled"] = True
features["file_upload"]["number_limits"] = image_number_limits
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
"allowed_file_extensions", []
)
del features["file_upload"]["image"]
return features
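# Worked example of the legacy upgrade above (values are illustrative):
#
#   before: {"file_upload": {"image": {"enabled": true, "number_limits": 3,
#                                      "transfer_methods": ["remote_url"]}}}
#   after:  {"file_upload": {"enabled": true, "number_limits": 3,
#                            "allowed_file_upload_methods": ["remote_url"],
#                            "allowed_file_types": ["image"],
#                            "allowed_file_extensions": []}}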
def walk_nodes(
self, specific_node_type: NodeType | None = None
) -> Generator[tuple[str, Mapping[str, Any]], None, None]:
@ -423,7 +495,7 @@ class Workflow(Base): # bug
def rag_pipeline_user_input_form(self) -> list:
# get user_input_form from start node
variables: list[Any] = self.rag_pipeline_variables
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
return variables
@ -466,17 +538,13 @@ class Workflow(Base): # bug
def environment_variables(
self,
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
results = [
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
]
@ -536,14 +604,39 @@ class Workflow(Base): # bug
)
self._environment_variables = environment_variables_json
def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]:
@staticmethod
def normalize_environment_variable_mappings(
mappings: Sequence[Mapping[str, Any]],
) -> list[dict[str, Any]]:
"""Convert masked secret placeholders into the draft hidden sentinel.
Regular draft sync requests should preserve existing secrets without shipping
plaintext values back from the client. The dedicated restore endpoint now
copies published secrets server-side, so draft sync only needs to normalize
the UI mask into `HIDDEN_VALUE`.
"""
masked_secret_value = encrypter.full_mask_token()
normalized_mappings: list[dict[str, Any]] = []
for mapping in mappings:
normalized_mapping = dict(mapping)
if (
normalized_mapping.get("value_type") == SegmentType.SECRET.value
and normalized_mapping.get("value") == masked_secret_value
):
normalized_mapping["value"] = HIDDEN_VALUE
normalized_mappings.append(normalized_mapping)
return normalized_mappings
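# Illustrative before/after (assumes full_mask_token() renders "******" and
# SegmentType.SECRET.value == "secret"; the "name" key is hypothetical):
#
#   {"name": "API_KEY", "value_type": "secret", "value": "******"}
#     -> {"name": "API_KEY", "value_type": "secret", "value": HIDDEN_VALUE}
#
# Non-secret mappings, and secrets carrying genuinely edited values, pass
# through unchanged.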
def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
environment_variables = list(self.environment_variables)
environment_variables = [
v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""})
for v in environment_variables
]
result = {
result: WorkflowContentDict = {
"graph": self.graph_dict,
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
@ -554,11 +647,7 @@ class Workflow(Base): # bug
@property
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
return results
@ -570,22 +659,29 @@ class Workflow(Base): # bug
)
@property
def rag_pipeline_variables(self) -> list[dict]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._rag_pipeline_variables is None:
self._rag_pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
results = list(variables_dict.values())
return results
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
@rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: list[dict]) -> None:
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
self._rag_pipeline_variables = json.dumps(
{item["variable"]: item for item in values},
{
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
for rag_pipeline_variable in (
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
for item in values
)
},
ensure_ascii=False,
)
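# Round-trip sketch (fields beyond "variable" are illustrative; the real
# schema lives on RAGPipelineVariable):
#
#     workflow.rag_pipeline_variables = [{"variable": "query", ...}]
#     workflow.rag_pipeline_variables     # -> re-validated JSON-mode dicts
#
# Raw mappings and RAGPipelineVariable instances are both accepted, and the
# getter re-validates on read, so callers always observe the same shape.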
def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None:
"""Copy stored variable JSON directly for same-tenant restore flows."""
self._environment_variables = source_workflow._environment_variables
self._conversation_variables = source_workflow._conversation_variables
self._rag_pipeline_variables = source_workflow._rag_pipeline_variables
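# Copying the serialized JSON verbatim presumably keeps encrypted secret
# values intact (no decrypt/re-encrypt round trip), hence the same-tenant
# restriction noted in the docstring.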
@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)
@ -701,14 +797,14 @@ class WorkflowRun(Base):
def message(self):
from .model import Message
return (
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
return db.session.scalar(
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
)
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def workflow(self):
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
@property
def outputs_as_generation(self):
@ -825,44 +921,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
__tablename__ = "workflow_node_executions"
@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
# The first argument is the index name,
# which we leave as `None`` to allow auto-generation by the ORM.
None,
cls.tenant_id,
cls.workflow_id,
cls.node_id,
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(),
),
)
__table_args__ = (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
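# Passing None as the index name asks SQLAlchemy to auto-generate one from
# the naming convention; string column names plus sa.desc() also remove the
# need for the previous declared_attr classmethod and its MyPy workaround.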
Index(
None,
"tenant_id",
"workflow_id",
"node_id",
sa.desc("created_at"),
),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID)
@ -971,8 +1059,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
datasource_info = execution_metadata["datasource_info"]
extras["icon"] = datasource_info.get("icon")
elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
elif (
self.node_type == TRIGGER_PLUGIN_NODE_TYPE
and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
):
trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
provider_id = trigger_info.get("provider_id")
if provider_id:
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
@ -1270,7 +1361,7 @@ class WorkflowArchiveLog(TypeBase):
)
@property
def workflow_run_summary(self) -> dict[str, Any]:
def workflow_run_summary(self) -> WorkflowRunSummaryDict:
return {
"id": self.workflow_run_id,
"status": self.run_status,
@ -1325,16 +1416,17 @@ class WorkflowDraftVariable(Base):
"""
@staticmethod
def unique_app_id_node_id_name() -> list[str]:
def unique_app_id_user_id_node_id_name() -> list[str]:
return [
"app_id",
"user_id",
"node_id",
"name",
]
__tablename__ = "workflow_draft_variables"
__table_args__ = (
UniqueConstraint(*unique_app_id_node_id_name()),
UniqueConstraint(*unique_app_id_user_id_node_id_name()),
Index("workflow_draft_variable_file_id_idx", "file_id"),
)
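# With user_id in the constraint, two collaborators editing the same draft
# workflow keep independent variables under the same app_id/node_id/name.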
# Required for instance variable annotation.
@ -1360,6 +1452,11 @@ class WorkflowDraftVariable(Base):
# "`app_id` maps to the `id` field in the `model.App` model."
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# Owner of this draft variable.
#
# This field is nullable during migration and will be migrated to NOT NULL
# in a follow-up release.
user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
# `last_edited_at` records when the value of a given draft variable
# is edited.
@ -1612,6 +1709,7 @@ class WorkflowDraftVariable(Base):
cls,
*,
app_id: str,
user_id: str | None,
node_id: str,
name: str,
value: Segment,
@ -1625,6 +1723,7 @@ class WorkflowDraftVariable(Base):
variable.updated_at = naive_utc_now()
variable.description = description
variable.app_id = app_id
variable.user_id = user_id
variable.node_id = node_id
variable.name = name
variable.set_value(value)
@ -1638,12 +1737,14 @@ class WorkflowDraftVariable(Base):
cls,
*,
app_id: str,
user_id: str | None = None,
name: str,
value: Segment,
description: str = "",
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
user_id=user_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
name=name,
value=value,
@ -1658,6 +1759,7 @@ class WorkflowDraftVariable(Base):
cls,
*,
app_id: str,
user_id: str | None = None,
name: str,
value: Segment,
node_execution_id: str,
@ -1665,6 +1767,7 @@ class WorkflowDraftVariable(Base):
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
user_id=user_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
name=name,
node_execution_id=node_execution_id,
@ -1678,6 +1781,7 @@ class WorkflowDraftVariable(Base):
cls,
*,
app_id: str,
user_id: str | None = None,
node_id: str,
name: str,
value: Segment,
@ -1688,6 +1792,7 @@ class WorkflowDraftVariable(Base):
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
user_id=user_id,
node_id=node_id,
name=name,
node_execution_id=node_execution_id,