Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -26,7 +26,6 @@ from .dataset import (
TidbAuthBinding,
Whitelist,
)
from .engine import db
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
ApiRequest,
@ -57,6 +56,7 @@ from .model import (
TraceAppConfig,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
from .provider import (
LoadBalancingModelConfig,
Provider,
@ -90,6 +90,7 @@ from .workflow import (
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowSchedulePlan,
@ -131,6 +132,8 @@ __all__ = [
"DatasetProcessRule",
"DatasetQuery",
"DatasetRetrieverResource",
"DatasourceOauthParamConfig",
"DatasourceProvider",
"DifySetup",
"Document",
"DocumentSegment",
@ -183,11 +186,11 @@ __all__ = [
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",
"WorkflowRun",
"WorkflowRunTriggeredFrom",
"WorkflowSchedulePlan",
"WorkflowToolProvider",
"WorkflowType",
"db",
]

View File

@ -1,12 +1,13 @@
import enum
import json
from datetime import datetime
from typing import Optional
from typing import Any, Optional
import sqlalchemy as sa
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
from typing_extensions import deprecated
from models.base import Base
@ -89,24 +90,24 @@ class Account(UserMixin, Base):
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[Optional[str]] = mapped_column(String(255))
password_salt: Mapped[Optional[str]] = mapped_column(String(255))
avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
interface_language: Mapped[Optional[str]] = mapped_column(String(255))
interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
timezone: Mapped[Optional[str]] = mapped_column(String(255))
last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
password: Mapped[str | None] = mapped_column(String(255))
password_salt: Mapped[str | None] = mapped_column(String(255))
avatar: Mapped[str | None] = mapped_column(String(255), nullable=True)
interface_language: Mapped[str | None] = mapped_column(String(255))
interface_theme: Mapped[str | None] = mapped_column(String(255), nullable=True)
timezone: Mapped[str | None] = mapped_column(String(255))
last_login_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
last_login_ip: Mapped[str | None] = mapped_column(String(255), nullable=True)
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
@reconstructor
def init_on_load(self):
self.role: Optional[TenantAccountRole] = None
self._current_tenant: Optional[Tenant] = None
self.role: TenantAccountRole | None = None
self._current_tenant: Tenant | None = None
@property
def is_password_set(self):
@ -187,7 +188,28 @@ class Account(UserMixin, Base):
return TenantAccountRole.is_admin_role(self.role)
@property
@deprecated("Use has_edit_permission instead.")
def is_editor(self):
    """Deprecated alias of ``has_edit_permission``.

    Returns exactly what ``has_edit_permission`` returns for this account in
    its current tenant (workspace). That check covers every role with editing
    privileges, not only the role literally named ``EDITOR``:

    - `OWNER`
    - `ADMIN`
    - `EDITOR`
    """
    return self.has_edit_permission
@property
def has_edit_permission(self):
    """Tell whether the account's current role is an editing role.

    The decision is delegated to ``TenantAccountRole.is_editing_role``, called
    with ``self.role`` (the role in the current tenant / workspace). The roles
    with editing privileges are:

    - `OWNER`
    - `ADMIN`
    - `EDITOR`
    """
    current_role = self.role
    return TenantAccountRole.is_editing_role(current_role)
@property
@ -210,26 +232,28 @@ class Tenant(Base):
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[Optional[str]] = mapped_column(sa.Text)
encrypt_public_key: Mapped[str | None] = mapped_column(sa.Text)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
custom_config: Mapped[Optional[str]] = mapped_column(sa.Text)
custom_config: Mapped[str | None] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
def get_accounts(self) -> list[Account]:
return (
db.session.query(Account)
.where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
.all()
return list(
db.session.scalars(
select(Account).where(
Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id
)
).all()
)
@property
def custom_config_dict(self) -> dict:
def custom_config_dict(self) -> dict[str, Any]:
return json.loads(self.custom_config) if self.custom_config else {}
@custom_config_dict.setter
def custom_config_dict(self, value: dict):
def custom_config_dict(self, value: dict[str, Any]) -> None:
self.custom_config = json.dumps(value)
@ -247,7 +271,7 @@ class TenantAccountJoin(Base):
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
role: Mapped[str] = mapped_column(String(16), server_default="normal")
invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
invited_by: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
@ -281,10 +305,10 @@ class InvitationCode(Base):
batch: Mapped[str] = mapped_column(String(255))
code: Mapped[str] = mapped_column(String(32))
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
used_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
used_at: Mapped[datetime | None] = mapped_column(DateTime)
used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID)
used_by_account_id: Mapped[str | None] = mapped_column(StringUUID)
deprecated_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))

View File

@ -1,7 +1,15 @@
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass
from models.engine import metadata
class Base(DeclarativeBase):
    """Declarative base class bound to the shared ``metadata`` from ``models.engine``."""

    # Reuse the module-level ``metadata`` object imported from ``models.engine``.
    metadata = metadata
class TypeBase(MappedAsDataclass, DeclarativeBase):
    """Dataclass-enabled declarative base used while type annotations are added to models.

    This is a transitional class: once the typing work is finished for all
    models, it is meant to be renamed to ``Base``.
    """

    # Reuse the module-level ``metadata`` object imported from ``models.engine``.
    metadata = metadata

View File

@ -10,12 +10,12 @@ import re
import time
from datetime import datetime
from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, cast
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
@ -49,24 +49,47 @@ class Dataset(Base):
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
PROVIDER_LIST = ["vendor", "external", None]
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
name: Mapped[str] = mapped_column(String(255))
description = mapped_column(sa.Text, nullable=True)
provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying"))
permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying"))
data_source_type = mapped_column(String(255))
indexing_technique: Mapped[Optional[str]] = mapped_column(String(255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
index_struct = mapped_column(sa.Text, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = mapped_column(String(255), nullable=True)
embedding_model_provider = mapped_column(String(255), nullable=True)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = mapped_column(db.String(255), nullable=True)
embedding_model_provider = mapped_column(db.String(255), nullable=True)
keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(JSONB, nullable=True)
built_in_field_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
icon_info = db.Column(JSONB, nullable=True)
runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
pipeline_id = db.Column(StringUUID, nullable=True)
chunk_structure = db.Column(db.String(255), nullable=True)
enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
@property
def total_documents(self):
    """Return the number of ``Document`` rows whose ``dataset_id`` is this dataset's ``id``."""
    count_query = db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id)
    return count_query.scalar()
@property
def total_available_documents(self):
    """Count this dataset's documents that are completed, enabled and not archived."""
    # ``== True`` / ``== False`` are intentional: applied to mapped columns they
    # build SQL comparison expressions instead of evaluating Python booleans.
    availability_filters = (
        Document.dataset_id == self.id,
        Document.indexing_status == "completed",
        Document.enabled == True,  # noqa: E712
        Document.archived == False,  # noqa: E712
    )
    return db.session.query(func.count(Document.id)).where(*availability_filters).scalar()
@property
def dataset_keyword_table(self):
@ -150,7 +173,9 @@ class Dataset(Base):
)
@property
def doc_form(self):
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
if document:
return document.doc_form
@ -206,9 +231,19 @@ class Dataset(Base):
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
}
@property
def is_published(self):
    """Report whether the pipeline linked to this dataset is published.

    Returns ``False`` when the dataset has no ``pipeline_id`` or when no
    ``Pipeline`` row with that id is found; otherwise returns that pipeline's
    own ``is_published`` value.
    """
    if not self.pipeline_id:
        return False

    linked_pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
    if not linked_pipeline:
        return False
    return linked_pipeline.is_published
@property
def doc_metadata(self):
dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
dataset_metadatas = db.session.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id)
).all()
doc_metadata = [
{
@ -222,35 +257,35 @@ class Dataset(Base):
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.document_name.value,
"name": BuiltInField.document_name,
"type": "string",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.uploader.value,
"name": BuiltInField.uploader,
"type": "string",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.upload_date.value,
"name": BuiltInField.upload_date,
"type": "time",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.last_update_date.value,
"name": BuiltInField.last_update_date,
"type": "time",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.source.value,
"name": BuiltInField.source,
"type": "string",
}
)
@ -286,7 +321,7 @@ class DatasetProcessRule(Base):
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
}
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"dataset_id": self.dataset_id,
@ -295,7 +330,7 @@ class DatasetProcessRule(Base):
}
@property
def rules_dict(self):
def rules_dict(self) -> dict[str, Any] | None:
try:
return json.loads(self.rules) if self.rules else None
except JSONDecodeError:
@ -328,42 +363,42 @@ class Document(Base):
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
# start processing
processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
processing_started_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# parsing
file_id = mapped_column(sa.Text, nullable=True)
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable
parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable
parsing_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# cleaning
cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
cleaning_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# split
splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
splitting_completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# indexing
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
indexing_latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# pause
is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
is_paused: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
paused_by = mapped_column(StringUUID, nullable=True)
paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
paused_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# error
error = mapped_column(sa.Text, nullable=True)
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
# basic fields
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying"))
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
archived_reason = mapped_column(String(255), nullable=True)
archived_by = mapped_column(StringUUID, nullable=True)
archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
doc_type = mapped_column(String(40), nullable=True)
doc_metadata = mapped_column(JSONB, nullable=True)
@ -392,21 +427,21 @@ class Document(Base):
return status
@property
def data_source_info_dict(self):
def data_source_info_dict(self) -> dict[str, Any]:
if self.data_source_info:
try:
data_source_info_dict = json.loads(self.data_source_info)
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
except JSONDecodeError:
data_source_info_dict = {}
return data_source_info_dict
return None
return {}
@property
def data_source_detail_dict(self):
def data_source_detail_dict(self) -> dict[str, Any]:
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict = json.loads(self.data_source_info)
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
@ -425,7 +460,8 @@ class Document(Base):
}
}
elif self.data_source_type in {"notion_import", "website_crawl"}:
return json.loads(self.data_source_info)
result: dict[str, Any] = json.loads(self.data_source_info)
return result
return {}
@property
@ -471,7 +507,7 @@ class Document(Base):
return self.updated_at
@property
def doc_metadata_details(self):
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
@ -481,9 +517,9 @@ class Document(Base):
)
.all()
)
metadata_list = []
metadata_list: list[dict[str, Any]] = []
for metadata in document_metadatas:
metadata_dict = {
metadata_dict: dict[str, Any] = {
"id": metadata.id,
"name": metadata.name,
"type": metadata.type,
@ -497,13 +533,13 @@ class Document(Base):
return None
@property
def process_rule_dict(self):
if self.dataset_process_rule_id:
def process_rule_dict(self) -> dict[str, Any] | None:
if self.dataset_process_rule_id and self.dataset_process_rule:
return self.dataset_process_rule.to_dict()
return None
def get_built_in_fields(self):
built_in_fields = []
def get_built_in_fields(self) -> list[dict[str, Any]]:
built_in_fields: list[dict[str, Any]] = []
built_in_fields.append(
{
"id": "built-in",
@ -541,12 +577,12 @@ class Document(Base):
"id": "built-in",
"name": BuiltInField.source,
"type": "string",
"value": MetadataDataSource[self.data_source_type].value,
"value": MetadataDataSource[self.data_source_type],
}
)
return built_in_fields
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"tenant_id": self.tenant_id,
@ -592,13 +628,13 @@ class Document(Base):
"data_source_info_dict": self.data_source_info_dict,
"average_segment_length": self.average_segment_length,
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
"dataset": self.dataset.to_dict() if self.dataset else None,
"dataset": None, # Dataset class doesn't have a to_dict method
"segment_count": self.segment_count,
"hit_count": self.hit_count,
}
@classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: dict[str, Any]):
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@ -674,17 +710,17 @@ class DocumentSegment(Base):
# basic fields
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(sa.Text, nullable=True)
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
@property
def dataset(self):
@ -711,50 +747,52 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
def child_chunks(self) -> list[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
def get_child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
def get_child_chunks(self) -> list[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
@property
def sign_content(self):
def sign_content(self) -> str:
return self.get_sign_content()
def get_sign_content(self):
signed_urls = []
def get_sign_content(self) -> str:
signed_urls: list[tuple[int, int, str]] = []
text = self.content
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview"
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
matches = re.finditer(pattern, text)
for match in matches:
upload_file_id = match.group(1)
@ -766,11 +804,12 @@ class DocumentSegment(Base):
encoded_sign = base64.urlsafe_b64encode(sign).decode()
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
signed_url = f"{match.group(0)}?{params}"
base_url = f"/files/{upload_file_id}/image-preview"
signed_url = f"{base_url}?{params}"
signed_urls.append((match.start(), match.end(), signed_url))
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview"
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
matches = re.finditer(pattern, text)
for match in matches:
upload_file_id = match.group(1)
@ -782,7 +821,27 @@ class DocumentSegment(Base):
encoded_sign = base64.urlsafe_b64encode(sign).decode()
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
signed_url = f"{match.group(0)}?{params}"
base_url = f"/files/{upload_file_id}/file-preview"
signed_url = f"{base_url}?{params}"
signed_urls.append((match.start(), match.end(), signed_url))
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
# Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
matches = re.finditer(pattern, text)
for match in matches:
upload_file_id = match.group(1)
file_extension = match.group(2)
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
base_url = f"/files/tools/{upload_file_id}.{file_extension}"
signed_url = f"{base_url}?{params}"
signed_urls.append((match.start(), match.end(), signed_url))
# Reconstruct the text with signed URLs
@ -824,8 +883,8 @@ class ChildChunk(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(sa.Text, nullable=True)
@property
@ -890,17 +949,22 @@ class DatasetKeywordTable(Base):
)
@property
def keyword_table_dict(self):
def keyword_table_dict(self) -> dict[str, set[Any]] | None:
class SetDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)
def __init__(self, *args: Any, **kwargs: Any) -> None:
def object_hook(dct: Any) -> Any:
if isinstance(dct, dict):
result: dict[str, Any] = {}
items = cast(dict[str, Any], dct).items()
for keyword, node_idxs in items:
if isinstance(node_idxs, list):
result[keyword] = set(cast(list[Any], node_idxs))
else:
result[keyword] = node_idxs
return result
return dct
def object_hook(self, dct):
if isinstance(dct, dict):
for keyword, node_idxs in dct.items():
if isinstance(node_idxs, list):
dct[keyword] = set(node_idxs)
return dct
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
@ -915,7 +979,7 @@ class DatasetKeywordTable(Base):
if keyword_table_text:
return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder)
return None
except Exception as e:
except Exception:
logger.exception("Failed to load keyword table from file: %s", file_key)
return None
@ -1026,7 +1090,7 @@ class ExternalKnowledgeApis(Base):
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"tenant_id": self.tenant_id,
@ -1039,22 +1103,20 @@ class ExternalKnowledgeApis(Base):
}
@property
def settings_dict(self):
def settings_dict(self) -> dict[str, Any] | None:
try:
return json.loads(self.settings) if self.settings else None
except JSONDecodeError:
return None
@property
def dataset_bindings(self):
external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.all()
)
def dataset_bindings(self) -> list[dict[str, Any]]:
external_knowledge_bindings = db.session.scalars(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
).all()
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = []
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
dataset_bindings: list[dict[str, Any]] = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
@ -1158,3 +1220,112 @@ class DatasetMetadataBinding(Base):
document_id = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_by = mapped_column(StringUUID, nullable=False)
class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
    """ORM model mapped to the ``pipeline_built_in_templates`` table."""

    __tablename__ = "pipeline_built_in_templates"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),
    )

    # Primary key; generated server-side through ``uuidv7()`` when omitted.
    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))

    # Template payload and presentation columns (all required).
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=False)
    chunk_structure = db.Column(db.String(255), nullable=False)
    icon = db.Column(db.JSON, nullable=False)
    yaml_content = db.Column(db.Text, nullable=False)
    copyright = db.Column(db.String(255), nullable=False)
    privacy_policy = db.Column(db.String(255), nullable=False)
    position = db.Column(db.Integer, nullable=False)
    # Python-side default: a new row starts at 0 unless a value is given.
    install_count = db.Column(db.Integer, nullable=False, default=0)
    language = db.Column(db.String(255), nullable=False)

    # Audit columns; both timestamps default to the database's current timestamp.
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    created_by = db.Column(StringUUID, nullable=False)
    updated_by = db.Column(StringUUID, nullable=True)

    @property
    def created_user_name(self):
        """Return the ``name`` of the account whose id is ``created_by``, or ``""`` if none is found."""
        creator = db.session.query(Account).where(Account.id == self.created_by).first()
        return creator.name if creator else ""
class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
    """ORM model mapped to the ``pipeline_customized_templates`` table.

    Rows carry a ``tenant_id``, which is indexed by
    ``pipeline_customized_template_tenant_idx``.
    """

    __tablename__ = "pipeline_customized_templates"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
        db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
    )

    # Primary key; generated server-side through ``uuidv7()`` when omitted.
    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id = db.Column(StringUUID, nullable=False)

    # Template payload and presentation columns (all required).
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=False)
    chunk_structure = db.Column(db.String(255), nullable=False)
    icon = db.Column(db.JSON, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    yaml_content = db.Column(db.Text, nullable=False)
    # Python-side default: a new row starts at 0 unless a value is given.
    install_count = db.Column(db.Integer, nullable=False, default=0)
    language = db.Column(db.String(255), nullable=False)

    # Audit columns; both timestamps default to the database's current timestamp.
    created_by = db.Column(StringUUID, nullable=False)
    updated_by = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    @property
    def created_user_name(self):
        """Return the ``name`` of the account whose id is ``created_by``, or ``""`` if none is found."""
        creator = db.session.query(Account).where(Account.id == self.created_by).first()
        return creator.name if creator else ""
class Pipeline(Base):  # type: ignore[name-defined]
    """ORM model mapped to the ``pipelines`` table."""

    __tablename__ = "pipelines"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="pipeline_pkey"),
    )

    # Primary key; generated server-side through ``uuidv7()`` when omitted.
    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    # NOTE(review): this is the only column here that pairs a ``Mapped[...]``
    # annotation with ``db.Column``; other models in this file pair ``Mapped``
    # with ``mapped_column`` — consider aligning once verified.
    tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    # Required, with an empty string as the server-side default.
    description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
    workflow_id = db.Column(StringUUID, nullable=True)

    # Both flags default to false on the server side.
    is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))

    # Audit columns; both timestamps default to the database's current timestamp.
    created_by = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    def retrieve_dataset(self, session: Session):
        """Return the first ``Dataset`` whose ``pipeline_id`` equals this pipeline's ``id``.

        The lookup runs on the caller-supplied ``session``; the result of
        ``Query.first()`` is returned as is.
        """
        dataset_query = session.query(Dataset).where(Dataset.pipeline_id == self.id)
        return dataset_query.first()
# ORM model mapped to the "document_pipeline_execution_logs" table.
# Declares a named primary key on `id` and an index on `document_id`.
class DocumentPipelineExecutionLog(Base):
__tablename__ = "document_pipeline_execution_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
)
# `id` is filled in by the database via the `uuidv7()` server default.
id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
# Required ids; presumably the pipeline that ran and the document it processed — TODO confirm.
pipeline_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
# Required datasource fields. `datasource_info` is stored as plain text;
# NOTE(review): it may hold a serialized payload — verify against the writer.
datasource_type = db.Column(db.String(255), nullable=False)
datasource_info = db.Column(db.Text, nullable=False)
datasource_node_id = db.Column(db.String(255), nullable=False)
# Required JSON column.
input_data = db.Column(db.JSON, nullable=False)
# Audit columns; `created_at` defaults to the database's current timestamp.
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
# ORM model mapped to the "pipeline_recommended_plugins" table.
# The primary-key constraint on `id` is named "pipeline_recommended_plugin_pkey".
class PipelineRecommendedPlugin(Base):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
# `id` is filled in by the database via the `uuidv7()` server default.
id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
plugin_id = db.Column(db.Text, nullable=False)
provider_name = db.Column(db.Text, nullable=False)
# `position` and `active` use `default=` (applied by SQLAlchemy on INSERT), not `server_default`,
# so rows inserted outside the ORM must supply both values explicitly.
position = db.Column(db.Integer, nullable=False, default=0)
active = db.Column(db.Boolean, nullable=False, default=True)
# Both timestamps only carry a server default of the current timestamp; no `onupdate` is declared.
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -14,6 +14,8 @@ class UserFrom(StrEnum):
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"
APP_RUN = "app-run" # webapp / service api
RAG_PIPELINE_RUN = "rag-pipeline-run"
RAG_PIPELINE_DEBUGGING = "rag-pipeline-debugging"
WEBHOOK = "webhook"
SCHEDULE = "schedule"
PLUGIN = "plugin"
@ -33,3 +35,9 @@ class MessageStatus(StrEnum):
NORMAL = "normal"
ERROR = "error"
# String enum with three members: "inputs", "process_data" and "outputs".
# NOTE(review): presumably names which part of an execution record is offloaded — confirm against its users.
class ExecutionOffLoadType(StrEnum):
INPUTS = "inputs"
PROCESS_DATA = "process_data"
OUTPUTS = "outputs"

View File

@ -3,20 +3,12 @@ import re
import uuid
from collections.abc import Mapping
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.signature import sign_tool_file
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
if TYPE_CHECKING:
from models.workflow import Workflow
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
@ -24,14 +16,20 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from libs.helper import generate_string
from core.tools.signature import sign_tool_file
from core.workflow.enums import WorkflowExecutionStatus
from libs.helper import generate_string # type: ignore[import-not-found]
from .account import Account, Tenant
from .base import Base
from .engine import db
from .enums import CreatorUserRole
from .provider_ids import GenericProviderID
from .types import StringUUID
if TYPE_CHECKING:
from models.workflow import Workflow
class DifySetup(Base):
__tablename__ = "dify_setups"
@ -47,6 +45,8 @@ class AppMode(StrEnum):
CHAT = "chat"
ADVANCED_CHAT = "advanced-chat"
AGENT_CHAT = "agent-chat"
CHANNEL = "channel"
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> "AppMode":
@ -62,9 +62,9 @@ class AppMode(StrEnum):
raise ValueError(f"invalid mode value {value}")
class IconType(Enum):
IMAGE = "image"
EMOJI = "emoji"
class IconType(StrEnum):
IMAGE = auto()
EMOJI = auto()
class App(Base):
@ -76,9 +76,9 @@ class App(Base):
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying"))
mode: Mapped[str] = mapped_column(String(255))
icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji
icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji
icon = mapped_column(String(255))
icon_background: Mapped[Optional[str]] = mapped_column(String(255))
icon_background: Mapped[str | None] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
workflow_id = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
@ -90,7 +90,7 @@ class App(Base):
is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
tracing = mapped_column(sa.Text, nullable=True)
max_active_requests: Mapped[Optional[int]]
max_active_requests: Mapped[int | None]
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -98,7 +98,7 @@ class App(Base):
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def desc_or_prompt(self):
def desc_or_prompt(self) -> str:
if self.description:
return self.description
else:
@ -109,12 +109,12 @@ class App(Base):
return ""
@property
def site(self):
def site(self) -> Optional["Site"]:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
@property
def app_model_config(self):
def app_model_config(self) -> Optional["AppModelConfig"]:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
@ -130,11 +130,11 @@ class App(Base):
return None
@property
def api_base_url(self):
def api_base_url(self) -> str:
return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
@property
def tenant(self):
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
@ -149,21 +149,21 @@ class App(Base):
if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get(
"strategy", ""
) in {"function_call", "react"}:
self.mode = AppMode.AGENT_CHAT.value
self.mode = AppMode.AGENT_CHAT
db.session.commit()
return True
return False
@property
def mode_compatible_with_agent(self) -> str:
if self.mode == AppMode.CHAT.value and self.is_agent:
return AppMode.AGENT_CHAT.value
if self.mode == AppMode.CHAT and self.is_agent:
return AppMode.AGENT_CHAT
return str(self.mode)
@property
def deleted_tools(self) -> list:
from core.tools.tool_manager import ToolManager
def deleted_tools(self) -> list[dict[str, str]]:
from core.tools.tool_manager import ToolManager, ToolProviderType
from services.plugin.plugin_service import PluginService
# get agent mode tools
@ -178,6 +178,7 @@ class App(Base):
tools = agent_mode.get("tools", [])
api_provider_ids: list[str] = []
builtin_provider_ids: list[GenericProviderID] = []
for tool in tools:
@ -242,7 +243,7 @@ class App(Base):
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
}
deleted_tools = []
deleted_tools: list[dict[str, str]] = []
for tool in tools:
keys = list(tool.keys())
@ -275,7 +276,7 @@ class App(Base):
return deleted_tools
@property
def tags(self):
def tags(self) -> list["Tag"]:
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
@ -291,7 +292,7 @@ class App(Base):
return tags or []
@property
def author_name(self):
def author_name(self) -> str | None:
if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
@ -334,20 +335,20 @@ class AppModelConfig(Base):
file_upload = mapped_column(sa.Text)
@property
def app(self):
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def model_dict(self) -> dict:
def model_dict(self) -> dict[str, Any]:
return json.loads(self.model) if self.model else {}
@property
def suggested_questions_list(self) -> list:
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
@property
def suggested_questions_after_answer_dict(self) -> dict:
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
return (
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
@ -355,19 +356,19 @@ class AppModelConfig(Base):
)
@property
def speech_to_text_dict(self) -> dict:
def speech_to_text_dict(self) -> dict[str, Any]:
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
@property
def text_to_speech_dict(self) -> dict:
def text_to_speech_dict(self) -> dict[str, Any]:
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
@property
def retriever_resource_dict(self) -> dict:
def retriever_resource_dict(self) -> dict[str, Any]:
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
@property
def annotation_reply_dict(self) -> dict:
def annotation_reply_dict(self) -> dict[str, Any]:
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
@ -390,11 +391,11 @@ class AppModelConfig(Base):
return {"enabled": False}
@property
def more_like_this_dict(self) -> dict:
def more_like_this_dict(self) -> dict[str, Any]:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@property
def sensitive_word_avoidance_dict(self) -> dict:
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
return (
json.loads(self.sensitive_word_avoidance)
if self.sensitive_word_avoidance
@ -402,15 +403,15 @@ class AppModelConfig(Base):
)
@property
def external_data_tools_list(self) -> list[dict]:
def external_data_tools_list(self) -> list[dict[str, Any]]:
return json.loads(self.external_data_tools) if self.external_data_tools else []
@property
def user_input_form_list(self):
def user_input_form_list(self) -> list[dict[str, Any]]:
return json.loads(self.user_input_form) if self.user_input_form else []
@property
def agent_mode_dict(self) -> dict:
def agent_mode_dict(self) -> dict[str, Any]:
return (
json.loads(self.agent_mode)
if self.agent_mode
@ -418,17 +419,17 @@ class AppModelConfig(Base):
)
@property
def chat_prompt_config_dict(self) -> dict:
def chat_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
@property
def completion_prompt_config_dict(self) -> dict:
def completion_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
@property
def dataset_configs_dict(self) -> dict:
def dataset_configs_dict(self) -> dict[str, Any]:
if self.dataset_configs:
dataset_configs: dict = json.loads(self.dataset_configs)
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"}
else:
@ -438,7 +439,7 @@ class AppModelConfig(Base):
}
@property
def file_upload_dict(self) -> dict:
def file_upload_dict(self) -> dict[str, Any]:
return (
json.loads(self.file_upload)
if self.file_upload
@ -452,7 +453,7 @@ class AppModelConfig(Base):
}
)
def to_dict(self) -> dict:
def to_dict(self) -> dict[str, Any]:
return {
"opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list,
@ -546,7 +547,7 @@ class RecommendedApp(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@ -570,12 +571,12 @@ class InstalledApp(Base):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def tenant(self):
def tenant(self) -> Tenant | None:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
@ -622,7 +623,7 @@ class Conversation(Base):
mode: Mapped[str] = mapped_column(String(255))
name: Mapped[str] = mapped_column(String(255), nullable=False)
summary = mapped_column(sa.Text)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
introduction = mapped_column(sa.Text)
system_instruction = mapped_column(sa.Text)
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@ -652,7 +653,7 @@ class Conversation(Base):
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def inputs(self):
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy()
# Convert file mapping to File object
@ -660,22 +661,39 @@ class Conversation(Base):
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
inputs[key] = []
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
value_dict = cast(dict[str, Any], value)
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
value_dict["tool_file_id"] = value_dict["related_id"]
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value_dict["upload_file_id"] = value_dict["related_id"]
tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
inputs[key] = file_list
return inputs
@ -685,16 +703,18 @@ class Conversation(Base):
for k, v in inputs.items():
if isinstance(v, File):
inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
inputs[k] = [item.model_dump() for item in v]
elif isinstance(v, list):
v_list = cast(list[Any], v)
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs
@property
def model_config(self):
model_config = {}
app_model_config: Optional[AppModelConfig] = None
app_model_config: AppModelConfig | None = None
if self.mode == AppMode.ADVANCED_CHAT.value:
if self.mode == AppMode.ADVANCED_CHAT:
if self.override_model_configs:
override_model_configs = json.loads(self.override_model_configs)
model_config = override_model_configs
@ -793,7 +813,7 @@ class Conversation(Base):
@property
def status_count(self):
messages = db.session.query(Message).where(Message.conversation_id == self.id).all()
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
status_counts = {
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
@ -826,8 +846,9 @@ class Conversation(Base):
)
@property
def app(self):
return db.session.query(App).where(App.id == self.app_id).first()
def app(self) -> App | None:
with Session(db.engine, expire_on_commit=False) as session:
return session.query(App).where(App.id == self.app_id).first()
@property
def from_end_user_session_id(self):
@ -839,7 +860,7 @@ class Conversation(Base):
return None
@property
def from_account_name(self):
def from_account_name(self) -> str | None:
if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
if account:
@ -848,10 +869,10 @@ class Conversation(Base):
return None
@property
def in_debug_mode(self):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,
@ -897,7 +918,7 @@ class Message(Base):
model_id = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(sa.Text)
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
query: Mapped[str] = mapped_column(sa.Text, nullable=False)
message = mapped_column(sa.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@ -914,38 +935,55 @@ class Message(Base):
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
error = mapped_column(sa.Text)
message_metadata = mapped_column(sa.Text)
invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID)
from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
@property
def inputs(self):
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy()
for key, value in inputs.items():
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
):
inputs[key] = []
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
value_dict = cast(dict[str, Any], value)
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
value_dict["tool_file_id"] = value_dict["related_id"]
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value_dict["upload_file_id"] = value_dict["related_id"]
tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
inputs[key] = file_list
return inputs
@inputs.setter
@ -954,8 +992,10 @@ class Message(Base):
for k, v in inputs.items():
if isinstance(v, File):
inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
inputs[k] = [item.model_dump() for item in v]
elif isinstance(v, list):
v_list = cast(list[Any], v)
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs
@property
@ -1004,7 +1044,7 @@ class Message(Base):
sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
elif "file-preview" in url:
# get upload file id
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview\?timestamp="
result = re.search(upload_file_id_pattern, url)
if not result:
continue
@ -1015,7 +1055,7 @@ class Message(Base):
sign_url = file_helpers.get_signed_file_url(upload_file_id)
elif "image-preview" in url:
# image-preview is deprecated, use file-preview instead
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview\?timestamp="
result = re.search(upload_file_id_pattern, url)
if not result:
continue
@ -1052,7 +1092,7 @@ class Message(Base):
@property
def feedbacks(self):
feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all()
feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all()
return feedbacks
@property
@ -1083,15 +1123,15 @@ class Message(Base):
return None
@property
def in_debug_mode(self):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None
@property
def message_metadata_dict(self) -> dict:
def message_metadata_dict(self) -> dict[str, Any]:
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self):
def agent_thoughts(self) -> list["MessageAgentThought"]:
return (
db.session.query(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
@ -1100,19 +1140,19 @@ class Message(Base):
)
@property
def retriever_resources(self):
def retriever_resources(self) -> Any:
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property
def message_files(self):
def message_files(self) -> list[dict[str, Any]]:
from factories import file_factory
message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
if not current_app:
raise ValueError(f"App {self.app_id} not found")
files = []
files: list[File] = []
for message_file in message_files:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value:
if message_file.upload_file_id is None:
@ -1159,7 +1199,7 @@ class Message(Base):
)
files.append(file)
result = [
result: list[dict[str, Any]] = [
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files)
]
@ -1176,7 +1216,7 @@ class Message(Base):
return None
def to_dict(self) -> dict:
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,
@ -1200,7 +1240,7 @@ class Message(Base):
}
@classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: dict[str, Any]) -> "Message":
return cls(
id=data["id"],
app_id=data["app_id"],
@ -1250,7 +1290,7 @@ class MessageFeedback(Base):
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": str(self.id),
"app_id": str(self.app_id),
@ -1299,9 +1339,9 @@ class MessageFile(Base):
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
url: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True)
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -1318,8 +1358,8 @@ class MessageAnnotation(Base):
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[Optional[str]] = mapped_column(StringUUID)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[str | None] = mapped_column(StringUUID)
question = mapped_column(sa.Text, nullable=True)
content = mapped_column(sa.Text, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@ -1421,6 +1461,14 @@ class OperationLog(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
# String enum with a single member whose value is the literal "DEFAULT-USER".
# NOTE(review): presumably the fallback `session_id` for end users with no explicit one — confirm at call sites.
class DefaultEndUserSessionID(StrEnum):
"""
End User Session ID enum.
"""
DEFAULT_SESSION_ID = "DEFAULT-USER"
class EndUser(Base, UserMixin):
__tablename__ = "end_users"
__table_args__ = (
@ -1435,7 +1483,18 @@ class EndUser(Base, UserMixin):
type: Mapped[str] = mapped_column(String(255), nullable=False)
external_user_id = mapped_column(String(255), nullable=True)
name = mapped_column(String(255))
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
_is_anonymous: Mapped[bool] = mapped_column(
"is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true")
)
# Read side of `is_anonymous`: always reports False and never consults the stored
# `_is_anonymous` column value.
# NOTE(review): the getter/setter asymmetry looks deliberate (the return type is Literal[False]) — confirm.
@property
def is_anonymous(self) -> Literal[False]:
return False
# Write side: assignments to `is_anonymous` are stored in `_is_anonymous`, the attribute
# mapped to the "is_anonymous" database column.
@is_anonymous.setter
def is_anonymous(self, value: bool) -> None:
self._is_anonymous = value
session_id: Mapped[str] = mapped_column()
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -1461,7 +1520,7 @@ class AppMCPServer(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_server_code(n):
def generate_server_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
@ -1518,7 +1577,7 @@ class Site(Base):
self._custom_disclaimer = value
@staticmethod
def generate_code(n):
def generate_code(n: int) -> str:
while True:
result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0:
@ -1549,7 +1608,7 @@ class ApiToken(Base):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_api_key(prefix, n):
def generate_api_key(prefix: str, n: int) -> str:
while True:
result = prefix + generate_string(n)
if db.session.scalar(select(exists().where(ApiToken.token == result))):
@ -1564,6 +1623,9 @@ class UploadFile(Base):
sa.Index("upload_file_tenant_idx", "tenant_id"),
)
# NOTE: The `id` field is generated within the application to minimize extra roundtrips
# (especially when generating `source_url`).
# The `server_default` serves as a fallback mechanism.
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1572,12 +1634,32 @@ class UploadFile(Base):
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
extension: Mapped[str] = mapped_column(String(255), nullable=False)
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
# Its value is derived from the `CreatorUserRole` enumeration.
created_by_role: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'account'::character varying")
)
# The `created_by` field stores the ID of the entity that created this upload file.
#
# If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`.
# Otherwise, it corresponds to `EndUser.id`.
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
# The fields `used` and `used_by` are not consistently maintained.
#
# When using this model in new code, ensure the following:
#
# 1. Set `used` to `true` when the file is utilized.
# 2. Assign `used_by` to the corresponding `Account.id` or `EndUser.id` based on the `created_by_role`.
# 3. Avoid relying on these fields for logic, as their values may not always be accurate.
#
# `used` may indicate whether the file has been utilized by another service.
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
# `used_by` may indicate the ID of the user who utilized this file.
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
@ -1602,6 +1684,7 @@ class UploadFile(Base):
hash: str | None = None,
source_url: str = "",
):
self.id = str(uuid.uuid4())
self.tenant_id = tenant_id
self.storage_type = storage_type
self.key = key
@ -1672,24 +1755,24 @@ class MessageAgentThought(Base):
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
tool_process_data = mapped_column(sa.Text, nullable=True)
message = mapped_column(sa.Text, nullable=True)
message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
message_unit_price = mapped_column(sa.Numeric, nullable=True)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_files = mapped_column(sa.Text, nullable=True)
answer = mapped_column(sa.Text, nullable=True)
answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String, nullable=True)
latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
def files(self) -> list:
def files(self) -> list[Any]:
if self.message_files:
return cast(list[Any], json.loads(self.message_files))
else:
@ -1700,32 +1783,32 @@ class MessageAgentThought(Base):
return self.tool.split(";") if self.tool else []
@property
def tool_labels(self) -> dict:
def tool_labels(self) -> dict[str, Any]:
try:
if self.tool_labels_str:
return cast(dict, json.loads(self.tool_labels_str))
return cast(dict[str, Any], json.loads(self.tool_labels_str))
else:
return {}
except Exception:
return {}
@property
def tool_meta(self) -> dict:
def tool_meta(self) -> dict[str, Any]:
try:
if self.tool_meta_str:
return cast(dict, json.loads(self.tool_meta_str))
return cast(dict[str, Any], json.loads(self.tool_meta_str))
else:
return {}
except Exception:
return {}
@property
def tool_inputs_dict(self) -> dict:
def tool_inputs_dict(self) -> dict[str, Any]:
tools = self.tools
try:
if self.tool_input:
data = json.loads(self.tool_input)
result = {}
result: dict[str, Any] = {}
for tool in tools:
if tool in data:
result[tool] = data[tool]
@ -1741,12 +1824,12 @@ class MessageAgentThought(Base):
return {}
@property
def tool_outputs_dict(self):
def tool_outputs_dict(self) -> dict[str, Any]:
tools = self.tools
try:
if self.observation:
data = json.loads(self.observation)
result = {}
result: dict[str, Any] = {}
for tool in tools:
if tool in data:
result[tool] = data[tool]
@ -1781,11 +1864,11 @@ class DatasetRetrieverResource(Base):
document_name = mapped_column(sa.Text, nullable=False)
data_source_type = mapped_column(sa.Text, nullable=True)
segment_id = mapped_column(StringUUID, nullable=True)
score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
score: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
content = mapped_column(sa.Text, nullable=False)
hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
hit_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
segment_position: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
index_node_hash = mapped_column(sa.Text, nullable=True)
retriever_from = mapped_column(sa.Text, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
@ -1844,14 +1927,14 @@ class TraceAppConfig(Base):
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@property
def tracing_config_dict(self):
def tracing_config_dict(self) -> dict[str, Any]:
return self.tracing_config or {}
@property
def tracing_config_str(self):
def tracing_config_str(self) -> str:
return json.dumps(self.tracing_config_dict)
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"app_id": self.app_id,

61
api/models/oauth.py Normal file
View File

@ -0,0 +1,61 @@
from datetime import datetime
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped
from .base import Base
from .engine import db
from .types import StringUUID
class DatasourceOauthParamConfig(Base):  # type: ignore[name-defined]
    """OAuth parameters stored per datasource plugin provider.

    The unique constraint allows at most one row per ``(plugin_id, provider)`` pair.
    """

    __tablename__ = "datasource_oauth_params"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
        db.UniqueConstraint(
            "plugin_id",
            "provider",
            name="datasource_oauth_config_datasource_id_provider_idx",
        ),
    )

    # Primary key; the value is produced server-side by `uuidv7()`.
    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    # Identifier of the plugin that ships the provider.
    plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
    # Provider name within the plugin.
    provider: Mapped[str] = db.Column(db.String(255), nullable=False)
    # JSONB payload holding the credentials for this provider.
    # NOTE(review): presumably system-wide OAuth client settings -- confirm against the callers.
    system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
class DatasourceProvider(Base):
    """Credential record of one datasource provider configured for a tenant.

    A tenant may hold several named credential sets for the same
    ``(plugin_id, provider)`` pair; ``name`` must be unique within that group.
    """

    __tablename__ = "datasource_providers"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
        db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
        db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
    )

    # Primary key; the value is produced server-side by `uuidv7()`.
    id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
    # Display name of this credential set (unique per tenant/plugin/provider).
    name: Mapped[str] = db.Column(db.String(255), nullable=False)
    provider: Mapped[str] = db.Column(db.String(255), nullable=False)
    plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
    auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
    encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
    # The column is nullable, so the annotation must admit None.
    avatar_url: Mapped[str | None] = db.Column(db.Text, nullable=True, default="default")
    is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    # NOTE(review): -1 looks like a "never expires" sentinel -- confirm with the token refresh code.
    expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1")
    # NOTE(review): `datetime.now` yields naive local time, unlike the server-side
    # `current_timestamp()` defaults used by other models -- confirm this is intended.
    created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
    # `onupdate` refreshes the timestamp on every UPDATE that does not set it
    # explicitly; without it the value stayed frozen at insertion time.
    updated_at: Mapped[datetime] = db.Column(
        db.DateTime, nullable=False, default=datetime.now, onupdate=datetime.now
    )
class DatasourceOauthTenantParamConfig(Base):
    """Tenant-scoped OAuth client parameters for a datasource plugin provider.

    The unique constraint allows at most one row per
    ``(tenant_id, plugin_id, provider)`` combination.
    """

    __tablename__ = "datasource_oauth_tenant_params"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
        db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
    )

    # Primary key; the value is produced server-side by `uuidv7()`.
    id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
    provider: Mapped[str] = db.Column(db.String(255), nullable=False)
    plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
    # A callable default builds a fresh dict per row; a literal `{}` would be a
    # single object shared by every instance that relies on the default.
    client_params: Mapped[dict] = db.Column(JSONB, nullable=False, default=dict)
    enabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, default=False)
    # NOTE(review): `datetime.now` yields naive local time -- confirm this matches
    # the timestamps written by the rest of the models.
    created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
    # `onupdate` refreshes the timestamp on every UPDATE that does not set it
    # explicitly; without it the value stayed frozen at insertion time.
    updated_at: Mapped[datetime] = db.Column(
        db.DateTime, nullable=False, default=datetime.now, onupdate=datetime.now
    )

View File

@ -1,7 +1,6 @@
from datetime import datetime
from enum import Enum
from enum import StrEnum, auto
from functools import cached_property
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, text
@ -12,30 +11,30 @@ from .engine import db
from .types import StringUUID
class ProviderType(Enum):
CUSTOM = "custom"
SYSTEM = "system"
class ProviderType(StrEnum):
CUSTOM = auto()
SYSTEM = auto()
@staticmethod
def value_of(value):
def value_of(value: str) -> "ProviderType":
for member in ProviderType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ProviderQuotaType(Enum):
PAID = "paid"
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""
FREE = "free"
FREE = auto()
"""third-party free quota"""
TRIAL = "trial"
TRIAL = auto()
"""hosted trial quota"""
@staticmethod
def value_of(value):
def value_of(value: str) -> "ProviderQuotaType":
for member in ProviderQuotaType:
if member.value == value:
return member
@ -63,14 +62,14 @@ class Provider(Base):
String(40), nullable=False, server_default=text("'custom'::character varying")
)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
quota_type: Mapped[Optional[str]] = mapped_column(
quota_type: Mapped[str | None] = mapped_column(
String(40), nullable=True, server_default=text("''::character varying")
)
quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0)
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True)
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -133,7 +132,7 @@ class ProviderModel(Base):
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -201,17 +200,17 @@ class ProviderOrder(Base):
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False)
payment_id: Mapped[Optional[str]] = mapped_column(String(191))
transaction_id: Mapped[Optional[str]] = mapped_column(String(191))
payment_id: Mapped[str | None] = mapped_column(String(191))
transaction_id: Mapped[str | None] = mapped_column(String(191))
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
currency: Mapped[Optional[str]] = mapped_column(String(40))
total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer)
currency: Mapped[str | None] = mapped_column(String(40))
total_amount: Mapped[int | None] = mapped_column(sa.Integer)
payment_status: Mapped[str] = mapped_column(
String(40), nullable=False, server_default=text("'wait_pay'::character varying")
)
paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
paid_at: Mapped[datetime | None] = mapped_column(DateTime)
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -255,9 +254,9 @@ class LoadBalancingModelConfig(Base):
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True)
encrypted_config: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -0,0 +1,59 @@
"""Provider ID entities for plugin system."""
import re
from werkzeug.exceptions import NotFound
class GenericProviderID:
    """Parsed plugin provider identifier: ``organization/plugin_name/provider_name``.

    A bare name such as ``"openai"`` is accepted as shorthand and expanded to
    ``"langgenius/openai/openai"``.
    """

    organization: str
    plugin_name: str
    provider_name: str
    is_hardcoded: bool

    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        """Parse ``value`` into its three components.

        Args:
            value: Full ``org/plugin/provider`` id, or a bare provider name.
            is_hardcoded: Stored unchanged on the instance.

        Raises:
            NotFound: If ``value`` is empty.
            ValueError: If ``value`` matches neither accepted form.
        """
        if not value:
            raise NotFound("plugin not found, please add plugin")

        # `fullmatch` is used instead of `match` with `^...$`: `$` also matches just
        # before a trailing newline, which let ids like "a/b/c\n" through and left
        # the newline inside `provider_name`.
        if not re.fullmatch(r"[a-z0-9_-]+/[a-z0-9_-]+/[a-z0-9_-]+", value):
            # A single bare segment is shorthand for langgenius/$value/$value.
            if re.fullmatch(r"[a-z0-9_-]+", value):
                value = f"langgenius/{value}/{value}"
            else:
                raise ValueError(f"Invalid plugin id {value}")

        self.organization, self.plugin_name, self.provider_name = value.split("/")
        self.is_hardcoded = is_hardcoded

    def __str__(self) -> str:
        return f"{self.organization}/{self.plugin_name}/{self.provider_name}"

    def to_string(self) -> str:
        """Return the canonical ``organization/plugin_name/provider_name`` form."""
        return str(self)

    def is_langgenius(self) -> bool:
        """Return True when the organization segment is ``langgenius``."""
        return self.organization == "langgenius"

    @property
    def plugin_id(self) -> str:
        """The ``organization/plugin_name`` prefix of the identifier."""
        return f"{self.organization}/{self.plugin_name}"
class ModelProviderID(GenericProviderID):
    """Provider id for model plugins.

    The ``langgenius`` provider named ``google`` is remapped to the ``gemini`` plugin.
    """

    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        super().__init__(value, is_hardcoded)
        is_official_google = self.provider_name == "google" and self.organization == "langgenius"
        if is_official_google:
            self.plugin_name = "gemini"
class ToolProviderID(GenericProviderID):
    """Provider id for tool plugins.

    For a fixed set of ``langgenius`` providers the plugin name is rewritten to
    ``<provider_name>_tool``.
    """

    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        super().__init__(value, is_hardcoded)
        # Only official (langgenius) providers are subject to the rename.
        if self.organization != "langgenius":
            return
        renamed_providers = ["jina", "siliconflow", "stepfun", "gitee_ai"]
        if self.provider_name in renamed_providers:
            self.plugin_name = f"{self.provider_name}_tool"
class DatasourceProviderID(GenericProviderID):
    """Provider id for datasource plugins; parsing is inherited unchanged."""

    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        # No datasource-specific remapping: defer entirely to the generic parser.
        super().__init__(value, is_hardcoded=is_hardcoded)

View File

@ -1,6 +1,5 @@
import json
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
@ -27,7 +26,7 @@ class DataSourceOauthBinding(Base):
source_info = mapped_column(JSONB, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
class DataSourceApiKeyAuthBinding(Base):
@ -45,7 +44,7 @@ class DataSourceApiKeyAuthBinding(Base):
credentials = mapped_column(sa.Text, nullable=True) # JSON
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
def to_dict(self):
return {

View File

@ -1,5 +1,4 @@
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from celery import states
@ -32,7 +31,7 @@ class CeleryTask(Base):
args = mapped_column(sa.LargeBinary, nullable=True)
kwargs = mapped_column(sa.LargeBinary, nullable=True)
worker = mapped_column(String(155), nullable=True)
retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
queue = mapped_column(String(155), nullable=True)
@ -46,4 +45,4 @@ class CeleryTaskSet(Base):
)
taskset_id = mapped_column(String(155), unique=True)
result = mapped_column(db.PickleType, nullable=True)
date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True)
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True)

View File

@ -1,6 +1,7 @@
import json
from collections.abc import Mapping
from datetime import datetime
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse
import sqlalchemy as sa
@ -8,29 +9,33 @@ from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy.orm import Mapped, mapped_column
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.mcp.types import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from models.base import Base
from models.base import Base, TypeBase
from .engine import db
from .model import Account, App, Tenant
from .types import StringUUID
if TYPE_CHECKING:
from core.mcp.types import Tool as MCPTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base):
class ToolOAuthSystemClient(TypeBase):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
plugin_id = mapped_column(String(512), nullable=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
@ -54,8 +59,8 @@ class ToolOAuthTenantClient(Base):
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
def oauth_params(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
class BuiltinToolProvider(Base):
@ -96,8 +101,8 @@ class BuiltinToolProvider(Base):
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
@property
def credentials(self) -> dict:
return cast(dict, json.loads(self.encrypted_credentials))
def credentials(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.encrypted_credentials))
class ApiToolProvider(Base):
@ -138,16 +143,20 @@ class ApiToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def schema_type(self) -> ApiProviderSchemaType:
def schema_type(self) -> "ApiProviderSchemaType":
from core.tools.entities.tool_entities import ApiProviderSchemaType
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
def tools(self) -> list[ApiToolBundle]:
def tools(self) -> list["ApiToolBundle"]:
from core.tools.entities.tool_bundle import ApiToolBundle
return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
@property
def credentials(self) -> dict:
return dict(json.loads(self.credentials_str))
def credentials(self) -> dict[str, Any]:
return dict[str, Any](json.loads(self.credentials_str))
@property
def user(self) -> Account | None:
@ -160,7 +169,7 @@ class ApiToolProvider(Base):
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
class ToolLabelBinding(Base):
class ToolLabelBinding(TypeBase):
"""
The table stores the labels for tools.
"""
@ -171,7 +180,7 @@ class ToolLabelBinding(Base):
sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
@ -230,7 +239,9 @@ class WorkflowToolProvider(Base):
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
@property
@ -280,6 +291,8 @@ class MCPToolProvider(Base):
)
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
# encrypted headers for MCP server requests
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
@ -289,20 +302,24 @@ class MCPToolProvider(Base):
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def credentials(self) -> dict:
def credentials(self) -> dict[str, Any]:
try:
return cast(dict, json.loads(self.encrypted_credentials)) or {}
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
except Exception:
return {}
@property
def mcp_tools(self) -> list[Tool]:
return [Tool(**tool) for tool in json.loads(self.tools)]
def mcp_tools(self) -> list["MCPTool"]:
from core.mcp.types import Tool as MCPTool
return [MCPTool(**tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> dict[str, str] | str:
def provider_icon(self) -> Mapping[str, str] | str:
from core.file import helpers as file_helpers
try:
return cast(dict[str, str], json.loads(self.icon))
return json.loads(self.icon)
except json.JSONDecodeError:
return file_helpers.get_signed_file_url(self.icon)
@ -310,6 +327,62 @@ class MCPToolProvider(Base):
def decrypted_server_url(self) -> str:
return encrypter.decrypt_token(self.tenant_id, self.server_url)
@property
def decrypted_headers(self) -> dict[str, Any]:
    """Return the stored MCP request headers in decrypted form.

    Best-effort: yields an empty dict when no headers are stored or when
    parsing/decryption raises for any reason.
    """
    from core.entities.provider_entities import BasicProviderConfig
    from core.helper.provider_cache import NoOpProviderCredentialCache
    from core.tools.utils.encryption import create_provider_encrypter

    try:
        if not self.encrypted_headers:
            return {}

        stored_headers = json.loads(self.encrypted_headers)

        # Every header is declared as a SECRET_INPUT field for the encrypter.
        header_configs = [
            BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=header_name)
            for header_name in stored_headers
        ]

        header_encrypter, _ = create_provider_encrypter(
            tenant_id=self.tenant_id,
            config=header_configs,
            cache=NoOpProviderCredentialCache(),
        )

        return header_encrypter.decrypt(stored_headers)
    except Exception:
        return {}
@property
def masked_headers(self) -> dict[str, Any]:
    """Return the stored MCP request headers, decrypted and then masked.

    Best-effort: yields an empty dict when no headers are stored or when
    parsing/decryption/masking raises for any reason.
    """
    from core.entities.provider_entities import BasicProviderConfig
    from core.helper.provider_cache import NoOpProviderCredentialCache
    from core.tools.utils.encryption import create_provider_encrypter

    try:
        if not self.encrypted_headers:
            return {}

        stored_headers = json.loads(self.encrypted_headers)

        # Every header is declared as a SECRET_INPUT field for the encrypter.
        header_configs = [
            BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=header_name)
            for header_name in stored_headers
        ]

        header_encrypter, _ = create_provider_encrypter(
            tenant_id=self.tenant_id,
            config=header_configs,
            cache=NoOpProviderCredentialCache(),
        )

        # Masking operates on plaintext, so decrypt first.
        plain_headers = header_encrypter.decrypt(stored_headers)
        return header_encrypter.mask_tool_credentials(plain_headers)
    except Exception:
        return {}
@property
def masked_server_url(self) -> str:
def mask_url(url: str, mask_char: str = "*") -> str:
@ -327,12 +400,12 @@ class MCPToolProvider(Base):
return mask_url(self.decrypted_server_url)
@property
def decrypted_credentials(self) -> dict:
def decrypted_credentials(self) -> dict[str, Any]:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController._from_db(self)
provider_controller = MCPToolProviderController.from_db(self)
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
@ -340,7 +413,7 @@ class MCPToolProvider(Base):
cache=NoOpProviderCredentialCache(),
)
return encrypter.decrypt(self.credentials) # type: ignore
return encrypter.decrypt(self.credentials)
class ToolModelInvoke(Base):
@ -408,11 +481,11 @@ class ToolConversationVariables(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def variables(self) -> Any:
def variables(self):
return json.loads(self.variables_str)
class ToolFile(Base):
class ToolFile(TypeBase):
"""This table stores file metadata generated in workflows,
not only files created by agent.
"""
@ -423,19 +496,19 @@ class ToolFile(Base):
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID)
# conversation id
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
# file key
file_key: Mapped[str] = mapped_column(String(255), nullable=False)
# mime type
mimetype: Mapped[str] = mapped_column(String(255), nullable=False)
# original url
original_url: Mapped[str] = mapped_column(String(2048), nullable=True)
original_url: Mapped[str | None] = mapped_column(String(2048), nullable=True, default=None)
# name
name: Mapped[str] = mapped_column(default="")
# size
@ -476,5 +549,7 @@ class DeprecatedPublishedAppTool(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
@property
def description_i18n(self) -> I18nObject:
def description_i18n(self) -> "I18nObject":
from core.tools.entities.common_entities import I18nObject
return I18nObject(**json.loads(self.description))

View File

@ -1,29 +1,34 @@
import enum
from typing import Generic, TypeVar
import uuid
from typing import Any, Generic, TypeVar
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine
class StringUUID(TypeDecorator):
class StringUUID(TypeDecorator[uuid.UUID | str | None]):
impl = CHAR
cache_ok = True
def process_bind_param(self, value, dialect):
def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
elif dialect.name == "postgresql":
return str(value)
else:
return value.hex
if isinstance(value, uuid.UUID):
return value.hex
return value
def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect):
def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
return str(value)
@ -32,7 +37,7 @@ class StringUUID(TypeDecorator):
_E = TypeVar("_E", bound=enum.StrEnum)
class EnumText(TypeDecorator, Generic[_E]):
class EnumText(TypeDecorator[_E | None], Generic[_E]):
impl = VARCHAR
cache_ok = True
@ -50,28 +55,25 @@ class EnumText(TypeDecorator, Generic[_E]):
# leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20)
def process_bind_param(self, value: _E | str | None, dialect):
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
if isinstance(value, self._enum_class):
return value.value
elif isinstance(value, str):
self._enum_class(value)
return value
else:
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
# Since _E is bound to StrEnum which inherits from str, at this point value must be str
self._enum_class(value)
return value
def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value, dialect) -> _E | None:
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
if value is None:
return value
if not isinstance(value, str):
raise TypeError(f"expected str, got {type(value)}")
# Type annotation guarantees value is str at this point
return self._enum_class(value)
def compare_values(self, x, y):
def compare_values(self, x: _E | None, y: _E | None) -> bool:
if x is None or y is None:
return x is y
return x == y

View File

@ -2,26 +2,28 @@ import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, exists, orm, select
from sqlalchemy import DateTime, Select, exists, orm, select
from core.file.constants import maybe_file_object
from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
if TYPE_CHECKING:
from models.model import AppMode
from models.model import AppMode, UploadFile
from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
@ -35,19 +37,20 @@ from libs import helper
from .account import Account
from .base import Base
from .engine import db
from .enums import CreatorUserRole, DraftVariableType
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
from .types import EnumText, StringUUID
logger = logging.getLogger(__name__)
class WorkflowType(Enum):
class WorkflowType(StrEnum):
"""
Workflow Type Enum
"""
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> "WorkflowType":
@ -130,7 +133,7 @@ class Workflow(Base):
_features: Mapped[str] = mapped_column("features", sa.TEXT)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
updated_by: Mapped[str | None] = mapped_column(StringUUID)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
@ -143,6 +146,9 @@ class Workflow(Base):
_conversation_variables: Mapped[str] = mapped_column(
"conversation_variables", sa.Text, nullable=False, server_default="{}"
)
# JSON text column: an object keyed by each RAG pipeline variable's "variable" entry
# (see the `rag_pipeline_variables` setter). Uses `sa.Text` like the sibling
# `_conversation_variables` column instead of going through the Flask `db` handle.
_rag_pipeline_variables: Mapped[str] = mapped_column(
    "rag_pipeline_variables", sa.Text, nullable=False, server_default="{}"
)
VERSION_DRAFT = "draft"
@ -159,6 +165,7 @@ class Workflow(Base):
created_by: str,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> "Workflow":
@ -173,6 +180,7 @@ class Workflow(Base):
workflow.created_by = created_by
workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or []
workflow.rag_pipeline_variables = rag_pipeline_variables or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = naive_utc_now()
@ -224,7 +232,7 @@ class Workflow(Base):
raise WorkflowDataError("nodes not found in workflow graph")
try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
assert isinstance(node_config, dict)
@ -282,7 +290,7 @@ class Workflow(Base):
return self._features
@features.setter
def features(self, value: str) -> None:
def features(self, value: str):
self._features = value
@property
@ -337,7 +345,7 @@ class Workflow(Base):
else:
yield from ((node["id"], node["data"]) for node in graph_dict["nodes"])
def user_input_form(self, to_old_structure: bool = False) -> list:
def user_input_form(self, to_old_structure: bool = False) -> list[Any]:
# get start node from graph
if not self.graph:
return []
@ -354,7 +362,7 @@ class Workflow(Base):
variables: list[Any] = start_node.get("data", {}).get("variables", [])
if to_old_structure:
old_structure_variables = []
old_structure_variables: list[dict[str, Any]] = []
for variable in variables:
old_structure_variables.append({variable["type"]: variable})
@ -362,6 +370,12 @@ class Workflow(Base):
return variables
def rag_pipeline_user_input_form(self) -> list:
    """Return ``self.rag_pipeline_variables`` unchanged.

    Unlike ``user_input_form``, this does not inspect the graph's start node;
    the list comes straight from the ``rag_pipeline_variables`` property.
    """
    return self.rag_pipeline_variables
@property
def unique_hash(self) -> str:
"""
@ -394,9 +408,7 @@ class Workflow(Base):
@property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# _environment_variables is guaranteed to be non-None due to server_default="{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
@ -404,23 +416,24 @@ class Workflow(Base):
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
results = [
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
]
# decrypt secret variables value
def decrypt_func(var):
def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var
else:
raise AssertionError("this statement should be unreachable.")
# Other variable types are not supported for environment variables
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
map(decrypt_func, results)
)
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
decrypt_func(var) for var in results
]
return decrypted_results
@environment_variables.setter
@ -448,7 +461,7 @@ class Workflow(Base):
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
def encrypt_func(var):
def encrypt_func(var: Variable) -> Variable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@ -473,26 +486,42 @@ class Workflow(Base):
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"rag_pipeline_variables": self.rag_pipeline_variables,
}
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
# _conversation_variables is guaranteed to be non-None due to server_default="{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]) -> None:
def conversation_variables(self, value: Sequence[Variable]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
)
@property
def rag_pipeline_variables(self) -> list[dict]:
    """Return the stored RAG pipeline variable definitions as a list.

    The backing column holds a JSON object keyed by each item's ``"variable"``
    entry; only the values are returned. A ``None`` or empty column yields ``[]``.
    """
    # Fall back to "{}" without assigning to the attribute, so that a plain read
    # never mutates the instance (same approach as `environment_variables`).
    variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables or "{}")
    return list(variables_dict.values())


@rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: list[dict]) -> None:
    """Serialize ``values`` to JSON, keyed by each item's ``"variable"`` entry.

    Items sharing the same ``"variable"`` value collapse to the last one.

    Raises:
        KeyError: if an item has no ``"variable"`` key.
    """
    self._rag_pipeline_variables = json.dumps(
        {item["variable"]: item for item in values},
        ensure_ascii=False,
    )
@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)
@ -550,18 +579,18 @@ class WorkflowRun(Base):
type: Mapped[str] = mapped_column(String(255))
triggered_from: Mapped[str] = mapped_column(String(255))
version: Mapped[str] = mapped_column(String(255))
graph: Mapped[Optional[str]] = mapped_column(sa.Text)
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
graph: Mapped[str | None] = mapped_column(sa.Text)
inputs: Mapped[str | None] = mapped_column(sa.Text)
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(sa.Text)
outputs: Mapped[str | None] = mapped_column(sa.Text, default="{}")
error: Mapped[str | None] = mapped_column(sa.Text)
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
@property
@ -625,7 +654,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict) -> "WorkflowRun":
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@ -657,9 +686,10 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum):
SINGLE_STEP = "single-step"
WORKFLOW_RUN = "workflow-run"
RAG_PIPELINE_RUN = "rag-pipeline-run"
class WorkflowNodeExecutionModel(Base):
class WorkflowNodeExecutionModel(Base): # This model is expected to have `offload_data` preloaded in most cases.
"""
Workflow Node Execution
@ -710,7 +740,8 @@ class WorkflowNodeExecutionModel(Base):
__tablename__ = "workflow_node_executions"
@declared_attr
def __table_args__(cls): # noqa
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
@ -747,7 +778,7 @@ class WorkflowNodeExecutionModel(Base):
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(), # type: ignore
cls.created_at.desc(),
),
)
@ -756,24 +787,50 @@ class WorkflowNodeExecutionModel(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID)
triggered_from: Mapped[str] = mapped_column(String(255))
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
index: Mapped[int] = mapped_column(sa.Integer)
predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255))
node_execution_id: Mapped[Optional[str]] = mapped_column(String(255))
predecessor_node_id: Mapped[str | None] = mapped_column(String(255))
node_execution_id: Mapped[str | None] = mapped_column(String(255))
node_id: Mapped[str] = mapped_column(String(255))
node_type: Mapped[str] = mapped_column(String(255))
title: Mapped[str] = mapped_column(String(255))
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
process_data: Mapped[Optional[str]] = mapped_column(sa.Text)
outputs: Mapped[Optional[str]] = mapped_column(sa.Text)
inputs: Mapped[str | None] = mapped_column(sa.Text)
process_data: Mapped[str | None] = mapped_column(sa.Text)
outputs: Mapped[str | None] = mapped_column(sa.Text)
status: Mapped[str] = mapped_column(String(255))
error: Mapped[Optional[str]] = mapped_column(sa.Text)
error: Mapped[str | None] = mapped_column(sa.Text)
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text)
execution_metadata: Mapped[str | None] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
created_by_role: Mapped[str] = mapped_column(String(255))
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
lazy="raise",
back_populates="execution",
)
@staticmethod
def preload_offload_data(
    query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
    """Return ``query`` with ``offload_data`` loaded through ``selectinload``.

    The ``offload_data`` relationship is declared with ``lazy="raise"``, so it
    has to be loaded explicitly like this before it is accessed.
    """
    offload_loader = orm.selectinload(WorkflowNodeExecutionModel.offload_data)
    return query.options(offload_loader)
@staticmethod
def preload_offload_data_and_files(
    query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
    """Return ``query`` with ``offload_data`` and each offload's ``file`` eagerly loaded."""
    # Both levels are loaded with `selectinload`: the offload rows first, then the
    # `file` relationship of those rows as a nested option.
    # NOTE(review): the file level uses `selectinload`, not `joinedload`; switch to
    # `orm.joinedload` here if fetching offloads and files in a single round trip is wanted.
    file_loader = orm.selectinload(WorkflowNodeExecutionOffload.file)
    offload_loader = orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(file_loader)
    return query.options(offload_loader)
@property
def created_by_account(self):
@ -809,25 +866,148 @@ class WorkflowNodeExecutionModel(Base):
return json.loads(self.execution_metadata) if self.execution_metadata else {}
@property
def extras(self):
def extras(self) -> dict[str, Any]:
from core.tools.tool_manager import ToolManager
extras = {}
extras: dict[str, Any] = {}
if self.execution_metadata_dict:
from core.workflow.nodes import NodeType
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
tool_info = self.execution_metadata_dict["tool_info"]
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self.tenant_id,
provider_type=tool_info["provider_type"],
provider_id=tool_info["provider_id"],
)
elif self.node_type == NodeType.DATASOURCE.value and "datasource_info" in self.execution_metadata_dict:
datasource_info = self.execution_metadata_dict["datasource_info"]
extras["icon"] = datasource_info.get("icon")
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
    """Return the first entry of ``self.offload_data`` whose ``type_`` equals ``type_``.

    Returns ``None`` when no entry matches. ``offload_data`` must already be
    loaded, since the relationship is declared with ``lazy="raise"``.
    """
    # A generator lets `next` stop at the first match instead of building a
    # throwaway list of every matching record.
    return next((offload for offload in self.offload_data if offload.type_ == type_), None)
class WorkflowAppLogCreatedFrom(Enum):
@property
def inputs_truncated(self) -> bool:
    """Whether an INPUTS offload record exists, i.e. inputs were truncated and offloaded."""
    offload = self._get_offload_by_type(ExecutionOffLoadType.INPUTS)
    return offload is not None
@property
def outputs_truncated(self) -> bool:
    """Whether an OUTPUTS offload record exists, i.e. outputs were truncated and offloaded."""
    offload = self._get_offload_by_type(ExecutionOffLoadType.OUTPUTS)
    return offload is not None
@property
def process_data_truncated(self) -> bool:
    """Whether a PROCESS_DATA offload record exists, i.e. process data was truncated and offloaded."""
    offload = self._get_offload_by_type(ExecutionOffLoadType.PROCESS_DATA)
    return offload is not None
@staticmethod
def _load_full_content(session: orm.Session, file_id: str, storage: Storage):
    """Load and JSON-decode the offloaded content referenced by ``file_id``.

    Looks up the ``UploadFile`` row with the given id, reads the object stored
    under its ``key`` from ``storage`` and parses it as JSON.

    Raises:
        AssertionError: if no ``UploadFile`` with ``file_id`` exists.
    """
    # Imported inside the function; presumably this avoids a circular import
    # with `models.model` - TODO confirm.
    from .model import UploadFile

    lookup = sa.select(UploadFile).where(UploadFile.id == file_id)
    upload_file = session.scalars(lookup).first()
    # NOTE(review): this check is an `assert`, so it is skipped when Python runs with -O.
    assert upload_file is not None, f"UploadFile with id {file_id} should exist but not"
    raw = storage.load(upload_file.key)
    return json.loads(raw)
def load_full_inputs(self, session: orm.Session, storage: Storage) -> Mapping[str, Any] | None:
    """Return the complete inputs of this execution.

    When an INPUTS offload record exists, the content is loaded from ``storage``
    through its ``file_id``; otherwise ``self.inputs_dict`` is returned.
    """
    record = self._get_offload_by_type(ExecutionOffLoadType.INPUTS)
    if record is not None:
        return self._load_full_content(session, record.file_id, storage)
    return self.inputs_dict
def load_full_outputs(self, session: orm.Session, storage: Storage) -> Mapping[str, Any] | None:
    """Return the complete outputs of this execution.

    When an OUTPUTS offload record exists, the content is loaded from ``storage``
    through its ``file_id``; otherwise ``self.outputs_dict`` is returned.
    """
    record: WorkflowNodeExecutionOffload | None = self._get_offload_by_type(ExecutionOffLoadType.OUTPUTS)
    if record is not None:
        return self._load_full_content(session, record.file_id, storage)
    return self.outputs_dict
def load_full_process_data(self, session: orm.Session, storage: Storage) -> Mapping[str, Any] | None:
    """Return the complete process data of this execution.

    When a PROCESS_DATA offload record exists, the content is loaded from
    ``storage`` through its ``file_id``; otherwise ``self.process_data_dict``
    is returned.
    """
    record: WorkflowNodeExecutionOffload | None = self._get_offload_by_type(ExecutionOffLoadType.PROCESS_DATA)
    if record is not None:
        return self._load_full_content(session, record.file_id, storage)
    return self.process_data_dict
class WorkflowNodeExecutionOffload(Base):
    """Associates a workflow node execution with an offloaded payload file.

    Each row ties one `node_execution_id` and one `ExecutionOffLoadType` to the
    `file_id` of the `UploadFile` that holds the offloaded content.
    """

    __tablename__ = "workflow_node_execution_offload"
    __table_args__ = (
        # PostgreSQL 14 treats NULL values as distinct in unique constraints by default,
        # allowing multiple records with NULL values for the same column combination.
        #
        # This behavior allows us to have multiple records with NULL node_execution_id,
        # simplifying the garbage collection process.
        UniqueConstraint(
            "node_execution_id",
            "type",
            # Note: PostgreSQL 15+ supports explicit `nulls distinct` behavior through
            # `postgresql_nulls_not_distinct=False`, which would make our intention clearer.
            # We rely on PostgreSQL's default behavior of treating NULLs as distinct values.
            # postgresql_nulls_not_distinct=False,
        ),
    )
    # NOTE(review): not referenced anywhere inside this class - confirm it is still needed.
    _HASH_COL_SIZE = 64

    # Primary key; generated by the database through `uuidv7()`.
    id: Mapped[str] = mapped_column(
        StringUUID,
        primary_key=True,
        server_default=sa.text("uuidv7()"),
    )

    created_at: Mapped[datetime] = mapped_column(
        DateTime, default=naive_utc_now, server_default=func.current_timestamp()
    )

    tenant_id: Mapped[str] = mapped_column(StringUUID)
    app_id: Mapped[str] = mapped_column(StringUUID)

    # `node_execution_id` indicates the `WorkflowNodeExecutionModel` associated with this offload record.
    # A value of `None` signifies that this offload record is not linked to any execution record
    # and should be considered for garbage collection.
    node_execution_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    # Which payload of the execution this record offloads; stored in the "type" column.
    type_: Mapped[ExecutionOffLoadType] = mapped_column(EnumText(ExecutionOffLoadType), name="type", nullable=False)

    # Design Decision: Combining inputs and outputs into a single object was considered to reduce I/O
    # operations. However, due to the current design of `WorkflowNodeExecutionRepository`,
    # the `save` method is called at two distinct times:
    #
    # - When the node starts execution: the `inputs` field exists, but the `outputs` field is absent
    # - When the node completes execution (either succeeded or failed): the `outputs` field becomes available
    #
    # It's difficult to correlate these two successive calls to `save` for combined storage.
    # Converting the `WorkflowNodeExecutionRepository` to buffer the first `save` call and flush
    # when execution completes was also considered, but this would make the execution state unobservable
    # until completion, significantly damaging the observability of workflow execution.
    #
    # Given these constraints, `inputs` and `outputs` are stored separately to maintain real-time
    # observability and system reliability.

    # `file_id` references the offloaded storage object containing the data.
    file_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    # Owning execution; `lazy="raise"` means it must be loaded explicitly before access.
    # NOTE(review): `node_execution_id` is nullable, yet the annotation here is not optional - confirm.
    execution: Mapped[WorkflowNodeExecutionModel] = orm.relationship(
        foreign_keys=[node_execution_id],
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id",
        back_populates="offload_data",
    )

    # The `UploadFile` joined on `file_id`; `lazy="raise"` means it must be loaded explicitly before access.
    file: Mapped[Optional["UploadFile"]] = orm.relationship(
        foreign_keys=[file_id],
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id",
    )
class WorkflowAppLogCreatedFrom(StrEnum):
"""
Workflow App Log Created From Enum
"""
@ -883,6 +1063,7 @@ class WorkflowAppLog(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
@ -939,7 +1120,7 @@ class ConversationVariable(Base):
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str):
self.id = id
self.app_id = app_id
self.conversation_id = conversation_id
@ -988,7 +1169,10 @@ class WorkflowDraftVariable(Base):
]
__tablename__ = "workflow_draft_variables"
__table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),)
__table_args__ = (
UniqueConstraint(*unique_app_id_node_id_name()),
Index("workflow_draft_variable_file_id_idx", "file_id"),
)
# Required for instance variable annotation.
__allow_unmapped__ = True
@ -1049,9 +1233,16 @@ class WorkflowDraftVariable(Base):
selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")
# The data type of this variable's value
#
# If the variable is offloaded, `value_type` represents the type of the truncated value,
# which may differ from the original value's type. Typically, they are the same,
# but in cases where the structurally truncated value still exceeds the size limit,
# text slicing is applied, and the `value_type` is converted to `STRING`.
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
# The variable's value serialized as a JSON string
#
# If the variable is offloaded, `value` contains a truncated version, not the full original value.
value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")
# Controls whether the variable should be displayed in the variable inspection panel
@ -1071,6 +1262,35 @@ class WorkflowDraftVariable(Base):
default=None,
)
# Reference to WorkflowDraftVariableFile for offloaded large variables
#
# Indicates whether the current draft variable is offloaded.
# If not offloaded, this field will be None.
file_id: Mapped[str | None] = mapped_column(
StringUUID,
nullable=True,
default=None,
comment="Reference to WorkflowDraftVariableFile if variable is offloaded to external storage",
)
is_default_value: Mapped[bool] = mapped_column(
sa.Boolean,
nullable=False,
default=False,
comment=(
"Indicates whether the current value is the default for a conversation variable. "
"Always `FALSE` for other types of variables."
),
)
# Relationship to WorkflowDraftVariableFile
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
primaryjoin="WorkflowDraftVariableFile.id == WorkflowDraftVariable.file_id",
)
# Cache for deserialized value
#
# NOTE(QuantumGhost): This field serves two purposes:
@ -1084,7 +1304,7 @@ class WorkflowDraftVariable(Base):
# making this attribute harder to access from outside the class.
__value: Segment | None
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
The constructor of `WorkflowDraftVariable` is not intended for
direct use outside this file. Its solo purpose is setup private state
@ -1102,15 +1322,15 @@ class WorkflowDraftVariable(Base):
self.__value = None
def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
selector: Any = json.loads(self.selector)
if not isinstance(selector, list):
logger.error(
"invalid selector loaded from database, type=%s, value=%s",
type(selector),
type(selector).__name__,
self.selector,
)
raise ValueError("invalid selector.")
return selector
return cast(list[str], selector)
def _set_selector(self, value: list[str]):
self.selector = json.dumps(value)
@ -1120,7 +1340,7 @@ class WorkflowDraftVariable(Base):
return self.build_segment_with_type(self.value_type, value)
@staticmethod
def rebuild_file_types(value: Any) -> Any:
def rebuild_file_types(value: Any):
# NOTE(QuantumGhost): Temporary workaround for structured data handling.
# By this point, `output` has been converted to dict by
# `WorkflowEntry.handle_special_values`, so we need to
@ -1133,15 +1353,17 @@ class WorkflowDraftVariable(Base):
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
if isinstance(value, dict):
if not maybe_file_object(value):
return value
return cast(Any, value)
return File.model_validate(value)
elif isinstance(value, list) and value:
first = value[0]
value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first):
return value
return [File.model_validate(i) for i in value]
return cast(Any, value)
file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list]
return cast(Any, file_list)
else:
return value
return cast(Any, value)
@classmethod
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:
@ -1218,6 +1440,9 @@ class WorkflowDraftVariable(Base):
case _:
return DraftVariableType.NODE
def is_truncated(self) -> bool:
    """Return True when ``file_id`` is set.

    ``file_id`` is None unless the variable was offloaded, in which case
    ``value`` only holds a truncated version of the original.
    """
    has_offload_file = self.file_id is not None
    return has_offload_file
@classmethod
def _new(
cls,
@ -1228,6 +1453,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.created_at = _naive_utc_datetime()
@ -1237,6 +1463,7 @@ class WorkflowDraftVariable(Base):
variable.node_id = node_id
variable.name = name
variable.set_value(value)
variable.file_id = file_id
variable._set_selector(list(variable_utils.to_selector(node_id, name)))
variable.node_execution_id = node_execution_id
return variable
@ -1292,6 +1519,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str,
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
@ -1299,6 +1527,7 @@ class WorkflowDraftVariable(Base):
name=name,
node_execution_id=node_execution_id,
value=value,
file_id=file_id,
)
variable.visible = visible
variable.editable = editable
@ -1309,6 +1538,93 @@ class WorkflowDraftVariable(Base):
return self.last_edited_at is not None
class WorkflowDraftVariableFile(Base):
    """Stores metadata about files associated with large workflow draft variables.

    This model acts as an intermediary between WorkflowDraftVariable and UploadFile,
    allowing for proper cleanup of orphaned files when variables are updated or deleted.

    The MIME type of the stored content is recorded in `UploadFile.mime_type`.
    Possible values are 'application/json' for JSON types other than plain text,
    and 'text/plain' for JSON strings.
    """

    __tablename__ = "workflow_draft_variable_files"

    # Primary key; generated client-side by `uuidv7`, with a matching database-side default.
    id: Mapped[str] = mapped_column(
        StringUUID,
        primary_key=True,
        default=uuidv7,
        server_default=sa.text("uuidv7()"),
    )

    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        server_default=func.current_timestamp(),
    )

    tenant_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
        comment="The tenant to which the WorkflowDraftVariableFile belongs, referencing Tenant.id",
    )

    app_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
        comment="The application to which the WorkflowDraftVariableFile belongs, referencing App.id",
    )

    # NOTE(review): the column comment below contains a typo ("owner to of"). It is left
    # untouched because the string is part of the column definition, not a code comment.
    user_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
        comment="The owner to of the WorkflowDraftVariableFile, referencing Account.id",
    )

    # Reference to the `UploadFile.id` field
    upload_file_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
        comment="Reference to UploadFile containing the large variable data",
    )

    # -------------- metadata about the variable content --------------

    # The `size` is already recorded in UploadFiles. It is duplicated here to avoid an additional database lookup.
    # The column is declared `nullable=False`, so the annotation is plain `int`.
    size: Mapped[int] = mapped_column(
        sa.BigInteger,
        nullable=False,
        comment="Size of the original variable content in bytes",
    )

    length: Mapped[int | None] = mapped_column(
        sa.Integer,
        nullable=True,
        comment=(
            "Length of the original variable content. For array and array-like types, "
            "this represents the number of elements. For object types, it indicates the number of keys. "
            "For other types, the value is NULL."
        ),
    )

    # The `value_type` field records the type of the original value.
    value_type: Mapped[SegmentType] = mapped_column(
        EnumText(SegmentType, length=20),
        nullable=False,
    )

    # Relationship to UploadFile; `lazy="raise"` means it must be loaded explicitly before access.
    upload_file: Mapped["UploadFile"] = orm.relationship(
        foreign_keys=[upload_file_id],
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id",
    )
def is_system_variable_editable(name: str) -> bool:
    """Return whether ``name`` is a member of ``_EDITABLE_SYSTEM_VARIABLE``."""
    editable = name in _EDITABLE_SYSTEM_VARIABLE
    return editable