diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 52b000c171..8201162c0c 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -41,7 +41,7 @@ from fields.document_fields import ( from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile -from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary +from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService @@ -329,66 +329,14 @@ class DatasetDocumentListApi(Resource): # Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled) summary_status_map: dict[str, str | None] = {} if has_summary_index and document_ids_need_summary: - # Get all segments for these documents (excluding qa_model and re_segment) - segments = ( - db.session.query(DocumentSegment.id, DocumentSegment.document_id) - .where( - DocumentSegment.document_id.in_(document_ids_need_summary), - DocumentSegment.status != "re_segment", - DocumentSegment.tenant_id == current_tenant_id, - ) - .all() + from services.summary_index_service import SummaryIndexService + + summary_status_map = SummaryIndexService.get_documents_summary_index_status( + document_ids=document_ids_need_summary, + dataset_id=dataset_id, + tenant_id=current_tenant_id, ) - # Group segments by document_id - document_segments_map: dict[str, list[str]] = {} - for segment in segments: - doc_id = str(segment.document_id) - if doc_id not in document_segments_map: - document_segments_map[doc_id] = [] - document_segments_map[doc_id].append(segment.id) - - # Get all summary records for these segments - all_segment_ids = [seg.id for seg in segments] - summaries = {} - if all_segment_ids: - summary_records = ( - db.session.query(DocumentSegmentSummary) - .where( - DocumentSegmentSummary.chunk_id.in_(all_segment_ids), - DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only count enabled summaries - ) - .all() - ) - summaries = {summary.chunk_id: summary.status for summary in summary_records} - - # Calculate summary_index_status for each document - for doc_id in document_ids_need_summary: - segment_ids = document_segments_map.get(doc_id, []) - if not segment_ids: - # No segments, status is None (not started) - summary_status_map[doc_id] = None - continue - - # Check if there are any "not_started" or "generating" status summaries - # Only check enabled=True summaries (already filtered in query) - # If segment has no summary record (summaries.get returns None), - # it means the summary is disabled (enabled=False) or not created yet, ignore it - has_pending_summaries = any( - summaries.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summaries[segment_id] in ("not_started", "generating") - for segment_id in segment_ids - ) - - if has_pending_summaries: - # Task is still running (not started or generating) - summary_status_map[doc_id] = "SUMMARIZING" - else: - # All enabled=True summaries are "completed" or "error", task finished - # Or no enabled=True summaries exist (all disabled) - summary_status_map[doc_id] = None - # Add summary_index_status to each document for document in documents: if has_summary_index and document.need_summary is True: @@ -1491,15 +1439,12 @@ class DocumentSummaryStatusApi(DocumentResource): segment_ids = [segment.id for segment in segments] summaries = [] if segment_ids: - summaries = ( - db.session.query(DocumentSegmentSummary) - .filter( - DocumentSegmentSummary.document_id == document_id, - DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.chunk_id.in_(segment_ids), - DocumentSegmentSummary.enabled == True, # Only return enabled summaries - ) - .all() + from services.summary_index_service import SummaryIndexService + + summaries = SummaryIndexService.get_document_summaries( + document_id=document_id, + dataset_id=dataset_id, + segment_ids=segment_ids, ) # Create a mapping of chunk_id to summary diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index c88cf1f71d..228de0914f 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required -from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary +from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile from services.dataset_service import DatasetService, DocumentService, SegmentService from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs @@ -43,17 +43,11 @@ from tasks.batch_create_segment_to_index_task import batch_create_segment_to_ind def _get_segment_with_summary(segment, dataset_id): """Helper function to marshal segment and add summary information.""" + from services.summary_index_service import SummaryIndexService + segment_dict = dict(marshal(segment, segment_fields)) # Query summary for this segment (only enabled summaries) - summary = ( - db.session.query(DocumentSegmentSummary) - .where( - DocumentSegmentSummary.chunk_id == segment.id, - DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries - ) - .first() - ) + summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) segment_dict["summary"] = summary.summary_content if summary else None return segment_dict @@ -203,17 +197,12 @@ class DatasetDocumentSegmentListApi(Resource): segment_ids = [segment.id for segment in segments.items] summaries = {} if segment_ids: - summary_records = ( - db.session.query(DocumentSegmentSummary) - .where( - DocumentSegmentSummary.chunk_id.in_(segment_ids), - DocumentSegmentSummary.dataset_id == dataset_id, - ) - .all() - ) - # Only include enabled summaries + from services.summary_index_service import SummaryIndexService + + summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) + # Only include enabled summaries (already filtered by service) summaries = { - summary.chunk_id: summary.summary_content for summary in summary_records if summary.enabled is True + chunk_id: summary.summary_content for chunk_id, summary in summary_records.items() } # Add summary to each segment diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 6cda62149c..b327c0d49e 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( Segmentation, ) from services.file_service import FileService +from services.summary_index_service import SummaryIndexService class DocumentTextCreatePayload(BaseModel): @@ -518,68 +519,12 @@ class DocumentListApi(DatasetApiResource): # Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled) summary_status_map: dict[str, str | None] = {} if has_summary_index and document_ids_need_summary: - # Get all segments for these documents (excluding qa_model and re_segment) - segments = ( - db.session.query(DocumentSegment.id, DocumentSegment.document_id) - .where( - DocumentSegment.document_id.in_(document_ids_need_summary), - DocumentSegment.status != "re_segment", - DocumentSegment.tenant_id == tenant_id, - ) - .all() + summary_status_map = SummaryIndexService.get_documents_summary_index_status( + document_ids=document_ids_need_summary, + dataset_id=dataset_id, + tenant_id=tenant_id, ) - # Group segments by document_id - document_segments_map: dict[str, list[str]] = {} - for segment in segments: - doc_id = str(segment.document_id) - if doc_id not in document_segments_map: - document_segments_map[doc_id] = [] - document_segments_map[doc_id].append(segment.id) - - # Get all summary records for these segments - all_segment_ids = [seg.id for seg in segments] - summaries = {} - if all_segment_ids: - from models.dataset import DocumentSegmentSummary - - summary_records = ( - db.session.query(DocumentSegmentSummary) - .where( - DocumentSegmentSummary.chunk_id.in_(all_segment_ids), - DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only count enabled summaries - ) - .all() - ) - summaries = {summary.chunk_id: summary.status for summary in summary_records} - - # Calculate summary_index_status for each document - for doc_id in document_ids_need_summary: - segment_ids = document_segments_map.get(doc_id, []) - if not segment_ids: - # No segments, status is None (not started) - summary_status_map[doc_id] = None # type: ignore[assignment] - continue - - # Check if there are any "not_started" or "generating" status summaries - # Only check enabled=True summaries (already filtered in query) - # If segment has no summary record (summaries.get returns None), - # it means the summary is disabled (enabled=False) or not created yet, ignore it - has_pending_summaries = any( - summaries.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summaries[segment_id] in ("not_started", "generating") - for segment_id in segment_ids - ) - - if has_pending_summaries: - # Task is still running (not started or generating) - summary_status_map[doc_id] = "SUMMARIZING" - else: - # All enabled=True summaries are "completed" or "error", task finished - # Or no enabled=True summaries exist (all disabled) - summary_status_map[doc_id] = None # type: ignore[assignment] - # Add summary_index_status to each document for document in documents: if has_summary_index and document.need_summary is True: @@ -697,46 +642,11 @@ class DocumentApi(DatasetApiResource): summary_index_status = None has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True if has_summary_index and document.need_summary is True: - # Get all segments for this document (excluding qa_model and re_segment) - segments = ( - db.session.query(DocumentSegment.id) - .where( - DocumentSegment.document_id == document_id, - DocumentSegment.status != "re_segment", - DocumentSegment.tenant_id == tenant_id, - ) - .all() + summary_index_status = SummaryIndexService.get_document_summary_index_status( + document_id=document_id, + dataset_id=dataset_id, + tenant_id=tenant_id, ) - segment_ids = [seg.id for seg in segments] - - if segment_ids: - from models.dataset import DocumentSegmentSummary - - # Get all summary records for these segments - summary_records = ( - db.session.query(DocumentSegmentSummary) - .where( - DocumentSegmentSummary.chunk_id.in_(segment_ids), - DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only count enabled summaries - ) - .all() - ) - summaries = {summary.chunk_id: summary.status for summary in summary_records} - - # Check if there are any "not_started" or "generating" status summaries - has_pending_summaries = any( - summaries.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summaries[segment_id] in ("not_started", "generating") - for segment_id in segment_ids - ) - - if has_pending_summaries: - summary_index_status = "SUMMARIZING" - else: - summary_index_status = None - else: - summary_index_status = None if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 1701cf2b9c..f423fdf7ef 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -5,13 +5,13 @@ import time import uuid from datetime import UTC, datetime +from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document -from extensions.ext_database import db from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument @@ -78,37 +78,38 @@ class SummaryIndexService: Returns: Created or updated DocumentSegmentSummary instance """ - # Check if summary record already exists - existing_summary = ( - db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() - ) - - if existing_summary: - # Update existing record - existing_summary.summary_content = summary_content - existing_summary.status = status - existing_summary.error = None # type: ignore[assignment] # Clear any previous errors - # Re-enable if it was disabled - if not existing_summary.enabled: - existing_summary.enabled = True - existing_summary.disabled_at = None - existing_summary.disabled_by = None - db.session.add(existing_summary) - db.session.flush() - return existing_summary - else: - # Create new record (enabled by default) - summary_record = DocumentSegmentSummary( - dataset_id=dataset.id, - document_id=segment.document_id, - chunk_id=segment.id, - summary_content=summary_content, - status=status, - enabled=True, # Explicitly set enabled to True + with session_factory.create_session() as session: + # Check if summary record already exists + existing_summary = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) - db.session.add(summary_record) - db.session.flush() - return summary_record + + if existing_summary: + # Update existing record + existing_summary.summary_content = summary_content + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + # Re-enable if it was disabled + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + session.flush() + return existing_summary + else: + # Create new record (enabled by default) + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + status=status, + enabled=True, # Explicitly set enabled to True + ) + session.add(summary_record) + session.flush() + return summary_record @staticmethod def vectorize_summary( @@ -131,6 +132,9 @@ class SummaryIndexService: ) return + # Get summary_record_id for later session queries + summary_record_id = summary_record.id + # Reuse existing index_node_id if available (like segment does), otherwise generate new one old_summary_node_id = summary_record.summary_index_node_id if old_summary_node_id: @@ -141,7 +145,8 @@ class SummaryIndexService: summary_index_node_id = str(uuid.uuid4()) # Always regenerate hash (in case summary content changed) - summary_hash = helper.generate_text_hash(summary_record.summary_content) + summary_content = summary_record.summary_content + summary_hash = helper.generate_text_hash(summary_content) # Delete old vector only if we're reusing the same index_node_id (to overwrite) # If index_node_id changed, the old vector should have been deleted elsewhere @@ -167,14 +172,14 @@ class SummaryIndexService: model=dataset.embedding_model, ) if embedding_model: - tokens_list = embedding_model.get_text_embedding_num_tokens([summary_record.summary_content]) + tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) embedding_tokens = tokens_list[0] if tokens_list else 0 except Exception as e: logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) # Create document with summary content and metadata summary_document = Document( - page_content=summary_record.summary_content, + page_content=summary_content, metadata={ "doc_id": summary_index_node_id, "doc_hash": summary_hash, @@ -207,14 +212,28 @@ class SummaryIndexService: ) # Success - update summary record with index node info - summary_record.summary_index_node_id = summary_index_node_id - summary_record.summary_index_node_hash = summary_hash - summary_record.tokens = embedding_tokens # Save embedding tokens - summary_record.status = "completed" - # Explicitly update updated_at to ensure it's refreshed even if other fields haven't changed - summary_record.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.add(summary_record) - db.session.flush() + with session_factory.create_session() as session: + # Refresh the summary record in the new session + summary_record_in_session = ( + session.query(DocumentSegmentSummary) + .filter_by(id=summary_record_id) + .first() + ) + if summary_record_in_session: + summary_record_in_session.summary_index_node_id = summary_index_node_id + summary_record_in_session.summary_index_node_hash = summary_hash + summary_record_in_session.tokens = embedding_tokens # Save embedding tokens + summary_record_in_session.status = "completed" + # Explicitly update updated_at to ensure it's refreshed even if other fields haven't changed + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + session.add(summary_record_in_session) + session.commit() + # Update the original object for consistency + summary_record.summary_index_node_id = summary_index_node_id + summary_record.summary_index_node_hash = summary_hash + summary_record.tokens = embedding_tokens + summary_record.status = "completed" + summary_record.updated_at = summary_record_in_session.updated_at # Success, exit function return @@ -256,12 +275,23 @@ class SummaryIndexService: str(e), exc_info=True, ) - summary_record.status = "error" - summary_record.error = f"Vectorization failed: {str(e)}" - # Explicitly update updated_at to ensure it's refreshed - summary_record.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.add(summary_record) - db.session.flush() + # Update error status in session + with session_factory.create_session() as session: + summary_record_in_session = ( + session.query(DocumentSegmentSummary) + .filter_by(id=summary_record_id) + .first() + ) + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: {str(e)}" + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + session.add(summary_record_in_session) + session.commit() + # Update the original object for consistency + summary_record.status = "error" + summary_record.error = summary_record_in_session.error + summary_record.updated_at = summary_record_in_session.updated_at raise @staticmethod @@ -283,40 +313,41 @@ class SummaryIndexService: if not segment_ids: return - # Query existing summary records - existing_summaries = ( - db.session.query(DocumentSegmentSummary) - .filter( - DocumentSegmentSummary.chunk_id.in_(segment_ids), - DocumentSegmentSummary.dataset_id == dataset.id, - ) - .all() - ) - existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} - - # Create or update records - for segment in segments: - existing_summary = existing_summary_map.get(segment.id) - if existing_summary: - # Update existing record - existing_summary.status = status - existing_summary.error = None # type: ignore[assignment] # Clear any previous errors - if not existing_summary.enabled: - existing_summary.enabled = True - existing_summary.disabled_at = None - existing_summary.disabled_by = None - db.session.add(existing_summary) - else: - # Create new record - summary_record = DocumentSegmentSummary( - dataset_id=dataset.id, - document_id=segment.document_id, - chunk_id=segment.id, - summary_content=None, # Will be filled later - status=status, - enabled=True, + with session_factory.create_session() as session: + # Query existing summary records + existing_summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset.id, ) - db.session.add(summary_record) + .all() + ) + existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} + + # Create or update records + for segment in segments: + existing_summary = existing_summary_map.get(segment.id) + if existing_summary: + # Update existing record + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + else: + # Create new record + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=None, # Will be filled later + status=status, + enabled=True, + ) + session.add(summary_record) @staticmethod def update_summary_record_error( @@ -332,17 +363,18 @@ class SummaryIndexService: dataset: Dataset containing the segment error: Error message """ - summary_record = ( - db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() - ) + with session_factory.create_session() as session: + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) - if summary_record: - summary_record.status = "error" - summary_record.error = error - db.session.add(summary_record) - db.session.flush() - else: - logger.warning("Summary record not found for segment %s when updating error", segment.id) + if summary_record: + summary_record.status = "error" + summary_record.error = error + session.add(summary_record) + session.commit() + else: + logger.warning("Summary record not found for segment %s when updating error", segment.id) @staticmethod def generate_and_vectorize_summary( @@ -365,61 +397,76 @@ class SummaryIndexService: Raises: ValueError: If summary generation fails """ - # Get existing summary record (should have been created by batch_create_summary_records) - summary_record = ( - db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() - ) - - if not summary_record: - # If not found (shouldn't happen), create one - logger.warning("Summary record not found for segment %s, creating one", segment.id) - summary_record = SummaryIndexService.create_summary_record( - segment, dataset, summary_content="", status="generating" - ) - if summary_record: - summary_record.error = None # type: ignore[assignment] - - try: - # Update status to "generating" - summary_record.status = "generating" - summary_record.error = None # type: ignore[assignment] - db.session.add(summary_record) - db.session.flush() - - # Generate summary (returns summary_content and llm_usage) - summary_content, llm_usage = SummaryIndexService.generate_summary_for_segment( - segment, dataset, summary_index_setting - ) - - # Update summary content - summary_record.summary_content = summary_content - - # Log LLM usage for summary generation - if llm_usage and llm_usage.total_tokens > 0: - logger.info( - "Summary generation for segment %s used %s tokens (prompt: %s, completion: %s)", - segment.id, - llm_usage.total_tokens, - llm_usage.prompt_tokens, - llm_usage.completion_tokens, + with session_factory.create_session() as session: + try: + # Get or refresh summary record in this session + summary_record_in_session = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() ) - # Vectorize summary (will delete old vector if exists before creating new one) - SummaryIndexService.vectorize_summary(summary_record, segment, dataset) + if not summary_record_in_session: + # If not found, create one + logger.warning("Summary record not found for segment %s, creating one", segment.id) + summary_record_in_session = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content="", + status="generating", + enabled=True, + ) + session.add(summary_record_in_session) + session.flush() - # Status will be updated to "completed" by vectorize_summary on success - db.session.commit() - logger.info("Successfully generated and vectorized summary for segment %s", segment.id) - return summary_record + # Update status to "generating" + summary_record_in_session.status = "generating" + summary_record_in_session.error = None # type: ignore[assignment] + session.add(summary_record_in_session) + session.flush() - except Exception as e: - logger.exception("Failed to generate summary for segment %s", segment.id) - # Update summary record with error status - summary_record.status = "error" - summary_record.error = str(e) - db.session.add(summary_record) - db.session.commit() - raise + # Generate summary (returns summary_content and llm_usage) + summary_content, llm_usage = SummaryIndexService.generate_summary_for_segment( + segment, dataset, summary_index_setting + ) + + # Update summary content + summary_record_in_session.summary_content = summary_content + + # Log LLM usage for summary generation + if llm_usage and llm_usage.total_tokens > 0: + logger.info( + "Summary generation for segment %s used %s tokens (prompt: %s, completion: %s)", + segment.id, + llm_usage.total_tokens, + llm_usage.prompt_tokens, + llm_usage.completion_tokens, + ) + + # Vectorize summary (will delete old vector if exists before creating new one) + # Pass the session-managed record to vectorize_summary + SummaryIndexService.vectorize_summary(summary_record_in_session, segment, dataset) + + # Status will be updated to "completed" by vectorize_summary on success + session.commit() + logger.info("Successfully generated and vectorized summary for segment %s", segment.id) + return summary_record_in_session + + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + summary_record_in_session = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = str(e) + session.add(summary_record_in_session) + session.commit() + raise @staticmethod def generate_summaries_for_document( @@ -468,31 +515,32 @@ class SummaryIndexService: only_parent_chunks, ) - # Query segments (only enabled segments) - query = db.session.query(DocumentSegment).filter_by( - dataset_id=dataset.id, - document_id=document.id, - status="completed", - enabled=True, # Only generate summaries for enabled segments - ) + with session_factory.create_session() as session: + # Query segments (only enabled segments) + query = session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, # Only generate summaries for enabled segments + ) - if segment_ids: - query = query.filter(DocumentSegment.id.in_(segment_ids)) + if segment_ids: + query = query.filter(DocumentSegment.id.in_(segment_ids)) - segments = query.all() + segments = query.all() - if not segments: - logger.info("No segments found for document %s", document.id) - return [] + if not segments: + logger.info("No segments found for document %s", document.id) + return [] - # Batch create summary records with "not_started" status before processing - # This ensures all records exist upfront, allowing status tracking - SummaryIndexService.batch_create_summary_records( - segments=segments, - dataset=dataset, - status="not_started", - ) - db.session.commit() # Commit initial records + # Batch create summary records with "not_started" status before processing + # This ensures all records exist upfront, allowing status tracking + SummaryIndexService.batch_create_summary_records( + segments=segments, + dataset=dataset, + status="not_started", + ) + session.commit() # Commit initial records summary_records = [] @@ -524,8 +572,6 @@ class SummaryIndexService: # Continue with other segments continue - db.session.commit() # Commit any remaining changes - logger.info( "Completed summary generation for document %s: %s summaries generated and vectorized", document.id, @@ -550,46 +596,47 @@ class SummaryIndexService: """ from libs.datetime_utils import naive_utc_now - query = db.session.query(DocumentSegmentSummary).filter_by( - dataset_id=dataset.id, - enabled=True, # Only disable enabled summaries - ) + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=True, # Only disable enabled summaries + ) - if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = query.all() - if not summaries: - return + if not summaries: + return - logger.info( - "Disabling %s summary records for dataset %s, segment_ids: %s", - len(summaries), - dataset.id, - len(segment_ids) if segment_ids else "all", - ) + logger.info( + "Disabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) - # Remove from vector database (but keep records) - if dataset.indexing_technique == "high_quality": - summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] - if summary_node_ids: - try: - vector = Vector(dataset) - vector.delete_by_ids(summary_node_ids) - except Exception as e: - logger.warning("Failed to remove summary vectors: %s", str(e)) + # Remove from vector database (but keep records) + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + try: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + except Exception as e: + logger.warning("Failed to remove summary vectors: %s", str(e)) - # Disable summary records (don't delete) - now = naive_utc_now() - for summary in summaries: - summary.enabled = False - summary.disabled_at = now - summary.disabled_by = disabled_by - db.session.add(summary) + # Disable summary records (don't delete) + now = naive_utc_now() + for summary in summaries: + summary.enabled = False + summary.disabled_at = now + summary.disabled_by = disabled_by + session.add(summary) - db.session.commit() - logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id) + session.commit() + logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id) @staticmethod def enable_summaries_for_segments( @@ -612,63 +659,65 @@ class SummaryIndexService: if dataset.indexing_technique != "high_quality": return - query = db.session.query(DocumentSegmentSummary).filter_by( - dataset_id=dataset.id, - enabled=False, # Only enable disabled summaries - ) - - if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - - summaries = query.all() - - if not summaries: - return - - logger.info( - "Enabling %s summary records for dataset %s, segment_ids: %s", - len(summaries), - dataset.id, - len(segment_ids) if segment_ids else "all", - ) - - # Re-vectorize and re-add to vector database - enabled_count = 0 - for summary in summaries: - # Get the original segment - segment = ( - db.session.query(DocumentSegment) - .filter_by( - id=summary.chunk_id, - dataset_id=dataset.id, - ) - .first() + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=False, # Only enable disabled summaries ) - # Summary.enabled stays in sync with chunk.enabled, only enable summary if the associated chunk is enabled. - if not segment or not segment.enabled or segment.status != "completed": - continue + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - if not summary.summary_content: - continue + summaries = query.all() - try: - # Re-vectorize summary - SummaryIndexService.vectorize_summary(summary, segment, dataset) + if not summaries: + return - # Enable summary record - summary.enabled = True - summary.disabled_at = None - summary.disabled_by = None - db.session.add(summary) - enabled_count += 1 - except Exception: - logger.exception("Failed to re-vectorize summary %s", summary.id) - # Keep it disabled if vectorization fails - continue + logger.info( + "Enabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) - db.session.commit() - logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id) + # Re-vectorize and re-add to vector database + enabled_count = 0 + for summary in summaries: + # Get the original segment + segment = ( + session.query(DocumentSegment) + .filter_by( + id=summary.chunk_id, + dataset_id=dataset.id, + ) + .first() + ) + + # Summary.enabled stays in sync with chunk.enabled, + # only enable summary if the associated chunk is enabled. + if not segment or not segment.enabled or segment.status != "completed": + continue + + if not summary.summary_content: + continue + + try: + # Re-vectorize summary + SummaryIndexService.vectorize_summary(summary, segment, dataset) + + # Enable summary record + summary.enabled = True + summary.disabled_at = None + summary.disabled_by = None + session.add(summary) + enabled_count += 1 + except Exception: + logger.exception("Failed to re-vectorize summary %s", summary.id) + # Keep it disabled if vectorization fails + continue + + session.commit() + logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id) @staticmethod def delete_summaries_for_segments( @@ -683,29 +732,30 @@ class SummaryIndexService: dataset: Dataset containing the segments segment_ids: List of segment IDs to delete summaries for. If None, delete all. """ - query = db.session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) - if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = query.all() - if not summaries: - return + if not summaries: + return - # Delete from vector database - if dataset.indexing_technique == "high_quality": - summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] - if summary_node_ids: - vector = Vector(dataset) - vector.delete_by_ids(summary_node_ids) + # Delete from vector database + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) - # Delete summary records - for summary in summaries: - db.session.delete(summary) + # Delete summary records + for summary in summaries: + session.delete(summary) - db.session.commit() - logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id) + session.commit() + logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id) @staticmethod def update_summary_for_segment( @@ -736,85 +786,280 @@ class SummaryIndexService: if segment.document and segment.document.doc_form == "qa_model": return None - try: - # Check if summary_content is empty (whitespace-only strings are considered empty) - if not summary_content or not summary_content.strip(): - # If summary is empty, only delete existing summary vector and record + with session_factory.create_session() as session: + try: + # Check if summary_content is empty (whitespace-only strings are considered empty) + if not summary_content or not summary_content.strip(): + # If summary is empty, only delete existing summary vector and record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record: + # Delete old vector if exists + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Delete summary record since summary is empty + session.delete(summary_record) + session.commit() + logger.info("Deleted summary for segment %s (empty content provided)", segment.id) + return None + else: + # No existing summary record, nothing to do + logger.info("No summary record found for segment %s, nothing to delete", segment.id) + return None + + # Find existing summary record summary_record = ( - db.session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) if summary_record: - # Delete old vector if exists + # Update existing summary old_summary_node_id = summary_record.summary_index_node_id + + # Update summary content + summary_record.summary_content = summary_content + summary_record.status = "generating" + session.add(summary_record) + session.flush() + + # Delete old vector if exists if old_summary_node_id: - try: - vector = Vector(dataset) - vector.delete_by_ids([old_summary_node_id]) - except Exception as e: - logger.warning( - "Failed to delete old summary vector for segment %s: %s", - segment.id, - str(e), - ) + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) - # Delete summary record since summary is empty - db.session.delete(summary_record) - db.session.commit() - logger.info("Deleted summary for segment %s (empty content provided)", segment.id) - return None + # Re-vectorize summary + SummaryIndexService.vectorize_summary(summary_record, segment, dataset) + + session.commit() + logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id) + return summary_record else: - # No existing summary record, nothing to do - logger.info("No summary record found for segment %s, nothing to delete", segment.id) - return None + # Create new summary record if doesn't exist + summary_record = SummaryIndexService.create_summary_record( + segment, dataset, summary_content, status="generating" + ) + SummaryIndexService.vectorize_summary(summary_record, segment, dataset) + session.commit() + logger.info("Successfully created and vectorized summary for segment %s", segment.id) + return summary_record - # Find existing summary record - summary_record = ( - db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() - ) - - if summary_record: - # Update existing summary - old_summary_node_id = summary_record.summary_index_node_id - - # Update summary content - summary_record.summary_content = summary_content - summary_record.status = "generating" - db.session.add(summary_record) - db.session.flush() - - # Delete old vector if exists - if old_summary_node_id: - vector = Vector(dataset) - vector.delete_by_ids([old_summary_node_id]) - - # Re-vectorize summary - SummaryIndexService.vectorize_summary(summary_record, segment, dataset) - - db.session.commit() - logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id) - return summary_record - else: - # Create new summary record if doesn't exist - summary_record = SummaryIndexService.create_summary_record( - segment, dataset, summary_content, status="generating" + except Exception as e: + logger.exception("Failed to update summary for segment %s", segment.id) + # Update summary record with error status if it exists + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) - SummaryIndexService.vectorize_summary(summary_record, segment, dataset) - db.session.commit() - logger.info("Successfully created and vectorized summary for segment %s", segment.id) - return summary_record + if summary_record: + summary_record.status = "error" + summary_record.error = str(e) + session.add(summary_record) + session.commit() + raise - except Exception as e: - logger.exception("Failed to update summary for segment %s", segment.id) - # Update summary record with error status if it exists - summary_record = ( - db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + @staticmethod + def get_segment_summary(segment_id: str, dataset_id: str) -> DocumentSegmentSummary | None: + """ + Get summary for a single segment. + + Args: + segment_id: Segment ID (chunk_id) + dataset_id: Dataset ID + + Returns: + DocumentSegmentSummary instance if found, None otherwise + """ + with session_factory.create_session() as session: + return ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .first() ) - if summary_record: - summary_record.status = "error" - summary_record.error = str(e) - db.session.add(summary_record) - db.session.commit() - raise + + @staticmethod + def get_segments_summaries(segment_ids: list[str], dataset_id: str) -> dict[str, DocumentSegmentSummary]: + """ + Get summaries for multiple segments. + + Args: + segment_ids: List of segment IDs (chunk_ids) + dataset_id: Dataset ID + + Returns: + Dictionary mapping segment_id to DocumentSegmentSummary (only enabled summaries) + """ + if not segment_ids: + return {} + + with session_factory.create_session() as session: + summary_records = ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .all() + ) + + return {summary.chunk_id: summary for summary in summary_records} + + @staticmethod + def get_document_summaries( + document_id: str, dataset_id: str, segment_ids: list[str] | None = None + ) -> list[DocumentSegmentSummary]: + """ + Get all summary records for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + segment_ids: Optional list of segment IDs to filter by + + Returns: + List of DocumentSegmentSummary instances (only enabled summaries) + """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter( + DocumentSegmentSummary.document_id == document_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + return query.all() + + @staticmethod + def get_document_summary_index_status( + document_id: str, dataset_id: str, tenant_id: str + ) -> str | None: + """ + Get summary_index_status for a single document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + "SUMMARIZING" if there are pending summaries, None otherwise + """ + # Get all segments for this document (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + segment_ids = [seg.id for seg in segments] + + if not segment_ids: + return None + + # Get all summary records for these segments + summaries = SummaryIndexService.get_segments_summaries(segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Check if there are any "not_started" or "generating" status summaries + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + return "SUMMARIZING" if has_pending_summaries else None + + @staticmethod + def get_documents_summary_index_status( + document_ids: list[str], dataset_id: str, tenant_id: str + ) -> dict[str, str | None]: + """ + Get summary_index_status for multiple documents. + + Args: + document_ids: List of document IDs + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + Dictionary mapping document_id to summary_index_status ("SUMMARIZING" or None) + """ + if not document_ids: + return {} + + # Get all segments for these documents (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id, DocumentSegment.document_id) + .where( + DocumentSegment.document_id.in_(document_ids), + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + + # Group segments by document_id + document_segments_map: dict[str, list[str]] = {} + for segment in segments: + doc_id = str(segment.document_id) + if doc_id not in document_segments_map: + document_segments_map[doc_id] = [] + document_segments_map[doc_id].append(segment.id) + + # Get all summary records for these segments + all_segment_ids = [seg.id for seg in segments] + summaries = SummaryIndexService.get_segments_summaries(all_segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Calculate summary_index_status for each document + result: dict[str, str | None] = {} + for doc_id in document_ids: + segment_ids = document_segments_map.get(doc_id, []) + if not segment_ids: + # No segments, status is None (not started) + result[doc_id] = None + continue + + # Check if there are any "not_started" or "generating" status summaries + # Only check enabled=True summaries (already filtered in query) + # If segment has no summary record (summary_status_map.get returns None), + # it means the summary is disabled (enabled=False) or not created yet, ignore it + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + if has_pending_summaries: + # Task is still running (not started or generating) + result[doc_id] = "SUMMARIZING" + else: + # All enabled=True summaries are "completed" or "error", task finished + # Or no enabled=True summaries exist (all disabled) + result[doc_id] = None + + return result