mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 06:28:05 +08:00
Merge main HEAD (segment 5) into sandboxed-agent-rebase
Resolve 83 conflicts: 10 backend, 62 frontend, 11 config/lock files. Preserve sandbox/agent/collaboration features while adopting main's UI refactorings (Dialog/AlertDialog/Popover), model provider updates, and enterprise features. Made-with: Cursor
This commit is contained in:
@ -177,13 +177,11 @@ class Account(UserMixin, TypeBase):
|
||||
|
||||
@classmethod
|
||||
def get_by_openid(cls, provider: str, open_id: str):
|
||||
account_integrate = (
|
||||
db.session.query(AccountIntegrate)
|
||||
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
.one_or_none()
|
||||
)
|
||||
account_integrate = db.session.execute(
|
||||
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
).scalar_one_or_none()
|
||||
if account_integrate:
|
||||
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
|
||||
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
|
||||
return None
|
||||
|
||||
# check current_user.current_tenant.current_role in ['admin', 'owner']
|
||||
|
||||
@ -8,9 +8,10 @@ import os
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -30,13 +31,81 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
|
||||
from .account import Account
|
||||
from .base import Base, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole
|
||||
from .enums import (
|
||||
CollectionBindingType,
|
||||
CreatorUserRole,
|
||||
DatasetMetadataType,
|
||||
DatasetQuerySource,
|
||||
DatasetRuntimeMode,
|
||||
DataSourceType,
|
||||
DocumentCreatedFrom,
|
||||
DocumentDocType,
|
||||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SummaryStatus,
|
||||
)
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreProcessingRuleItem(TypedDict):
|
||||
id: str
|
||||
enabled: bool
|
||||
|
||||
|
||||
class SegmentationConfig(TypedDict):
|
||||
delimiter: str
|
||||
max_tokens: int
|
||||
chunk_overlap: int
|
||||
|
||||
|
||||
class AutomaticRulesConfig(TypedDict):
|
||||
pre_processing_rules: list[PreProcessingRuleItem]
|
||||
segmentation: SegmentationConfig
|
||||
|
||||
|
||||
class ProcessRuleDict(TypedDict):
|
||||
id: str
|
||||
dataset_id: str
|
||||
mode: str
|
||||
rules: dict[str, Any] | None
|
||||
|
||||
|
||||
class DocMetadataDetailItem(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
value: Any
|
||||
|
||||
|
||||
class AttachmentItem(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
|
||||
|
||||
class DatasetBindingItem(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ExternalKnowledgeApiDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
name: str
|
||||
description: str
|
||||
settings: dict[str, Any] | None
|
||||
dataset_bindings: list[DatasetBindingItem]
|
||||
created_by: str
|
||||
created_at: str
|
||||
|
||||
|
||||
class DatasetPermissionEnum(enum.StrEnum):
|
||||
ONLY_ME = "only_me"
|
||||
ALL_TEAM = "all_team_members"
|
||||
@ -65,7 +134,7 @@ class Dataset(Base):
|
||||
server_default=sa.text("'only_me'"),
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
data_source_type = mapped_column(String(255))
|
||||
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
|
||||
indexing_technique: Mapped[str | None] = mapped_column(String(255))
|
||||
index_struct = mapped_column(LongText, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
@ -82,7 +151,9 @@ class Dataset(Base):
|
||||
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
|
||||
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
icon_info = mapped_column(AdjustedJSON, nullable=True)
|
||||
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
|
||||
runtime_mode = mapped_column(
|
||||
EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'")
|
||||
)
|
||||
pipeline_id = mapped_column(StringUUID, nullable=True)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
||||
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
@ -90,30 +161,25 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def total_documents(self):
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def total_available_documents(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
dataset_keyword_table = (
|
||||
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
|
||||
)
|
||||
if dataset_keyword_table:
|
||||
return dataset_keyword_table
|
||||
|
||||
return None
|
||||
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
|
||||
|
||||
@property
|
||||
def index_struct_dict(self):
|
||||
@ -140,64 +206,66 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def latest_process_rule(self):
|
||||
return (
|
||||
db.session.query(DatasetProcessRule)
|
||||
return db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == self.id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@property
|
||||
def app_count(self):
|
||||
return (
|
||||
db.session.query(func.count(AppDatasetJoin.id))
|
||||
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
|
||||
.scalar()
|
||||
db.session.scalar(
|
||||
select(func.count(AppDatasetJoin.id)).where(
|
||||
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def document_count(self):
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def available_document_count(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def available_segment_count(self):
|
||||
return (
|
||||
db.session.query(func.count(DocumentSegment.id))
|
||||
.where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def word_count(self):
|
||||
return (
|
||||
db.session.query(Document)
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
|
||||
.where(Document.dataset_id == self.id)
|
||||
.scalar()
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
def doc_form(self) -> str | None:
|
||||
if self.chunk_structure:
|
||||
return self.chunk_structure
|
||||
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
|
||||
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
|
||||
if document:
|
||||
return document.doc_form
|
||||
return None
|
||||
@ -215,8 +283,8 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
@ -224,8 +292,7 @@ class Dataset(Base):
|
||||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "knowledge",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
@ -233,8 +300,8 @@ class Dataset(Base):
|
||||
def external_knowledge_info(self):
|
||||
if self.provider != "external":
|
||||
return None
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
return None
|
||||
@ -255,7 +322,7 @@ class Dataset(Base):
|
||||
@property
|
||||
def is_published(self):
|
||||
if self.pipeline_id:
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
|
||||
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
|
||||
if pipeline:
|
||||
return pipeline.is_published
|
||||
return False
|
||||
@ -327,14 +394,14 @@ class DatasetProcessRule(Base): # bug
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
rules = mapped_column(LongText, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
MODES = ["automatic", "custom", "hierarchical"]
|
||||
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
|
||||
AUTOMATIC_RULES: dict[str, Any] = {
|
||||
AUTOMATIC_RULES: AutomaticRulesConfig = {
|
||||
"pre_processing_rules": [
|
||||
{"id": "remove_extra_spaces", "enabled": True},
|
||||
{"id": "remove_urls_emails", "enabled": False},
|
||||
@ -342,7 +409,7 @@ class DatasetProcessRule(Base): # bug
|
||||
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
|
||||
}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> ProcessRuleDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"dataset_id": self.dataset_id,
|
||||
@ -373,12 +440,12 @@ class Document(Base):
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False)
|
||||
data_source_info = mapped_column(LongText, nullable=True)
|
||||
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
|
||||
batch: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_api_request_id = mapped_column(StringUUID, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@ -412,7 +479,9 @@ class Document(Base):
|
||||
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# basic fields
|
||||
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'"))
|
||||
indexing_status = mapped_column(
|
||||
EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'")
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
@ -423,7 +492,7 @@ class Document(Base):
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
doc_type = mapped_column(String(40), nullable=True)
|
||||
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
|
||||
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
@ -466,10 +535,8 @@ class Document(Base):
|
||||
if self.data_source_info:
|
||||
if self.data_source_type == "upload_file":
|
||||
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
.one_or_none()
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
)
|
||||
if file_detail:
|
||||
return {
|
||||
@ -502,24 +569,23 @@ class Document(Base):
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
|
||||
@property
|
||||
def segment_count(self):
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
|
||||
return (
|
||||
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
|
||||
)
|
||||
|
||||
@property
|
||||
def hit_count(self):
|
||||
return (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
|
||||
.where(DocumentSegment.document_id == self.id)
|
||||
.scalar()
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
def uploader(self):
|
||||
user = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
return user.name if user else None
|
||||
|
||||
@property
|
||||
@ -531,19 +597,18 @@ class Document(Base):
|
||||
return self.updated_at
|
||||
|
||||
@property
|
||||
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
|
||||
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
|
||||
if self.doc_metadata:
|
||||
document_metadatas = (
|
||||
db.session.query(DatasetMetadata)
|
||||
document_metadatas = db.session.scalars(
|
||||
select(DatasetMetadata)
|
||||
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
|
||||
)
|
||||
.all()
|
||||
)
|
||||
metadata_list: list[dict[str, Any]] = []
|
||||
).all()
|
||||
metadata_list: list[DocMetadataDetailItem] = []
|
||||
for metadata in document_metadatas:
|
||||
metadata_dict: dict[str, Any] = {
|
||||
metadata_dict: DocMetadataDetailItem = {
|
||||
"id": metadata.id,
|
||||
"name": metadata.name,
|
||||
"type": metadata.type,
|
||||
@ -557,13 +622,13 @@ class Document(Base):
|
||||
return None
|
||||
|
||||
@property
|
||||
def process_rule_dict(self) -> dict[str, Any] | None:
|
||||
def process_rule_dict(self) -> ProcessRuleDict | None:
|
||||
if self.dataset_process_rule_id and self.dataset_process_rule:
|
||||
return self.dataset_process_rule.to_dict()
|
||||
return None
|
||||
|
||||
def get_built_in_fields(self) -> list[dict[str, Any]]:
|
||||
built_in_fields: list[dict[str, Any]] = []
|
||||
def get_built_in_fields(self) -> list[DocMetadataDetailItem]:
|
||||
built_in_fields: list[DocMetadataDetailItem] = []
|
||||
built_in_fields.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
@ -736,7 +801,7 @@ class DocumentSegment(Base):
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'"))
|
||||
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
@ -771,7 +836,7 @@ class DocumentSegment(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def child_chunks(self) -> list[Any]:
|
||||
def child_chunks(self) -> Sequence[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
@ -780,16 +845,13 @@ class DocumentSegment(Base):
|
||||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
def get_child_chunks(self) -> list[Any]:
|
||||
def get_child_chunks(self) -> Sequence[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
@ -798,12 +860,9 @@ class DocumentSegment(Base):
|
||||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
@ -877,7 +936,7 @@ class DocumentSegment(Base):
|
||||
return text
|
||||
|
||||
@property
|
||||
def attachments(self) -> list[dict[str, Any]]:
|
||||
def attachments(self) -> list[AttachmentItem]:
|
||||
# Use JOIN to fetch attachments in a single query instead of two separate queries
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
@ -891,7 +950,7 @@ class DocumentSegment(Base):
|
||||
).all()
|
||||
if not attachments_with_bindings:
|
||||
return []
|
||||
attachment_list = []
|
||||
attachment_list: list[AttachmentItem] = []
|
||||
for _, attachment in attachments_with_bindings:
|
||||
upload_file_id = attachment.id
|
||||
nonce = os.urandom(16).hex()
|
||||
@ -952,15 +1011,15 @@ class ChildChunk(Base):
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
return db.session.query(Document).where(Document.id == self.document_id).first()
|
||||
return db.session.scalar(select(Document).where(Document.id == self.document_id))
|
||||
|
||||
@property
|
||||
def segment(self):
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
|
||||
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
|
||||
|
||||
|
||||
class AppDatasetJoin(TypeBase):
|
||||
@ -1006,7 +1065,7 @@ class DatasetQuery(TypeBase):
|
||||
)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
content: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False)
|
||||
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
@ -1021,7 +1080,7 @@ class DatasetQuery(TypeBase):
|
||||
if isinstance(queries, list):
|
||||
for query in queries:
|
||||
if query["content_type"] == QueryType.IMAGE_QUERY:
|
||||
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
|
||||
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
|
||||
if file_info:
|
||||
query["file_info"] = {
|
||||
"id": file_info.id,
|
||||
@ -1086,7 +1145,7 @@ class DatasetKeywordTable(TypeBase):
|
||||
super().__init__(object_hook=object_hook, *args, **kwargs)
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
if not dataset:
|
||||
return None
|
||||
if self.data_source_type == "database":
|
||||
@ -1151,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase):
|
||||
)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
|
||||
type: Mapped[str] = mapped_column(
|
||||
EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False
|
||||
)
|
||||
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
@ -1261,7 +1322,7 @@ class ExternalKnowledgeApis(TypeBase):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> ExternalKnowledgeApiDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
@ -1281,13 +1342,13 @@ class ExternalKnowledgeApis(TypeBase):
|
||||
return None
|
||||
|
||||
@property
|
||||
def dataset_bindings(self) -> list[dict[str, Any]]:
|
||||
def dataset_bindings(self) -> list[DatasetBindingItem]:
|
||||
external_knowledge_bindings = db.session.scalars(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
|
||||
).all()
|
||||
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
|
||||
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
dataset_bindings: list[dict[str, Any]] = []
|
||||
dataset_bindings: list[DatasetBindingItem] = []
|
||||
for dataset in datasets:
|
||||
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
|
||||
|
||||
@ -1378,7 +1439,7 @@ class DatasetMetadata(TypeBase):
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
@ -1480,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase):
|
||||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
if account:
|
||||
return account.name
|
||||
return ""
|
||||
@ -1515,7 +1576,7 @@ class Pipeline(TypeBase):
|
||||
)
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLog(TypeBase):
|
||||
@ -1605,7 +1666,9 @@ class DocumentSegmentSummary(Base):
|
||||
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
|
||||
status: Mapped[str] = mapped_column(
|
||||
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
|
||||
)
|
||||
error: Mapped[str] = mapped_column(LongText, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
else:
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
@ -96,3 +103,223 @@ class ConversationStatus(StrEnum):
|
||||
"""Conversation Status Enum"""
|
||||
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
class DataSourceType(StrEnum):
|
||||
"""Data Source Type for Dataset and Document"""
|
||||
|
||||
UPLOAD_FILE = "upload_file"
|
||||
NOTION_IMPORT = "notion_import"
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
LOCAL_FILE = "local_file"
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
|
||||
|
||||
class ProcessRuleMode(StrEnum):
|
||||
"""Dataset Process Rule Mode"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOM = "custom"
|
||||
HIERARCHICAL = "hierarchical"
|
||||
|
||||
|
||||
class IndexingStatus(StrEnum):
|
||||
"""Document Indexing Status"""
|
||||
|
||||
WAITING = "waiting"
|
||||
PARSING = "parsing"
|
||||
CLEANING = "cleaning"
|
||||
SPLITTING = "splitting"
|
||||
INDEXING = "indexing"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DocumentCreatedFrom(StrEnum):
|
||||
"""Document Created From"""
|
||||
|
||||
WEB = "web"
|
||||
API = "api"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
class ConversationFromSource(StrEnum):
|
||||
"""Conversation / Message from_source"""
|
||||
|
||||
API = "api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class FeedbackFromSource(StrEnum):
|
||||
"""MessageFeedback from_source"""
|
||||
|
||||
USER = "user"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class FeedbackRating(StrEnum):
|
||||
"""MessageFeedback rating"""
|
||||
|
||||
LIKE = "like"
|
||||
DISLIKE = "dislike"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
"""How a conversation/message was invoked"""
|
||||
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
return cls(value)
|
||||
|
||||
def to_source(self) -> str:
|
||||
source_mapping = {
|
||||
InvokeFrom.WEB_APP: "web_app",
|
||||
InvokeFrom.DEBUGGER: "dev",
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
class DocumentDocType(StrEnum):
|
||||
"""Document doc_type classification"""
|
||||
|
||||
BOOK = "book"
|
||||
WEB_PAGE = "web_page"
|
||||
PAPER = "paper"
|
||||
SOCIAL_MEDIA_POST = "social_media_post"
|
||||
WIKIPEDIA_ENTRY = "wikipedia_entry"
|
||||
PERSONAL_DOCUMENT = "personal_document"
|
||||
BUSINESS_DOCUMENT = "business_document"
|
||||
IM_CHAT_LOG = "im_chat_log"
|
||||
SYNCED_FROM_NOTION = "synced_from_notion"
|
||||
SYNCED_FROM_GITHUB = "synced_from_github"
|
||||
OTHERS = "others"
|
||||
|
||||
|
||||
class TagType(StrEnum):
|
||||
"""Tag type"""
|
||||
|
||||
KNOWLEDGE = "knowledge"
|
||||
APP = "app"
|
||||
|
||||
|
||||
class DatasetMetadataType(StrEnum):
|
||||
"""Dataset metadata value type"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
WAITING = "waiting"
|
||||
INDEXING = "indexing"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
PAUSED = "paused"
|
||||
RE_SEGMENT = "re_segment"
|
||||
|
||||
|
||||
class DatasetRuntimeMode(StrEnum):
|
||||
"""Dataset runtime mode"""
|
||||
|
||||
GENERAL = "general"
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class CollectionBindingType(StrEnum):
|
||||
"""Dataset collection binding type"""
|
||||
|
||||
DATASET = "dataset"
|
||||
ANNOTATION = "annotation"
|
||||
|
||||
|
||||
class DatasetQuerySource(StrEnum):
|
||||
"""Dataset query source"""
|
||||
|
||||
HIT_TESTING = "hit_testing"
|
||||
APP = "app"
|
||||
|
||||
|
||||
class TidbAuthBindingStatus(StrEnum):
|
||||
"""TiDB auth binding status"""
|
||||
|
||||
CREATING = "CREATING"
|
||||
ACTIVE = "ACTIVE"
|
||||
|
||||
|
||||
class MessageFileBelongsTo(StrEnum):
|
||||
"""MessageFile belongs_to"""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class CredentialSourceType(StrEnum):
|
||||
"""Load balancing credential source type"""
|
||||
|
||||
PROVIDER = "provider"
|
||||
CUSTOM_MODEL = "custom_model"
|
||||
|
||||
|
||||
class PaymentStatus(StrEnum):
|
||||
"""Provider order payment status"""
|
||||
|
||||
WAIT_PAY = "wait_pay"
|
||||
PAID = "paid"
|
||||
FAILED = "failed"
|
||||
REFUNDED = "refunded"
|
||||
|
||||
|
||||
class BannerStatus(StrEnum):
|
||||
"""ExporleBanner status"""
|
||||
|
||||
ENABLED = "enabled"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class SummaryStatus(StrEnum):
|
||||
"""Document segment summary status"""
|
||||
|
||||
NOT_STARTED = "not_started"
|
||||
GENERATING = "generating"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class MessageChainType(StrEnum):
|
||||
"""Message chain type"""
|
||||
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = "paid"
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> "ProviderQuotaType":
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
@ -30,6 +30,15 @@ def _generate_token() -> str:
|
||||
|
||||
class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_forms"
|
||||
__table_args__ = (
|
||||
sa.Index(
|
||||
"human_input_forms_workflow_run_id_node_id_idx",
|
||||
"workflow_run_id",
|
||||
"node_id",
|
||||
),
|
||||
sa.Index("human_input_forms_status_expiration_time_idx", "status", "expiration_time"),
|
||||
sa.Index("human_input_forms_status_created_at_idx", "status", "created_at"),
|
||||
)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
@ -84,6 +93,12 @@ class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
|
||||
class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_form_deliveries"
|
||||
__table_args__ = (
|
||||
sa.Index(
|
||||
None,
|
||||
"form_id",
|
||||
),
|
||||
)
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
@ -181,6 +196,10 @@ RecipientPayload = Annotated[
|
||||
|
||||
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_form_recipients"
|
||||
__table_args__ = (
|
||||
sa.Index(None, "form_id"),
|
||||
sa.Index(None, "delivery_id"),
|
||||
)
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
|
||||
@ -23,13 +23,27 @@ from core.tools.signature import sign_tool_file
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .base import Base, TypeBase, gen_uuidv4_string
|
||||
from .engine import db
|
||||
from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus
|
||||
from .enums import (
|
||||
AppMCPServerStatus,
|
||||
AppStatus,
|
||||
BannerStatus,
|
||||
ConversationFromSource,
|
||||
ConversationStatus,
|
||||
CreatorUserRole,
|
||||
FeedbackFromSource,
|
||||
FeedbackRating,
|
||||
InvokeFrom,
|
||||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
@ -382,13 +396,12 @@ class App(Base):
|
||||
|
||||
@property
|
||||
def site(self) -> Site | None:
|
||||
site = db.session.query(Site).where(Site.app_id == self.id).first()
|
||||
return site
|
||||
return db.session.scalar(select(Site).where(Site.app_id == self.id))
|
||||
|
||||
@property
|
||||
def app_model_config(self) -> AppModelConfig | None:
|
||||
if self.app_model_config_id:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
|
||||
|
||||
return None
|
||||
|
||||
@ -397,7 +410,7 @@ class App(Base):
|
||||
if self.workflow_id:
|
||||
from .workflow import Workflow
|
||||
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
return None
|
||||
|
||||
@ -407,8 +420,7 @@ class App(Base):
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
@property
|
||||
def is_agent(self) -> bool:
|
||||
@ -548,9 +560,9 @@ class App(Base):
|
||||
return deleted_tools
|
||||
|
||||
@property
|
||||
def tags(self) -> list[Tag]:
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
def tags(self) -> Sequence[Tag]:
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
@ -558,15 +570,14 @@ class App(Base):
|
||||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "app",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return tags or []
|
||||
|
||||
@property
|
||||
def author_name(self) -> str | None:
|
||||
if self.created_by:
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
@ -618,8 +629,7 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def model_dict(self) -> ModelConfig:
|
||||
@ -654,8 +664,8 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
annotation_setting = db.session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
@ -847,8 +857,7 @@ class RecommendedApp(Base): # bug
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class InstalledApp(TypeBase):
|
||||
@ -875,13 +884,11 @@ class InstalledApp(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
@ -901,8 +908,7 @@ class TrialApp(Base):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
@ -921,13 +927,11 @@ class AccountTrialAppRecord(Base):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return user
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class ExporleBanner(TypeBase):
|
||||
@ -937,8 +941,11 @@ class ExporleBanner(TypeBase):
|
||||
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
link: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
|
||||
status: Mapped[BannerStatus] = mapped_column(
|
||||
EnumText(BannerStatus, length=255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'enabled'::character varying"),
|
||||
default=BannerStatus.ENABLED,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
@ -1019,10 +1026,12 @@ class Conversation(Base):
|
||||
#
|
||||
# Its value corresponds to the members of `InvokeFrom`.
|
||||
# (api/core/app/entities/app_invoke_entities.py)
|
||||
invoke_from = mapped_column(String(255), nullable=True)
|
||||
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
|
||||
|
||||
# ref: ConversationSource.
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_source: Mapped[ConversationFromSource] = mapped_column(
|
||||
EnumText(ConversationFromSource, length=255), nullable=False
|
||||
)
|
||||
from_end_user_id = mapped_column(StringUUID)
|
||||
from_account_id = mapped_column(StringUUID)
|
||||
read_at = mapped_column(sa.DateTime)
|
||||
@ -1119,8 +1128,8 @@ class Conversation(Base):
|
||||
else:
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
else:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
app_model_config = db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
|
||||
)
|
||||
if app_model_config:
|
||||
model_config = app_model_config.to_dict()
|
||||
@ -1143,36 +1152,43 @@ class Conversation(Base):
|
||||
|
||||
@property
|
||||
def annotated(self):
|
||||
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
|
||||
)
|
||||
or 0
|
||||
) > 0
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
|
||||
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
|
||||
|
||||
@property
|
||||
def message_count(self):
|
||||
return db.session.query(Message).where(Message.conversation_id == self.id).count()
|
||||
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
|
||||
|
||||
@property
|
||||
def user_feedback_stats(self):
|
||||
like = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == FeedbackRating.LIKE,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == FeedbackRating.DISLIKE,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
@ -1180,23 +1196,25 @@ class Conversation(Base):
|
||||
@property
|
||||
def admin_feedback_stats(self):
|
||||
like = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == FeedbackRating.LIKE,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == FeedbackRating.DISLIKE,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
@ -1258,22 +1276,19 @@ class Conversation(Base):
|
||||
|
||||
@property
|
||||
def first_message(self):
|
||||
return (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == self.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return session.query(App).where(App.id == self.app_id).first()
|
||||
return session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
@property
|
||||
def from_end_user_session_id(self):
|
||||
if self.from_end_user_id:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
|
||||
if end_user:
|
||||
return end_user.session_id
|
||||
|
||||
@ -1282,7 +1297,7 @@ class Conversation(Base):
|
||||
@property
|
||||
def from_account_name(self) -> str | None:
|
||||
if self.from_account_id:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
@ -1365,8 +1380,10 @@ class Message(Base):
|
||||
)
|
||||
error: Mapped[str | None] = mapped_column(LongText)
|
||||
message_metadata: Mapped[str | None] = mapped_column(LongText)
|
||||
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
|
||||
from_source: Mapped[ConversationFromSource] = mapped_column(
|
||||
EnumText(ConversationFromSource, length=255), nullable=False
|
||||
)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
||||
@ -1507,21 +1524,15 @@ class Message(Base):
|
||||
|
||||
@property
|
||||
def user_feedback(self):
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def admin_feedback(self):
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def feedbacks(self):
|
||||
@ -1530,28 +1541,27 @@ class Message(Base):
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
|
||||
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
|
||||
return annotation
|
||||
|
||||
@property
|
||||
def annotation_hit_history(self):
|
||||
annotation_history = (
|
||||
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
|
||||
annotation_history = db.session.scalar(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
|
||||
)
|
||||
if annotation_history:
|
||||
annotation = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
)
|
||||
return annotation
|
||||
return None
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
|
||||
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
|
||||
if conversation:
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
|
||||
return db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@ -1564,13 +1574,12 @@ class Message(Base):
|
||||
return json.loads(self.message_metadata) if self.message_metadata else {}
|
||||
|
||||
@property
|
||||
def agent_thoughts(self) -> list[MessageAgentThought]:
|
||||
return (
|
||||
db.session.query(MessageAgentThought)
|
||||
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
|
||||
return db.session.scalars(
|
||||
select(MessageAgentThought)
|
||||
.where(MessageAgentThought.message_id == self.id)
|
||||
.order_by(MessageAgentThought.position.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
# FIXME (Novice) -- It's easy to cause N+1 query problem here.
|
||||
@property
|
||||
@ -1593,7 +1602,7 @@ class Message(Base):
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
current_app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
@ -1739,8 +1748,8 @@ class MessageFeedback(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
rating: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
|
||||
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
|
||||
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
@ -1757,8 +1766,7 @@ class MessageFeedback(TypeBase):
|
||||
|
||||
@property
|
||||
def from_account(self) -> Account | None:
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
return {
|
||||
@ -1794,7 +1802,9 @@ class MessageFile(TypeBase):
|
||||
)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
|
||||
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
|
||||
)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@ -1831,13 +1841,11 @@ class MessageAnnotation(Base):
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class AppAnnotationHitHistory(TypeBase):
|
||||
@ -1866,18 +1874,15 @@ class AppAnnotationHitHistory(TypeBase):
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
account = (
|
||||
db.session.query(Account)
|
||||
return db.session.scalar(
|
||||
select(Account)
|
||||
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
|
||||
.where(MessageAnnotation.id == self.annotation_id)
|
||||
.first()
|
||||
)
|
||||
return account
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
|
||||
|
||||
class AppAnnotationSetting(TypeBase):
|
||||
@ -1910,12 +1915,9 @@ class AppAnnotationSetting(TypeBase):
|
||||
def collection_binding_detail(self):
|
||||
from .dataset import DatasetCollectionBinding
|
||||
|
||||
collection_binding_detail = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
)
|
||||
return collection_binding_detail
|
||||
|
||||
|
||||
class OperationLog(TypeBase):
|
||||
@ -2021,7 +2023,9 @@ class AppMCPServer(TypeBase):
|
||||
def generate_server_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
|
||||
while (
|
||||
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
|
||||
) > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
@ -2082,7 +2086,7 @@ class Site(Base):
|
||||
def generate_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(Site).where(Site.code == result).count() > 0:
|
||||
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
@ -2130,7 +2134,7 @@ class UploadFile(Base):
|
||||
# The `server_default` serves as a fallback mechanism.
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
@ -2174,7 +2178,7 @@ class UploadFile(Base):
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
storage_type: str,
|
||||
storage_type: StorageType,
|
||||
key: str,
|
||||
name: str,
|
||||
size: int,
|
||||
@ -2239,7 +2243,7 @@ class MessageChain(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False)
|
||||
input: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
output: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@ -6,13 +6,14 @@ from functools import cached_property
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, text
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
@ -96,7 +97,7 @@ class Provider(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
|
||||
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
@ -159,10 +160,8 @@ class ProviderModel(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.where(ProviderModelCredential.id == self.credential_id)
|
||||
.first()
|
||||
return db.session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
)
|
||||
|
||||
@property
|
||||
@ -211,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase):
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
@ -239,7 +238,9 @@ class ProviderOrder(TypeBase):
|
||||
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
|
||||
currency: Mapped[str | None] = mapped_column(String(40))
|
||||
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
|
||||
payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'"))
|
||||
payment_status: Mapped[PaymentStatus] = mapped_column(
|
||||
EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'")
|
||||
)
|
||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase):
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
|
||||
credential_source_type: Mapped[CredentialSourceType | None] = mapped_column(
|
||||
EnumText(CredentialSourceType, length=40), nullable=True, default=None
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
|
||||
@ -8,7 +8,7 @@ from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy import ForeignKey, String, func
|
||||
from sqlalchemy import ForeignKey, String, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
|
||||
def user(self) -> Account | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class ToolLabelBinding(TypeBase):
|
||||
@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
@property
|
||||
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
|
||||
@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.query(App).where(App.id == self.app_id).first()
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class MCPToolProvider(TypeBase):
|
||||
@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -23,6 +23,47 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
|
||||
from .model import Account
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
TriggerJsonObject = dict[str, object]
|
||||
TriggerCredentials = dict[str, str]
|
||||
|
||||
|
||||
class WorkflowTriggerLogDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_run_id: str | None
|
||||
root_node_id: str | None
|
||||
trigger_metadata: Any
|
||||
trigger_type: str
|
||||
trigger_data: Any
|
||||
inputs: Any
|
||||
outputs: Any
|
||||
status: str
|
||||
error: str | None
|
||||
queue_name: str
|
||||
celery_task_id: str | None
|
||||
retry_count: int
|
||||
elapsed_time: float | None
|
||||
total_tokens: int | None
|
||||
created_by_role: str
|
||||
created_by: str
|
||||
created_at: str | None
|
||||
triggered_at: str | None
|
||||
finished_at: str | None
|
||||
|
||||
|
||||
class WorkflowSchedulePlanDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
cron_expression: str
|
||||
timezone: str
|
||||
next_run_at: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TriggerSubscription(TypeBase):
|
||||
"""
|
||||
@ -51,10 +92,14 @@ class TriggerSubscription(TypeBase):
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||
parameters: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription parameters JSON"
|
||||
)
|
||||
properties: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription properties JSON"
|
||||
)
|
||||
|
||||
credentials: Mapped[dict[str, Any]] = mapped_column(
|
||||
credentials: Mapped[TriggerCredentials] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription credentials JSON"
|
||||
)
|
||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||
@ -162,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> Mapping[str, Any]:
|
||||
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
def oauth_params(self) -> Mapping[str, object]:
|
||||
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class WorkflowTriggerLog(TypeBase):
|
||||
@ -250,7 +295,7 @@ class WorkflowTriggerLog(TypeBase):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> WorkflowTriggerLogDict:
|
||||
"""Convert to dictionary for API responses"""
|
||||
return {
|
||||
"id": self.id,
|
||||
@ -481,7 +526,7 @@ class WorkflowSchedulePlan(TypeBase):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> WorkflowSchedulePlanDict:
|
||||
"""Convert to dictionary representation"""
|
||||
return {
|
||||
"id": self.id,
|
||||
|
||||
@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy import DateTime, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import TypeBase
|
||||
@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return db.session.query(Message).where(Message.id == self.message_id).first()
|
||||
return db.session.scalar(select(Message).where(Message.id == self.message_id))
|
||||
|
||||
|
||||
class PinnedConversation(TypeBase):
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -19,21 +20,21 @@ from sqlalchemy import (
|
||||
orm,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
)
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -59,6 +60,25 @@ from .types import EnumText, LongText, StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SerializedWorkflowValue = dict[str, Any]
|
||||
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
|
||||
|
||||
|
||||
class WorkflowContentDict(TypedDict):
|
||||
graph: Mapping[str, Any]
|
||||
features: dict[str, Any]
|
||||
environment_variables: list[dict[str, Any]]
|
||||
conversation_variables: list[dict[str, Any]]
|
||||
rag_pipeline_variables: list[dict[str, Any]]
|
||||
|
||||
|
||||
class WorkflowRunSummaryDict(TypedDict):
|
||||
id: str
|
||||
status: str
|
||||
triggered_from: str
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
|
||||
|
||||
def is_generation_outputs(outputs: Mapping[str, Any]) -> bool:
|
||||
if not outputs:
|
||||
@ -314,26 +334,40 @@ class Workflow(Base): # bug
|
||||
def features(self) -> str:
|
||||
"""
|
||||
Convert old features structure to new features structure.
|
||||
|
||||
This property avoids rewriting the underlying JSON when normalization
|
||||
produces no effective change, to prevent marking the row dirty on read.
|
||||
"""
|
||||
if not self._features:
|
||||
return self._features
|
||||
|
||||
features = json.loads(self._features)
|
||||
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
|
||||
image_enabled = True
|
||||
image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
|
||||
image_transfer_methods = features["file_upload"]["image"].get(
|
||||
"transfer_methods", ["remote_url", "local_file"]
|
||||
)
|
||||
features["file_upload"]["enabled"] = image_enabled
|
||||
features["file_upload"]["number_limits"] = image_number_limits
|
||||
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
|
||||
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
|
||||
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
|
||||
"allowed_file_extensions", []
|
||||
)
|
||||
del features["file_upload"]["image"]
|
||||
self._features = json.dumps(features)
|
||||
# Parse once and deep-copy before normalization to detect in-place changes.
|
||||
original_dict = self._decode_features_payload(self._features)
|
||||
if original_dict is None:
|
||||
return self._features
|
||||
|
||||
# Fast-path: if the legacy file_upload.image.enabled shape is absent, skip
|
||||
# deep-copy and normalization entirely and return the stored JSON.
|
||||
file_upload_payload = original_dict.get("file_upload")
|
||||
if not isinstance(file_upload_payload, dict):
|
||||
return self._features
|
||||
file_upload = cast(dict[str, Any], file_upload_payload)
|
||||
|
||||
image_payload = file_upload.get("image")
|
||||
if not isinstance(image_payload, dict):
|
||||
return self._features
|
||||
image = cast(dict[str, Any], image_payload)
|
||||
if "enabled" not in image:
|
||||
return self._features
|
||||
|
||||
normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict))
|
||||
|
||||
if normalized_dict == original_dict:
|
||||
# No effective change; return stored JSON unchanged.
|
||||
return self._features
|
||||
|
||||
# Normalization changed the payload: persist the normalized JSON.
|
||||
self._features = json.dumps(normalized_dict)
|
||||
return self._features
|
||||
|
||||
@features.setter
|
||||
@ -347,6 +381,44 @@ class Workflow(Base): # bug
|
||||
def get_feature(self, key: WorkflowFeatures) -> WorkflowFeature:
|
||||
return WorkflowFeature.from_dict(self.features_dict.get(key.value))
|
||||
|
||||
@property
|
||||
def serialized_features(self) -> str:
|
||||
"""Return the stored features JSON without triggering compatibility rewrites."""
|
||||
return self._features
|
||||
|
||||
@property
|
||||
def normalized_features_dict(self) -> dict[str, Any]:
|
||||
"""Decode features with legacy normalization without mutating the model state."""
|
||||
if not self._features:
|
||||
return {}
|
||||
|
||||
features = self._decode_features_payload(self._features)
|
||||
return self._normalize_features_payload(features) if features is not None else {}
|
||||
|
||||
@staticmethod
|
||||
def _decode_features_payload(features: str) -> dict[str, Any] | None:
|
||||
"""Decode workflow features JSON when it contains an object payload."""
|
||||
payload = json.loads(features)
|
||||
return cast(dict[str, Any], payload) if isinstance(payload, dict) else None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]:
|
||||
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
|
||||
image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
|
||||
image_transfer_methods = features["file_upload"]["image"].get(
|
||||
"transfer_methods", ["remote_url", "local_file"]
|
||||
)
|
||||
features["file_upload"]["enabled"] = True
|
||||
features["file_upload"]["number_limits"] = image_number_limits
|
||||
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
|
||||
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
|
||||
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
|
||||
"allowed_file_extensions", []
|
||||
)
|
||||
del features["file_upload"]["image"]
|
||||
|
||||
return features
|
||||
|
||||
def walk_nodes(
|
||||
self, specific_node_type: NodeType | None = None
|
||||
) -> Generator[tuple[str, Mapping[str, Any]], None, None]:
|
||||
@ -423,7 +495,7 @@ class Workflow(Base): # bug
|
||||
|
||||
def rag_pipeline_user_input_form(self) -> list:
|
||||
# get user_input_form from start node
|
||||
variables: list[Any] = self.rag_pipeline_variables
|
||||
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
|
||||
|
||||
return variables
|
||||
|
||||
@ -466,17 +538,13 @@ class Workflow(Base): # bug
|
||||
def environment_variables(
|
||||
self,
|
||||
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
|
||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
|
||||
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
|
||||
results = [
|
||||
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
||||
]
|
||||
@ -536,14 +604,39 @@ class Workflow(Base): # bug
|
||||
)
|
||||
self._environment_variables = environment_variables_json
|
||||
|
||||
def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]:
|
||||
@staticmethod
|
||||
def normalize_environment_variable_mappings(
|
||||
mappings: Sequence[Mapping[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert masked secret placeholders into the draft hidden sentinel.
|
||||
|
||||
Regular draft sync requests should preserve existing secrets without shipping
|
||||
plaintext values back from the client. The dedicated restore endpoint now
|
||||
copies published secrets server-side, so draft sync only needs to normalize
|
||||
the UI mask into `HIDDEN_VALUE`.
|
||||
"""
|
||||
masked_secret_value = encrypter.full_mask_token()
|
||||
normalized_mappings: list[dict[str, Any]] = []
|
||||
|
||||
for mapping in mappings:
|
||||
normalized_mapping = dict(mapping)
|
||||
if (
|
||||
normalized_mapping.get("value_type") == SegmentType.SECRET.value
|
||||
and normalized_mapping.get("value") == masked_secret_value
|
||||
):
|
||||
normalized_mapping["value"] = HIDDEN_VALUE
|
||||
normalized_mappings.append(normalized_mapping)
|
||||
|
||||
return normalized_mappings
|
||||
|
||||
def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
|
||||
environment_variables = list(self.environment_variables)
|
||||
environment_variables = [
|
||||
v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""})
|
||||
for v in environment_variables
|
||||
]
|
||||
|
||||
result = {
|
||||
result: WorkflowContentDict = {
|
||||
"graph": self.graph_dict,
|
||||
"features": self.features_dict,
|
||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||
@ -554,11 +647,7 @@ class Workflow(Base): # bug
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@ -570,22 +659,29 @@ class Workflow(Base): # bug
|
||||
)
|
||||
|
||||
@property
|
||||
def rag_pipeline_variables(self) -> list[dict]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = list(variables_dict.values())
|
||||
return results
|
||||
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
|
||||
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
|
||||
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: list[dict]) -> None:
|
||||
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{item["variable"]: item for item in values},
|
||||
{
|
||||
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
|
||||
for rag_pipeline_variable in (
|
||||
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
|
||||
for item in values
|
||||
)
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None:
|
||||
"""Copy stored variable JSON directly for same-tenant restore flows."""
|
||||
self._environment_variables = source_workflow._environment_variables
|
||||
self._conversation_variables = source_workflow._conversation_variables
|
||||
self._rag_pipeline_variables = source_workflow._rag_pipeline_variables
|
||||
|
||||
@staticmethod
|
||||
def version_from_datetime(d: datetime) -> str:
|
||||
return str(d)
|
||||
@ -701,14 +797,14 @@ class WorkflowRun(Base):
|
||||
def message(self):
|
||||
from .model import Message
|
||||
|
||||
return (
|
||||
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
|
||||
)
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def workflow(self):
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
@property
|
||||
def outputs_as_generation(self):
|
||||
@ -825,44 +921,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
@declared_attr.directive
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
# The first argument is the index name,
|
||||
# which we leave as `None`` to allow auto-generation by the ORM.
|
||||
None,
|
||||
cls.tenant_id,
|
||||
cls.workflow_id,
|
||||
cls.node_id,
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(),
|
||||
),
|
||||
)
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
None,
|
||||
"tenant_id",
|
||||
"workflow_id",
|
||||
"node_id",
|
||||
sa.desc("created_at"),
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
@ -971,8 +1059,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
|
||||
datasource_info = execution_metadata["datasource_info"]
|
||||
extras["icon"] = datasource_info.get("icon")
|
||||
elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
|
||||
trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
|
||||
elif (
|
||||
self.node_type == TRIGGER_PLUGIN_NODE_TYPE
|
||||
and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
|
||||
):
|
||||
trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
|
||||
provider_id = trigger_info.get("provider_id")
|
||||
if provider_id:
|
||||
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
|
||||
@ -1270,7 +1361,7 @@ class WorkflowArchiveLog(TypeBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def workflow_run_summary(self) -> dict[str, Any]:
|
||||
def workflow_run_summary(self) -> WorkflowRunSummaryDict:
|
||||
return {
|
||||
"id": self.workflow_run_id,
|
||||
"status": self.run_status,
|
||||
@ -1325,16 +1416,17 @@ class WorkflowDraftVariable(Base):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def unique_app_id_node_id_name() -> list[str]:
|
||||
def unique_app_id_user_id_node_id_name() -> list[str]:
|
||||
return [
|
||||
"app_id",
|
||||
"user_id",
|
||||
"node_id",
|
||||
"name",
|
||||
]
|
||||
|
||||
__tablename__ = "workflow_draft_variables"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(*unique_app_id_node_id_name()),
|
||||
UniqueConstraint(*unique_app_id_user_id_node_id_name()),
|
||||
Index("workflow_draft_variable_file_id_idx", "file_id"),
|
||||
)
|
||||
# Required for instance variable annotation.
|
||||
@ -1360,6 +1452,11 @@ class WorkflowDraftVariable(Base):
|
||||
|
||||
# "`app_id` maps to the `id` field in the `model.App` model."
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# Owner of this draft variable.
|
||||
#
|
||||
# This field is nullable during migration and will be migrated to NOT NULL
|
||||
# in a follow-up release.
|
||||
user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
|
||||
# `last_edited_at` records when the value of a given draft variable
|
||||
# is edited.
|
||||
@ -1612,6 +1709,7 @@ class WorkflowDraftVariable(Base):
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
user_id: str | None,
|
||||
node_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
@ -1625,6 +1723,7 @@ class WorkflowDraftVariable(Base):
|
||||
variable.updated_at = naive_utc_now()
|
||||
variable.description = description
|
||||
variable.app_id = app_id
|
||||
variable.user_id = user_id
|
||||
variable.node_id = node_id
|
||||
variable.name = name
|
||||
variable.set_value(value)
|
||||
@ -1638,12 +1737,14 @@ class WorkflowDraftVariable(Base):
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
user_id: str | None = None,
|
||||
name: str,
|
||||
value: Segment,
|
||||
description: str = "",
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
node_id=CONVERSATION_VARIABLE_NODE_ID,
|
||||
name=name,
|
||||
value=value,
|
||||
@ -1658,6 +1759,7 @@ class WorkflowDraftVariable(Base):
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
user_id: str | None = None,
|
||||
name: str,
|
||||
value: Segment,
|
||||
node_execution_id: str,
|
||||
@ -1665,6 +1767,7 @@ class WorkflowDraftVariable(Base):
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
name=name,
|
||||
node_execution_id=node_execution_id,
|
||||
@ -1678,6 +1781,7 @@ class WorkflowDraftVariable(Base):
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
user_id: str | None = None,
|
||||
node_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
@ -1688,6 +1792,7 @@ class WorkflowDraftVariable(Base):
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
node_execution_id=node_execution_id,
|
||||
|
||||
Reference in New Issue
Block a user