mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
feat:json metadata filter adapt (#65)
* config adapt revert * ci test * fix mysql migration test * fix * fix * lint fix * fix ob config * fix * fix * fix * test over * test * fix * fix * fix style * test over * retain gin for pg * gin for pg * uuid default in versions * ci test * ci test * fix * fix * fix * fix * pg jsonb * fix
This commit is contained in:
@ -29,7 +29,7 @@ from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import BinaryData, LongText, StringUUID
|
||||
from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -45,6 +45,7 @@ class Dataset(Base):
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_pkey"),
|
||||
sa.Index("dataset_tenant_idx", "tenant_id"),
|
||||
adjusted_json_index("retrieval_model_idx", "retrieval_model"),
|
||||
)
|
||||
|
||||
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
|
||||
@ -69,9 +70,9 @@ class Dataset(Base):
|
||||
embedding_model_provider = mapped_column(sa.String(255), nullable=True)
|
||||
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
|
||||
collection_binding_id = mapped_column(StringUUID, nullable=True)
|
||||
retrieval_model = mapped_column(sa.JSON, nullable=True)
|
||||
retrieval_model = mapped_column(AdjustedJSON, nullable=True)
|
||||
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
icon_info = mapped_column(sa.JSON, nullable=True)
|
||||
icon_info = mapped_column(AdjustedJSON, nullable=True)
|
||||
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
|
||||
pipeline_id = mapped_column(StringUUID, nullable=True)
|
||||
chunk_structure = mapped_column(sa.String(255), nullable=True)
|
||||
@ -347,6 +348,7 @@ class Document(Base):
|
||||
sa.Index("document_dataset_id_idx", "dataset_id"),
|
||||
sa.Index("document_is_paused_idx", "is_paused"),
|
||||
sa.Index("document_tenant_idx", "tenant_id"),
|
||||
adjusted_json_index("document_metadata_idx", "doc_metadata"),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
@ -405,7 +407,7 @@ class Document(Base):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
doc_type = mapped_column(String(40), nullable=True)
|
||||
doc_metadata = mapped_column(sa.JSON, nullable=True)
|
||||
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import Base
|
||||
from .types import LongText, StringUUID
|
||||
from .types import AdjustedJSON, LongText, StringUUID
|
||||
|
||||
|
||||
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
|
||||
@ -20,7 +20,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
system_credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
||||
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
|
||||
|
||||
|
||||
class DatasourceProvider(Base):
|
||||
@ -36,7 +36,7 @@ class DatasourceProvider(Base):
|
||||
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
|
||||
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
|
||||
@ -58,7 +58,7 @@ class DatasourceOauthTenantParamConfig(Base):
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
client_params: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default={})
|
||||
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default={})
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@ -8,7 +8,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from models.base import TypeBase
|
||||
|
||||
from .types import LongText, StringUUID
|
||||
from .types import AdjustedJSON, LongText, StringUUID, adjusted_json_index
|
||||
|
||||
|
||||
class DataSourceOauthBinding(TypeBase):
|
||||
@ -16,13 +16,14 @@ class DataSourceOauthBinding(TypeBase):
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
sa.Index("source_binding_tenant_id_idx", "tenant_id"),
|
||||
adjusted_json_index("source_info_idx", "source_info"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_info: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
|
||||
source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@ -2,12 +2,15 @@ import enum
|
||||
import uuid
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
|
||||
from sqlalchemy.dialects.mysql import LONGBLOB, LONGTEXT
|
||||
from sqlalchemy.dialects.postgresql import BYTEA, UUID
|
||||
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, UUID
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class StringUUID(TypeDecorator[uuid.UUID | str | None]):
|
||||
impl = CHAR
|
||||
@ -81,6 +84,32 @@ class BinaryData(TypeDecorator[bytes | None]):
|
||||
return value
|
||||
|
||||
|
||||
class AdjustedJSON(TypeDecorator[dict | list | None]):
    """JSON column type whose concrete storage type depends on the dialect.

    On PostgreSQL the column is rendered as ``JSONB``; on MySQL and every
    other dialect it is rendered as the generic ``sa.JSON`` type. Values are
    passed through unchanged in both directions.
    """

    impl = sa.JSON
    cache_ok = True

    def __init__(self, astext_type=None):
        """Remember the optional ``astext_type`` forwarded to ``JSONB``.

        Args:
            astext_type: Optional type handed to ``JSONB(astext_type=...)``
                when the PostgreSQL implementation is built. Ignored for
                other dialects.
        """
        self.astext_type = astext_type
        super().__init__()

    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
        """Return the dialect-specific type used to store this column.

        Args:
            dialect: The dialect in use; only its ``name`` is inspected.

        Returns:
            A ``JSONB`` descriptor for PostgreSQL, otherwise a generic
            ``sa.JSON`` descriptor.
        """
        # MySQL and all remaining dialects share the generic JSON type.
        if dialect.name != "postgresql":
            return dialect.type_descriptor(sa.JSON())

        # Only forward astext_type when one was actually supplied.
        jsonb_type = JSONB(astext_type=self.astext_type) if self.astext_type else JSONB()
        return dialect.type_descriptor(jsonb_type)

    def process_bind_param(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
        """Pass the Python value through to the underlying type unchanged."""
        return value

    def process_result_value(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
        """Pass the loaded value back to the caller unchanged."""
        return value
|
||||
|
||||
|
||||
_E = TypeVar("_E", bound=enum.StrEnum)
|
||||
|
||||
|
||||
@ -124,3 +153,11 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
|
||||
if x is None or y is None:
|
||||
return x is y
|
||||
return x == y
|
||||
|
||||
|
||||
def adjusted_json_index(index_name: str | None, column_name: str) -> sa.Index | None:
    """Build a GIN index for a JSON column when the database is PostgreSQL.

    Args:
        index_name: Name for the index. A falsy value falls back to
            ``"<column_name>_idx"``.
        column_name: Name of the column to index.

    Returns:
        An ``sa.Index`` created with ``postgresql_using="gin"`` when
        ``dify_config.DB_TYPE`` is ``"postgresql"``; ``None`` otherwise.
    """
    # NOTE(review): call sites place the result directly inside
    # ``__table_args__``, so the ``None`` case presumably relies on
    # SQLAlchemy skipping ``None`` entries there - confirm.
    if dify_config.DB_TYPE != "postgresql":
        return None

    resolved_name = index_name or f"{column_name}_idx"
    return sa.Index(resolved_name, column_name, postgresql_using="gin")
|
||||
|
||||
Reference in New Issue
Block a user