mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 09:27:39 +08:00
refactor: migrate session.query to select API in retrieval_service (#34638)
This commit is contained in:
@ -240,7 +240,7 @@ class RetrievalService:
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
||||
with Session(db.engine) as session:
|
||||
return session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
@ -573,15 +573,13 @@ class RetrievalService:
|
||||
|
||||
# Batch query summaries for segments retrieved via summary (only enabled summaries)
|
||||
if summary_segment_ids:
|
||||
summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
summaries = session.scalars(
|
||||
select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
|
||||
DocumentSegmentSummary.status == "completed",
|
||||
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
for summary in summaries:
|
||||
if summary.summary_content:
|
||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
@ -851,12 +849,12 @@ class RetrievalService:
|
||||
def get_segment_attachment_info(
|
||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||
) -> SegmentAttachmentResult | None:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1))
|
||||
if upload_file:
|
||||
attachment_binding = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
attachment_binding = session.scalar(
|
||||
select(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if attachment_binding:
|
||||
attachment_info: AttachmentInfoDict = {
|
||||
@ -875,14 +873,12 @@ class RetrievalService:
|
||||
cls, attachment_ids: list[str], session: Session
|
||||
) -> list[SegmentAttachmentInfoResult]:
|
||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
attachment_bindings = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
||||
.all()
|
||||
)
|
||||
attachment_bindings = session.scalars(
|
||||
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
||||
).all()
|
||||
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
|
||||
|
||||
if attachment_bindings:
|
||||
|
||||
Reference in New Issue
Block a user