Merge branch 'main' into feat/mcp

This commit is contained in:
Novice
2025-05-28 09:37:55 +08:00
799 changed files with 22592 additions and 6640 deletions

View File

@ -27,7 +27,7 @@ from .dataset import (
Whitelist,
)
from .engine import db
from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
ApiRequest,
ApiToken,
@ -114,7 +114,7 @@ __all__ = [
"CeleryTaskSet",
"Conversation",
"ConversationVariable",
"CreatedByRole",
"CreatorUserRole",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding",
"Dataset",

View File

@ -1,9 +1,10 @@
import enum
import json
from typing import Optional, cast
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from models.base import Base
@ -11,6 +12,66 @@ from .engine import db
from .types import StringUUID
class TenantAccountRole(enum.StrEnum):
OWNER = "owner"
ADMIN = "admin"
EDITOR = "editor"
NORMAL = "normal"
DATASET_OPERATOR = "dataset_operator"
@staticmethod
def is_valid_role(role: str) -> bool:
if not role:
return False
return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
}
@staticmethod
def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
@staticmethod
def is_admin_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role == TenantAccountRole.ADMIN
@staticmethod
def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role in {
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
}
@staticmethod
def is_editing_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
@staticmethod
def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool:
if not role:
return False
return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.DATASET_OPERATOR,
}
class AccountStatus(enum.StrEnum):
PENDING = "pending"
UNINITIALIZED = "uninitialized"
@ -40,54 +101,54 @@ class Account(UserMixin, Base):
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
# Registered through SQLAlchemy's @reconstructor hook so that instances loaded
# from the database also start with these two non-column attributes set.
@reconstructor
def init_on_load(self):
# Role of this account in the currently selected tenant; None until a tenant is selected.
self.role: Optional[TenantAccountRole] = None
# Backing field for the `current_tenant` property.
self._current_tenant: Optional[Tenant] = None
@property
def is_password_set(self):
return self.password is not None
@property
def current_tenant(self):
# FIXME: fix the type error later; the type matters here and ignoring it may cause bugs
return self._current_tenant # type: ignore
return self._current_tenant
@current_tenant.setter
def current_tenant(self, value: "Tenant"):
tenant = value
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first()
def current_tenant(self, tenant: "Tenant"):
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
if ta:
tenant.current_role = ta.role
else:
tenant = None # type: ignore
self._current_tenant = tenant
self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant
return
self._current_tenant = None
@property
def current_tenant_id(self) -> str | None:
return self._current_tenant.id if self._current_tenant else None
@current_tenant_id.setter
def current_tenant_id(self, value: str):
try:
tenant_account_join = (
def set_tenant_id(self, tenant_id: str):
    """Select the tenant identified by ``tenant_id`` as this account's current tenant.

    The tenant is fetched together with this account's membership row. When no
    such pair exists the call returns without changing ``role`` or the current
    tenant.
    """
    tenant_account_join = cast(
        tuple[Tenant, TenantAccountJoin],
        (
            db.session.query(Tenant, TenantAccountJoin)
            .filter(Tenant.id == tenant_id)
            .filter(TenantAccountJoin.tenant_id == Tenant.id)
            .filter(TenantAccountJoin.account_id == self.id)
            .one_or_none()
        ),
    )

    if not tenant_account_join:
        return

    tenant, join = tenant_account_join
    # Coerce the stored value to the enum, as the `current_tenant` setter does,
    # so `self.role` always matches its declared Optional[TenantAccountRole] type.
    self.role = TenantAccountRole(join.role)
    self._current_tenant = tenant
@property
def current_role(self):
return self._current_tenant.current_role
return self.role
def get_status(self) -> AccountStatus:
status_str = self.status
@ -107,23 +168,23 @@ class Account(UserMixin, Base):
# check current_user.current_tenant.current_role in ['admin', 'owner']
@property
def is_admin_or_owner(self):
return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
return TenantAccountRole.is_privileged_role(self.role)
@property
def is_admin(self):
return TenantAccountRole.is_admin_role(self._current_tenant.current_role)
return TenantAccountRole.is_admin_role(self.role)
@property
def is_editor(self):
return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
return TenantAccountRole.is_editing_role(self.role)
@property
def is_dataset_editor(self):
return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role)
return TenantAccountRole.is_dataset_edit_role(self.role)
@property
def is_dataset_operator(self):
return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
return self.role == TenantAccountRole.DATASET_OPERATOR
class TenantStatus(enum.StrEnum):
@ -131,67 +192,7 @@ class TenantStatus(enum.StrEnum):
ARCHIVE = "archive"
# Role an account can hold inside a tenant. Every helper below returns False
# for a falsy role (None or "").
class TenantAccountRole(enum.StrEnum):
OWNER = "owner"
ADMIN = "admin"
EDITOR = "editor"
NORMAL = "normal"
DATASET_OPERATOR = "dataset_operator"
# True when `role` equals one of the five roles defined above.
@staticmethod
def is_valid_role(role: str) -> bool:
if not role:
return False
return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
}
# True for OWNER and ADMIN.
@staticmethod
def is_privileged_role(role: str) -> bool:
if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
# True only for ADMIN; OWNER is not counted here.
@staticmethod
def is_admin_role(role: str) -> bool:
if not role:
return False
return role == TenantAccountRole.ADMIN
# True for every defined role except OWNER.
@staticmethod
def is_non_owner_role(role: str) -> bool:
if not role:
return False
return role in {
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR,
}
# True for OWNER, ADMIN and EDITOR.
@staticmethod
def is_editing_role(role: str) -> bool:
if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
# True for OWNER, ADMIN, EDITOR and DATASET_OPERATOR.
@staticmethod
def is_dataset_edit_role(role: str) -> bool:
if not role:
return False
return role in {
TenantAccountRole.OWNER,
TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR,
TenantAccountRole.DATASET_OPERATOR,
}
class Tenant(db.Model): # type: ignore[name-defined]
class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
@ -220,7 +221,7 @@ class Tenant(db.Model): # type: ignore[name-defined]
self.custom_config = json.dumps(value)
class TenantAccountJoin(db.Model): # type: ignore[name-defined]
class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@ -239,7 +240,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AccountIntegrate(db.Model): # type: ignore[name-defined]
class AccountIntegrate(Base):
__tablename__ = "account_integrates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@ -256,7 +257,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class InvitationCode(db.Model): # type: ignore[name-defined]
class InvitationCode(Base):
__tablename__ = "invitation_codes"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),

View File

@ -2,6 +2,7 @@ import enum
from sqlalchemy import func
from .base import Base
from .engine import db
from .types import StringUUID
@ -13,7 +14,7 @@ class APIBasedExtensionPoint(enum.Enum):
APP_MODERATION_OUTPUT = "app.moderation.output"
class APIBasedExtension(db.Model): # type: ignore[name-defined]
class APIBasedExtension(Base):
__tablename__ = "api_based_extensions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),

View File

@ -1,5 +1,7 @@
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase
from models.engine import metadata
Base = declarative_base(metadata=metadata)
class Base(DeclarativeBase):
metadata = metadata

View File

@ -22,6 +22,7 @@ from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account
from .base import Base
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
@ -33,7 +34,7 @@ class DatasetPermissionEnum(enum.StrEnum):
PARTIAL_TEAM = "partial_members"
class Dataset(db.Model): # type: ignore[name-defined]
class Dataset(Base):
__tablename__ = "datasets"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
@ -92,7 +93,8 @@ class Dataset(db.Model): # type: ignore[name-defined]
@property
def latest_process_rule(self):
return (
DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
)
@ -137,7 +139,8 @@ class Dataset(db.Model): # type: ignore[name-defined]
@property
def word_count(self):
return (
Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count)))
.filter(Document.dataset_id == self.id)
.scalar()
)
@ -255,7 +258,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
return f"Vector_index_{normalized_dataset_id}_Node"
class DatasetProcessRule(db.Model): # type: ignore[name-defined]
class DatasetProcessRule(Base):
__tablename__ = "dataset_process_rules"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
@ -295,7 +298,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
return None
class Document(db.Model): # type: ignore[name-defined]
class Document(Base):
__tablename__ = "documents"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"),
@ -439,12 +442,13 @@ class Document(db.Model): # type: ignore[name-defined]
@property
def segment_count(self):
return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()
return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
@property
def hit_count(self):
return (
DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
.filter(DocumentSegment.document_id == self.id)
.scalar()
)
@ -635,7 +639,7 @@ class Document(db.Model): # type: ignore[name-defined]
)
class DocumentSegment(db.Model): # type: ignore[name-defined]
class DocumentSegment(Base):
__tablename__ = "document_segments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@ -786,7 +790,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text
class ChildChunk(db.Model): # type: ignore[name-defined]
class ChildChunk(Base):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@ -829,7 +833,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined]
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model): # type: ignore[name-defined]
class AppDatasetJoin(Base):
__tablename__ = "app_dataset_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
@ -846,7 +850,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined]
return db.session.get(App, self.app_id)
class DatasetQuery(db.Model): # type: ignore[name-defined]
class DatasetQuery(Base):
__tablename__ = "dataset_queries"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
@ -863,7 +867,7 @@ class DatasetQuery(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
class DatasetKeywordTable(Base):
__tablename__ = "dataset_keyword_tables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
@ -891,7 +895,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
return dct
# get dataset
dataset = Dataset.query.filter_by(id=self.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
if not dataset:
return None
if self.data_source_type == "database":
@ -908,7 +912,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
return None
class Embedding(db.Model): # type: ignore[name-defined]
class Embedding(Base):
__tablename__ = "embeddings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
@ -932,7 +936,7 @@ class Embedding(db.Model): # type: ignore[name-defined]
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301
class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
class DatasetCollectionBinding(Base):
__tablename__ = "dataset_collection_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
@ -947,7 +951,7 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(db.Model): # type: ignore[name-defined]
class TidbAuthBinding(Base):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@ -967,7 +971,7 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(db.Model): # type: ignore[name-defined]
class Whitelist(Base):
__tablename__ = "whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
@ -979,7 +983,7 @@ class Whitelist(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(db.Model): # type: ignore[name-defined]
class DatasetPermission(Base):
__tablename__ = "dataset_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
@ -996,7 +1000,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
class ExternalKnowledgeApis(Base):
__tablename__ = "external_knowledge_apis"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
@ -1049,7 +1053,7 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
return dataset_bindings
class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
class ExternalKnowledgeBindings(Base):
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
@ -1070,7 +1074,7 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
class DatasetAutoDisableLog(Base):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
@ -1087,7 +1091,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class RateLimitLog(db.Model): # type: ignore[name-defined]
class RateLimitLog(Base):
__tablename__ = "rate_limit_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
@ -1102,7 +1106,7 @@ class RateLimitLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetMetadata(db.Model): # type: ignore[name-defined]
class DatasetMetadata(Base):
__tablename__ = "dataset_metadatas"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
@ -1121,7 +1125,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined]
updated_by = db.Column(StringUUID, nullable=True)
class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
class DatasetMetadataBinding(Base):
__tablename__ = "dataset_metadata_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),

View File

@ -1,7 +1,7 @@
from enum import StrEnum
class CreatedByRole(StrEnum):
class CreatorUserRole(StrEnum):
ACCOUNT = "account"
END_USER = "end_user"
@ -14,3 +14,10 @@ class UserFrom(StrEnum):
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
APP_RUN = "app-run"
class DraftVariableType(StrEnum):
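# NOTE(review): the original comment was cut off mid-sentence; presumably `node`
# means the corresponding variable belongs to a workflow node — confirm.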
NODE = "node"
SYS = "sys"
CONVERSATION = "conversation"

View File

@ -16,7 +16,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column
@ -25,13 +25,13 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from libs.helper import generate_string
from models.base import Base
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus
from .account import Account, Tenant
from .base import Base
from .engine import db
from .enums import CreatorUserRole
from .types import StringUUID
from .workflow import WorkflowRunStatus
if TYPE_CHECKING:
from .workflow import Workflow
@ -298,6 +298,15 @@ class App(Base):
def mcp_server(self):
return db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.id).first()
@property
def author_name(self):
    """Name of the account referenced by ``created_by``, or None.

    None is returned both when ``created_by`` is empty and when no account
    with that id is found.
    """
    if not self.created_by:
        return None
    creator = db.session.query(Account).filter(Account.id == self.created_by).first()
    return creator.name if creator else None
class AppModelConfig(Base):
__tablename__ = "app_model_configs"
@ -606,7 +615,7 @@ class InstalledApp(Base):
return tenant
class Conversation(db.Model): # type: ignore[name-defined]
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@ -798,7 +807,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
for message in messages:
if message.workflow_run:
status_counts[message.workflow_run.status] += 1
status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1
return (
{
@ -868,7 +877,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
}
class Message(db.Model): # type: ignore[name-defined]
class Message(Base):
__tablename__ = "messages"
__table_args__ = (
PrimaryKeyConstraint("id", name="message_pkey"),
@ -1215,7 +1224,7 @@ class Message(db.Model): # type: ignore[name-defined]
)
class MessageFeedback(db.Model): # type: ignore[name-defined]
class MessageFeedback(Base):
__tablename__ = "message_feedbacks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@ -1241,8 +1250,23 @@ class MessageFeedback(db.Model): # type: ignore[name-defined]
account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
return account
def to_dict(self):
    """Return this feedback as a plain dict.

    The four required ids are stringified, the two optional ``from_*`` ids
    become None when empty, and both timestamps are rendered with
    ``isoformat()``.
    """

    def _str_or_none(value):
        return str(value) if value else None

    payload = {
        field: str(getattr(self, field))
        for field in ("id", "app_id", "conversation_id", "message_id")
    }
    payload["rating"] = self.rating
    payload["content"] = self.content
    payload["from_source"] = self.from_source
    payload["from_end_user_id"] = _str_or_none(self.from_end_user_id)
    payload["from_account_id"] = _str_or_none(self.from_account_id)
    payload["created_at"] = self.created_at.isoformat()
    payload["updated_at"] = self.updated_at.isoformat()
    return payload
class MessageFile(db.Model): # type: ignore[name-defined]
class MessageFile(Base):
__tablename__ = "message_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@ -1259,7 +1283,7 @@ class MessageFile(db.Model): # type: ignore[name-defined]
url: str | None = None,
belongs_to: Literal["user", "assistant"] | None = None,
upload_file_id: str | None = None,
created_by_role: CreatedByRole,
created_by_role: CreatorUserRole,
created_by: str,
):
self.message_id = message_id
@ -1283,7 +1307,7 @@ class MessageFile(db.Model): # type: ignore[name-defined]
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(db.Model): # type: ignore[name-defined]
class MessageAnnotation(Base):
__tablename__ = "message_annotations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@ -1314,7 +1338,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined]
return account
class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@ -1326,7 +1350,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
annotation_id = db.Column(StringUUID, nullable=False)
annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False)
source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
@ -1352,7 +1376,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
return account
class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@ -1368,26 +1392,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def updated_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
@ -1426,7 +1430,7 @@ class EndUser(Base, UserMixin):
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(255), nullable=False)
external_user_id = db.Column(db.String(255), nullable=True)
@ -1585,7 +1589,7 @@ class UploadFile(Base):
size: int,
extension: str,
mime_type: str,
created_by_role: CreatedByRole,
created_by_role: CreatorUserRole,
created_by: str,
created_at: datetime,
used: bool,

View File

@ -2,8 +2,7 @@ from enum import Enum
from sqlalchemy import func
from models.base import Base
from .base import Base
from .engine import db
from .types import StringUUID

View File

@ -9,7 +9,7 @@ from .engine import db
from .types import StringUUID
class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),

View File

@ -1,6 +1,6 @@
import json
from datetime import datetime
from typing import Any, Optional, cast
from typing import Any, cast
import sqlalchemy as sa
from deprecated import deprecated
@ -173,10 +173,6 @@ class WorkflowToolProvider(Base):
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
@property
def schema_type(self) -> ApiProviderSchemaType:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
@ -366,8 +362,11 @@ class DeprecatedPublishedAppTool(Base):
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the app
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
# who published this tool
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
@ -387,34 +386,3 @@ class DeprecatedPublishedAppTool(Base):
@property
def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description))
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
file_key: Mapped[str] = db.Column(db.String(255), nullable=False)
mimetype: Mapped[str] = db.Column(db.String(255), nullable=False)
original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True)
name: Mapped[str] = mapped_column(default="")
size: Mapped[int] = mapped_column(default=-1)
def __init__(
self,
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
file_key: str,
mimetype: str,
original_url: Optional[str] = None,
name: str,
size: int,
):
self.user_id = user_id
self.tenant_id = tenant_id
self.conversation_id = conversation_id
self.file_key = file_key
self.mimetype = mimetype
self.original_url = original_url
self.name = name
self.size = size

View File

@ -1,4 +1,7 @@
from sqlalchemy import CHAR, TypeDecorator
import enum
from typing import Generic, TypeVar
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
@ -24,3 +27,51 @@ class StringUUID(TypeDecorator):
if value is None:
return value
return str(value)
_E = TypeVar("_E", bound=enum.StrEnum)


class EnumText(TypeDecorator, Generic[_E]):
    """Column type that persists a ``StrEnum`` member as a VARCHAR string.

    Values are validated against ``enum_class`` when bound, and converted back
    to ``enum_class`` members when rows are read.
    """

    impl = VARCHAR
    cache_ok = True

    _length: int
    _enum_class: type[_E]

    def __init__(self, enum_class: type[_E], length: int | None = None):
        """Remember the enum class and settle the VARCHAR length.

        Raises:
            ValueError: if ``length`` is shorter than the longest enum value.
        """
        self._enum_class = enum_class
        longest_value = max(len(member.value) for member in enum_class)
        if length is None:
            # Leave some room for longer enum values added in the future.
            self._length = max(longest_value, 20)
            return
        if length < longest_value:
            raise ValueError("length should be greater than enum value length.")
        self._length = length

    def process_bind_param(self, value: _E | str | None, dialect):
        """Convert an enum member or raw string into the stored string.

        Raises:
            ValueError: if a raw string is not a value of the enum class.
            TypeError: if ``value`` is neither None, a str, nor an enum member.
        """
        if value is None:
            return value
        if isinstance(value, self._enum_class):
            return value.value
        if not isinstance(value, str):
            raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
        # Constructing the member only validates the string; the result is discarded.
        self._enum_class(value)
        return value

    def load_dialect_impl(self, dialect):
        """Use a VARCHAR sized to the length settled in ``__init__``."""
        return dialect.type_descriptor(VARCHAR(self._length))

    def process_result_value(self, value, dialect) -> _E | None:
        """Convert a stored string back into an enum member (None passes through).

        Raises:
            TypeError: if the stored value is not a str.
        """
        if value is None:
            return value
        if isinstance(value, str):
            return self._enum_class(value)
        raise TypeError(f"expected str, got {type(value)}")

    def compare_values(self, x, y):
        """Compare two values; when either is None they match only if both are."""
        if x is not None and y is not None:
            return x == y
        return x is y

View File

@ -1,29 +1,37 @@
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
from flask_login import current_user
from core.variables import utils as variable_utils
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
if TYPE_CHECKING:
from models.model import AppMode
import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func
from sqlalchemy import UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
import contexts
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Variable
from core.variables import SecretVariable, Segment, SegmentType, Variable
from factories import variable_factory
from libs import helper
from models.base import Base
from models.enums import CreatedByRole
from .account import Account
from .base import Base
from .engine import db
from .types import StringUUID
from .enums import CreatorUserRole, DraftVariableType
from .types import EnumText, StringUUID
_logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from models.model import AppMode
@ -143,7 +151,7 @@ class Workflow(Base):
conversation_variables: Sequence[Variable],
marked_name: str = "",
marked_comment: str = "",
) -> Self:
) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@ -192,7 +200,9 @@ class Workflow(Base):
features["file_upload"]["number_limits"] = image_number_limits
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
features["file_upload"]["allowed_file_extensions"] = []
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
"allowed_file_extensions", []
)
del features["file_upload"]["image"]
self._features = json.dumps(features)
return self._features
@ -265,7 +275,16 @@ class Workflow(Base):
if self._environment_variables is None:
self._environment_variables = "{}"
tenant_id = contexts.tenant_id.get()
# Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account):
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
results = [
@ -288,7 +307,17 @@ class Workflow(Base):
self._environment_variables = "{}"
return
tenant_id = contexts.tenant_id.get()
# Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account):
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id:
self._environment_variables = "{}"
return
value = list(value)
if any(var for var in value if not var.id):
@ -418,29 +447,29 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at = db.Column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0"))
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
@property
def graph_dict(self):
def graph_dict(self) -> Mapping[str, Any]:
return json.loads(self.graph) if self.graph else {}
@property
@ -634,24 +663,24 @@ class WorkflowNodeExecution(Base):
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
created_by_role = CreatorUserRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
created_by_role = CreatorUserRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
@property
def inputs_dict(self):
    """Return `self.inputs` decoded from JSON, or `None` when `self.inputs` is falsy."""
    if not self.inputs:
        return None
    return json.loads(self.inputs)
@property
def outputs_dict(self):
def outputs_dict(self) -> dict[str, Any] | None:
return json.loads(self.outputs) if self.outputs else None
@property
@ -659,8 +688,11 @@ class WorkflowNodeExecution(Base):
return json.loads(self.process_data) if self.process_data else None
@property
def execution_metadata_dict(self):
return json.loads(self.execution_metadata) if self.execution_metadata else None
def execution_metadata_dict(self) -> dict[str, Any]:
# When the metadata is unset, we return an empty dictionary instead of `None`.
# This approach streamlines the logic for the caller, making it easier to handle
# cases where metadata is absent.
return json.loads(self.execution_metadata) if self.execution_metadata else {}
@property
def extras(self):
@ -736,19 +768,18 @@ class WorkflowAppLog(Base):
__tablename__ = "workflow_app_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id", "created_at"),
db.Index("workflow_app_log_workflow_run_idx", "workflow_run_id"),
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id = db.Column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_from: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def workflow_run(self):
@ -756,31 +787,28 @@ class WorkflowAppLog(Base):
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
class ConversationVariable(Base):
__tablename__ = "workflow_conversation_variables"
__table_args__ = (
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
Index("workflow__conversation_variables_app_id_idx", "app_id"),
Index("workflow__conversation_variables_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
data = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data: Mapped[str] = mapped_column(db.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@ -803,3 +831,201 @@ class ConversationVariable(Base):
def to_variable(self) -> Variable:
    """Build a conversation `Variable` from the JSON mapping stored in `self.data`."""
    return variable_factory.build_conversation_variable_from_mapping(json.loads(self.data))
# Names of the system variables that may be modified: only `sys.query` and `sys.files`.
_EDITABLE_SYSTEM_VARIABLE = frozenset({"query", "files"})
def _naive_utc_datetime():
return datetime.now(UTC).replace(tzinfo=None)
class WorkflowDraftVariable(Base):
    """ORM model for rows of the `workflow_draft_variables` table.

    Each row holds one draft variable: the app it belongs to (`app_id`), the
    node it is attached to (`node_id`), its `name`, a JSON-encoded `selector`
    and a JSON-encoded `value` with its `value_type`. The combination
    `(app_id, node_id, name)` is unique.

    The `new_conversation_variable`, `new_sys_variable` and `new_node_variable`
    classmethods build new, unsaved instances.
    """

    @staticmethod
    def unique_columns() -> list[str]:
        """Return the names of the columns covered by the table's unique constraint."""
        return ["app_id", "node_id", "name"]

    __tablename__ = "workflow_draft_variables"
    # While the class body runs, `unique_columns` is still a plain `staticmethod` object;
    # calling it directly relies on staticmethod objects being callable (Python 3.10+).
    __table_args__ = (UniqueConstraint(*unique_columns()),)

    # `id` is the unique identifier of a draft variable.
    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))

    created_at: Mapped[datetime] = mapped_column(
        db.DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp()
    )

    updated_at: Mapped[datetime] = mapped_column(
        db.DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )

    # `app_id` maps to the `id` field in the `model.App` model.
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    # `last_edited_at` records when the value of a given draft variable was edited.
    #
    # If the variable has not been edited since creation, it is `None`.
    last_edited_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True, default=None)

    # The `node_id` field is special.
    #
    # If the variable is a conversation variable or a system variable, the value of `node_id`
    # is `conversation` or `sys`, respectively.
    #
    # Otherwise, if the variable belongs to a specific node, the value of `node_id` is the
    # identity of the corresponding node in the graph definition. An example of a node id is
    # `"1745769620734"`.
    #
    # However, there is one caveat: the id of the first "Answer" node in a chatflow is "answer".
    # (Other "Answer" nodes conform to the rules above.)
    node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id")

    # From `VARIABLE_PATTERN`, we may conclude that the length of a top-level variable name is
    # less than 80 chars.
    #
    # ref: api/core/workflow/entities/variable_pool.py:18
    name: Mapped[str] = mapped_column(sa.String(255), nullable=False)

    description: Mapped[str] = mapped_column(sa.String(255), default="", nullable=False)

    # JSON-encoded list; read it with `get_selector()` and write it with `_set_selector()`.
    selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")

    value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))

    # JSON string; read it with `get_value()` and write it with `set_value()`.
    value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")

    # Boolean flags. Column-level defaults: `visible=True`, `editable=False`.
    visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
    editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)

    def get_selector(self) -> list[str]:
        """Decode the JSON stored in the `selector` column.

        Raises:
            ValueError: if the decoded JSON value is not a list (the bad value is logged first).
        """
        decoded = json.loads(self.selector)
        if isinstance(decoded, list):
            return decoded
        _logger.error(
            "invalid selector loaded from database, type=%s, value=%s",
            type(decoded),
            self.selector,
        )
        raise ValueError("invalid selector.")

    def _set_selector(self, value: list[str]):
        """Store `value` in the `selector` column as a JSON string."""
        self.selector = json.dumps(value)

    def get_value(self) -> Segment | None:
        """Decode the JSON stored in the `value` column and pass it to `build_segment`."""
        raw_value = json.loads(self.value)
        return build_segment(raw_value)

    def set_name(self, name: str):
        """Rename the variable and rebuild its selector as `[node_id, name]`."""
        self.name = name
        self._set_selector([self.node_id, name])

    def set_value(self, value: Segment):
        """Store `value.value` as JSON and record `value.value_type`."""
        self.value = json.dumps(value.value)
        self.value_type = value.value_type

    def get_node_id(self) -> str | None:
        """Return `node_id` for node variables, and `None` for conversation or system variables."""
        return self.node_id if self.get_variable_type() == DraftVariableType.NODE else None

    def get_variable_type(self) -> DraftVariableType:
        """Classify the variable by comparing `node_id` with the reserved type values.

        Any `node_id` that equals neither `DraftVariableType.CONVERSATION` nor
        `DraftVariableType.SYS` is treated as a node variable.
        """
        owner = self.node_id
        for reserved in (DraftVariableType.CONVERSATION, DraftVariableType.SYS):
            if owner == reserved:
                return reserved
        return DraftVariableType.NODE

    @classmethod
    def _new(
        cls,
        *,
        app_id: str,
        node_id: str,
        name: str,
        value: Segment,
        description: str = "",
    ) -> "WorkflowDraftVariable":
        """Build an unsaved instance with timestamps, owner, name, value and selector set.

        `visible`, `editable` and `last_edited_at` are not assigned here.
        """
        draft = WorkflowDraftVariable()
        draft.created_at = _naive_utc_datetime()
        draft.updated_at = _naive_utc_datetime()
        draft.description = description
        draft.app_id = app_id
        draft.node_id = node_id
        draft.name = name
        draft.set_value(value)
        draft._set_selector(list(variable_utils.to_selector(node_id, name)))
        return draft

    @classmethod
    def new_conversation_variable(
        cls,
        *,
        app_id: str,
        name: str,
        value: Segment,
    ) -> "WorkflowDraftVariable":
        """Build a draft variable attached to `CONVERSATION_VARIABLE_NODE_ID`."""
        return cls._new(
            app_id=app_id,
            node_id=CONVERSATION_VARIABLE_NODE_ID,
            name=name,
            value=value,
        )

    @classmethod
    def new_sys_variable(
        cls,
        *,
        app_id: str,
        name: str,
        value: Segment,
        editable: bool = False,
    ) -> "WorkflowDraftVariable":
        """Build a draft variable attached to `SYSTEM_VARIABLE_NODE_ID` with the given `editable` flag."""
        sys_variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value)
        sys_variable.editable = editable
        return sys_variable

    @classmethod
    def new_node_variable(
        cls,
        *,
        app_id: str,
        node_id: str,
        name: str,
        value: Segment,
        visible: bool = True,
    ) -> "WorkflowDraftVariable":
        """Build a draft variable attached to node `node_id`; it is always marked editable."""
        node_variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
        node_variable.visible = visible
        node_variable.editable = True
        return node_variable

    @property
    def edited(self):
        """Whether `last_edited_at` has been set."""
        return self.last_edited_at is not None
def is_system_variable_editable(name: str) -> bool:
    """Return whether `name` is listed in `_EDITABLE_SYSTEM_VARIABLE`."""
    is_editable = name in _EDITABLE_SYSTEM_VARIABLE
    return is_editable