mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
Migrate SQLAlchemy from 1.x to 2.0 with automated and manual adjustments (#23224)
Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -3,6 +3,7 @@ from typing import Any, Optional
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
@ -211,11 +212,10 @@ class Jieba(BaseKeyword):
|
||||
return sorted_chunk_indices[:k]
|
||||
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
|
||||
document_segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
|
||||
.first()
|
||||
stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
|
||||
)
|
||||
document_segment = db.session.scalar(stmt)
|
||||
if document_segment:
|
||||
document_segment.keywords = keywords
|
||||
db.session.add(document_segment)
|
||||
|
||||
@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, load_only
|
||||
|
||||
from configs import dify_config
|
||||
@ -127,7 +128,8 @@ class RetrievalService:
|
||||
external_retrieval_model: Optional[dict] = None,
|
||||
metadata_filtering_conditions: Optional[dict] = None,
|
||||
):
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
@ -316,10 +318,8 @@ class RetrievalService:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
|
||||
)
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
@ -378,17 +378,13 @@ class RetrievalService:
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
.first()
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
segment = db.session.scalar(document_segment_stmt)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
|
||||
@ -18,6 +18,7 @@ from qdrant_client.http.models import (
|
||||
TokenizerType,
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@ -445,11 +446,8 @@ class QdrantVector(BaseVector):
|
||||
class QdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
dataset_collection_binding = db.session.scalars(stmt).one_or_none()
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
|
||||
@ -20,6 +20,7 @@ from qdrant_client.http.models import (
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
from requests.auth import HTTPDigestAuth
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector):
|
||||
|
||||
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
|
||||
if not tidb_auth_binding:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
|
||||
@ -3,6 +3,8 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@ -45,11 +47,10 @@ class Vector:
|
||||
vector_type = self._dataset.index_struct_dict["type"]
|
||||
else:
|
||||
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
|
||||
whitelist = (
|
||||
db.session.query(Whitelist)
|
||||
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
|
||||
.one_or_none()
|
||||
stmt = select(Whitelist).where(
|
||||
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
|
||||
)
|
||||
whitelist = db.session.scalars(stmt).one_or_none()
|
||||
if whitelist:
|
||||
vector_type = VectorType.TIDB_ON_QDRANT
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@ -41,9 +41,8 @@ class DatasetDocumentStore:
|
||||
|
||||
@property
|
||||
def docs(self) -> dict[str, Document]:
|
||||
document_segments = (
|
||||
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
|
||||
)
|
||||
stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id)
|
||||
document_segments = db.session.scalars(stmt).all()
|
||||
|
||||
output = {}
|
||||
for document_segment in document_segments:
|
||||
@ -228,10 +227,9 @@ class DatasetDocumentStore:
|
||||
return data
|
||||
|
||||
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
|
||||
document_segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
|
||||
.first()
|
||||
stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id
|
||||
)
|
||||
document_segment = db.session.scalar(stmt)
|
||||
|
||||
return document_segment
|
||||
|
||||
@ -4,6 +4,7 @@ import operator
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import requests
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
@ -367,18 +368,13 @@ class NotionExtractor(BaseExtractor):
|
||||
|
||||
@classmethod
|
||||
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
|
||||
)
|
||||
)
|
||||
.first()
|
||||
stmt = select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
|
||||
)
|
||||
data_source_binding = db.session.scalar(stmt)
|
||||
|
||||
if not data_source_binding:
|
||||
raise Exception(
|
||||
|
||||
@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import Float, and_, or_, text
|
||||
from sqlalchemy import Float, and_, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -135,7 +135,8 @@ class DatasetRetrieval:
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
@ -240,15 +241,12 @@ class DatasetRetrieval:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset and document:
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
@ -327,7 +325,8 @@ class DatasetRetrieval:
|
||||
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
if dataset:
|
||||
results = []
|
||||
if dataset.provider == "external":
|
||||
@ -514,22 +513,18 @@ class DatasetRetrieval:
|
||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
for document in dify_documents:
|
||||
if document.metadata is not None:
|
||||
dataset_document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(DatasetDocument.id == document.metadata["document_id"])
|
||||
.first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
)
|
||||
dataset_document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
.first()
|
||||
child_chunk_stmt = select(ChildChunk).where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
@ -600,7 +595,8 @@ class DatasetRetrieval:
|
||||
):
|
||||
with flask_app.app_context():
|
||||
with Session(db.engine) as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
@ -685,7 +681,8 @@ class DatasetRetrieval:
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
@ -958,7 +955,8 @@ class DatasetRetrieval:
|
||||
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(metadata_stmt).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
# get metadata model config
|
||||
if metadata_model_config is None:
|
||||
|
||||
Reference in New Issue
Block a user