refactor: replace sa.String with EnumText in mapped_column for type s… (#33332)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
tmimmanuel
2026-03-14 04:38:27 +00:00
committed by GitHub
parent 6043ec4423
commit e64f4d6039
40 changed files with 218 additions and 138 deletions

View File

@ -8,12 +8,12 @@ from uuid import uuid4
import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, validates
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated
from .base import TypeBase
from .engine import db
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
class TenantAccountRole(enum.StrEnum):
@ -104,7 +104,9 @@ class Account(UserMixin, TypeBase):
last_active_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'"), default="active")
status: Mapped[AccountStatus] = mapped_column(
EnumText(AccountStatus, length=16), server_default=sa.text("'active'"), default=AccountStatus.ACTIVE
)
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
@ -116,12 +118,6 @@ class Account(UserMixin, TypeBase):
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)
@validates("status")
def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
if isinstance(value, AccountStatus):
return value.value
return value
@property
def is_password_set(self):
return self.password is not None
@ -177,8 +173,7 @@ class Account(UserMixin, TypeBase):
return self.role
def get_status(self) -> AccountStatus:
status_str = self.status
return AccountStatus(status_str)
return self.status
@classmethod
def get_by_openid(cls, provider: str, open_id: str):
@ -249,7 +244,9 @@ class Tenant(TypeBase):
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic")
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'"), default="normal")
status: Mapped[TenantStatus] = mapped_column(
EnumText(TenantStatus, length=255), server_default=sa.text("'normal'"), default=TenantStatus.NORMAL
)
custom_config: Mapped[str | None] = mapped_column(LongText, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
@ -291,7 +288,9 @@ class TenantAccountJoin(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
role: Mapped[TenantAccountRole] = mapped_column(
EnumText(TenantAccountRole, length=16), server_default="normal", default=TenantAccountRole.NORMAL
)
invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False

View File

@ -30,8 +30,9 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
from .account import Account
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
logger = logging.getLogger(__name__)
@ -59,7 +60,11 @@ class Dataset(Base):
name: Mapped[str] = mapped_column(String(255))
description = mapped_column(LongText, nullable=True)
provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'"))
permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'"))
permission: Mapped[DatasetPermissionEnum] = mapped_column(
EnumText(DatasetPermissionEnum, length=255),
server_default=sa.text("'only_me'"),
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(String(255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
index_struct = mapped_column(LongText, nullable=True)
@ -1003,7 +1008,7 @@ class DatasetQuery(TypeBase):
content: Mapped[str] = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False)
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False

View File

@ -29,9 +29,9 @@ from libs.uuid_utils import uuidv7
from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db
from .enums import CreatorUserRole
from .enums import CreatorUserRole, MessageStatus
from .provider_ids import GenericProviderID
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
if TYPE_CHECKING:
from .workflow import Workflow
@ -337,8 +337,8 @@ class App(Base):
tenant_id: Mapped[str] = mapped_column(StringUUID)
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(LongText, default=sa.text("''"))
mode: Mapped[str] = mapped_column(String(255))
icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link
mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255))
icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255))
icon = mapped_column(String(255))
icon_background: Mapped[str | None] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
@ -1000,7 +1000,7 @@ class Conversation(Base):
model_provider = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(LongText)
model_id = mapped_column(String(255), nullable=True)
mode: Mapped[str] = mapped_column(String(255))
mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255))
name: Mapped[str] = mapped_column(String(255), nullable=False)
summary = mapped_column(LongText)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
@ -1351,7 +1351,12 @@ class Message(Base):
provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7))
currency: Mapped[str] = mapped_column(String(255), nullable=False)
status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'"))
status: Mapped[MessageStatus] = mapped_column(
EnumText(MessageStatus, length=255),
nullable=False,
server_default=sa.text("'normal'"),
default=MessageStatus.NORMAL,
)
error: Mapped[str | None] = mapped_column(LongText)
message_metadata: Mapped[str | None] = mapped_column(LongText)
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
@ -1364,7 +1369,7 @@ class Message(Base):
)
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
app_mode: Mapped[AppMode | None] = mapped_column(EnumText(AppMode, length=255), nullable=True)
@property
def inputs(self) -> dict[str, Any]:
@ -1767,7 +1772,7 @@ class MessageFile(TypeBase):
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
@ -2015,7 +2020,7 @@ class Site(Base):
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=False)
title: Mapped[str] = mapped_column(String(255), nullable=False)
icon_type = mapped_column(String(255), nullable=True)
icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255), nullable=True)
icon = mapped_column(String(255))
icon_background = mapped_column(String(255))
description = mapped_column(LongText)
@ -2110,7 +2115,12 @@ class UploadFile(Base):
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
# Its value is derived from the `CreatorUserRole` enumeration.
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'account'"))
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255),
nullable=False,
server_default=sa.text("'account'"),
default=CreatorUserRole.ACCOUNT,
)
# The `created_by` field stores the ID of the entity that created this upload file.
#
@ -2163,7 +2173,7 @@ class UploadFile(Base):
self.size = size
self.extension = extension
self.mime_type = mime_type
self.created_by_role = created_by_role.value
self.created_by_role = created_by_role
self.created_by = created_by
self.created_at = created_at
self.used = used
@ -2226,7 +2236,7 @@ class MessageAgentThought(TypeBase):
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)

View File

@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
class ProviderType(StrEnum):
@ -69,8 +69,8 @@ class Provider(TypeBase):
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
String(40), nullable=False, server_default=text("'custom'"), default="custom"
provider_type: Mapped[ProviderType] = mapped_column(
EnumText(ProviderType, length=40), nullable=False, server_default=text("'custom'"), default=ProviderType.CUSTOM
)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False)
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)

View File

@ -227,7 +227,7 @@ class WorkflowTriggerLog(TypeBase):
queue_name: Mapped[str] = mapped_column(String(100), nullable=False)
celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(String(255), nullable=False)
retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)

View File

@ -2,13 +2,14 @@ from datetime import datetime
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import TypeBase
from .engine import db
from .enums import CreatorUserRole
from .model import Message
from .types import StringUUID
from .types import EnumText, StringUUID
class SavedMessage(TypeBase):
@ -24,7 +25,9 @@ class SavedMessage(TypeBase):
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'"))
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'")
)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime,
@ -50,8 +53,8 @@ class PinnedConversation(TypeBase):
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role: Mapped[str] = mapped_column(
String(255),
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255),
nullable=False,
server_default=sa.text("'end_user'"),
)

View File

@ -53,7 +53,7 @@ from libs import helper
from .account import Account
from .base import Base, DefaultFieldsMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
logger = logging.getLogger(__name__)
@ -141,7 +141,7 @@ class Workflow(Base): # bug
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255), nullable=False)
version: Mapped[str] = mapped_column(String(255), nullable=False)
marked_name: Mapped[str] = mapped_column(String(255), default="", server_default="")
marked_comment: Mapped[str] = mapped_column(String(255), default="", server_default="")
@ -188,7 +188,7 @@ class Workflow(Base): # bug
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
workflow.app_id = app_id
workflow.type = type
workflow.type = WorkflowType(type)
workflow.version = version
workflow.graph = graph
workflow.features = features
@ -608,8 +608,8 @@ class WorkflowRun(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID)
type: Mapped[str] = mapped_column(String(255))
triggered_from: Mapped[str] = mapped_column(String(255))
type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255))
triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(EnumText(WorkflowRunTriggeredFrom, length=255))
version: Mapped[str] = mapped_column(String(255))
graph: Mapped[str | None] = mapped_column(LongText)
inputs: Mapped[str | None] = mapped_column(LongText)
@ -830,7 +830,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID)
triggered_from: Mapped[str] = mapped_column(String(255))
triggered_from: Mapped[WorkflowNodeExecutionTriggeredFrom] = mapped_column(
EnumText(WorkflowNodeExecutionTriggeredFrom, length=255)
)
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
index: Mapped[int] = mapped_column(sa.Integer)
predecessor_node_id: Mapped[str | None] = mapped_column(String(255))
@ -846,7 +848,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
execution_metadata: Mapped[str | None] = mapped_column(LongText)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
created_by_role: Mapped[str] = mapped_column(String(255))
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255))
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
@ -1130,7 +1132,7 @@ class WorkflowAppLog(TypeBase):
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@ -1204,7 +1206,7 @@ class WorkflowArchiveLog(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
@ -1213,7 +1215,9 @@ class WorkflowArchiveLog(TypeBase):
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False)
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
)
run_error: Mapped[str | None] = mapped_column(LongText, nullable=True)
run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))