Feat/support multimodal embedding (#29115)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Author: Jyong
Date: 2025-12-09 14:41:46 +08:00
Committed by: GitHub
parent 77cf8f6c27
commit 9affc546c6
78 changed files with 3230 additions and 713 deletions

View File

@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
@ -30,9 +31,10 @@ class DataPostProcessor:
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)

View File

@ -1,23 +1,30 @@
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
@ -37,14 +44,15 @@ class RetrievalService:
retrieval_method: RetrievalMethod,
dataset_id: str,
query: str,
top_k: int,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
):
if not query:
if not query and not attachment_ids:
return []
dataset = cls._get_dataset(dataset_id)
if not dataset:
@ -56,69 +64,52 @@ class RetrievalService:
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
retrieval_service = RetrievalService()
if query:
futures.append(
executor.submit(
cls.keyword_search,
retrieval_service._retrieve,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
all_documents=all_documents,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
futures.append(
executor.submit(
cls.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
retrieval_method=retrieval_method,
dataset=dataset,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
reranking_mode=reranking_mode,
weights=weights,
document_ids_filter=document_ids_filter,
attachment_id=None,
all_documents=all_documents,
exceptions=exceptions,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
futures.append(
executor.submit(
cls.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
if attachment_ids:
for attachment_id in attachment_ids:
futures.append(
executor.submit(
retrieval_service._retrieve,
flask_app=current_app._get_current_object(), # type: ignore
retrieval_method=retrieval_method,
dataset=dataset,
query=None,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=reranking_mode,
weights=weights,
document_ids_filter=document_ids_filter,
attachment_id=attachment_id,
all_documents=all_documents,
exceptions=exceptions,
)
)
)
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k,
)
return all_documents
@classmethod
@ -223,6 +214,7 @@ class RetrievalService:
retrieval_method: RetrievalMethod,
exceptions: list,
document_ids_filter: list[str] | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
):
with flask_app.app_context():
try:
@ -231,14 +223,30 @@ class RetrievalService:
raise ValueError("dataset not found")
vector = Vector(dataset=dataset)
documents = vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
documents = []
if query_type == QueryType.TEXT_QUERY:
documents.extend(
vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
)
if query_type == QueryType.IMAGE_QUERY:
if not dataset.is_multimodal:
return
documents.extend(
vector.search_by_file(
file_id=query,
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
)
if documents:
if (
@ -250,14 +258,37 @@ class RetrievalService:
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
if dataset.is_multimodal:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
model_type=ModelType.RERANK,
)
if is_support_vision:
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
query_type=query_type,
)
)
else:
# rerank model does not support vision; skip reranking and return the original documents
all_documents.extend(documents)
else:
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
query_type=query_type,
)
)
)
else:
all_documents.extend(documents)
except Exception as e:
@ -339,103 +370,159 @@ class RetrievalService:
records = []
include_segment_ids = set()
segment_child_map = {}
# Process documents
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
segment_file_map = {}
with Session(db.engine) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == child_chunk.segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# Handle parent-child documents
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attchment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
if not child_chunk:
continue
segment_id = child_chunk.segment_id
if not segment_id:
continue
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
)
)
.first()
)
.first()
)
if not segment:
continue
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if child_chunk:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
if attachment_info:
segment_file_map[segment.id].append(attachment_info)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
# Handle normal documents
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
# Handle normal documents
segment = None
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attchment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = db.session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
records.append(record)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
records.append(record)
else:
if attachment_info:
attachment_infos = segment_file_map.get(segment.id, [])
if attachment_info not in attachment_infos:
attachment_infos.append(attachment_info)
segment_file_map[segment.id] = attachment_infos
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
result = []
for record in records:
@ -447,6 +534,11 @@ class RetrievalService:
if not isinstance(child_chunks, list):
child_chunks = None
# Extract files, ensuring it's a list or None
files = record.get("files")
if not isinstance(files, list):
files = None
# Extract score, ensuring it's a float or None
score_value = record.get("score")
score = (
@ -456,10 +548,149 @@ class RetrievalService:
)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks, score=score, files=files
)
result.append(retrieval_segment)
return result
except Exception as e:
db.session.rollback()
raise e
def _retrieve(
self,
flask_app: Flask,
retrieval_method: RetrievalMethod,
dataset: Dataset,
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
all_documents: list[Document] = [],
exceptions: list[str] = [],
):
if not query and not attachment_id:
return
with flask_app.app_context():
all_documents_item: list[Document] = []
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query:
futures.append(
executor.submit(
self.keyword_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
all_documents=all_documents_item,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
if query:
futures.append(
executor.submit(
self.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
query_type=QueryType.TEXT_QUERY,
)
)
if attachment_id:
futures.append(
executor.submit(
self.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=attachment_id,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
query_type=QueryType.IMAGE_QUERY,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query:
futures.append(
executor.submit(
self.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents_item,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE:
all_documents.extend(all_documents_item)
all_documents_item = self._deduplicate_documents(all_documents_item)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
query = query or attachment_id
if not query:
return
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY,
)
all_documents.extend(all_documents_item)
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.first()
)
if attachment_binding:
attchment_info = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
"size": upload_file.size,
}
return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id}
return None
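For orientation, a hedged usage sketch of the updated entry point (not part of the diff): a text query plus image attachment ids fan out into separate text and image embedding searches inside _retrieve. The module path, the SEMANTIC_SEARCH member, and all ids are assumptions for illustration, not confirmed by this diff.

from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.retrieval_methods import RetrievalMethod

# Placeholder ids; each attachment id triggers its own IMAGE_QUERY embedding search.
results = RetrievalService.retrieve(
    retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
    dataset_id="<dataset-id>",
    query="what does the architecture diagram show?",
    top_k=4,
    score_threshold=0.0,
    attachment_ids=["<upload-file-id>"],
)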

View File

@ -1,3 +1,4 @@
import base64
import logging
import time
from abc import ABC, abstractmethod
@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Dataset, Whitelist
from models.model import UploadFile
logger = logging.getLogger(__name__)
@ -203,6 +207,47 @@ class Vector:
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def create_multimodal(self, file_documents: list | None = None, **kwargs):
if file_documents:
start = time.time()
logger.info("start embedding %s files %s", len(file_documents), start)
batch_size = 1000
total_batches = (len(file_documents) + batch_size - 1) // batch_size
for i in range(0, len(file_documents), batch_size):
batch = file_documents[i : i + batch_size]
batch_start = time.time()
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
# Batch query all upload files to avoid N+1 queries
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
upload_files = db.session.scalars(stmt).all()
upload_file_map = {str(f.id): f for f in upload_files}
file_base64_list = []
real_batch = []
for document in batch:
attachment_id = document.metadata["doc_id"]
doc_type = document.metadata["doc_type"]
upload_file = upload_file_map.get(attachment_id)
if upload_file:
blob = storage.load_once(upload_file.key)
file_base64_str = base64.b64encode(blob).decode()
file_base64_list.append(
{
"content": file_base64_str,
"content_type": doc_type,
"file_id": attachment_id,
}
)
real_batch.append(document)
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
@ -223,6 +268,22 @@ class Vector:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
return []
blob = storage.load_once(upload_file.key)
file_base64_str = base64.b64encode(blob).decode()
multimodal_vector = self._embeddings.embed_multimodal_query(
{
"content": file_base64_str,
"content_type": DocType.IMAGE,
"file_id": file_id,
}
)
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
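The new multimodal paths exchange plain dicts rather than Document objects when talking to the embedding model. A small sketch of the payload shape built by create_multimodal and search_by_file above; the helper name is hypothetical and the file id is a placeholder.

import base64

from core.rag.index_processor.constant.doc_type import DocType

# Hypothetical helper (illustration only): the dict shape handed to
# embed_multimodal_documents / embed_multimodal_query.
def build_multimodal_payload(blob: bytes, file_id: str) -> dict:
    return {
        "content": base64.b64encode(blob).decode(),  # base64-encoded raw file bytes
        "content_type": DocType.IMAGE,               # only image content in this PR
        "file_id": file_id,                          # doubles as the embedding cache key
    }

payload = build_multimodal_payload(b"\x89PNG...", "<upload-file-id>")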

View File

@ -5,9 +5,9 @@ from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
class DatasetDocumentStore:
@ -120,6 +120,9 @@ class DatasetDocumentStore:
db.session.add(segment_document)
db.session.flush()
self.add_multimodel_documents_binding(
segment_id=segment_document.id, multimodel_documents=doc.attachments
)
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
@ -144,6 +147,9 @@ class DatasetDocumentStore:
segment_document.index_node_hash = doc.metadata.get("doc_hash")
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
self.add_multimodel_documents_binding(
segment_id=segment_document.id, multimodel_documents=doc.attachments
)
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).where(
@ -233,3 +239,15 @@ class DatasetDocumentStore:
document_segment = db.session.scalar(stmt)
return document_segment
def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
if multimodel_documents:
for multimodel_document in multimodel_documents:
binding = SegmentAttachmentBinding(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_id,
attachment_id=multimodel_document.metadata["doc_id"],
)
db.session.add(binding)

View File

@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings):
return text_embeddings
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
"""Embed file documents."""
# use doc embedding cache or store if not exists
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
embedding_queue_indices = []
for i, multimodel_document in enumerate(multimodel_documents):
file_id = multimodel_document["file_id"]
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
)
.first()
)
if embedding:
multimodel_embeddings[i] = embedding.get_embedding()
else:
embedding_queue_indices.append(i)
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
if embedding_queue_indices:
embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices]
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
else 1
)
for i in range(0, len(embedding_queue_multimodel_documents), max_chunks):
batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks]
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=batch_multimodel_documents,
user=self._user,
input_type=EmbeddingInputType.DOCUMENT,
)
for vector in embedding_result.embeddings:
try:
# FIXME: type ignore for numpy here
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
logger.warning("Normalized embedding is nan: %s", normalized_embedding)
continue
embedding_queue_embeddings.append(normalized_embedding)
except IntegrityError:
db.session.rollback()
except Exception:
logger.exception("Failed transform embedding")
cache_embeddings = []
try:
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
multimodel_embeddings[i] = n_embedding
file_id = multimodel_documents[i]["file_id"]
if file_id not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
hash=file_id,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
)
embedding_cache.set_embedding(n_embedding)
db.session.add(embedding_cache)
cache_embeddings.append(file_id)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.exception("Failed to embed documents")
raise ex
return multimodel_embeddings
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings):
raise ex
return embedding_results # type: ignore
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
return [float(x) for x in decoded_embedding]
try:
embedding_result = self._model_instance.invoke_multimodal_embedding(
multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
)
embedding_results = embedding_result.embeddings[0]
# FIXME: type ignore for numpy here
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"])
raise ex
try:
# encode embedding to base64
embedding_vector = np.array(embedding_results)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str)
except Exception as ex:
if dify_config.DEBUG:
logger.exception(
"Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"]
)
raise ex
return embedding_results # type: ignore
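A self-contained sketch of the Redis cache roundtrip used by embed_multimodal_query above: the normalized vector is stored as base64-encoded raw float bytes and recovered with np.frombuffer, so the dtype must match on both sides.

import base64
import numpy as np

vector = np.array([0.1, 0.2, 0.3], dtype="float")             # normalized embedding (float64)
encoded = base64.b64encode(vector.tobytes()).decode("utf-8")   # string stored via setex(key, 600, ...)
decoded = [float(x) for x in np.frombuffer(base64.b64decode(encoded), dtype="float")]
assert decoded == vector.tolist()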

View File

@ -9,11 +9,21 @@ class Embeddings(ABC):
"""Embed search docs."""
raise NotImplementedError
@abstractmethod
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
"""Embed file documents."""
raise NotImplementedError
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
raise NotImplementedError
@abstractmethod
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
"""Embed multimodal query."""
raise NotImplementedError
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError
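With the two new abstract methods, every Embeddings implementation must now also cover the multimodal cases. A no-op stub showing the required surface, assuming embed_documents is the remaining abstract method of the base class; return values are dummies.

from core.rag.embedding.embedding_base import Embeddings

class DummyEmbeddings(Embeddings):
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [[0.0] for _ in texts]

    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
        return [[0.0] for _ in multimodel_documents]

    def embed_query(self, text: str) -> list[float]:
        return [0.0]

    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
        return [0.0]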

View File

@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel):
segment: DocumentSegment
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None

View File

@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel):
page: int | None = None
doc_metadata: dict[str, Any] | None = None
title: str | None = None
files: list[dict[str, Any]] | None = None

View File

@ -0,0 +1,6 @@
from enum import StrEnum
class DocType(StrEnum):
TEXT = "text"
IMAGE = "image"

View File

@ -1,7 +1,12 @@
from enum import StrEnum
class IndexType(StrEnum):
class IndexStructureType(StrEnum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "hierarchical_model"
class IndexTechniqueType(StrEnum):
ECONOMY = "economy"
HIGH_QUALITY = "high_quality"

View File

@ -0,0 +1,6 @@
from enum import StrEnum
class QueryType(StrEnum):
TEXT_QUERY = "text_query"
IMAGE_QUERY = "image_query"

View File

@ -1,20 +1,34 @@
"""Abstract interface for document loader implementations."""
import cgi
import logging
import mimetypes
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import Account, ToolFile
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
if TYPE_CHECKING:
from core.model_manager import ModelInstance
@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError
@abstractmethod
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
raise NotImplementedError
@abstractmethod
@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC):
)
return character_splitter # type: ignore
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""
Get the content files from the document.
"""
multi_model_documents: list[AttachmentDocument] = []
text = document.page_content
images = self._extract_markdown_images(text)
if not images:
return multi_model_documents
upload_file_id_list = []
for image in images:
# Collect all upload_file_ids including duplicates to preserve occurrence count
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
# Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
if current_user:
tool_file_id = match.group(1)
upload_file_id = self._download_tool_file(tool_file_id, current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
continue
if current_user:
upload_file_id = self._download_image(image.split(" ")[0], current_user)
if upload_file_id:
upload_file_id_list.append(upload_file_id)
if not upload_file_id_list:
return multi_model_documents
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
# Create a mapping from ID to UploadFile for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
# Create a Document for each occurrence (including duplicates)
for upload_file_id in upload_file_id_list:
upload_file = upload_file_map.get(upload_file_id)
if upload_file:
multi_model_documents.append(
AttachmentDocument(
page_content=upload_file.name,
metadata={
"doc_id": upload_file.id,
"doc_hash": "",
"document_id": document.metadata.get("document_id"),
"dataset_id": document.metadata.get("dataset_id"),
"doc_type": DocType.IMAGE,
},
)
)
return multi_model_documents
def _extract_markdown_images(self, text: str) -> list[str]:
"""
Extract the markdown images from the text.
"""
pattern = r"!\[.*?\]\((.*?)\)"
return re.findall(pattern, text)
def _download_image(self, image_url: str, current_user: Account) -> str | None:
"""
Download the image from the URL.
Image size must not exceed 2MB.
"""
from services.file_service import FileService
MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024
DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT
try:
# Download with timeout
response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT)
response.raise_for_status()
# Check Content-Length header if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_IMAGE_SIZE:
logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length)
return None
filename = None
content_disposition = response.headers.get("content-disposition")
if content_disposition:
_, params = cgi.parse_header(content_disposition)
if "filename" in params:
filename = params["filename"]
filename = unquote(filename)
if not filename:
parsed_url = urlparse(image_url)
# unquote non-ASCII (e.g. Chinese) characters in the URL path
path = unquote(parsed_url.path)
filename = os.path.basename(path)
if not filename:
filename = "downloaded_image_file"
name, current_ext = os.path.splitext(filename)
content_type = response.headers.get("content-type", "").split(";")[0].strip()
real_ext = mimetypes.guess_extension(content_type)
if not current_ext and real_ext or current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext:
filename = f"{name}{real_ext}"
# Download content with size limit
blob = b""
for chunk in response.iter_bytes(chunk_size=8192):
blob += chunk
if len(blob) > MAX_IMAGE_SIZE:
logging.warning("Image from %s exceeds 2MB limit during download", image_url)
return None
if not blob:
logging.warning("Image from %s is empty", image_url)
return None
upload_file = FileService(db.engine).upload_file(
filename=filename,
content=blob,
mimetype=content_type,
user=current_user,
)
return upload_file.id
except httpx.TimeoutException:
logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT)
return None
except httpx.RequestError as e:
logging.warning("Error downloading image from %s: %s", image_url, str(e))
return None
except Exception:
logging.exception("Unexpected error downloading image from %s", image_url)
return None
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
"""
Download the tool file from the ID.
"""
from services.file_service import FileService
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
upload_file = FileService(db.engine).upload_file(
filename=tool_file.name,
content=blob,
mimetype=tool_file.mimetype,
user=current_user,
)
return upload_file.id
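A quick sketch exercising the extraction logic above: markdown image URLs are pulled out of the chunk text, then upload_file ids are recovered from "/files/<id>/file-preview" style links. The regex patterns are copied from the diff; the URL and uuid are made up.

import re

text = "intro ![fig](https://example.com/files/0a1b2c3d-1111-2222-3333-444455556666/file-preview?timestamp=1) outro"
image_urls = re.findall(r"!\[.*?\]\((.*?)\)", text)  # _extract_markdown_images pattern
upload_file_ids = []
for url in image_urls:
    m = re.search(r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?", url)  # post-v0.10.0 pattern
    if m:
        upload_file_ids.append(m.group(1))
print(upload_file_ids)  # ['0a1b2c3d-1111-2222-3333-444455556666']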

View File

@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
@ -19,11 +19,11 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX:
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX:
elif self._index_type == IndexStructureType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")
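Callers keep the same factory flow after the rename. A usage sketch assuming the IndexProcessorFactory constructor and init_index_processor method from the existing codebase, which this diff does not show.

from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

# doc_form values map 1:1 onto the renamed StrEnum ("hierarchical_model", etc.).
processor = IndexProcessorFactory(IndexStructureType.PARENT_CHILD_INDEX).init_index_processor()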

View File

@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
multimodal_documents = (
self._get_content_files(document_node, current_user) if document_node.metadata else None
)
if multimodal_documents:
document_node.attachments = multimodal_documents
# delete Splitter character
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
with_keywords = False
if with_keywords:
keywords_list = kwargs.get("keywords_list")
@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
documents: list[Any] = []
all_multimodal_documents: list[Any] = []
if isinstance(chunks, list):
documents = []
for content in chunks:
metadata = {
"dataset_id": dataset.id,
@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
"doc_hash": helper.generate_text_hash(content),
}
doc = Document(page_content=content, metadata=metadata)
attachments = self._get_content_files(doc)
if attachments:
doc.attachments = attachments
all_multimodal_documents.extend(attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
else:
raise ValueError("Chunks is not a list")
multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks)
for general_chunk in multimodal_general_structure.general_chunks:
metadata = {
"dataset_id": dataset.id,
"document_id": document.id,
"doc_id": str(uuid.uuid4()),
"doc_hash": helper.generate_text_hash(general_chunk.content),
}
doc = Document(page_content=general_chunk.content, metadata=metadata)
if general_chunk.files:
attachments = []
for file in general_chunk.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(
page_content=file.filename or "image_file", metadata=file_metadata
)
attachments.append(file_document)
all_multimodal_documents.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
documents.append(doc)
if documents:
# save node to document segment
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
# add document segments
doc_store.add_documents(docs=documents, save_child=False)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if all_multimodal_documents:
vector.create_multimodal(all_multimodal_documents)
elif dataset.indexing_technique == "economy":
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
if isinstance(chunks, list):
preview = []
for content in chunks:
preview.append({"content": content})
return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)}
return {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview,
"total_segments": len(chunks),
}
else:
raise ValueError("Chunks is not a list")
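For reference, a hedged sketch of the non-list chunks payload the reworked index() accepts, validated through the new MultimodalGeneralStructureChunk model. File attachments are omitted here because the full File field set is not shown in this diff; the content strings are placeholders.

from core.rag.models.document import MultimodalGeneralStructureChunk

chunks = {
    "general_chunks": [
        {"content": "First paragraph of the source document.", "files": None},
        {"content": "Second paragraph that mentions an embedded image.", "files": None},
    ]
}
structure = MultimodalGeneralStructureChunk.model_validate(chunks)
print(len(structure.general_chunks))  # 2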

View File

@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
multimodel_documents = self._get_content_files(document_node, current_user)
if multimodel_documents:
document_node.attachments = multimodel_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
multimodel_documents = self._get_content_files(document)
if multimodel_documents:
document.attachments = multimodel_documents
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
Document.model_validate(child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
child_documents.append(ChildDocument(page_content=child, metadata=child_metadata))
doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents)
if parent_child.files and len(parent_child.files) > 0:
attachments = []
for file in parent_child.files:
file_metadata = {
"doc_id": file.id,
"doc_hash": "",
"document_id": document.id,
"dataset_id": dataset.id,
"doc_type": DocType.IMAGE,
}
file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata)
attachments.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)
documents.append(doc)
if documents:
# update document parent mode
@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
doc_store.add_documents(docs=documents, save_child=True)
if dataset.indexing_technique == "high_quality":
all_child_documents = []
all_multimodal_documents = []
for doc in documents:
if doc.children:
all_child_documents.extend(doc.children)
if doc.attachments:
all_multimodal_documents.extend(doc.attachments)
vector = Vector(dataset)
if all_child_documents:
vector = Vector(dataset)
vector.create(all_child_documents)
if all_multimodal_documents:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {
"chunk_structure": IndexType.PARENT_CHILD_INDEX,
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode,
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),

View File

@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor):
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
if not process_rule:
@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor):
try:
# Skip the first row
df = pd.read_csv(file)
df = pd.read_csv(file) # type: ignore
text_docs = []
for _, row in df.iterrows():
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e))
return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
def load(
self,
dataset: Dataset,
documents: list[Document],
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
vector = Vector(dataset)
@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {
"chunk_structure": IndexType.QA_INDEX,
"chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks),
}

View File

@ -4,6 +4,8 @@ from typing import Any
from pydantic import BaseModel, Field
from core.file import File
class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
@ -15,7 +17,19 @@ class ChildDocument(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
class AttachmentDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
provider: str | None = "dify"
vector: list[float] | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
class Document(BaseModel):
@ -28,12 +42,31 @@ class Document(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
provider: str | None = "dify"
children: list[ChildDocument] | None = None
attachments: list[AttachmentDocument] | None = None
class GeneralChunk(BaseModel):
"""
General Chunk.
"""
content: str
files: list[File] | None = None
class MultimodalGeneralStructureChunk(BaseModel):
"""
Multimodal General Structure Chunk.
"""
general_chunks: list[GeneralChunk]
class GeneralStructureChunk(BaseModel):
"""
@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel):
parent_content: str
child_contents: list[str]
files: list[File] | None = None
class ParentChildStructureChunk(BaseModel):

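An illustration of the new attachments field on Document: a parent chunk that carries an image attachment alongside its text. All ids and the hash are placeholders; the metadata keys follow the ones used elsewhere in this PR.

from core.rag.models.document import AttachmentDocument, Document

image = AttachmentDocument(
    page_content="architecture-diagram.png",
    metadata={"doc_id": "<upload-file-id>", "doc_hash": "", "doc_type": "image"},
)
chunk = Document(
    page_content="The diagram below shows the ingestion pipeline.",
    metadata={"doc_id": "<segment-node-id>", "doc_hash": "<hash>"},
    attachments=[image],
)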
View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
@ -12,6 +13,7 @@ class BaseRerankRunner(ABC):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model

View File

@ -1,6 +1,15 @@
from core.model_manager import ModelInstance
import base64
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_base import BaseRerankRunner
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import UploadFile
class RerankModelRunner(BaseRerankRunner):
@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model
@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner):
:param user: unique user id if needed
:param query_type: type of the search query (text or image)
:return:
"""
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
model=self.rerank_model_instance.model,
model_type=ModelType.RERANK,
)
if not is_support_vision:
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
else:
return documents
else:
rerank_result, unique_documents = self.fetch_multimodal_rerank(
query, documents, score_threshold, top_n, user, query_type
)
rerank_documents = []
for result in rerank_result.docs:
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
metadata=unique_documents[result.index].metadata,
provider=unique_documents[result.index].provider,
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
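
A hedged usage sketch of the updated run() entry point; constructing the runner directly and the sample query values are assumptions for illustration (in practice DataPostProcessor drives it).

# Hedged sketch: the same runner serves text and image queries via query_type.
# `rerank_model_instance` and `candidate_documents` are assumed to be prepared by the caller.
runner = RerankModelRunner(rerank_model_instance)
text_results = runner.run(
    query="replace the air filter",
    documents=candidate_documents,
    score_threshold=0.3,
    top_n=5,
    query_type=QueryType.TEXT_QUERY,
)
# For an image query, `query` carries the UploadFile id of the query image.
image_results = runner.run(
    query="upload-file-id",
    documents=candidate_documents,
    top_n=5,
    query_type=QueryType.IMAGE_QUERY,
)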
def fetch_text_rerank(
self,
query: str,
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch text rerank
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:return:
"""
docs = []
doc_ids = set()
unique_documents = []
@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner):
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
return rerank_result, unique_documents
rerank_documents = []
def fetch_multimodal_rerank(
self,
query: str,
documents: list[Document],
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> tuple[RerankResult, list[Document]]:
"""
Fetch multimodal rerank
:param query: search query
:param documents: documents for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id if needed
:param query_type: query type
:return: rerank result and the deduplicated documents it was computed over
"""
docs = []
doc_ids = set()
unique_documents = []
for document in documents:
if (
document.provider == "dify"
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
if document.metadata.get("doc_type") == DocType.IMAGE:
# Query file info within db.session context to ensure thread-safe access
upload_file = (
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
)
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = {
"content": document_file_base64,
"content_type": document.metadata["doc_type"],
}
docs.append(document_file_dict)
else:
document_text_dict = {
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
docs.append(document_text_dict)
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(
{
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
)
unique_documents.append(document)
for result in rerank_result.docs:
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
documents = unique_documents
if query_type == QueryType.TEXT_QUERY:
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user)
return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()
file_query_dict = {
"content": file_query,
"content_type": DocType.IMAGE,
}
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
return rerank_result, unique_documents
else:
raise ValueError(f"Upload file not found for query: {query}")
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
else:
raise ValueError(f"Query type {query_type} is not supported")

View File

@ -7,6 +7,8 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.entity.weight import VectorSetting, Weights
from core.rag.rerank.rerank_base import BaseRerankRunner
@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner):
score_threshold: float | None = None,
top_n: int | None = None,
user: str | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
) -> list[Document]:
"""
Run rerank model
@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner):
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
# weight rerank only supports text documents
if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT:
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
else:
if document not in unique_documents:
unique_documents.append(document)
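
As the comment above notes, weighted reranking only scores text chunks; a minimal sketch of that doc_type filter in isolation (documents and DocType are assumed to be in scope):

# Hedged sketch: keep only text chunks (doc_type missing or TEXT) before
# keyword/vector weighted scoring; image documents are skipped.
text_documents = [
    doc
    for doc in documents
    if not doc.metadata.get("doc_type") or doc.metadata.get("doc_type") == DocType.TEXT
]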

View File

@ -8,6 +8,7 @@ from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,
@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
@ -99,7 +105,8 @@ class DatasetRetrieval:
message_id: str,
memory: TokenBufferMemory | None = None,
inputs: Mapping[str, Any] | None = None,
) -> str | None:
vision_enabled: bool = False,
) -> tuple[str | None, list[File] | None]:
"""
Retrieve dataset context and, when vision is enabled, the related attachment files.
:param app_id: app_id
@ -118,7 +125,7 @@ class DatasetRetrieval:
"""
dataset_ids = config.dataset_ids
if len(dataset_ids) == 0:
return None
return None, []
retrieve_config = config.retrieve_config
# check model is support tool calling
@ -136,7 +143,7 @@ class DatasetRetrieval:
)
if not model_schema:
return None
return None, []
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
@ -182,8 +189,8 @@ class DatasetRetrieval:
tenant_id,
user_id,
user_from,
available_datasets,
query,
available_datasets,
model_instance,
model_config,
planning_strategy,
@ -213,6 +220,7 @@ class DatasetRetrieval:
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list: list[DocumentContext] = []
context_files: list[File] = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents
for item in external_documents:
@ -248,6 +256,31 @@ class DatasetRetrieval:
score=record.score,
)
)
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment.id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=segment.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attachment_info)
if show_retrieve_source:
for record in records:
segment = record.segment
@ -288,8 +321,10 @@ class DatasetRetrieval:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
return str("\n".join([document_context.content for document_context in document_context_list]))
return ""
return str(
"\n".join([document_context.content for document_context in document_context_list])
), context_files
return "", context_files
def single_retrieve(
self,
@ -297,8 +332,8 @@ class DatasetRetrieval:
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
available_datasets: list,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
@ -336,7 +371,7 @@ class DatasetRetrieval:
dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
self._record_usage(router_usage)
timer = None
if dataset_id:
# get retrieval model config
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@ -406,10 +441,19 @@ class DatasetRetrieval:
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
if results:
self._on_retrieval_end(results, message_id, timer)
thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": results,
"message_id": message_id,
"timer": timer,
},
)
thread.start()
return results
return []
@ -421,7 +465,7 @@ class DatasetRetrieval:
user_id: str,
user_from: str,
available_datasets: list,
query: str,
query: str | None,
top_k: int,
score_threshold: float,
reranking_mode: str,
@ -431,10 +475,11 @@ class DatasetRetrieval:
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
if not available_datasets:
return []
threads = []
all_threads = []
all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
@ -467,131 +512,226 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
with measure_time() as timer:
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
},
)
else:
if index_type == "economy":
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
elif index_type == "high_quality":
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
else:
all_documents = all_documents[:top_k] if top_k else all_documents
self._on_query(query, dataset_ids, app_id, user_from, user_id)
all_threads.append(query_thread)
query_thread.start()
if attachment_ids:
for attachment_id in attachment_ids:
attachment_thread = threading.Thread(
target=self._multiple_retrieve_thread,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"available_datasets": available_datasets,
"metadata_condition": metadata_condition,
"metadata_filter_document_ids": metadata_filter_document_ids,
"all_documents": all_documents,
"tenant_id": tenant_id,
"reranking_enable": reranking_enable,
"reranking_mode": reranking_mode,
"reranking_model": reranking_model,
"weights": weights,
"top_k": top_k,
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
for thread in all_threads:
thread.join()
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
self._on_retrieval_end(all_documents, message_id, timer)
return all_documents
def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None):
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
# add thread to call _on_retrieval_end
retrieval_end_thread = threading.Thread(
target=self._on_retrieval_end,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"documents": all_documents,
"message_id": message_id,
"timer": timer,
},
)
retrieval_end_thread.start()
retrieval_resource_list = []
doc_ids_filter = []
for document in all_documents:
if document.provider == "dify":
doc_id = document.metadata.get("doc_id")
if doc_id and doc_id not in doc_ids_filter:
doc_ids_filter.append(doc_id)
retrieval_resource_list.append(document)
elif document.provider == "external":
retrieval_resource_list.append(document)
return retrieval_resource_list
def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str):
def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
):
"""Handle retrieval end."""
with flask_app.app_context():
dify_documents = [document for document in documents if document.provider == "dify"]
segment_ids = []
segment_index_node_ids = []
with Session(db.engine) as session:
for document in dify_documents:
if document.metadata is not None:
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
segment_id = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = session.scalar(child_chunk_stmt)
if child_chunk:
segment_id = child_chunk.segment_id
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id:
if segment_id not in segment_ids:
segment_ids.append(segment_id)
_ = (
session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
else:
query = None
if (
"doc_type" not in document.metadata
or document.metadata.get("doc_type") == DocType.TEXT
):
if document.metadata["doc_id"] not in segment_index_node_ids:
segment = (
session.query(DocumentSegment)
.where(DocumentSegment.index_node_id == document.metadata["doc_id"])
.first()
)
if segment:
segment_index_node_ids.append(document.metadata["doc_id"])
segment_ids.append(segment.id)
query = session.query(DocumentSegment).where(
DocumentSegment.id == segment.id
)
elif (
"doc_type" in document.metadata
and document.metadata.get("doc_type") == DocType.IMAGE
):
attachment_info_dict = RetrievalService.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
segment_id = attachment_info_dict["segment_id"]
if segment_id not in segment_ids:
segment_ids.append(segment_id)
query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
if query:
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.where(
DocumentSegment.dataset_id == document.metadata["dataset_id"]
)
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
)
def _on_query(
self,
query: str | None,
attachment_ids: list[str] | None,
dataset_ids: list[str],
app_id: str,
user_from: str,
user_id: str,
):
"""
Handle query.
"""
if not query:
if not query and not attachment_ids:
return
dataset_queries = []
for dataset_id in dataset_ids:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
contents = []
if query:
contents.append({"content_type": QueryType.TEXT_QUERY, "content": query})
if attachment_ids:
for attachment_id in attachment_ids:
contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id})
if contents:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=json.dumps(contents),
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
db.session.add_all(dataset_queries)
db.session.commit()
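
A hedged sketch of what DatasetQuery.content now holds when a text query and an image attachment are logged together; the concrete values are placeholders.

# Hedged sketch: the JSON stored in DatasetQuery.content for a mixed query.
import json

from core.rag.index_processor.constant.query_type import QueryType

contents = [
    {"content_type": QueryType.TEXT_QUERY, "content": "how do I replace the filter?"},
    {"content_type": QueryType.IMAGE_QUERY, "content": "upload-file-id-1"},
]
stored_content = json.dumps(contents)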
def _retriever(
@ -603,6 +743,7 @@ class DatasetRetrieval:
all_documents: list,
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
):
with flask_app.app_context():
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@ -611,7 +752,7 @@ class DatasetRetrieval:
if not dataset:
return []
if dataset.provider == "external":
if dataset.provider == "external" and query:
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
@ -663,6 +804,7 @@ class DatasetRetrieval:
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
attachment_ids=attachment_ids,
)
all_documents.extend(documents)
@ -1222,3 +1364,86 @@ class DatasetRetrieval:
usage = LLMUsage.empty_usage()
return full_text, usage
def _multiple_retrieve_thread(
self,
flask_app: Flask,
available_datasets: list,
metadata_condition: MetadataCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
top_k: int,
score_threshold: float,
query: str | None,
attachment_id: str | None,
):
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
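
A hedged sketch of the fan-out that multiple_retrieve performs above: one _multiple_retrieve_thread worker per text query and per attachment id, all appending into a shared list and joined before post-processing; `retrieval` and `per_modality_kwargs` are placeholders.

# Hedged sketch of the fan-out pattern: one worker per modality, shared result list.
import threading

all_documents: list[Document] = []
workers: list[threading.Thread] = []
for worker_kwargs in per_modality_kwargs:  # one kwargs dict per text query / attachment id
    worker = threading.Thread(target=retrieval._multiple_retrieve_thread, kwargs=worker_kwargs)
    workers.append(worker)
    worker.start()
for worker in workers:
    worker.join()
# all_documents now holds the merged results from every modality.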