feat:json metadat filter adapt (#65)

* config adapt revert

* ci test

* fix mysql migration test

* fix

* fix

* lint fix

* fix ob config

* fix

* fix

* fix

* test over

* test

* fix

* fix

* fix style

* test over

* retain gin for pg

* gin for pg

* uuid defalut in versions

* ci test

* ci test

* fix

* fix

* fix

* fix

* pg josnb

* fix
This commit is contained in:
longbingljw
2025-11-15 22:29:59 +08:00
committed by GitHub
parent 84935b9169
commit 6433ac8209
57 changed files with 386 additions and 487 deletions

View File

@ -29,7 +29,7 @@ from .account import Account
from .base import Base
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import BinaryData, LongText, StringUUID
from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index
logger = logging.getLogger(__name__)
@ -45,6 +45,7 @@ class Dataset(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="dataset_pkey"),
sa.Index("dataset_tenant_idx", "tenant_id"),
adjusted_json_index("retrieval_model_idx", "retrieval_model"),
)
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
@ -69,9 +70,9 @@ class Dataset(Base):
embedding_model_provider = mapped_column(sa.String(255), nullable=True)
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(sa.JSON, nullable=True)
retrieval_model = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(sa.JSON, nullable=True)
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
@ -347,6 +348,7 @@ class Document(Base):
sa.Index("document_dataset_id_idx", "dataset_id"),
sa.Index("document_is_paused_idx", "is_paused"),
sa.Index("document_tenant_idx", "tenant_id"),
adjusted_json_index("document_metadata_idx", "doc_metadata"),
)
# initial fields
@ -405,7 +407,7 @@ class Document(Base):
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
doc_type = mapped_column(String(40), nullable=True)
doc_metadata = mapped_column(sa.JSON, nullable=True)
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)

View File

@ -7,7 +7,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .base import Base
from .types import LongText, StringUUID
from .types import AdjustedJSON, LongText, StringUUID
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
@ -20,7 +20,7 @@ class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
class DatasourceProvider(Base):
@ -36,7 +36,7 @@ class DatasourceProvider(Base):
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
@ -58,7 +58,7 @@ class DatasourceOauthTenantParamConfig(Base):
tenant_id = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
client_params: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default={})
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default={})
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -8,7 +8,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from models.base import TypeBase
from .types import LongText, StringUUID
from .types import AdjustedJSON, LongText, StringUUID, adjusted_json_index
class DataSourceOauthBinding(TypeBase):
@ -16,13 +16,14 @@ class DataSourceOauthBinding(TypeBase):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
sa.Index("source_binding_tenant_id_idx", "tenant_id"),
adjusted_json_index("source_info_idx", "source_info"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
source_info: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@ -2,12 +2,15 @@ import enum
import uuid
from typing import Any, Generic, TypeVar
import sqlalchemy as sa
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
from sqlalchemy.dialects.mysql import LONGBLOB, LONGTEXT
from sqlalchemy.dialects.postgresql import BYTEA, UUID
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine
from configs import dify_config
class StringUUID(TypeDecorator[uuid.UUID | str | None]):
impl = CHAR
@ -81,6 +84,32 @@ class BinaryData(TypeDecorator[bytes | None]):
return value
class AdjustedJSON(TypeDecorator[dict | list | None]):
impl = sa.JSON
cache_ok = True
def __init__(self, astext_type=None):
self.astext_type = astext_type
super().__init__()
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
if self.astext_type:
return dialect.type_descriptor(JSONB(astext_type=self.astext_type))
else:
return dialect.type_descriptor(JSONB())
elif dialect.name == "mysql":
return dialect.type_descriptor(sa.JSON())
else:
return dialect.type_descriptor(sa.JSON())
def process_bind_param(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
return value
def process_result_value(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
return value
_E = TypeVar("_E", bound=enum.StrEnum)
@ -124,3 +153,11 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
if x is None or y is None:
return x is y
return x == y
def adjusted_json_index(index_name, column_name):
index_name = index_name or f"{column_name}_idx"
if dify_config.DB_TYPE == "postgresql":
return sa.Index(index_name, column_name, postgresql_using="gin")
else:
return None