Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

This commit is contained in:
-LAN-
2025-09-11 15:13:31 +08:00
105 changed files with 3132 additions and 568 deletions

View File

@@ -246,6 +246,8 @@ class AccountService:
account.name = name
if password:
valid_password(password)
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()

View File

@@ -263,11 +263,9 @@ class AppAnnotationService:
db.session.delete(annotation)
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.where(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
annotation_hit_histories = db.session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
).all()
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)

View File

@@ -1,5 +1,7 @@
import json
from sqlalchemy import select
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
@@ -9,11 +11,11 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return data_source_api_key_bindings
@staticmethod

View File

@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs:
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:

View File

@@ -6,6 +6,7 @@ import secrets
import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, Optional
import sqlalchemy as sa
@@ -741,14 +742,12 @@ class DatasetService:
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.where(
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
.all()
)
).all()
if dataset_auto_disable_logs:
return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
@@ -885,69 +884,58 @@ class DocumentService:
return document
@staticmethod
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
).all()
return documents
@staticmethod
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
.all()
)
).all()
return documents
@staticmethod
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
).all()
return documents
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.all()
)
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
).all()
return documents
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]:
assert isinstance(current_user, Account)
documents = (
db.session.query(Document)
.where(
documents = db.session.scalars(
select(Document).where(
Document.batch == batch,
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id,
)
.all()
)
).all()
return documents
@@ -984,7 +972,7 @@ class DocumentService:
# Check if document_ids is not empty to avoid WHERE false condition
if not document_ids or len(document_ids) == 0:
return
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
for document in documents
@@ -2424,16 +2412,14 @@ class SegmentService:
if not segment_ids or len(segment_ids) == 0:
return
if action == "enable":
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
.all()
)
).all()
if not segments:
return
real_deal_segment_ids = []
@@ -2451,16 +2437,14 @@ class SegmentService:
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
.all()
)
).all()
if not segments:
return
real_deal_segment_ids = []
@@ -2532,16 +2516,13 @@ class SegmentService:
dataset: Dataset,
) -> list[ChildChunk]:
assert isinstance(current_user, Account)
child_chunks = (
db.session.query(ChildChunk)
.where(
child_chunks = db.session.scalars(
select(ChildChunk).where(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.all()
)
).all()
child_chunks_map = {chunk.id: chunk for chunk in child_chunks}
new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
@@ -2751,19 +2732,13 @@ class DatasetCollectionBindingService:
class DatasetPermissionService:
@classmethod
def get_dataset_partial_member_list(cls, dataset_id):
user_list_query = (
db.session.query(
user_list_query = db.session.scalars(
select(
DatasetPermission.account_id,
)
.where(DatasetPermission.dataset_id == dataset_id)
.all()
)
).where(DatasetPermission.dataset_id == dataset_id)
).all()
user_list = []
for user in user_list_query:
user_list.append(user.account_id)
return user_list
return user_list_query
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):

View File

@@ -3,7 +3,7 @@ import logging
from json import JSONDecodeError
from typing import Optional, Union
from sqlalchemy import or_
from sqlalchemy import or_, select
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
@@ -322,16 +322,14 @@ class ModelLoadBalancingService:
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
current_load_balancing_configs = db.session.scalars(
select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
)
).all()
# id as key, config as value
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}

View File

@@ -1,5 +1,7 @@
from typing import Optional
from sqlalchemy import select
from constants.languages import languages
from extensions.ext_database import db
from models.model import App, RecommendedApp
@@ -31,18 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
:param language: language
:return:
"""
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
).all()
if len(recommended_apps) == 0:
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
).all()
categories = set()
recommended_apps_result = []

View File

@@ -2,7 +2,7 @@ import uuid
from typing import Optional
from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
@@ -29,35 +29,30 @@ class TagService:
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tags = (
db.session.query(Tag)
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
tags = db.session.scalars(
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tag_bindings = (
db.session.query(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all()
)
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
tag_bindings = db.session.scalars(
select(TagBinding.target_id).where(
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
)
).all()
return tag_bindings
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
if not tag_type or not tag_name:
return []
tags = (
db.session.query(Tag)
.where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
tags = list(
db.session.scalars(
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
)
if not tags:
return []
@@ -117,7 +112,7 @@ class TagService:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -443,9 +444,7 @@ class ApiToolManageService:
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
result: list[ToolProviderApiEntity] = []

View File

@@ -3,7 +3,7 @@ from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_
from sqlalchemy import or_, select
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@@ -186,7 +186,9 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools: