refactor: migrate session.query to select API in retrieval_service (#34638)

This commit is contained in:
Renzo
2026-04-06 23:46:30 -05:00
committed by GitHub
parent 1194957fde
commit 72adb5468c
2 changed files with 41 additions and 47 deletions

View File

@ -240,7 +240,7 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session:
return session.query(Dataset).where(Dataset.id == dataset_id).first()
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
@classmethod
def keyword_search(
@ -573,15 +573,13 @@ class RetrievalService:
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
summaries = (
session.query(DocumentSegmentSummary)
.filter(
summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries
)
.all()
)
).all()
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
@ -851,12 +849,12 @@ class RetrievalService:
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> SegmentAttachmentResult | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1))
if upload_file:
attachment_binding = (
session.query(SegmentAttachmentBinding)
attachment_binding = session.scalar(
select(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.first()
.limit(1)
)
if attachment_binding:
attachment_info: AttachmentInfoDict = {
@ -875,14 +873,12 @@ class RetrievalService:
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
.all()
)
attachment_bindings = session.scalars(
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
).all()
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
if attachment_bindings: