refactor: migrate session.query to select API in small task files (#34617)

This commit is contained in:
Renzo
2026-04-06 23:13:22 -05:00
committed by GitHub
parent b55bef4438
commit ac8bd12609
6 changed files with 28 additions and 12 deletions

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from core.db.session_factory import session_factory
@ -35,7 +36,9 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
with session_factory.create_session() as session:
# get app info
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
app = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1)
)
if app:
try:
@ -53,8 +56,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
)
documents.append(document)
# if annotation reply is enabled , batch add annotations' index
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if app_annotation_setting:

View File

@ -24,14 +24,16 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
start_at = time.perf_counter()
# get app info
with session_factory.create_session() as session:
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
app = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1)
)
annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
return
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if not app_annotation_setting:

View File

@ -36,7 +36,9 @@ def enable_annotation_reply_task(
start_at = time.perf_counter()
# get app info
with session_factory.create_session() as session:
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
app = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1)
)
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
@ -51,8 +53,8 @@ def enable_annotation_reply_task(
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION
)
annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@ -29,7 +30,7 @@ def enable_segment_to_index_task(segment_id: str):
start_at = time.perf_counter()
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@ -24,7 +25,9 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
start_at = time.perf_counter()
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@ -22,7 +23,11 @@ def _now_ts() -> int:
def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None:
return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
return session.scalar(
select(TriggerSubscription)
.where(TriggerSubscription.tenant_id == tenant_id, TriggerSubscription.id == subscription_id)
.limit(1)
)
def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None: