Feat/support multimodal embedding (#29115)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
api/services/attachment_service.py (new file, +31)
@@ -0,0 +1,31 @@
+import base64
+
+from sqlalchemy import Engine
+from sqlalchemy.orm import sessionmaker
+from werkzeug.exceptions import NotFound
+
+from extensions.ext_storage import storage
+from models.model import UploadFile
+
+PREVIEW_WORDS_LIMIT = 3000
+
+
+class AttachmentService:
+    _session_maker: sessionmaker
+
+    def __init__(self, session_factory: sessionmaker | Engine | None = None):
+        if isinstance(session_factory, Engine):
+            self._session_maker = sessionmaker(bind=session_factory)
+        elif isinstance(session_factory, sessionmaker):
+            self._session_maker = session_factory
+        else:
+            raise AssertionError("must be a sessionmaker or an Engine.")
+
+    def get_file_base64(self, file_id: str) -> str:
+        upload_file = (
+            self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
+        )
+        if not upload_file:
+            raise NotFound("File not found")
+        blob = storage.load_once(upload_file.key)
+        return base64.b64encode(blob).decode()
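For orientation, a minimal usage sketch of the new service; the engine URL and file id below are illustrative placeholders, not part of this commit:

from sqlalchemy import create_engine

from services.attachment_service import AttachmentService

engine = create_engine("postgresql+psycopg2://localhost/dify")  # hypothetical DSN
service = AttachmentService(session_factory=engine)

# Returns the stored blob base64-encoded; raises werkzeug's NotFound
# when no UploadFile row matches the id.
b64 = service.get_file_base64("example-upload-file-id")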
api/services/dataset_service.py
@@ -7,7 +7,7 @@ import time
 import uuid
 from collections import Counter
 from collections.abc import Sequence
-from typing import Any, Literal
+from typing import Any, Literal, cast

 import sqlalchemy as sa
 from redis.exceptions import LockNotOwnedError
@@ -19,9 +19,10 @@ from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.helper.name_generator import generate_incremental_name
 from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.rag.index_processor.constant.built_in_field import BuiltInField
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from enums.cloud_plan import CloudPlan
 from events.dataset_event import dataset_was_deleted
@@ -46,6 +47,7 @@ from models.dataset import (
     DocumentSegment,
     ExternalKnowledgeBindings,
     Pipeline,
+    SegmentAttachmentBinding,
 )
 from models.model import UploadFile
 from models.provider_ids import ModelProviderID
@@ -363,6 +365,27 @@ class DatasetService:
         except ProviderTokenNotInitError as ex:
             raise ValueError(ex.description)

+    @staticmethod
+    def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
+        try:
+            model_manager = ModelManager()
+            model_instance = model_manager.get_model_instance(
+                tenant_id=tenant_id,
+                provider=model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=model,
+            )
+            text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
+            model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
+            if not model_schema:
+                raise ValueError("Model schema not found")
+            if model_schema.features and ModelFeature.VISION in model_schema.features:
+                return True
+            else:
+                return False
+        except LLMBadRequestError:
+            raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
+
     @staticmethod
     def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
         try:
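A dataset's embedding model counts as multimodal when its schema advertises ModelFeature.VISION. A brief usage sketch; the tenant id and provider/model names are placeholders:

is_multimodal = DatasetService.check_is_multimodal_model(
    tenant_id="tenant-uuid",            # hypothetical tenant
    model_provider="openai",            # hypothetical provider
    model="text-embedding-3-large",     # hypothetical model name
)
# True only when the schema lists ModelFeature.VISION; plain text embedding
# models return False, and an unconfigured provider raises ValueError.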
@@ -402,13 +425,13 @@ class DatasetService:
         if not dataset:
             raise ValueError("Dataset not found")

-        # check if dataset name is exists
-        if DatasetService._has_dataset_same_name(
-            tenant_id=dataset.tenant_id,
-            dataset_id=dataset_id,
-            name=data.get("name", dataset.name),
-        ):
-            raise ValueError("Dataset name already exists")
+        if data.get("name") and data.get("name") != dataset.name:
+            if DatasetService._has_dataset_same_name(
+                tenant_id=dataset.tenant_id,
+                dataset_id=dataset_id,
+                name=data.get("name", dataset.name),
+            ):
+                raise ValueError("Dataset name already exists")

         # Verify user has permission to update this dataset
         DatasetService.check_dataset_permission(dataset, user)
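Behaviorally, the duplicate-name check now fires only when the request actually renames the dataset, so an update that resubmits the current name no longer fails. An illustrative sketch, assuming the surrounding method is DatasetService.update_dataset(dataset_id, data, user):

# Hypothetical calls against the same dataset:
DatasetService.update_dataset(dataset_id, {"name": dataset.name}, user)         # unchanged name: check skipped
DatasetService.update_dataset(dataset_id, {"name": "fresh-unique-name"}, user)  # renamed: duplicate check runs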
@@ -844,6 +867,12 @@ class DatasetService:
                 model_type=ModelType.TEXT_EMBEDDING,
                 model=knowledge_configuration.embedding_model or "",
             )
+            is_multimodal = DatasetService.check_is_multimodal_model(
+                current_user.current_tenant_id,
+                knowledge_configuration.embedding_model_provider,
+                knowledge_configuration.embedding_model,
+            )
+            dataset.is_multimodal = is_multimodal
             dataset.embedding_model = embedding_model.model
             dataset.embedding_model_provider = embedding_model.provider
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@@ -880,6 +909,12 @@ class DatasetService:
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                 embedding_model.provider, embedding_model.model
             )
+            is_multimodal = DatasetService.check_is_multimodal_model(
+                current_user.current_tenant_id,
+                knowledge_configuration.embedding_model_provider,
+                knowledge_configuration.embedding_model,
+            )
+            dataset.is_multimodal = is_multimodal
             dataset.collection_binding_id = dataset_collection_binding.id
             dataset.indexing_technique = knowledge_configuration.indexing_technique
         except LLMBadRequestError:
@@ -937,6 +972,12 @@ class DatasetService:
                 )
             )
             dataset.collection_binding_id = dataset_collection_binding.id
+            is_multimodal = DatasetService.check_is_multimodal_model(
+                current_user.current_tenant_id,
+                knowledge_configuration.embedding_model_provider,
+                knowledge_configuration.embedding_model,
+            )
+            dataset.is_multimodal = is_multimodal
         except LLMBadRequestError:
             raise ValueError(
                 "No Embedding Model available. Please configure a valid provider "
@@ -2305,6 +2346,7 @@ class DocumentService:
             embedding_model_provider=knowledge_config.embedding_model_provider,
             collection_binding_id=dataset_collection_binding_id,
             retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
+            is_multimodal=knowledge_config.is_multimodal,
         )

         db.session.add(dataset)
@@ -2685,6 +2727,13 @@ class SegmentService:
         if "content" not in args or not args["content"] or not args["content"].strip():
             raise ValueError("Content is empty")

+        if args.get("attachment_ids"):
+            if not isinstance(args["attachment_ids"], list):
+                raise ValueError("Attachment IDs is invalid")
+            single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
+            if len(args["attachment_ids"]) > single_chunk_attachment_limit:
+                raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
+
     @classmethod
     def create_segment(cls, args: dict, document: Document, dataset: Dataset):
         assert isinstance(current_user, Account)
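A sketch of what the new validation accepts and rejects; the limit is whatever dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT is configured to (10 below is only an example):

ok = {"content": "segment text", "attachment_ids": ["file-1", "file-2"]}  # passes if within the limit
bad_type = {"content": "segment text", "attachment_ids": "file-1"}        # ValueError: Attachment IDs is invalid
too_many = {"content": "segment text", "attachment_ids": ["f"] * 11}      # ValueError if the limit is 10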
@@ -2731,11 +2780,23 @@ class SegmentService:
                 segment_document.word_count += len(args["answer"])
                 segment_document.answer = args["answer"]

             db.session.add(segment_document)
             # update document word count
             assert document.word_count is not None
             document.word_count += segment_document.word_count
             db.session.add(document)
             db.session.commit()

+            if args["attachment_ids"]:
+                for attachment_id in args["attachment_ids"]:
+                    binding = SegmentAttachmentBinding(
+                        tenant_id=current_user.current_tenant_id,
+                        dataset_id=document.dataset_id,
+                        document_id=document.id,
+                        segment_id=segment_document.id,
+                        attachment_id=attachment_id,
+                    )
+                    db.session.add(binding)
+                db.session.commit()
+
             # save vector index
@@ -2899,7 +2960,7 @@ class SegmentService:
             document.word_count = max(0, document.word_count + word_count_change)
             db.session.add(document)
         # update segment index task
-        if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+        if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
             # regenerate child chunks
             # get embedding model instance
             if dataset.indexing_technique == "high_quality":
@@ -2926,12 +2987,11 @@ class SegmentService:
                     .where(DatasetProcessRule.id == document.dataset_process_rule_id)
                     .first()
                 )
-                if not processing_rule:
-                    raise ValueError("No processing rule found.")
-                VectorService.generate_child_chunks(
-                    segment, document, dataset, embedding_model_instance, processing_rule, True
-                )
-            elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
+                if processing_rule:
+                    VectorService.generate_child_chunks(
+                        segment, document, dataset, embedding_model_instance, processing_rule, True
+                    )
+            elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
                 if args.enabled or keyword_changed:
                     # update segment vector index
                     VectorService.update_segment_vector(args.keywords, segment, dataset)
@@ -2976,7 +3036,7 @@ class SegmentService:
             db.session.add(document)
         db.session.add(segment)
         db.session.commit()
-        if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+        if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
             # get embedding model instance
             if dataset.indexing_technique == "high_quality":
                 # check embedding model setting
@@ -3002,15 +3062,15 @@ class SegmentService:
                     .where(DatasetProcessRule.id == document.dataset_process_rule_id)
                     .first()
                 )
-                if not processing_rule:
-                    raise ValueError("No processing rule found.")
-                VectorService.generate_child_chunks(
-                    segment, document, dataset, embedding_model_instance, processing_rule, True
-                )
-            elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
+                if processing_rule:
+                    VectorService.generate_child_chunks(
+                        segment, document, dataset, embedding_model_instance, processing_rule, True
+                    )
+            elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
                 # update segment vector index
                 VectorService.update_segment_vector(args.keywords, segment, dataset)

+            # update multimodel vector index
+            VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
         except Exception as e:
             logger.exception("update segment index failed")
             segment.enabled = False
@@ -3048,7 +3108,9 @@ class SegmentService:
             )
             child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]

-        delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
+        delete_segment_from_index_task.delay(
+            [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
+        )

         db.session.delete(segment)
         # update document word count
@@ -3097,7 +3159,9 @@ class SegmentService:

         # Start async cleanup with both parent and child node IDs
         if index_node_ids or child_node_ids:
-            delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
+            delete_segment_from_index_task.delay(
+                index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
+            )

         if document.word_count is None:
             document.word_count = 0
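Both delete paths now pass the segments' database ids to the async cleanup task alongside the index node ids. The task body is not part of this diff; the signature implied by the new call sites would look roughly like this (the decorator and parameter names are assumptions, modeled on how Dify's dataset tasks are usually declared):

from celery import shared_task

@shared_task(queue="dataset")  # hypothetical declaration, inferred from the call sites
def delete_segment_from_index_task(
    index_node_ids: list[str],
    dataset_id: str,
    document_id: str,
    segment_ids: list[str],
    child_node_ids: list[str],
):
    ...  # body omitted; not part of this diff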
api/services/entities/knowledge_entities/knowledge_entities.py
@@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel):
     embedding_model: str | None = None
     embedding_model_provider: str | None = None
     name: str | None = None
+    is_multimodal: bool = False
+
+
+class SegmentCreateArgs(BaseModel):
+    content: str | None = None
+    answer: str | None = None
+    keywords: list[str] | None = None
+    attachment_ids: list[str] | None = None


 class SegmentUpdateArgs(BaseModel):
@@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel):
     keywords: list[str] | None = None
     regenerate_child_chunks: bool = False
     enabled: bool | None = None
+    attachment_ids: list[str] | None = None


 class ChildChunkUpdateArgs(BaseModel):
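A hedged sketch of populating the new Pydantic fields (values are illustrative, and only fields visible in this diff are used):

create_args = SegmentCreateArgs(
    content="What does the diagram show?",
    keywords=["diagram"],
    attachment_ids=["upload-file-id-1", "upload-file-id-2"],  # hypothetical upload ids
)

update_args = SegmentUpdateArgs(
    enabled=True,
    attachment_ids=[],  # an empty list clears the segment's attachment bindings
)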
api/services/file_service.py
@@ -1,3 +1,4 @@
+import base64
 import hashlib
 import os
 import uuid
@@ -123,6 +124,15 @@ class FileService:

         return file_size <= file_size_limit

+    def get_file_base64(self, file_id: str) -> str:
+        upload_file = (
+            self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
+        )
+        if not upload_file:
+            raise NotFound("File not found")
+        blob = storage.load_once(upload_file.key)
+        return base64.b64encode(blob).decode()
+
     def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
         if len(text_name) > 200:
             text_name = text_name[:200]
api/services/hit_testing_service.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import time
 from typing import Any
@@ -5,6 +6,7 @@ from typing import Any
 from core.app.app_config.entities import ModelConfig
 from core.model_runtime.entities import LLMMode
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -32,6 +34,7 @@ class HitTestingService:
         account: Account,
         retrieval_model: Any,  # FIXME drop this any
         external_retrieval_model: dict,
+        attachment_ids: list | None = None,
         limit: int = 10,
     ):
         start = time.perf_counter()
@@ -41,7 +44,7 @@ class HitTestingService:
             retrieval_model = dataset.retrieval_model or default_retrieval_model
         document_ids_filter = None
         metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
-        if metadata_filtering_conditions:
+        if metadata_filtering_conditions and query:
             dataset_retrieval = DatasetRetrieval()

             from core.app.app_config.entities import MetadataFilteringCondition
@@ -66,6 +69,7 @@ class HitTestingService:
             retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
             dataset_id=dataset.id,
             query=query,
+            attachment_ids=attachment_ids,
             top_k=retrieval_model.get("top_k", 4),
             score_threshold=retrieval_model.get("score_threshold", 0.0)
             if retrieval_model["score_threshold_enabled"]
@@ -80,17 +84,24 @@ class HitTestingService:

         end = time.perf_counter()
         logger.debug("Hit testing retrieve in %s seconds", end - start)

-        dataset_query = DatasetQuery(
-            dataset_id=dataset.id,
-            content=query,
-            source="hit_testing",
-            source_app_id=None,
-            created_by_role="account",
-            created_by=account.id,
-        )
-
-        db.session.add(dataset_query)
+        dataset_queries = []
+        if query:
+            content = {"content_type": QueryType.TEXT_QUERY, "content": query}
+            dataset_queries.append(content)
+        if attachment_ids:
+            for attachment_id in attachment_ids:
+                content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}
+                dataset_queries.append(content)
+        if dataset_queries:
+            dataset_query = DatasetQuery(
+                dataset_id=dataset.id,
+                content=json.dumps(dataset_queries),
+                source="hit_testing",
+                source_app_id=None,
+                created_by_role="account",
+                created_by=account.id,
+            )
+            db.session.add(dataset_query)
         db.session.commit()

         return cls.compact_retrieve_response(query, all_documents)
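The hit-testing query log now stores a JSON list rather than a bare string. A runnable sketch of the stored payload, assuming QueryType members serialize to string values like "text_query" and "image_query" (the enum definition is not part of this diff):

import json

dataset_queries = [
    {"content_type": "text_query", "content": "find the wiring diagram"},  # assumed QueryType.TEXT_QUERY value
    {"content_type": "image_query", "content": "example-upload-file-id"},  # assumed QueryType.IMAGE_QUERY value
]
print(json.dumps(dataset_queries))  # roughly what DatasetQuery.content would hold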
@@ -168,9 +179,14 @@ class HitTestingService:
     @classmethod
     def hit_testing_args_check(cls, args):
         query = args["query"]
+        attachment_ids = args["attachment_ids"]

-        if not query or len(query) > 250:
-            raise ValueError("Query is required and cannot exceed 250 characters")
+        if not attachment_ids and not query:
+            raise ValueError("Query or attachment_ids is required")
+        if query and len(query) > 250:
+            raise ValueError("Query cannot exceed 250 characters")
+        if attachment_ids and not isinstance(attachment_ids, list):
+            raise ValueError("Attachment_ids must be a list")

     @staticmethod
     def escape_query_for_search(query: str) -> str:
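The net effect is that a hit test may now be image-only, text-only, or both, as the checks below illustrate:

args_check = HitTestingService.hit_testing_args_check
args_check({"query": "", "attachment_ids": ["file-1"]})  # OK: image-only
args_check({"query": "hello", "attachment_ids": []})     # OK: text-only
args_check({"query": "", "attachment_ids": []})          # ValueError: Query or attachment_ids is required
args_check({"query": "x" * 251, "attachment_ids": []})   # ValueError: Query cannot exceed 250 characters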
api/services/vector_service.py
@@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.doc_type import DocType
+from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import Document
+from core.rag.models.document import AttachmentDocument, Document
 from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
+from models import UploadFile
+from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
 from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import ParentMode
@@ -21,9 +24,10 @@ class VectorService:
         cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
     ):
         documents: list[Document] = []
+        multimodal_documents: list[AttachmentDocument] = []

         for segment in segments:
-            if doc_form == IndexType.PARENT_CHILD_INDEX:
+            if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
                 dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
                 if not dataset_document:
                     logger.warning(
@@ -70,12 +74,29 @@ class VectorService:
                     "doc_hash": segment.index_node_hash,
                     "document_id": segment.document_id,
                     "dataset_id": segment.dataset_id,
+                    "doc_type": DocType.TEXT,
                 },
             )
             documents.append(rag_document)
+            if dataset.is_multimodal:
+                for attachment in segment.attachments:
+                    multimodal_document: AttachmentDocument = AttachmentDocument(
+                        page_content=attachment["name"],
+                        metadata={
+                            "doc_id": attachment["id"],
+                            "doc_hash": "",
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                            "doc_type": DocType.IMAGE,
+                        },
+                    )
+                    multimodal_documents.append(multimodal_document)
+        index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor()

         if len(documents) > 0:
-            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
+            index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list)
+        if len(multimodal_documents) > 0:
+            index_processor.load(dataset, [], multimodal_documents, with_keywords=False)

     @classmethod
     def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
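Note that index_processor.load now takes the multimodal documents as a third positional argument: text-only callers pass None, while attachment-only batches pass an empty text list. Conceptually, a segment with two image attachments contributes one TEXT document plus two IMAGE documents:

# Sketch of the resulting calls for a multimodal dataset (shapes illustrative):
# documents            -> [Document(metadata={"doc_id": segment.index_node_id, ..., "doc_type": DocType.TEXT})]
# multimodal_documents -> [AttachmentDocument(metadata={"doc_id": file_id, ..., "doc_type": DocType.IMAGE}), ...]
index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list)
index_processor.load(dataset, [], multimodal_documents, with_keywords=False)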
@@ -130,6 +151,7 @@ class VectorService:
                 "doc_hash": segment.index_node_hash,
                 "document_id": segment.document_id,
                 "dataset_id": segment.dataset_id,
+                "doc_type": DocType.TEXT,
             },
         )
         # use full doc mode to generate segment's child chunk
@@ -226,3 +248,92 @@ class VectorService:
     def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
         vector = Vector(dataset=dataset)
         vector.delete_by_ids([child_chunk.index_node_id])
+
+    @classmethod
+    def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
+        if dataset.indexing_technique != "high_quality":
+            return
+
+        attachments = segment.attachments
+        old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else []
+
+        # Check if there's any actual change needed
+        if set(attachment_ids) == set(old_attachment_ids):
+            return
+
+        try:
+            vector = Vector(dataset=dataset)
+            if dataset.is_multimodal:
+                # Delete old vectors if they exist
+                if old_attachment_ids:
+                    vector.delete_by_ids(old_attachment_ids)
+
+            # Delete existing segment attachment bindings in one operation
+            db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
+                synchronize_session=False
+            )
+
+            if not attachment_ids:
+                db.session.commit()
+                return
+
+            # Bulk fetch upload files - only fetch needed fields
+            upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
+
+            if not upload_file_list:
+                db.session.commit()
+                return
+
+            # Create a mapping for quick lookup
+            upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list}
+
+            # Prepare batch operations
+            bindings = []
+            documents = []
+
+            # Create common metadata base to avoid repetition
+            base_metadata = {
+                "doc_hash": "",
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+                "doc_type": DocType.IMAGE,
+            }
+
+            # Process attachments in the order specified by attachment_ids
+            for attachment_id in attachment_ids:
+                upload_file = upload_file_map.get(attachment_id)
+                if not upload_file:
+                    logger.warning("Upload file not found for attachment_id: %s", attachment_id)
+                    continue
+
+                # Create segment attachment binding
+                bindings.append(
+                    SegmentAttachmentBinding(
+                        tenant_id=segment.tenant_id,
+                        dataset_id=segment.dataset_id,
+                        document_id=segment.document_id,
+                        segment_id=segment.id,
+                        attachment_id=upload_file.id,
+                    )
+                )

+                # Create document for vector indexing
+                documents.append(
+                    Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id})
+                )
+
+            # Bulk insert all bindings at once
+            if bindings:
+                db.session.add_all(bindings)
+
+            # Add documents to vector store if any
+            if documents and dataset.is_multimodal:
+                vector.add_texts(documents, duplicate_check=True)
+
+            # Single commit for all operations
+            db.session.commit()
+
+        except Exception:
+            logger.exception("Failed to update multimodal vector for segment %s", segment.id)
+            db.session.rollback()
+            raise
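A hedged sketch of how the new method is meant to be driven (ids are placeholders): replacing a segment's attachments first drops the old image vectors and bindings, then re-indexes the new files under their file names.

# E.g. from SegmentService.update_segment, after the text index is refreshed:
VectorService.update_multimodel_vector(segment, ["file-id-1", "file-id-2"], dataset)

# An empty list clears state: old vectors are deleted (when the dataset is
# multimodal), all SegmentAttachmentBinding rows for the segment are removed,
# and nothing new is indexed.
VectorService.update_multimodel_vector(segment, [], dataset)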