Feat/support multimodal embedding (#29115)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jyong
2025-12-09 14:41:46 +08:00
committed by GitHub
parent 77cf8f6c27
commit 9affc546c6
78 changed files with 3230 additions and 713 deletions

View File

@ -0,0 +1,31 @@
import base64
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from extensions.ext_storage import storage
from models.model import UploadFile
PREVIEW_WORDS_LIMIT = 3000
class AttachmentService:
    """Serves uploaded attachment files (e.g. segment images) from storage."""

    _session_maker: sessionmaker

    def __init__(self, session_factory: sessionmaker | Engine | None = None):
        """Accept either a ready-made sessionmaker or an Engine to wrap.

        Raises:
            AssertionError: if *session_factory* is neither a sessionmaker
                nor an Engine (including the default ``None``).
        """
        if isinstance(session_factory, Engine):
            self._session_maker = sessionmaker(bind=session_factory)
        elif isinstance(session_factory, sessionmaker):
            self._session_maker = session_factory
        else:
            raise AssertionError("must be a sessionmaker or an Engine.")

    def get_file_base64(self, file_id: str) -> str:
        """Load the UploadFile identified by *file_id* and return its content
        base64-encoded as ASCII text.

        Raises:
            NotFound: if no UploadFile row matches *file_id*.
        """
        # Use the session as a context manager so it is always closed;
        # calling the sessionmaker inline leaked one session per request.
        with self._session_maker(expire_on_commit=False) as session:
            upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
        if not upload_file:
            raise NotFound("File not found")
        blob = storage.load_once(upload_file.key)
        return base64.b64encode(blob).decode()

View File

@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal
from typing import Any, Literal, cast
import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
@ -19,9 +19,10 @@ from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from enums.cloud_plan import CloudPlan
from events.dataset_event import dataset_was_deleted
@ -46,6 +47,7 @@ from models.dataset import (
DocumentSegment,
ExternalKnowledgeBindings,
Pipeline,
SegmentAttachmentBinding,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
@ -363,6 +365,27 @@ class DatasetService:
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
@staticmethod
def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str):
try:
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=model,
)
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")
if model_schema.features and ModelFeature.VISION in model_schema.features:
return True
else:
return False
except LLMBadRequestError:
raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.")
@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
try:
@ -402,13 +425,13 @@ class DatasetService:
if not dataset:
raise ValueError("Dataset not found")
# check if dataset name is exists
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
if data.get("name") and data.get("name") != dataset.name:
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
# Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
@ -844,6 +867,12 @@ class DatasetService:
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model or "",
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
@ -880,6 +909,12 @@ class DatasetService:
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
except LLMBadRequestError:
@ -937,6 +972,12 @@ class DatasetService:
)
)
dataset.collection_binding_id = dataset_collection_binding.id
is_multimodal = DatasetService.check_is_multimodal_model(
current_user.current_tenant_id,
knowledge_configuration.embedding_model_provider,
knowledge_configuration.embedding_model,
)
dataset.is_multimodal = is_multimodal
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
@ -2305,6 +2346,7 @@ class DocumentService:
embedding_model_provider=knowledge_config.embedding_model_provider,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
is_multimodal=knowledge_config.is_multimodal,
)
db.session.add(dataset)
@ -2685,6 +2727,13 @@ class SegmentService:
if "content" not in args or not args["content"] or not args["content"].strip():
raise ValueError("Content is empty")
if args.get("attachment_ids"):
if not isinstance(args["attachment_ids"], list):
raise ValueError("Attachment IDs is invalid")
single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT
if len(args["attachment_ids"]) > single_chunk_attachment_limit:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
@ -2731,11 +2780,23 @@ class SegmentService:
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
if args["attachment_ids"]:
for attachment_id in args["attachment_ids"]:
binding = SegmentAttachmentBinding(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
segment_id=segment_document.id,
attachment_id=attachment_id,
)
db.session.add(binding)
db.session.commit()
# save vector index
@ -2899,7 +2960,7 @@ class SegmentService:
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
# update segment index task
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks
# get embedding model instance
if dataset.indexing_technique == "high_quality":
@ -2926,12 +2987,11 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
if args.enabled or keyword_changed:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
@ -2976,7 +3036,7 @@ class SegmentService:
db.session.add(document)
db.session.add(segment)
db.session.commit()
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
@ -3002,15 +3062,15 @@ class SegmentService:
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
)
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# update multimodel vector index
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
except Exception as e:
logger.exception("update segment index failed")
segment.enabled = False
@ -3048,7 +3108,9 @@ class SegmentService:
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
)
db.session.delete(segment)
# update document word count
@ -3097,7 +3159,9 @@ class SegmentService:
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
delete_segment_from_index_task.delay(
index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids
)
if document.word_count is None:
document.word_count = 0

View File

@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
is_multimodal: bool = False
class SegmentCreateArgs(BaseModel):
    """Validated payload for creating a document segment."""

    content: str | None = None  # segment body text
    answer: str | None = None  # answer text for Q&A-form documents
    keywords: list[str] | None = None  # optional keyword-index terms
    attachment_ids: list[str] | None = None  # UploadFile ids to bind to the segment
class SegmentUpdateArgs(BaseModel):
@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: bool | None = None
attachment_ids: list[str] | None = None
class ChildChunkUpdateArgs(BaseModel):

View File

@ -1,3 +1,4 @@
import base64
import hashlib
import os
import uuid
@ -123,6 +124,15 @@ class FileService:
return file_size <= file_size_limit
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
)
if not upload_file:
raise NotFound("File not found")
blob = storage.load_once(upload_file.key)
return base64.b64encode(blob).decode()
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]

View File

@ -1,3 +1,4 @@
import json
import logging
import time
from typing import Any
@ -5,6 +6,7 @@ from typing import Any
from core.app.app_config.entities import ModelConfig
from core.model_runtime.entities import LLMMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -32,6 +34,7 @@ class HitTestingService:
account: Account,
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,
):
start = time.perf_counter()
@ -41,7 +44,7 @@ class HitTestingService:
retrieval_model = dataset.retrieval_model or default_retrieval_model
document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions:
if metadata_filtering_conditions and query:
dataset_retrieval = DatasetRetrieval()
from core.app.app_config.entities import MetadataFilteringCondition
@ -66,6 +69,7 @@ class HitTestingService:
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
dataset_id=dataset.id,
query=query,
attachment_ids=attachment_ids,
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
@ -80,17 +84,24 @@ class HitTestingService:
end = time.perf_counter()
logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
dataset_queries = []
if query:
content = {"content_type": QueryType.TEXT_QUERY, "content": query}
dataset_queries.append(content)
if attachment_ids:
for attachment_id in attachment_ids:
content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}
dataset_queries.append(content)
if dataset_queries:
dataset_query = DatasetQuery(
dataset_id=dataset.id,
content=json.dumps(dataset_queries),
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
db.session.commit()
return cls.compact_retrieve_response(query, all_documents)
@ -168,9 +179,14 @@ class HitTestingService:
@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]
attachment_ids = args["attachment_ids"]
if not query or len(query) > 250:
raise ValueError("Query is required and cannot exceed 250 characters")
if not attachment_ids and not query:
raise ValueError("Query or attachment_ids is required")
if query and len(query) > 250:
raise ValueError("Query cannot exceed 250 characters")
if attachment_ids and not isinstance(attachment_ids, list):
raise ValueError("Attachment_ids must be a list")
@staticmethod
def escape_query_for_search(query: str) -> str:

View File

@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
@ -21,9 +24,10 @@ class VectorService:
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
multimodal_documents: list[AttachmentDocument] = []
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
if doc_form == IndexStructureType.PARENT_CHILD_INDEX:
dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not dataset_document:
logger.warning(
@ -70,12 +74,29 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
documents.append(rag_document)
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_document: AttachmentDocument = AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
multimodal_documents.append(multimodal_document)
index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor()
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list)
if len(multimodal_documents) > 0:
index_processor.load(dataset, [], multimodal_documents, with_keywords=False)
@classmethod
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
@ -130,6 +151,7 @@ class VectorService:
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.TEXT,
},
)
# use full doc mode to generate segment's child chunk
@ -226,3 +248,92 @@ class VectorService:
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])
    @classmethod
    def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset):
        """Replace the image-attachment bindings and their vector-index entries
        for *segment* with the set given by *attachment_ids*.

        No-ops when the dataset is not "high_quality" indexed or when the new
        attachment set equals the current one. All DB work is flushed in a
        single commit; on any failure the transaction is rolled back and the
        exception re-raised.
        """
        if dataset.indexing_technique != "high_quality":
            return
        attachments = segment.attachments
        old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else []
        # Check if there's any actual change needed (order changes are ignored
        # here — only set membership is compared).
        if set(attachment_ids) == set(old_attachment_ids):
            return
        try:
            vector = Vector(dataset=dataset)
            if dataset.is_multimodal:
                # Delete old vectors if they exist (vector docs are keyed by
                # attachment/upload-file id).
                if old_attachment_ids:
                    vector.delete_by_ids(old_attachment_ids)
            # Delete existing segment attachment bindings in one operation
            db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete(
                synchronize_session=False
            )
            # Nothing new to bind: persist the deletions and stop.
            if not attachment_ids:
                db.session.commit()
                return
            # Bulk fetch upload files - only fetch needed fields
            upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
            if not upload_file_list:
                db.session.commit()
                return
            # Create a mapping for quick lookup
            upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list}
            # Prepare batch operations
            bindings = []
            documents = []
            # Create common metadata base to avoid repetition
            base_metadata = {
                "doc_hash": "",
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
                "doc_type": DocType.IMAGE,
            }
            # Process attachments in the order specified by attachment_ids
            for attachment_id in attachment_ids:
                upload_file = upload_file_map.get(attachment_id)
                if not upload_file:
                    # Unknown ids are skipped (logged), not treated as errors.
                    logger.warning("Upload file not found for attachment_id: %s", attachment_id)
                    continue
                # Create segment attachment binding
                bindings.append(
                    SegmentAttachmentBinding(
                        tenant_id=segment.tenant_id,
                        dataset_id=segment.dataset_id,
                        document_id=segment.document_id,
                        segment_id=segment.id,
                        attachment_id=upload_file.id,
                    )
                )
                # Create document for vector indexing (page_content is the
                # file name; doc_id is the upload-file id).
                documents.append(
                    Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id})
                )
            # Bulk insert all bindings at once
            if bindings:
                db.session.add_all(bindings)
            # Add documents to vector store if any
            if documents and dataset.is_multimodal:
                vector.add_texts(documents, duplicate_check=True)
            # Single commit for all operations
            db.session.commit()
        except Exception:
            logger.exception("Failed to update multimodal vector for segment %s", segment.id)
            db.session.rollback()
            raise