chore(api/core): apply ruff reformatting (#7624)

Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions
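The changes below are mechanical: ruff's formatter normalizes string literals to double quotes and rewraps calls that exceed the line length, adding trailing commas. A minimal, illustrative sketch of that pattern follows; the function names are invented for illustration and are not part of the Dify codebase, and ruff only rewraps calls that actually exceed the configured line length.

import re


def clean_before(text: str) -> str:
    # pre-format style seen in the removed lines: single-quoted literals on one long line
    return re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)


def clean_after(text: str) -> str:
    # post-format style seen in the added lines: double-quoted literals; an over-long
    # call is split across lines with a trailing comma
    return re.sub(
        r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]",
        "",
        text,
    )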


@ -2,37 +2,35 @@ import re
class CleanProcessor:
@classmethod
def clean(cls, text: str, process_rule: dict) -> str:
# default clean
# remove invalid symbol
text = re.sub(r'<\|', '<', text)
text = re.sub(r'\|>', '>', text)
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
text = re.sub(r"<\|", "<", text)
text = re.sub(r"\|>", ">", text)
text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text)
# Unicode U+FFFE
text = re.sub('\uFFFE', '', text)
text = re.sub("\ufffe", "", text)
rules = process_rule['rules'] if process_rule else None
if 'pre_processing_rules' in rules:
rules = process_rule["rules"] if process_rule else None
if "pre_processing_rules" in rules:
pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules:
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
# Remove extra spaces
pattern = r'\n{3,}'
text = re.sub(pattern, '\n\n', text)
pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
text = re.sub(pattern, ' ', text)
pattern = r"\n{3,}"
text = re.sub(pattern, "\n\n", text)
pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}"
text = re.sub(pattern, " ", text)
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
# Remove email
pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
text = re.sub(pattern, '', text)
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
# Remove URL
pattern = r'https?://[^\s]+'
text = re.sub(pattern, '', text)
pattern = r"https?://[^\s]+"
text = re.sub(pattern, "", text)
return text
def filter_string(self, text):
return text
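The process_rule mapping drives which pre-processing rules fire. A minimal usage sketch of the classmethod above, with rule ids taken from this hunk; the module path is an assumption.

# Usage sketch only; the import path is assumed, the rule ids come from the code above.
from core.rag.cleaner.clean_processor import CleanProcessor  # assumed path

process_rule = {
    "rules": {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": True},
        ]
    }
}
cleaned = CleanProcessor.clean("See   https://example.com\n\n\n\nNext paragraph", process_rule)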


@ -1,12 +1,11 @@
"""Abstract interface for document cleaner implementations."""
from abc import ABC, abstractmethod
class BaseCleaner(ABC):
"""Interface for clean chunk content.
"""
"""Interface for clean chunk content."""
@abstractmethod
def clean(self, content: str):
raise NotImplementedError


@ -1,9 +1,9 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_extra_whitespace


@ -1,9 +1,9 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
import re


@ -1,9 +1,9 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_non_ascii_chars


@ -1,11 +1,12 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""Replaces unicode quote characters, such as the \x91 character in a string."""
from unstructured.cleaners.core import replace_unicode_quotes
return replace_unicode_quotes(content)


@ -1,9 +1,9 @@
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredTranslateTextCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.translate import translate_text


@ -12,17 +12,27 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner
class DataPostProcessor:
"""Interface for data post-processing document.
"""
"""Interface for data post-processing document."""
def __init__(self, tenant_id: str, reranking_mode: str,
reranking_model: Optional[dict] = None, weights: Optional[dict] = None,
reorder_enabled: bool = False):
def __init__(
self,
tenant_id: str,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
reorder_enabled: bool = False,
):
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
def invoke(
self,
query: str,
documents: list[Document],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
@ -31,21 +41,26 @@ class DataPostProcessor:
return documents
def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None,
weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]:
def _get_rerank_runner(
self,
reranking_mode: str,
tenant_id: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
) -> Optional[RerankModelRunner | WeightRerankRunner]:
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
return WeightRerankRunner(
tenant_id,
Weights(
vector_setting=VectorSetting(
vector_weight=weights['vector_setting']['vector_weight'],
embedding_provider_name=weights['vector_setting']['embedding_provider_name'],
embedding_model_name=weights['vector_setting']['embedding_model_name'],
vector_weight=weights["vector_setting"]["vector_weight"],
embedding_provider_name=weights["vector_setting"]["embedding_provider_name"],
embedding_model_name=weights["vector_setting"]["embedding_model_name"],
),
keyword_setting=KeywordSetting(
keyword_weight=weights['keyword_setting']['keyword_weight'],
)
)
keyword_weight=weights["keyword_setting"]["keyword_weight"],
),
),
)
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
if reranking_model:
@ -53,9 +68,9 @@ class DataPostProcessor:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model['reranking_provider_name'],
provider=reranking_model["reranking_provider_name"],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
model=reranking_model["reranking_model_name"],
)
except InvokeAuthorizationError:
return None
@ -67,5 +82,3 @@ class DataPostProcessor:
if reorder_enabled:
return ReorderRunner()
return None
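For reference, a sketch of the weighted-score configuration the reworked constructor expects; the import paths, tenant id, and provider/model names are assumptions or placeholders, not values taken from the diff.

from core.rag.data_post_processor.data_post_processor import DataPostProcessor  # assumed path
from core.rag.rerank.rerank_type import RerankMode  # assumed path

weights = {
    "vector_setting": {
        "vector_weight": 0.7,  # illustrative value
        "embedding_provider_name": "openai",  # placeholder
        "embedding_model_name": "text-embedding-3-small",  # placeholder
    },
    "keyword_setting": {"keyword_weight": 0.3},  # illustrative value
}
post_processor = DataPostProcessor(
    tenant_id="tenant-uuid",  # placeholder
    reranking_mode=RerankMode.WEIGHTED_SCORE.value,
    weights=weights,
    reorder_enabled=False,
)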


@ -2,7 +2,6 @@ from core.rag.models.document import Document
class ReorderRunner:
def run(self, documents: list[Document]) -> list[Document]:
# Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list
odd_elements = documents[::2]


@ -24,37 +24,42 @@ class Jieba(BaseKeyword):
self._config = KeywordTableConfig()
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get('keywords_list', None)
keywords_list = kwargs.get("keywords_list", None)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(text.page_content,
self._config.max_keywords_per_chunk)
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
self._save_dataset_keyword_table(keyword_table)
@ -63,97 +68,91 @@ class Jieba(BaseKeyword):
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search(self, query: str, **kwargs: Any) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
k = kwargs.get('top_k', 4)
k = kwargs.get("top_k", 4)
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
.first()
)
if segment:
documents.append(Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
))
documents.append(
Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
)
return documents
def delete(self) -> None:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
if dataset_keyword_table.data_source_type != 'database':
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
if dataset_keyword_table.data_source_type != "database":
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
storage.delete(file_key)
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": keyword_table
}
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == 'database':
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
db.session.commit()
else:
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict['__data__']['table']
return keyword_table_dict["__data__"]["table"]
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table='',
keyword_table="",
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == 'database':
dataset_keyword_table.keyword_table = json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(
{
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
},
cls=SetEncoder,
)
db.session.add(dataset_keyword_table)
db.session.commit()
@ -174,9 +173,7 @@ class Jieba(BaseKeyword):
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
@ -202,13 +199,14 @@ class Jieba(BaseKeyword):
reverse=True,
)
return sorted_chunk_indices[: k]
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id == node_id
).first()
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)
@ -224,14 +222,14 @@ class Jieba(BaseKeyword):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data['segment']
if pre_segment_data['keywords']:
segment.keywords = pre_segment_data['keywords']
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
pre_segment_data['keywords'])
segment = pre_segment_data["segment"]
if pre_segment_data["keywords"]:
segment.keywords = pre_segment_data["keywords"]
keyword_table = self._add_text_to_keyword_table(
keyword_table, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keywords = keyword_table_handler.extract_keywords(segment.content,
self._config.max_keywords_per_chunk)
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
self._save_dataset_keyword_table(keyword_table)
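For orientation, the shape of the dict that _save_dataset_keyword_table serializes (via SetEncoder, which turns the keyword sets into lists); the ids below are placeholders.

keyword_table_dict = {
    "__type__": "keyword_table",
    "__data__": {
        "index_id": "dataset-uuid",  # placeholder
        "summary": None,
        "table": {"keyword": ["node-id-1", "node-id-2"]},  # doc ids grouped per keyword
    },
}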


@ -8,7 +8,6 @@ from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
@ -30,4 +29,4 @@ class JiebaKeywordTableHandler:
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
return results

File diff suppressed because it is too large.


@ -8,7 +8,6 @@ from models.dataset import Dataset
class BaseKeyword(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@ -31,15 +30,12 @@ class BaseKeyword(ABC):
def delete(self) -> None:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
@ -47,4 +43,4 @@ class BaseKeyword(ABC):
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
return [text.metadata["doc_id"] for text in texts]


@ -20,9 +20,7 @@ class Keyword:
raise ValueError("Keyword store must be specified.")
if keyword_type == "jieba":
return Jieba(
dataset=self._dataset
)
return Jieba(dataset=self._dataset)
else:
raise ValueError(f"Keyword store {keyword_type} is not supported.")
@ -41,10 +39,7 @@ class Keyword:
def delete(self) -> None:
self._keyword_processor.delete()
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search(self, query: str, **kwargs: Any) -> list[Document]:
return self._keyword_processor.search(query, **kwargs)
def __getattr__(self, name):


@ -12,73 +12,83 @@ from extensions.ext_database import db
from models.dataset import Dataset
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class RetrievalService:
@classmethod
def retrieve(cls, retrieval_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
weights: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
def retrieve(
cls,
retrieval_method: str,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float] = 0.0,
reranking_model: Optional[dict] = None,
reranking_mode: Optional[str] = "reranking_model",
weights: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
threads = []
exceptions = []
# retrieval_model source with keyword
if retrieval_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'all_documents': all_documents,
'exceptions': exceptions,
})
if retrieval_method == "keyword_search":
keyword_thread = threading.Thread(
target=RetrievalService.keyword_search,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"exceptions": exceptions,
},
)
threads.append(keyword_thread)
keyword_thread.start()
# retrieval_model source with semantic
if RetrievalMethod.is_support_semantic_search(retrieval_method):
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrieval_method': retrieval_method,
'exceptions': exceptions,
})
embedding_thread = threading.Thread(
target=RetrievalService.embedding_search,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
"score_threshold": score_threshold,
"reranking_model": reranking_model,
"all_documents": all_documents,
"retrieval_method": retrieval_method,
"exceptions": exceptions,
},
)
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source with full text
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrieval_method': retrieval_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents,
'exceptions': exceptions,
})
full_text_index_thread = threading.Thread(
target=RetrievalService.full_text_index_search,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"retrieval_method": retrieval_method,
"score_threshold": score_threshold,
"top_k": top_k,
"reranking_model": reranking_model,
"all_documents": all_documents,
"exceptions": exceptions,
},
)
threads.append(full_text_index_thread)
full_text_index_thread.start()
@ -86,110 +96,117 @@ class RetrievalService:
thread.join()
if exceptions:
exception_message = ';\n'.join(exceptions)
exception_message = ";\n".join(exceptions)
raise Exception(exception_message)
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
reranking_model, weights, False)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
)
return all_documents
@classmethod
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, all_documents: list, exceptions: list):
def keyword_search(
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
keyword = Keyword(
dataset=dataset
)
keyword = Keyword(dataset=dataset)
documents = keyword.search(
cls.escape_query_for_search(query),
top_k=top_k
)
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@classmethod
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, retrieval_method: str, exceptions: list):
def embedding_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
vector = Vector(
dataset=dataset
)
vector = Vector(dataset=dataset)
documents = vector.search_by_vector(
cls.escape_query_for_search(query),
search_type='similarity_score_threshold',
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
filter={"group_id": [dataset.id]},
)
if documents:
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents)
))
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
)
)
else:
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, retrieval_method: str, exceptions: list):
def full_text_index_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
vector_processor = Vector(
dataset=dataset,
)
documents = vector_processor.search_by_full_text(
cls.escape_query_for_search(query),
top_k=top_k
)
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
if documents:
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents)
))
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
)
)
else:
all_documents.extend(documents)
except Exception as e:
@ -197,4 +214,4 @@ class RetrievalService:
@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')
return query.replace('"', '\\"')
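A usage sketch for the reworked retrieve() signature above; the dataset id is a placeholder, the import path is an assumption, and the call must run inside a Flask application context because retrieve() spawns worker threads bound to current_app.

from core.rag.datasource.retrieval_service import RetrievalService  # assumed path

documents = RetrievalService.retrieve(
    retrieval_method="keyword_search",  # one of the method strings handled above
    dataset_id="dataset-uuid",  # placeholder
    query="how do I configure reranking?",
    top_k=2,
    score_threshold=0.0,
)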


@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel):
namespace_password: str = (None,)
metrics: str = ("cosine",)
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel):
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
_instance = None
_init = False
@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector):
except:
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(
user_agent="dify", **config.to_analyticdb_client_params()
)
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
AnalyticdbVector._init = True
@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector):
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector):
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector):
)
self._client.create_namespace(request)
else:
raise ValueError(
f"failed to create namespace {self.config.namespace}: {e}"
)
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector):
)
self._client.create_collection(request)
else:
raise ValueError(
f"failed to create collection {self._collection_name}: {e}"
)
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def get_type(self) -> str:
@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector):
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
def add_texts(
self, documents: list[Document], embeddings: list[list[float]], **kwargs
):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector):
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector):
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'"
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector):
)
self._client.delete_collection_data(request)
def search_by_vector(
self, query_vector: list[float], **kwargs: Any
) -> list[Document]:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector):
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector):
except Exception as e:
raise e
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"][
"class_prefix"
]
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params
if dify_config.ANALYTICDB_KEY_ID is None:


@ -27,21 +27,20 @@ class ChromaConfig(BaseModel):
settings = Settings(
# auth
chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials
chroma_client_auth_credentials=self.auth_credentials,
)
return {
'host': self.host,
'port': self.port,
'ssl': False,
'tenant': self.tenant,
'database': self.database,
'settings': settings,
"host": self.host,
"port": self.port,
"ssl": False,
"tenant": self.tenant,
"database": self.database,
"settings": settings,
}
class ChromaVector(BaseVector):
def __init__(self, collection_name: str, config: ChromaConfig):
super().__init__(collection_name)
self._client_config = config
@ -58,9 +57,9 @@ class ChromaVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
self._client.get_or_create_collection(collection_name)
@ -76,7 +75,7 @@ class ChromaVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(where={key: {'$eq': value}})
collection.delete(where={key: {"$eq": value}})
def delete(self):
self._client.delete_collection(self._collection_name)
@ -93,26 +92,26 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
ids: list[str] = results['ids'][0]
documents: list[str] = results['documents'][0]
metadatas: dict[str, Any] = results['metadatas'][0]
distances: list[float] = results['distances'][0]
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
metadatas: dict[str, Any] = results["metadatas"][0]
distances: list[float] = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = metadatas[index]
if distance >= score_threshold:
metadata['score'] = distance
metadata["score"] = distance
doc = Document(
page_content=documents[index],
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -123,15 +122,12 @@ class ChromaVector(BaseVector):
class ChromaVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector(


@ -26,15 +26,15 @@ class ElasticSearchConfig(BaseModel):
username: str
password: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config PORT is required")
if not values['username']:
if not values["username"]:
raise ValueError("config USERNAME is required")
if not values['password']:
if not values["password"]:
raise ValueError("config PASSWORD is required")
return values
@ -50,10 +50,10 @@ class ElasticSearchVector(BaseVector):
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
parsed_url = urlparse(config.host)
if parsed_url.scheme in ['http', 'https']:
hosts = f'{config.host}:{config.port}'
if parsed_url.scheme in ["http", "https"]:
hosts = f"{config.host}:{config.port}"
else:
hosts = f'http://{config.host}:{config.port}'
hosts = f"http://{config.host}:{config.port}"
client = Elasticsearch(
hosts=hosts,
basic_auth=(config.username, config.password),
@ -68,25 +68,27 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return info['version']['number']
return info["version"]["number"]
def _check_version(self):
if self._version < '8.0.0':
if self._version < "8.0.0":
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
def get_type(self) -> str:
return 'elasticsearch'
return "elasticsearch"
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
self._client.index(index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
})
self._client.index(
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
},
)
self._client.indices.refresh(index=self._collection_name)
return uuids
@ -98,15 +100,9 @@ class ElasticSearchVector(BaseVector):
self._client.delete(index=self._collection_name, id=id)
def delete_by_metadata_field(self, key: str, value: str) -> None:
query_str = {
'query': {
'match': {
f'metadata.{key}': f'{value}'
}
}
}
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit['_id'] for hit in results['hits']['hits']]
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
@ -115,44 +111,44 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 10)
knn = {
"field": Field.VECTOR.value,
"query_vector": query_vector,
"k": top_k
}
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
docs_and_scores = []
for hit in results['hits']['hits']:
for hit in results["hits"]["hits"]:
docs_and_scores.append(
(Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {
"match": {
Field.CONTENT_KEY.value: query
}
}
query_str = {"match": {Field.CONTENT_KEY.value: query}}
results = self._client.search(index=self._collection_name, query=query_str)
docs = []
for hit in results['hits']['hits']:
docs.append(Document(
page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value],
))
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
)
)
return docs
@ -162,11 +158,11 @@ class ElasticSearchVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name}'
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
@ -179,14 +175,14 @@ class ElasticSearchVector(BaseVector):
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"similarity": "cosine"
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
}
}
},
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mappings)
@ -197,22 +193,21 @@ class ElasticSearchVector(BaseVector):
class ElasticSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get('ELASTICSEARCH_HOST'),
port=config.get('ELASTICSEARCH_PORT'),
username=config.get('ELASTICSEARCH_USERNAME'),
password=config.get('ELASTICSEARCH_PASSWORD'),
host=config.get("ELASTICSEARCH_HOST"),
port=config.get("ELASTICSEARCH_PORT"),
username=config.get("ELASTICSEARCH_USERNAME"),
password=config.get("ELASTICSEARCH_PASSWORD"),
),
attributes=[]
attributes=[],
)

View File

@ -27,44 +27,39 @@ class MilvusConfig(BaseModel):
batch_size: int = 100
database: str = "default"
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values.get('uri'):
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get('user'):
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get('password'):
if not values.get("password"):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'uri': self.uri,
'token': self.token,
'user': self.user,
'password': self.password,
'db_name': self.database,
"uri": self.uri,
"token": self.token,
"user": self.user,
"password": self.password,
"db_name": self.database,
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._consistency_level = "Session"
self._fields = []
def get_type(self) -> str:
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@ -75,7 +70,7 @@ class MilvusVector(BaseVector):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
Field.METADATA_KEY.value: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
@ -84,22 +79,20 @@ class MilvusVector(BaseVector):
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise e
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
@ -107,17 +100,15 @@ class MilvusVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
if self._client.has_collection(self._collection_name):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {ids}',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
)
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
@ -130,29 +121,28 @@ class MilvusVector(BaseVector):
if not self._client.has_collection(self._collection_name):
return False
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"]
)
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
@ -161,11 +151,11 @@ class MilvusVector(BaseVector):
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
# Grab the existing collection if it exists
@ -180,19 +170,11 @@ class MilvusVector(BaseVector):
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create the schema for the collection
schema = CollectionSchema(fields)
@ -208,9 +190,12 @@ class MilvusVector(BaseVector):
# Create the collection
collection_name = self._collection_name
self._client.create_collection(collection_name=collection_name,
schema=schema, index_params=index_params_obj,
consistency_level=self._consistency_level)
self._client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
@ -221,13 +206,12 @@ class MilvusVector(BaseVector):
class MilvusVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
return MilvusVector(
collection_name=collection_name,
@ -237,5 +221,5 @@ class MilvusVectorFactory(AbstractVectorFactory):
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
database=dify_config.MILVUS_DATABASE,
)
),
)


@ -31,7 +31,6 @@ class SortOrder(Enum):
class MyScaleVector(BaseVector):
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config
@ -80,7 +79,7 @@ class MyScaleVector(BaseVector):
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {}
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
@ -101,7 +100,8 @@ class MyScaleVector(BaseVector):
def delete_by_ids(self, ids: list[str]) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)
def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(
@ -122,9 +122,12 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get('score_threshold') or 0.0
where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
score_threshold = kwargs.get("score_threshold") or 0.0
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
@ -133,7 +136,7 @@ class MyScaleVector(BaseVector):
return [
Document(
page_content=r["text"],
vector=r['vector'],
vector=r["vector"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
@ -149,13 +152,12 @@ class MyScaleVector(BaseVector):
class MyScaleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
return MyScaleVector(
collection_name=collection_name,


@ -28,11 +28,11 @@ class OpenSearchConfig(BaseModel):
password: Optional[str] = None
secure: bool = False
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values.get('host'):
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get('port'):
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
return values
@ -44,19 +44,18 @@ class OpenSearchConfig(BaseModel):
def to_opensearch_params(self) -> dict[str, Any]:
params = {
'hosts': [{'host': self.host, 'port': self.port}],
'use_ssl': self.secure,
'verify_certs': self.secure,
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
}
if self.user and self.password:
params['http_auth'] = (self.user, self.password)
params["http_auth"] = (self.user, self.password)
if self.secure:
params['ssl_context'] = self.create_ssl_context()
params["ssl_context"] = self.create_ssl_context()
return params
class OpenSearchVector(BaseVector):
def __init__(self, collection_name: str, config: OpenSearchConfig):
super().__init__(collection_name)
self._client_config = config
@ -81,7 +80,7 @@ class OpenSearchVector(BaseVector):
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
},
}
actions.append(action)
@ -90,8 +89,8 @@ class OpenSearchVector(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response['hits']['hits']:
return [hit['_id'] for hit in response['hits']['hits']]
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None
@ -110,7 +109,7 @@ class OpenSearchVector(BaseVector):
actual_ids = []
for doc_id in ids:
es_ids = self.get_ids_by_metadata_field('doc_id', doc_id)
es_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
if es_ids:
actual_ids.extend(es_ids)
else:
@ -122,9 +121,9 @@ class OpenSearchVector(BaseVector):
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get('delete', {})
status = delete_error.get('status')
doc_id = delete_error.get('_id')
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")
if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}")
@ -151,15 +150,8 @@ class OpenSearchVector(BaseVector):
raise ValueError("All elements in query_vector should be floats")
query = {
"size": kwargs.get('top_k', 4),
"query": {
"knn": {
Field.VECTOR.value: {
Field.VECTOR.value: query_vector,
"k": kwargs.get('top_k', 4)
}
}
}
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
}
try:
@ -169,17 +161,17 @@ class OpenSearchVector(BaseVector):
raise
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value, {})
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
# Make sure metadata is a dictionary
if metadata is None:
metadata = {}
metadata['score'] = hit['_score']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if hit['_score'] > score_threshold:
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
metadata["score"] = hit["_score"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
@ -190,32 +182,28 @@ class OpenSearchVector(BaseVector):
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value)
vector = hit['_source'].get(Field.VECTOR.value)
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name.lower()}'
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}'
collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name.lower()} already exists.")
return
if not self._client.indices.exists(index=self._collection_name.lower()):
index_body = {
"settings": {
"index": {
"knn": True
}
},
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
@ -226,20 +214,17 @@ class OpenSearchVector(BaseVector):
"name": "hnsw",
"space_type": "l2",
"engine": "faiss",
"parameters": {
"ef_construction": 64,
"m": 8
}
}
"parameters": {"ef_construction": 64, "m": 8},
},
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
}
}
},
},
}
}
},
}
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
@ -248,17 +233,14 @@ class OpenSearchVector(BaseVector):
class OpenSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST,
@ -268,7 +250,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
secure=dify_config.OPENSEARCH_SECURE,
)
return OpenSearchVector(
collection_name=collection_name,
config=open_search_config
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
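The search_by_vector method above applies a post-filtering step that recurs throughout these stores: read score_threshold from the keyword arguments, fall back to 0.0 when it is missing or falsy, record the score on each hit's metadata, and keep only hits above the threshold. A small self-contained sketch of that pattern, using hypothetical hit dictionaries in place of a real OpenSearch response:

from typing import Any


def filter_hits_by_score(hits: list[dict[str, Any]], **kwargs: Any) -> list[dict[str, Any]]:
    # Mirror the kwargs handling used above: a missing or falsy score_threshold
    # falls back to 0.0 before hits are compared against it.
    score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
    docs = []
    for hit in hits:
        metadata = dict(hit.get("metadata") or {})
        # Keep the score in the metadata so callers can re-rank or display it.
        metadata["score"] = hit["score"]
        if hit["score"] > score_threshold:
            docs.append({"page_content": hit["text"], "metadata": metadata})
    return docs


# Example: only the first hit clears a 0.5 threshold.
hits = [
    {"text": "relevant chunk", "score": 0.82, "metadata": {"doc_id": "a"}},
    {"text": "weak match", "score": 0.31, "metadata": {"doc_id": "b"}},
]
print(filter_hits_by_score(hits, score_threshold=0.5))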

View File

@ -31,7 +31,7 @@ class OracleVectorConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ORACLE_HOST is required")
@ -103,9 +103,16 @@ class OracleVector(BaseVector):
arraysize=cursor.arraysize,
outconverter=self.numpy_converter_out,
)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(
user=config.user,
password=config.password,
dsn="{}:{}/{}".format(config.host, config.port, config.database),
min=1,
max=50,
increment=1,
)
@contextmanager
def _get_cursor(self):
@ -136,13 +143,15 @@ class OracleVector(BaseVector):
doc_id,
doc.page_content,
json.dumps(doc.metadata),
#array.array("f", embeddings[i]),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
#print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values)
cur.executemany(
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
)
return pks
def text_exists(self, id: str) -> bool:
@ -157,7 +166,8 @@ class OracleVector(BaseVector):
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
#def get_ids_by_metadata_field(self, key: str, value: str):
# def get_ids_by_metadata_field(self, key: str, value: str):
# with self._get_cursor() as cur:
# cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
# idss = []
@ -184,7 +194,8 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)]
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only",
[numpy.array(query_vector)],
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
@ -202,7 +213,7 @@ class OracleVector(BaseVector):
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile('[\u4e00-\u9fa5]+')
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
match = zh_pattern.search(query)
entities = []
# match: the query may be a Chinese sentence, so split it with jieba; otherwise tokenize with nltk
@ -210,7 +221,15 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: person name, ns: place name, nt: organization name
if (
pos == "nr"
or pos == "Ng"
or pos == "eng"
or pos == "nz"
or pos == "n"
or pos == "ORG"
or pos == "v"
): # nr: person name, ns: place name, nt: organization name
current_entity += word
else:
if current_entity:
@ -220,22 +239,22 @@ class OracleVector(BaseVector):
entities.append(current_entity)
else:
try:
nltk.data.find('tokenizers/punkt')
nltk.data.find('corpora/stopwords')
nltk.data.find("tokenizers/punkt")
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download('punkt')
nltk.download('stopwords')
nltk.download("punkt")
nltk.download("stopwords")
print("run download")
e_str = re.sub(r'[^\w ]', '', query)
e_str = re.sub(r"[^\w ]", "", query)
all_tokens = nltk.word_tokenize(e_str)
stop_words = stopwords.words('english')
stop_words = stopwords.words("english")
for token in all_tokens:
if token not in stop_words:
entities.append(token)
with self._get_cursor() as cur:
cur.execute(
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)]
[" ACCUM ".join(entities)],
)
docs = []
for record in cur:
@ -273,8 +292,7 @@ class OracleVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
return OracleVector(
collection_name=collection_name,

View File

@ -31,27 +31,28 @@ class PgvectoRSConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config PGVECTO_RS_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config PGVECTO_RS_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config PGVECTO_RS_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config PGVECTO_RS_DATABASE is required")
return values
class PGVectoRS(BaseVector):
def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
super().__init__(collection_name)
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self._client = create_engine(self._url)
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
@ -80,9 +81,9 @@ class PGVectoRS(BaseVector):
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
@ -133,9 +134,7 @@ class PGVectoRS(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
)
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
@ -143,12 +142,11 @@ class PGVectoRS(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.execute(select_statement, {"ids": ids})
session.commit()
def delete_by_ids(self, ids: list[str]) -> None:
@ -156,13 +154,13 @@ class PGVectoRS(BaseVector):
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
)
result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.execute(select_statement, {"ids": ids})
session.commit()
def delete(self) -> None:
@ -187,7 +185,7 @@ class PGVectoRS(BaseVector):
query_vector,
).label("distance"),
)
.limit(kwargs.get('top_k', 2))
.limit(kwargs.get("top_k", 2))
.order_by("distance")
)
res = session.execute(stmt)
@ -198,11 +196,10 @@ class PGVectoRS(BaseVector):
for record, dis in results:
metadata = record.meta
score = 1 - dis
metadata['score'] = score
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
metadata["score"] = score
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if score > score_threshold:
doc = Document(page_content=record.text,
metadata=metadata)
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs
@ -225,13 +222,12 @@ class PGVectoRS(BaseVector):
class PGVectoRSFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dim = len(embeddings.embed_query("pgvecto_rs"))
return PGVectoRS(
@ -243,5 +239,5 @@ class PGVectoRSFactory(AbstractVectorFactory):
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
),
dim=dim
dim=dim,
)

View File

@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")
@ -201,8 +201,7 @@ class PGVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
return PGVector(
collection_name=collection_name,

View File

@ -48,28 +48,25 @@ class QdrantConfig(BaseModel):
prefer_grpc: bool = False
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
return {"path": path}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout,
'verify': self.endpoint.startswith('https'),
'grpc_port': self.grpc_port,
'prefer_grpc': self.prefer_grpc
"url": self.endpoint,
"api_key": self.api_key,
"timeout": self.timeout,
"verify": self.endpoint.startswith("https"),
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
}
class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
@ -80,10 +77,7 @@ class QdrantVector(BaseVector):
return VectorType.QDRANT
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
@ -97,9 +91,9 @@ class QdrantVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
@ -110,12 +104,19 @@ class QdrantVector(BaseVector):
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
@ -124,21 +125,24 @@ class QdrantVector(BaseVector):
)
# create group_id payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD)
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(collection_name, Field.DOC_ID.value,
field_schema=PayloadSchemaType.KEYWORD)
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -147,26 +151,23 @@ class QdrantVector(BaseVector):
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, metadatas, uuids, 64, self._group_id
):
self._client.upsert(
collection_name=self._collection_name, points=points
)
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
@ -203,13 +204,13 @@ class QdrantVector(BaseVector):
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
@ -219,18 +220,11 @@ class QdrantVector(BaseVector):
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
@ -248,9 +242,7 @@ class QdrantVector(BaseVector):
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@ -275,9 +267,7 @@ class QdrantVector(BaseVector):
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@ -288,7 +278,6 @@ class QdrantVector(BaseVector):
raise e
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
@ -304,9 +293,7 @@ class QdrantVector(BaseVector):
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@ -324,15 +311,13 @@ class QdrantVector(BaseVector):
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[id]
)
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
@ -348,22 +333,22 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", .0)
score_threshold=kwargs.get("score_threshold", 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# re-apply the score threshold as a safety check
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
if result.score > score_threshold:
metadata['score'] = result.score
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -372,6 +357,7 @@ class QdrantVector(BaseVector):
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
@ -381,24 +367,21 @@ class QdrantVector(BaseVector):
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
),
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get('top_k', 2),
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
)
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
documents.append(document)
return documents
@ -410,10 +393,10 @@ class QdrantVector(BaseVector):
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
@ -425,24 +408,25 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Binding does not exist!')
raise ValueError("Dataset Collection Binding does not exist!")
else:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
if not dataset.index_struct_dict:
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
config = current_app.config
return QdrantVector(
@ -454,6 +438,6 @@ class QdrantVectorFactory(AbstractVectorFactory):
root_path=config.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
)
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
),
)

View File

@ -33,28 +33,29 @@ class RelytConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config RELYT_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config RELYT_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config RELYT_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config RELYT_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config RELYT_DATABASE is required")
return values
class RelytVector(BaseVector):
def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
super().__init__(collection_name)
self.embedding_dimension = 1536
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields = []
self._group_id = group_id
@ -70,9 +71,9 @@ class RelytVector(BaseVector):
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
@ -110,7 +111,7 @@ class RelytVector(BaseVector):
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
for metadata in metadatas:
metadata['group_id'] = self._group_id
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
# Define the table schema
@ -127,9 +128,7 @@ class RelytVector(BaseVector):
chunks_table_data = []
with self.client.connect() as conn:
with conn.begin():
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
chunks_table_data.append(
{
"id": chunk_id,
@ -196,15 +195,13 @@ class RelytVector(BaseVector):
return False
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_uuids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self.client) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
)
@ -228,38 +225,34 @@ class RelytVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get('top_k')),
embedding=query_vector,
filter=kwargs.get('filter')
k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
)
# Organize results.
docs = []
for document, score in results:
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
if 1 - score > score_threshold:
docs.append(document)
return docs
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError(
"Could not import Row from sqlalchemy.engine. "
"Please 'pip install sqlalchemy>=1.4'."
)
raise ImportError("Could not import Row from sqlalchemy.engine. " "Please 'pip install sqlalchemy>=1.4'.")
filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
if len(value) > 1
else f"metadata->>{key!r} = {value[0]!r}"
for key, value in filter.items()
]
@ -305,13 +298,12 @@ class RelytVector(BaseVector):
class RelytVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))
return RelytVector(
collection_name=collection_name,
@ -322,5 +314,5 @@ class RelytVectorFactory(AbstractVectorFactory):
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
),
group_id=dataset.id
group_id=dataset.id,
)
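The create_collection method above follows the Redis-guarded initialization pattern shared by most of these vector stores: take a short-lived distributed lock, check a cache key so the schema is only created once, and mark the collection as created afterwards (the Qdrant variant above sets the key with a one-hour expiry). A simplified sketch of the idea, assuming a locally reachable Redis and a caller-supplied create_schema callback, both of which are placeholders rather than the project's actual wiring:

import redis

# Assumed connection for illustration; the application wires its redis_client up elsewhere.
redis_client = redis.Redis()


def ensure_collection(collection_name: str, create_schema) -> None:
    lock_name = "vector_indexing_lock_{}".format(collection_name)
    # Serialize concurrent workers that try to create the same collection.
    with redis_client.lock(lock_name, timeout=20):
        collection_exist_cache_key = "vector_indexing_{}".format(collection_name)
        # A cached flag means another worker already created the schema.
        if redis_client.get(collection_exist_cache_key):
            return
        create_schema(collection_name)
        # Remember the collection for an hour so later calls can return early.
        redis_client.set(collection_exist_cache_key, 1, ex=3600)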

View File

@ -25,16 +25,11 @@ class TencentConfig(BaseModel):
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = 1,
replicas: int = 2,
shard: int = (1,)
replicas: int = (2,)
def to_tencent_params(self):
return {
'url': self.url,
'username': self.username,
'key': self.api_key,
'timeout': self.timeout
}
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
class TencentVector(BaseVector):
@ -61,13 +56,10 @@ class TencentVector(BaseVector):
return self._client.create_database(database_name=self._client_config.database)
def get_type(self) -> str:
return 'tencent'
return "tencent"
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def _has_collection(self) -> bool:
collections = self._db.list_collections()
@ -77,9 +69,9 @@ class TencentVector(BaseVector):
return False
def _create_collection(self, dimension: int) -> None:
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
@ -101,9 +93,7 @@ class TencentVector(BaseVector):
raise ValueError("unsupported metric_type")
params = vdb_index.HNSWParams(m=16, efconstruction=200)
index = vdb_index.Index(
vdb_index.FilterIndex(
self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
),
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
@ -111,12 +101,8 @@ class TencentVector(BaseVector):
metric_type,
params,
),
vdb_index.FilterIndex(
self.field_text, enum.FieldType.String, enum.IndexType.FILTER
),
vdb_index.FilterIndex(
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
)
self._db.create_collection(
@ -163,15 +149,14 @@ class TencentVector(BaseVector):
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
res = self._db.collection(self._collection_name).search(vectors=[query_vector],
params=document.HNSWSearchParams(
ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get('top_k', 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
res = self._db.collection(self._collection_name).search(
vectors=[query_vector],
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -200,15 +185,13 @@ class TencentVector(BaseVector):
class TencentVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
return TencentVector(
collection_name=collection_name,
@ -220,5 +203,5 @@ class TencentVectorFactory(AbstractVectorFactory):
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
)
),
)

View File

@ -28,47 +28,56 @@ class TiDBVectorConfig(BaseModel):
database: str
program_name: str
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config TIDB_VECTOR_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config TIDB_VECTOR_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config TIDB_VECTOR_DATABASE is required")
if not values['program_name']:
if not values["program_name"]:
raise ValueError("config APPLICATION_NAME is required")
return values
class TiDBVector(BaseVector):
def get_type(self) -> str:
return VectorType.TIDB_VECTOR
def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType
return Table(
self._collection_name,
self._orm_base.metadata,
Column('id', String(36), primary_key=True, nullable=False),
Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
Column("id", String(36), primary_key=True, nullable=False),
Column(
"vector",
VectorType(dim),
nullable=False,
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
),
Column("text", TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
extend_existing=True
Column(
"update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
),
extend_existing=True,
)
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"):
super().__init__(collection_name)
self._client_config = config
self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}")
self._url = (
f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}"
)
self._distance_func = distance_func.lower()
self._engine = create_engine(self._url)
self._orm_base = declarative_base()
@ -83,9 +92,9 @@ class TiDBVector(BaseVector):
def _create_collection(self, dimension: int):
logger.info("_create_collection, collection_name " + self._collection_name)
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
with Session(self._engine) as session:
@ -116,9 +125,7 @@ class TiDBVector(BaseVector):
chunks_table_data = []
with self._engine.connect() as conn:
with conn.begin():
for id, text, meta, embedding in zip(
ids, texts, metas, embeddings
):
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
# Execute the batch insert when the batch size is reached
@ -133,12 +140,12 @@ class TiDBVector(BaseVector):
return ids
def text_exists(self, id: str) -> bool:
result = self.get_ids_by_metadata_field('doc_id', id)
result = self.get_ids_by_metadata_field("doc_id", id)
return bool(result)
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self._engine) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
)
@ -180,20 +187,22 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
filter = kwargs.get('filter')
filter = kwargs.get("filter")
distance = 1 - score_threshold
query_vector_str = ", ".join(format(x) for x in query_vector)
query_vector_str = "[" + query_vector_str + "]"
logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")
logger.debug(
f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}"
)
docs = []
if self._distance_func == 'l2':
tidb_func = 'Vec_l2_distance'
elif self._distance_func == 'cosine':
tidb_func = 'Vec_Cosine_distance'
if self._distance_func == "l2":
tidb_func = "Vec_l2_distance"
elif self._distance_func == "cosine":
tidb_func = "Vec_Cosine_distance"
else:
tidb_func = 'Vec_Cosine_distance'
tidb_func = "Vec_Cosine_distance"
with Session(self._engine) as session:
select_statement = sql_text(
@ -208,7 +217,7 @@ class TiDBVector(BaseVector):
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = json.loads(meta)
metadata['score'] = 1 - distance
metadata["score"] = 1 - distance
docs.append(Document(page_content=text, metadata=metadata))
return docs
@ -224,15 +233,13 @@ class TiDBVector(BaseVector):
class TiDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
return TiDBVector(
collection_name=collection_name,

View File

@ -7,7 +7,6 @@ from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@ -39,18 +38,11 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:
@ -58,7 +50,7 @@ class BaseVector(ABC):
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
@ -66,7 +58,7 @@ class BaseVector(ABC):
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
return [text.metadata["doc_id"] for text in texts]
@property
def collection_name(self):

View File

@ -20,17 +20,14 @@ class AbstractVectorFactory(ABC):
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
index_struct_dict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
return index_struct_dict
class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash', 'page']
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "page"]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
@ -39,7 +36,7 @@ class Vector:
def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
vector_type = self._dataset.index_struct_dict["type"]
if not vector_type:
raise ValueError("Vector store must be specified.")
@ -52,45 +49,59 @@ class Vector:
match vector_type:
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
@ -98,22 +109,14 @@ class Vector:
def create(self, texts: list = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(
texts=texts,
embeddings=embeddings,
**kwargs
)
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create(
texts=documents,
embeddings=embeddings,
**kwargs
)
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
@ -124,24 +127,18 @@ class Vector:
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self) -> None:
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:
collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name)
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
@ -151,14 +148,13 @@ class Vector:
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
model=self._dataset.embedding_model,
)
return CacheEmbedding(embedding_model)
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
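The _filter_duplicate_texts helper above iterates over a shallow copy of the list (texts[:]) so it can remove entries whose doc_id already exists in the store without breaking the iteration. A hedged, self-contained sketch of that in-place dedup with text_exists stubbed out (DummyStore and the dict-based documents are illustrative assumptions, not project code):

class DummyStore:
    # Stand-in store with a text_exists lookup, for illustration only.
    def __init__(self, existing_ids: set[str]):
        self._existing = existing_ids

    def text_exists(self, doc_id: str) -> bool:
        return doc_id in self._existing

    def filter_duplicate_texts(self, texts: list[dict]) -> list[dict]:
        # Iterate over a shallow copy so removing from the original list is safe.
        for text in texts[:]:
            if self.text_exists(text["metadata"]["doc_id"]):
                texts.remove(text)
        return texts


store = DummyStore(existing_ids={"a"})
docs = [{"metadata": {"doc_id": "a"}}, {"metadata": {"doc_id": "b"}}]
print(store.filter_duplicate_texts(docs))  # only doc_id "b" survives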

View File

@ -2,17 +2,17 @@ from enum import Enum
class VectorType(str, Enum):
ANALYTICDB = 'analyticdb'
CHROMA = 'chroma'
MILVUS = 'milvus'
MYSCALE = 'myscale'
PGVECTOR = 'pgvector'
PGVECTO_RS = 'pgvecto-rs'
QDRANT = 'qdrant'
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
ORACLE = 'oracle'
ELASTICSEARCH = 'elasticsearch'
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"
MYSCALE = "myscale"
PGVECTOR = "pgvector"
PGVECTO_RS = "pgvecto-rs"
QDRANT = "qdrant"
RELYT = "relyt"
TIDB_VECTOR = "tidb_vector"
WEAVIATE = "weaviate"
OPENSEARCH = "opensearch"
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"

View File

@ -22,15 +22,14 @@ class WeaviateConfig(BaseModel):
api_key: Optional[str] = None
batch_size: int = 100
@model_validator(mode='before')
@model_validator(mode="before")
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
@ -43,10 +42,7 @@ class WeaviateVector(BaseVector):
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
@ -68,10 +64,10 @@ class WeaviateVector(BaseVector):
def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"):
# original class_prefix
class_prefix += '_Node'
class_prefix += "_Node"
return class_prefix
@ -79,10 +75,7 @@ class WeaviateVector(BaseVector):
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection
@ -91,9 +84,9 @@ class WeaviateVector(BaseVector):
self.add_texts(texts, embeddings)
def _create_collection(self):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
schema = self._default_schema(self._collection_name)
@ -129,17 +122,9 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
def delete(self):
# check whether the index already exists
@ -154,11 +139,19 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
if not self._client.schema.contains(schema):
return False
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
result = (
self._client.query.get(collection_name)
.with_additional(["id"])
.with_where(
{
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}
)
.with_limit(1)
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
@ -211,13 +204,13 @@ class WeaviateVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -240,15 +233,15 @@ class WeaviateVector(BaseVector):
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
query_obj = query_obj.with_additional(["vector"])
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
additional = res.pop('_additional')
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:
@ -271,20 +264,19 @@ class WeaviateVector(BaseVector):
class WeaviateVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT,
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
attributes=attributes
attributes=attributes,
)

View File

@ -12,10 +12,10 @@ from models.dataset import Dataset, DocumentSegment
class DatasetDocumentStore:
def __init__(
self,
dataset: Dataset,
user_id: str,
document_id: Optional[str] = None,
self,
dataset: Dataset,
user_id: str,
document_id: Optional[str] = None,
):
self._dataset = dataset
self._user_id = user_id
@ -41,9 +41,9 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()
document_segments = (
db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all()
)
output = {}
for document_segment in document_segments:
@ -55,48 +55,45 @@ class DatasetDocumentStore:
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
},
)
return output
def add_documents(
self, docs: Sequence[Document], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document_id == self._document_id
).scalar()
def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None:
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id)
.scalar()
)
if max_position is None:
max_position = 0
embedding_model = None
if self._dataset.indexing_technique == 'high_quality':
if self._dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
model=self._dataset.embedding_model,
)
for doc in docs:
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])
segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"])
# NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document:
raise ValueError(
f"doc_id {doc.metadata['doc_id']} already exists. "
"Set allow_update to True to overwrite."
f"doc_id {doc.metadata['doc_id']} already exists. " "Set allow_update to True to overwrite."
)
# calc embedding use tokens
if embedding_model:
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[doc.page_content]
)
tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
else:
tokens = 0
@ -107,8 +104,8 @@ class DatasetDocumentStore:
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
index_node_id=doc.metadata['doc_id'],
index_node_hash=doc.metadata['doc_hash'],
index_node_id=doc.metadata["doc_id"],
index_node_hash=doc.metadata["doc_hash"],
position=max_position,
content=doc.page_content,
word_count=len(doc.page_content),
@ -116,15 +113,15 @@ class DatasetDocumentStore:
enabled=False,
created_by=self._user_id,
)
if doc.metadata.get('answer'):
segment_document.answer = doc.metadata.pop('answer', '')
if doc.metadata.get("answer"):
segment_document.answer = doc.metadata.pop("answer", "")
db.session.add(segment_document)
else:
segment_document.content = doc.page_content
if doc.metadata.get('answer'):
segment_document.answer = doc.metadata.pop('answer', '')
segment_document.index_node_hash = doc.metadata['doc_hash']
if doc.metadata.get("answer"):
segment_document.answer = doc.metadata.pop("answer", "")
segment_document.index_node_hash = doc.metadata["doc_hash"]
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
@ -135,9 +132,7 @@ class DatasetDocumentStore:
result = self.get_document_segment(doc_id)
return result is not None
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[Document]:
def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]:
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
@ -153,7 +148,7 @@ class DatasetDocumentStore:
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
},
)
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
@ -188,9 +183,10 @@ class DatasetDocumentStore:
return document_segment.index_node_hash
def get_document_segment(self, doc_id: str) -> DocumentSegment:
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
DocumentSegment.index_node_id == doc_id
).first()
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
)
return document_segment
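For reference, the add_documents() path reformatted above keys each segment off doc_id / doc_hash metadata (plus an optional answer). A minimal sketch of such a payload, reusing the Document import that appears in later hunks of this commit; the id and hash values are placeholders:

from core.rag.models.document import Document  # import as used elsewhere in this commit

doc = Document(
    page_content="What does the store persist?",
    metadata={
        "doc_id": "9b4f3c1e-0000-4000-8000-000000000000",   # placeholder uuid4
        "doc_hash": "deadbeefcafebabe",                      # placeholder text hash
        "answer": "One DocumentSegment row per document.",   # optional; copied into segment.answer
    },
)
# DatasetDocumentStore(dataset, user_id, document_id).add_documents([doc]) would then
# insert or update the matching DocumentSegment.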

View File

@ -4,6 +4,7 @@ The goal is to facilitate decoupling of content loading from content parsing cod
In addition, content loading code should provide a lazy loading interface by default.
"""
from __future__ import annotations
import contextlib

View File

@ -1,4 +1,5 @@
"""Abstract interface for document loader implementations."""
import csv
from typing import Optional
@ -18,12 +19,12 @@ class CSVExtractor(BaseExtractor):
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None,
csv_args: Optional[dict] = None,
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None,
csv_args: Optional[dict] = None,
):
"""Initialize with file path."""
self._file_path = file_path
@ -57,7 +58,7 @@ class CSVExtractor(BaseExtractor):
docs = []
try:
# load csv file into pandas dataframe
df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args)
df = pd.read_csv(csvfile, on_bad_lines="skip", **self.csv_args)
# check source column exists
if self.source_column and self.source_column not in df.columns:
@ -67,7 +68,7 @@ class CSVExtractor(BaseExtractor):
for i, row in df.iterrows():
content = ";".join(f"{col.strip()}: {str(row[col]).strip()}" for col in df.columns)
source = row[self.source_column] if self.source_column else ''
source = row[self.source_column] if self.source_column else ""
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
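As a usage sketch of the CSVExtractor above (the module path and file name are assumptions, not part of this diff):

from core.rag.extractor.csv_extractor import CSVExtractor  # path assumed

extractor = CSVExtractor("/tmp/example.csv", autodetect_encoding=True, source_column="url")
for doc in extractor.extract():
    # one Document per row, page_content shaped like "col: value;col: value"
    print(doc.metadata["source"], doc.metadata["row"])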

View File

@ -10,6 +10,7 @@ class NotionInfo(BaseModel):
"""
Notion import info.
"""
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
@ -25,6 +26,7 @@ class WebsiteInfo(BaseModel):
"""
website import info.
"""
provider: str
job_id: str
url: str
@ -43,6 +45,7 @@ class ExtractSetting(BaseModel):
"""
Model class for provider response.
"""
datasource_type: str
upload_file: Optional[UploadFile] = None
notion_info: Optional[NotionInfo] = None

View File

@ -1,4 +1,5 @@
"""Abstract interface for document loader implementations."""
import os
from typing import Optional
@ -17,23 +18,18 @@ class ExcelExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False
):
def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def extract(self) -> list[Document]:
""" Load from Excel file in xls or xlsx format using Pandas and openpyxl."""
"""Load from Excel file in xls or xlsx format using Pandas and openpyxl."""
documents = []
file_extension = os.path.splitext(self._file_path)[-1].lower()
if file_extension == '.xlsx':
if file_extension == ".xlsx":
wb = load_workbook(self._file_path, data_only=True)
for sheet_name in wb.sheetnames:
sheet = wb[sheet_name]
@ -44,35 +40,38 @@ class ExcelExtractor(BaseExtractor):
continue
df = pd.DataFrame(data, columns=cols)
df.dropna(how='all', inplace=True)
df.dropna(how="all", inplace=True)
for index, row in df.iterrows():
page_content = []
for col_index, (k, v) in enumerate(row.items()):
if pd.notna(v):
cell = sheet.cell(row=index + 2,
column=col_index + 1) # +2 to account for header and 1-based index
cell = sheet.cell(
row=index + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
page_content.append(f'"{k}":"{value}"')
else:
page_content.append(f'"{k}":"{v}"')
documents.append(Document(page_content=';'.join(page_content),
metadata={'source': self._file_path}))
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
elif file_extension == '.xls':
excel_file = pd.ExcelFile(self._file_path, engine='xlrd')
elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
for sheet_name in excel_file.sheet_names:
df = excel_file.parse(sheet_name=sheet_name)
df.dropna(how='all', inplace=True)
df.dropna(how="all", inplace=True)
for _, row in df.iterrows():
page_content = []
for k, v in row.items():
if pd.notna(v):
page_content.append(f'"{k}":"{v}"')
documents.append(Document(page_content=';'.join(page_content),
metadata={'source': self._file_path}))
documents.append(
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
)
else:
raise ValueError(f"Unsupported file extension: {file_extension}")
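A quick sketch of the ExcelExtractor above (import path assumed): .xlsx files go through openpyxl, so hyperlinked cells are rendered as markdown links, while .xls files fall back to the xlrd engine.

from core.rag.extractor.excel_extractor import ExcelExtractor  # path assumed

extractor = ExcelExtractor("/tmp/catalog.xlsx")
for doc in extractor.extract():
    # each row becomes one Document shaped like '"header":"value";"header":"value"'
    print(doc.page_content, doc.metadata["source"])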

View File

@ -29,61 +29,60 @@ from core.rag.models.document import Document
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json']
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class ExtractProcessor:
@classmethod
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
-> Union[list[Document], str]:
def load_from_upload_file(
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=upload_file,
document_model='text_model'
datasource_type="upload_file", upload_file=upload_file, document_model="text_model"
)
if return_text:
delimiter = '\n'
delimiter = "\n"
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
else:
return cls.extract(extract_setting, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = ssrf_proxy.get(url, headers={
"User-Agent": USER_AGENT
})
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
if not suffix and suffix != '.':
if not suffix and suffix != ".":
# get content-type
if response.headers.get('Content-Type'):
suffix = '.' + response.headers.get('Content-Type').split('/')[-1]
if response.headers.get("Content-Type"):
suffix = "." + response.headers.get("Content-Type").split("/")[-1]
else:
content_disposition = response.headers.get('Content-Disposition')
content_disposition = response.headers.get("Content-Disposition")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
suffix = '.' + re.search(r'\.(\w+)$', filename).group(1)
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
with open(file_path, "wb") as file:
file.write(response.content)
extract_setting = ExtractSetting(
datasource_type="upload_file",
document_model='text_model'
)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text:
delimiter = '\n'
return delimiter.join([document.page_content for document in cls.extract(
extract_setting=extract_setting, file_path=file_path)])
delimiter = "\n"
return delimiter.join(
[
document.page_content
for document in cls.extract(extract_setting=extract_setting, file_path=file_path)
]
)
else:
return cls.extract(extract_setting=extract_setting, file_path=file_path)
@classmethod
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
file_path: str = None) -> list[Document]:
def extract(
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
@ -96,50 +95,56 @@ class ExtractProcessor:
etl_type = dify_config.ETL_TYPE
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
if etl_type == 'Unstructured':
if file_extension == '.xlsx' or file_extension == '.xls':
if etl_type == "Unstructured":
if file_extension == ".xlsx" or file_extension == ".xls":
extractor = ExcelExtractor(file_path)
elif file_extension == '.pdf':
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
elif file_extension in ['.md', '.markdown']:
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
elif file_extension in [".md", ".markdown"]:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
if is_automatic
else MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
)
elif file_extension in [".htm", ".html"]:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
elif file_extension in [".docx"]:
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == '.csv':
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
elif file_extension == ".msg":
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
elif file_extension == '.eml':
elif file_extension == ".eml":
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
elif file_extension == '.ppt':
elif file_extension == ".ppt":
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == '.pptx':
elif file_extension == ".pptx":
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
elif file_extension == '.xml':
elif file_extension == ".xml":
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
elif file_extension == 'epub':
elif file_extension == "epub":
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url)
else:
# txt
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
extractor = (
UnstructuredTextExtractor(file_path, unstructured_api_url)
if is_automatic
else TextExtractor(file_path, autodetect_encoding=True)
)
else:
if file_extension == '.xlsx' or file_extension == '.xls':
if file_extension == ".xlsx" or file_extension == ".xls":
extractor = ExcelExtractor(file_path)
elif file_extension == '.pdf':
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
elif file_extension in ['.md', '.markdown']:
elif file_extension in [".md", ".markdown"]:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
elif file_extension in [".htm", ".html"]:
extractor = HtmlExtractor(file_path)
elif file_extension in ['.docx']:
elif file_extension in [".docx"]:
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == '.csv':
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == 'epub':
elif file_extension == "epub":
extractor = UnstructuredEpubExtractor(file_path)
else:
# txt
@ -155,13 +160,13 @@ class ExtractProcessor:
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
if extract_setting.website_info.provider == 'firecrawl':
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
tenant_id=extract_setting.website_info.tenant_id,
mode=extract_setting.website_info.mode,
only_main_content=extract_setting.website_info.only_main_content
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
else:
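The reformatted ExtractProcessor entry points can be exercised roughly as below; this is a sketch, and both the module path and the URL are assumptions:

from core.rag.extractor.extract_processor import ExtractProcessor  # path assumed

# joined plain text of every extracted Document
text = ExtractProcessor.load_from_url("https://example.com/report.pdf", return_text=True)

# or the raw list[Document]
docs = ExtractProcessor.load_from_url("https://example.com/report.pdf")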

View File

@ -1,12 +1,11 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
class BaseExtractor(ABC):
"""Interface for extract files.
"""
"""Interface for extract files."""
@abstractmethod
def extract(self):
raise NotImplementedError

View File

@ -9,108 +9,98 @@ from extensions.ext_storage import storage
class FirecrawlApp:
def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
self.base_url = base_url or 'https://api.firecrawl.dev'
if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
raise ValueError('No API key provided')
self.base_url = base_url or "https://api.firecrawl.dev"
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
raise ValueError("No API key provided")
def scrape_url(self, url, params=None) -> dict:
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
json_data = {'url': url}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
json_data = {"url": url}
if params:
json_data.update(params)
response = requests.post(
f'{self.base_url}/v0/scrape',
headers=headers,
json=json_data
)
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
if response.status_code == 200:
response = response.json()
if response['success'] == True:
data = response['data']
if response["success"] == True:
data = response["data"]
return {
'title': data.get('metadata').get('title'),
'description': data.get('metadata').get('description'),
'source_url': data.get('metadata').get('sourceURL'),
'markdown': data.get('markdown')
"title": data.get("metadata").get("title"),
"description": data.get("metadata").get("description"),
"source_url": data.get("metadata").get("sourceURL"),
"markdown": data.get("markdown"),
}
else:
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
elif response.status_code in [402, 409, 500]:
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
else:
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
def crawl_url(self, url, params=None) -> str:
headers = self._prepare_headers()
json_data = {'url': url}
json_data = {"url": url}
if params:
json_data.update(params)
response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
if response.status_code == 200:
job_id = response.json().get('jobId')
job_id = response.json().get("jobId")
return job_id
else:
self._handle_error(response, 'start crawl job')
self._handle_error(response, "start crawl job")
def check_crawl_status(self, job_id) -> dict:
headers = self._prepare_headers()
response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get('status') == 'completed':
total = crawl_status_response.get('total', 0)
if crawl_status_response.get("status") == "completed":
total = crawl_status_response.get("total", 0)
if total == 0:
raise Exception('Failed to check crawl status. Error: No page found')
data = crawl_status_response.get('data', [])
raise Exception("Failed to check crawl status. Error: No page found")
data = crawl_status_response.get("data", [])
url_data_list = []
for item in data:
if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = {
'title': item.get('metadata').get('title'),
'description': item.get('metadata').get('description'),
'source_url': item.get('metadata').get('sourceURL'),
'markdown': item.get('markdown')
"title": item.get("metadata").get("title"),
"description": item.get("metadata").get("description"),
"source_url": item.get("metadata").get("sourceURL"),
"markdown": item.get("markdown"),
}
url_data_list.append(url_data)
if url_data_list:
file_key = 'website_files/' + job_id + '.txt'
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
return {
'status': 'completed',
'total': crawl_status_response.get('total'),
'current': crawl_status_response.get('current'),
'data': url_data_list
"status": "completed",
"total": crawl_status_response.get("total"),
"current": crawl_status_response.get("current"),
"data": url_data_list,
}
else:
return {
'status': crawl_status_response.get('status'),
'total': crawl_status_response.get('total'),
'current': crawl_status_response.get('current'),
'data': []
"status": crawl_status_response.get("status"),
"total": crawl_status_response.get("total"),
"current": crawl_status_response.get("current"),
"data": [],
}
else:
self._handle_error(response, 'check crawl status')
self._handle_error(response, "check crawl status")
def _prepare_headers(self):
return {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
for attempt in range(retries):
response = requests.post(url, headers=headers, json=data)
if response.status_code == 502:
time.sleep(backoff_factor * (2 ** attempt))
time.sleep(backoff_factor * (2**attempt))
else:
return response
return response
@ -119,13 +109,11 @@ class FirecrawlApp:
for attempt in range(retries):
response = requests.get(url, headers=headers)
if response.status_code == 502:
time.sleep(backoff_factor * (2 ** attempt))
time.sleep(backoff_factor * (2**attempt))
else:
return response
return response
def _handle_error(self, response, action):
error_message = response.json().get('error', 'Unknown error occurred')
raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
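A hedged usage sketch of the FirecrawlApp wrapper above; the import path, API key and URL are illustrative assumptions:

from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp  # path assumed

app = FirecrawlApp(api_key="fc-xxxxxxxx")             # placeholder key
page = app.scrape_url("https://example.com")          # dict with title/description/source_url/markdown
job_id = app.crawl_url("https://example.com")         # v0 crawl job id
status = app.check_crawl_status(job_id)               # {"status": ..., "total": ..., "data": [...]} once completed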

View File

@ -5,7 +5,7 @@ from services.website_service import WebsiteService
class FirecrawlWebExtractor(BaseExtractor):
"""
Crawl and scrape websites and return content in clean llm-ready markdown.
Crawl and scrape websites and return content in clean llm-ready markdown.
Args:
@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor):
mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
"""
def __init__(
self,
url: str,
job_id: str,
tenant_id: str,
mode: str = 'crawl',
only_main_content: bool = False
):
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
"""Initialize with url, api_key, base_url and mode."""
self._url = url
self.job_id = job_id
@ -33,28 +26,31 @@ class FirecrawlWebExtractor(BaseExtractor):
def extract(self) -> list[Document]:
"""Extract content from the URL."""
documents = []
if self.mode == 'crawl':
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
if self.mode == "crawl":
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id)
if crawl_data is None:
return []
document = Document(page_content=crawl_data.get('markdown', ''),
metadata={
'source_url': crawl_data.get('source_url'),
'description': crawl_data.get('description'),
'title': crawl_data.get('title')
}
)
document = Document(
page_content=crawl_data.get("markdown", ""),
metadata={
"source_url": crawl_data.get("source_url"),
"description": crawl_data.get("description"),
"title": crawl_data.get("title"),
},
)
documents.append(document)
elif self.mode == 'scrape':
scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
self.only_main_content)
elif self.mode == "scrape":
scrape_data = WebsiteService.get_scrape_url_data(
"firecrawl", self._url, self.tenant_id, self.only_main_content
)
document = Document(page_content=scrape_data.get('markdown', ''),
metadata={
'source_url': scrape_data.get('source_url'),
'description': scrape_data.get('description'),
'title': scrape_data.get('title')
}
)
document = Document(
page_content=scrape_data.get("markdown", ""),
metadata={
"source_url": scrape_data.get("source_url"),
"description": scrape_data.get("description"),
"title": scrape_data.get("title"),
},
)
documents.append(document)
return documents
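Putting the extractor above to work might look like the following sketch; the constructor values are placeholders and the import path is assumed:

from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor  # path assumed

extractor = FirecrawlWebExtractor(
    url="https://example.com",
    job_id="job-123",         # consumed in "crawl" mode; a placeholder here
    tenant_id="tenant-1",
    mode="scrape",
    only_main_content=True,
)
docs = extractor.extract()    # list[Document] carrying source_url/description/title metadata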

View File

@ -37,9 +37,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
try:
encodings = future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
raise TimeoutError(
f"Timeout reached while detecting encoding for {file_path}"
)
raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}")
if all(encoding["encoding"] is None for encoding in encodings):
raise RuntimeError(f"Could not detect encoding for {file_path}")

View File

@ -1,4 +1,5 @@
"""Abstract interface for document loader implementations."""
from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor
@ -6,7 +7,6 @@ from core.rag.models.document import Document
class HtmlExtractor(BaseExtractor):
"""
Load html files.
@ -15,10 +15,7 @@ class HtmlExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
def __init__(self, file_path: str):
"""Initialize with file path."""
self._file_path = file_path
@ -27,8 +24,8 @@ class HtmlExtractor(BaseExtractor):
def _load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
soup = BeautifulSoup(fp, "html.parser")
text = soup.get_text()
text = text.strip() if text else ''
text = text.strip() if text else ""
return text
return text

View File

@ -1,4 +1,5 @@
"""Abstract interface for document loader implementations."""
import re
from typing import Optional, cast
@ -16,12 +17,12 @@ class MarkdownExtractor(BaseExtractor):
"""
def __init__(
self,
file_path: str,
remove_hyperlinks: bool = False,
remove_images: bool = False,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
self,
file_path: str,
remove_hyperlinks: bool = False,
remove_images: bool = False,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
"""Initialize with file path."""
self._file_path = file_path
@ -78,13 +79,10 @@ class MarkdownExtractor(BaseExtractor):
if current_header is not None:
# pass linting, assert keys are defined
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups
]
else:
markdown_tups = [
(key, re.sub("\n", "", value)) for key, value in markdown_tups
]
markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups]
return markdown_tups

View File

@ -21,22 +21,21 @@ RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
# if the user wants to split by headings, use the corresponding splitter
HEADING_SPLITTER = {
'heading_1': '# ',
'heading_2': '## ',
'heading_3': '### ',
"heading_1": "# ",
"heading_2": "## ",
"heading_3": "### ",
}
class NotionExtractor(BaseExtractor):
def __init__(
self,
notion_workspace_id: str,
notion_obj_id: str,
notion_page_type: str,
tenant_id: str,
document_model: Optional[DocumentModel] = None,
notion_access_token: Optional[str] = None,
self,
notion_workspace_id: str,
notion_obj_id: str,
notion_page_type: str,
tenant_id: str,
document_model: Optional[DocumentModel] = None,
notion_access_token: Optional[str] = None,
):
self._notion_access_token = None
self._document_model = document_model
@ -46,46 +45,38 @@ class NotionExtractor(BaseExtractor):
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(tenant_id,
self._notion_workspace_id)
self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id)
if not self._notion_access_token:
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment "
"variable `NOTION_INTEGRATION_TOKEN`."
"Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`."
)
self._notion_access_token = integration_token
def extract(self) -> list[Document]:
self.update_last_edited_time(
self._document_model
)
self.update_last_edited_time(self._document_model)
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
return text_docs
def _load_data_as_documents(
self, notion_obj_id: str, notion_page_type: str
) -> list[Document]:
def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> list[Document]:
docs = []
if notion_page_type == 'database':
if notion_page_type == "database":
# get all the pages in the database
page_text_documents = self._get_notion_database_data(notion_obj_id)
docs.extend(page_text_documents)
elif notion_page_type == 'page':
elif notion_page_type == "page":
page_text_list = self._get_notion_block_data(notion_obj_id)
docs.append(Document(page_content='\n'.join(page_text_list)))
docs.append(Document(page_content="\n".join(page_text_list)))
else:
raise ValueError("notion page type not supported")
return docs
def _get_notion_database_data(
self, database_id: str, query_dict: dict[str, Any] = {}
) -> list[Document]:
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
"""Get all the pages from a Notion database."""
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
@ -100,50 +91,50 @@ class NotionExtractor(BaseExtractor):
data = res.json()
database_content = []
if 'results' not in data or data["results"] is None:
if "results" not in data or data["results"] is None:
return []
for result in data["results"]:
properties = result['properties']
properties = result["properties"]
data = {}
for property_name, property_value in properties.items():
type = property_value['type']
if type == 'multi_select':
type = property_value["type"]
if type == "multi_select":
value = []
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select['name'])
elif type == 'rich_text' or type == 'title':
value.append(multi_select["name"])
elif type == "rich_text" or type == "title":
if len(property_value[type]) > 0:
value = property_value[type][0]['plain_text']
value = property_value[type][0]["plain_text"]
else:
value = ''
elif type == 'select' or type == 'status':
value = ""
elif type == "select" or type == "status":
if property_value[type]:
value = property_value[type]['name']
value = property_value[type]["name"]
else:
value = ''
value = ""
else:
value = property_value[type]
data[property_name] = value
row_dict = {k: v for k, v in data.items() if v}
row_content = ''
row_content = ""
for key, value in row_dict.items():
if isinstance(value, dict):
value_dict = {k: v for k, v in value.items() if v}
value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items())
row_content = row_content + f'{key}:{value_content}\n'
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
row_content = row_content + f"{key}:{value_content}\n"
else:
row_content = row_content + f'{key}:{value}\n'
row_content = row_content + f"{key}:{value}\n"
database_content.append(row_content)
return [Document(page_content='\n'.join(database_content))]
return [Document(page_content="\n".join(database_content))]
def _get_notion_block_data(self, page_id: str) -> list[str]:
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
while True:
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
res = requests.request(
"GET",
block_url,
@ -152,14 +143,14 @@ class NotionExtractor(BaseExtractor):
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
params=query_dict
params=query_dict,
)
data = res.json()
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
cur_result_text_arr = []
if result_type == 'table':
if result_type == "table":
result_block_id = result["id"]
text = self._read_table_rows(result_block_id)
text += "\n\n"
@ -175,17 +166,15 @@ class NotionExtractor(BaseExtractor):
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
if has_children and block_type != 'child_page':
children_text = self._read_block(
result_block_id, num_tabs=1
)
if has_children and block_type != "child_page":
children_text = self._read_block(result_block_id, num_tabs=1)
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
if result_type in HEADING_SPLITTER:
result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
else:
result_lines_arr.append(cur_result_text + '\n\n')
result_lines_arr.append(cur_result_text + "\n\n")
if data["next_cursor"] is None:
break
@ -199,7 +188,7 @@ class NotionExtractor(BaseExtractor):
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
while True:
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
res = requests.request(
"GET",
@ -209,16 +198,16 @@ class NotionExtractor(BaseExtractor):
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
params=query_dict
params=query_dict,
)
data = res.json()
if 'results' not in data or data["results"] is None:
if "results" not in data or data["results"] is None:
break
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
cur_result_text_arr = []
if result_type == 'table':
if result_type == "table":
result_block_id = result["id"]
text = self._read_table_rows(result_block_id)
result_lines_arr.append(text)
@ -233,17 +222,15 @@ class NotionExtractor(BaseExtractor):
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
if has_children and block_type != 'child_page':
children_text = self._read_block(
result_block_id, num_tabs=num_tabs + 1
)
if has_children and block_type != "child_page":
children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1)
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
if result_type in HEADING_SPLITTER:
result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}')
result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
else:
result_lines_arr.append(cur_result_text + '\n\n')
result_lines_arr.append(cur_result_text + "\n\n")
if data["next_cursor"] is None:
break
@ -260,7 +247,7 @@ class NotionExtractor(BaseExtractor):
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
while not done:
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
res = requests.request(
"GET",
@ -270,28 +257,28 @@ class NotionExtractor(BaseExtractor):
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
params=query_dict
params=query_dict,
)
data = res.json()
# get table headers text
table_header_cell_texts = []
table_header_cells = data["results"][0]['table_row']['cells']
table_header_cells = data["results"][0]["table_row"]["cells"]
for table_header_cell in table_header_cells:
if table_header_cell:
for table_header_cell_text in table_header_cell:
text = table_header_cell_text["text"]["content"]
table_header_cell_texts.append(text)
else:
table_header_cell_texts.append('')
table_header_cell_texts.append("")
# Initialize Markdown table with headers
markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n"
markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table_header_cell_texts)) + " |\n"
# Process data to format each row in Markdown table format
results = data["results"]
for i in range(len(results) - 1):
column_texts = []
table_column_cells = data["results"][i + 1]['table_row']['cells']
table_column_cells = data["results"][i + 1]["table_row"]["cells"]
for j in range(len(table_column_cells)):
if table_column_cells[j]:
for table_column_cell_text in table_column_cells[j]:
@ -315,10 +302,8 @@ class NotionExtractor(BaseExtractor):
last_edited_time = self.get_notion_last_edited_time()
data_source_info = document_model.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
DocumentModel.data_source_info: json.dumps(data_source_info)
}
data_source_info["last_edited_time"] = last_edited_time
update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
db.session.commit()
@ -326,7 +311,7 @@ class NotionExtractor(BaseExtractor):
def get_notion_last_edited_time(self) -> str:
obj_id = self._notion_obj_id
page_type = self._notion_page_type
if page_type == 'database':
if page_type == "database":
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
else:
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
@ -341,7 +326,7 @@ class NotionExtractor(BaseExtractor):
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
json=query_dict,
)
data = res.json()
@ -352,14 +337,16 @@ class NotionExtractor(BaseExtractor):
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
).first()
if not data_source_binding:
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
f'and notion workspace {notion_workspace_id}')
raise Exception(
f"No notion data source binding found for tenant {tenant_id} "
f"and notion workspace {notion_workspace_id}"
)
return data_source_binding.access_token
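An illustrative construction of the NotionExtractor above; the IDs and token are placeholders and the import path is assumed. Passing notion_access_token directly sidesteps the DataSourceOauthBinding lookup in _get_access_token:

from core.rag.extractor.notion_extractor import NotionExtractor  # path assumed

extractor = NotionExtractor(
    notion_workspace_id="ws_0000",
    notion_obj_id="page_0000",
    notion_page_type="page",           # or "database"
    tenant_id="tenant-1",
    notion_access_token="secret_xxx",  # placeholder token
)
docs = extractor.extract()             # Documents built from the page blocks; tables come back as markdown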

View File

@ -1,4 +1,5 @@
"""Abstract interface for document loader implementations."""
from collections.abc import Iterator
from typing import Optional
@ -16,21 +17,17 @@ class PdfExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
file_cache_key: Optional[str] = None
):
def __init__(self, file_path: str, file_cache_key: Optional[str] = None):
"""Initialize with file path."""
self._file_path = file_path
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
plaintext_file_key = ''
plaintext_file_key = ""
plaintext_file_exists = False
if self._file_cache_key:
try:
text = storage.load(self._file_cache_key).decode('utf-8')
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
@ -43,12 +40,12 @@ class PdfExtractor(BaseExtractor):
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
storage.save(plaintext_file_key, text.encode("utf-8"))
return documents
def load(
self,
self,
) -> Iterator[Document]:
"""Lazy load given path as pages."""
blob = Blob.from_path(self._file_path)
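A short sketch of the PdfExtractor above (import path assumed): when file_cache_key is supplied, the extracted plaintext is stored via the storage extension and reused on later calls.

from core.rag.extractor.pdf_extractor import PdfExtractor  # path assumed

extractor = PdfExtractor("/tmp/report.pdf", file_cache_key="plaintext/report.txt")
docs = extractor.extract()  # one Document per page on a cache miss, a single cached Document afterwards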

View File

@ -1,4 +1,5 @@
"""Abstract interface for document loader implementations."""
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
@ -14,12 +15,7 @@ class TextExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False
):
def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding

View File

@ -8,13 +8,12 @@ logger = logging.getLogger(__name__)
class UnstructuredWordExtractor(BaseExtractor):
"""Loader that uses unstructured to load word documents.
"""
"""Loader that uses unstructured to load word documents."""
def __init__(
self,
file_path: str,
api_url: str,
self,
file_path: str,
api_url: str,
):
"""Initialize with file path."""
self._file_path = file_path
@ -24,9 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor):
from unstructured.__version__ import __version__ as __unstructured_version__
from unstructured.file_utils.filetype import FileType, detect_filetype
unstructured_version = tuple(
int(x) for x in __unstructured_version__.split(".")
)
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
# check the file extension
try:
import magic # noqa: F401
@ -53,6 +50,7 @@ class UnstructuredWordExtractor(BaseExtractor):
elements = partition_docx(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = []
for chunk in chunks:

View File

@ -26,6 +26,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
def extract(self) -> list[Document]:
from unstructured.partition.email import partition_email
elements = partition_email(filename=self._file_path)
# noinspection PyBroadException
@ -34,15 +35,16 @@ class UnstructuredEmailExtractor(BaseExtractor):
element_text = element.text.strip()
padding_needed = 4 - len(element_text) % 4
element_text += '=' * padding_needed
element_text += "=" * padding_needed
element_decode = base64.b64decode(element_text)
soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser')
soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser")
element.text = soup.get_text()
except Exception:
pass
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = []
for chunk in chunks:

View File

@ -28,6 +28,7 @@ class UnstructuredEpubExtractor(BaseExtractor):
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = []
for chunk in chunks:

View File

@ -38,6 +38,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
elements = partition_md(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = []
for chunk in chunks:

View File

@ -14,11 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
def __init__(self, file_path: str, api_url: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
@ -28,6 +24,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
elements = partition_msg(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = []
for chunk in chunks:

View File

@ -14,12 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str,
api_key: str
):
def __init__(self, file_path: str, api_url: str, api_key: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url

View File

@ -14,11 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
def __init__(self, file_path: str, api_url: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url

View File

@ -14,11 +14,7 @@ class UnstructuredTextExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
def __init__(self, file_path: str, api_url: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
@ -28,6 +24,7 @@ class UnstructuredTextExtractor(BaseExtractor):
elements = partition_text(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = []
for chunk in chunks:

View File

@ -14,11 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str
):
def __init__(self, file_path: str, api_url: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
@ -28,6 +24,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = []
for chunk in chunks:

View File

@ -1,4 +1,5 @@
"""Abstract interface for document loader implementations."""
import datetime
import logging
import mimetypes
@ -21,6 +22,7 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
class WordExtractor(BaseExtractor):
"""Load docx files.
@ -43,9 +45,7 @@ class WordExtractor(BaseExtractor):
r = requests.get(self.file_path)
if r.status_code != 200:
raise ValueError(
f"Check the url of your file; returned status code {r.status_code}"
)
raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
self.web_path = self.file_path
self.temp_file = tempfile.NamedTemporaryFile()
@ -60,11 +60,13 @@ class WordExtractor(BaseExtractor):
def extract(self) -> list[Document]:
"""Load given path as single page."""
content = self.parse_docx(self.file_path, 'storage')
return [Document(
page_content=content,
metadata={"source": self.file_path},
)]
content = self.parse_docx(self.file_path, "storage")
return [
Document(
page_content=content,
metadata={"source": self.file_path},
)
]
@staticmethod
def _is_valid_url(url: str) -> bool:
@ -84,18 +86,18 @@ class WordExtractor(BaseExtractor):
url = rel.reltype
response = requests.get(url, stream=True)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers['Content-Type'])
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
file_uuid = str(uuid.uuid4())
file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, response.content)
else:
continue
else:
image_ext = rel.target_ref.split('.')[-1]
image_ext = rel.target_ref.split(".")[-1]
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, rel.target_part.blob)
@ -112,12 +114,14 @@ class WordExtractor(BaseExtractor):
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
used=True,
used_by=self.user_id,
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
)
db.session.add(upload_file)
db.session.commit()
image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
image_map[rel.target_part] = (
f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
)
return image_map
@ -167,8 +171,8 @@ class WordExtractor(BaseExtractor):
def _parse_cell_paragraph(self, paragraph, image_map):
paragraph_content = []
for run in paragraph.runs:
if run.element.xpath('.//a:blip'):
for blip in run.element.xpath('.//a:blip'):
if run.element.xpath(".//a:blip"):
for blip in run.element.xpath(".//a:blip"):
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if not image_id:
continue
@ -184,16 +188,16 @@ class WordExtractor(BaseExtractor):
def _parse_paragraph(self, paragraph, image_map):
paragraph_content = []
for run in paragraph.runs:
if run.element.xpath('.//a:blip'):
for blip in run.element.xpath('.//a:blip'):
embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
if run.element.xpath(".//a:blip"):
for blip in run.element.xpath(".//a:blip"):
embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if embed_id:
rel_target = run.part.rels[embed_id].target_ref
if rel_target in image_map:
paragraph_content.append(image_map[rel_target])
if run.text.strip():
paragraph_content.append(run.text.strip())
return ' '.join(paragraph_content) if paragraph_content else ''
return " ".join(paragraph_content) if paragraph_content else ""
def parse_docx(self, docx_path, image_folder):
doc = DocxDocument(docx_path)
@ -204,60 +208,59 @@ class WordExtractor(BaseExtractor):
image_map = self._extract_images_from_docx(doc, image_folder)
hyperlinks_url = None
url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+')
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
for para in doc.paragraphs:
for run in para.runs:
if run.text and hyperlinks_url:
result = f' [{run.text}]({hyperlinks_url}) '
result = f" [{run.text}]({hyperlinks_url}) "
run.text = result
hyperlinks_url = None
if 'HYPERLINK' in run.element.xml:
if "HYPERLINK" in run.element.xml:
try:
xml = ET.XML(run.element.xml)
x_child = [c for c in xml.iter() if c is not None]
for x in x_child:
if x_child is None:
continue
if x.tag.endswith('instrText'):
if x.tag.endswith("instrText"):
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception as e:
logger.error(e)
def parse_paragraph(paragraph):
paragraph_content = []
for run in paragraph.runs:
if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'):
if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"):
drawing_elements = run.element.findall(
'.//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing')
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
)
for drawing in drawing_elements:
blip_elements = drawing.findall(
'.//{http://schemas.openxmlformats.org/drawingml/2006/main}blip')
".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"
)
for blip in blip_elements:
embed_id = blip.get(
'{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
)
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
paragraph_content.append(image_map[image_part])
if run.text.strip():
paragraph_content.append(run.text.strip())
return ''.join(paragraph_content) if paragraph_content else ''
return "".join(paragraph_content) if paragraph_content else ""
paragraphs = doc.paragraphs.copy()
tables = doc.tables.copy()
for element in doc.element.body:
if hasattr(element, 'tag'):
if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph
if hasattr(element, "tag"):
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
para = paragraphs.pop(0)
parsed_paragraph = parse_paragraph(para)
if parsed_paragraph:
content.append(parsed_paragraph)
elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
table = tables.pop(0)
content.append(self._table_to_markdown(table, image_map))
return '\n'.join(content)
return "\n".join(content)
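The call sites elsewhere in this commit construct the extractor above as WordExtractor(file_path, tenant_id, user_id); a minimal sketch with placeholder values:

from core.rag.extractor.word_extractor import WordExtractor  # path assumed

extractor = WordExtractor("/tmp/handbook.docx", "tenant-1", "user-1")  # (file_path, tenant_id, user_id)
docs = extractor.extract()  # a single Document; embedded images are uploaded and inlined as markdown image links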

View File

@ -1,4 +1,5 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from typing import Optional
@ -15,8 +16,7 @@ from models.dataset import Dataset, DatasetProcessRule
class BaseIndexProcessor(ABC):
"""Interface for extract files.
"""
"""Interface for extract files."""
@abstractmethod
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
@ -34,18 +34,24 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError
@abstractmethod
def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]:
def retrieve(
self,
retrieval_method: str,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
) -> list[Document]:
raise NotImplementedError
def _get_splitter(self, processing_rule: dict,
embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
if processing_rule['mode'] == "custom":
if processing_rule["mode"] == "custom":
# The user-defined segmentation rule
rules = processing_rule['rules']
rules = processing_rule["rules"]
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
@ -53,22 +59,22 @@ class BaseIndexProcessor(ABC):
separator = segmentation["separator"]
if separator:
separator = separator.replace('\\n', '\n')
separator = separator.replace("\\n", "\n")
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=segmentation.get('chunk_overlap', 0) or 0,
chunk_overlap=segmentation.get("chunk_overlap", 0) or 0,
fixed_separator=separator,
separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=embedding_model_instance
embedding_model_instance=embedding_model_instance,
)
else:
# Automatic segmentation
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance
embedding_model_instance=embedding_model_instance,
)
return character_splitter
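For orientation, this is the shape of the processing_rule that the "custom" branch above reads. The field names are taken from the lookups in _get_splitter; the values are illustrative only:

processing_rule = {
    "mode": "custom",
    "rules": {
        "segmentation": {
            "separator": "\\n",  # unescaped to a real newline before splitting
            "max_tokens": 500,   # checked against 50..INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
            "chunk_overlap": 50,
        }
    },
}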

View File

@ -7,8 +7,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
class IndexProcessorFactory:
"""IndexProcessorInit.
"""
"""IndexProcessorInit."""
def __init__(self, index_type: str):
self._index_type = index_type
@ -22,7 +21,6 @@ class IndexProcessorFactory:
if self._index_type == IndexType.PARAGRAPH_INDEX.value:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX.value:
return QAIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@ -1,4 +1,5 @@
"""Paragraph index processor."""
import uuid
from typing import Optional
@ -15,33 +16,32 @@ from models.dataset import Dataset
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=kwargs.get('process_rule_mode') == "automatic")
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
# Split the text documents into nodes.
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
embedding_model_instance=kwargs.get('embedding_model_instance'))
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
)
all_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
@ -55,7 +55,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if with_keywords:
@ -63,7 +63,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
@ -76,17 +76,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else:
keyword.delete()
def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict) -> list[Document]:
def retrieve(
self,
retrieval_method: str,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata['score'] = result.score
metadata["score"] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)

View File

@ -1,4 +1,5 @@
"""Paragraph index processor."""
import logging
import re
import threading
@ -23,33 +24,33 @@ from models.dataset import Dataset
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
is_automatic=kwargs.get('process_rule_mode') == "automatic")
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
embedding_model_instance=kwargs.get('embedding_model_instance'))
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
)
# Split the text documents into nodes.
all_documents = []
all_qa_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata['doc_id'] = doc_id
document_node.metadata['doc_hash'] = hash
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
@ -61,14 +62,18 @@ class QAIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i:i + 10]
sub_documents = all_documents[i : i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
'flask_app': current_app._get_current_object(),
'tenant_id': kwargs.get('tenant_id'),
'document_node': doc,
'all_qa_documents': all_qa_documents,
'document_language': kwargs.get('doc_language', 'English')})
document_format_thread = threading.Thread(
target=self._format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(),
"tenant_id": kwargs.get("tenant_id"),
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": kwargs.get("doc_language", "English"),
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
@ -76,9 +81,8 @@ class QAIndexProcessor(BaseIndexProcessor):
return all_qa_documents
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
# check file type
if not file.filename.endswith('.csv'):
if not file.filename.endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
@ -86,7 +90,7 @@ class QAIndexProcessor(BaseIndexProcessor):
df = pd.read_csv(file)
text_docs = []
for index, row in df.iterrows():
data = Document(page_content=row[0], metadata={'answer': row[1]})
data = Document(page_content=row[0], metadata={"answer": row[1]})
text_docs.append(data)
if len(text_docs) == 0:
raise ValueError("The CSV file is empty.")
@ -96,7 +100,7 @@ class QAIndexProcessor(BaseIndexProcessor):
return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
if dataset.indexing_technique == 'high_quality':
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
@ -107,17 +111,29 @@ class QAIndexProcessor(BaseIndexProcessor):
else:
vector.delete()
def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int,
score_threshold: float, reranking_model: dict):
def retrieve(
self,
retrieval_method: str,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
):
# Set search parameters.
results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata['score'] = result.score
metadata["score"] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
@ -134,12 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor):
document_qa_list = self._format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy())
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result['question'])
qa_document.metadata['answer'] = result['answer']
qa_document.metadata['doc_id'] = doc_id
qa_document.metadata['doc_hash'] = hash
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
@ -151,10 +167,4 @@ class QAIndexProcessor(BaseIndexProcessor):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE)
return [
{
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
}
for q, a in matches if q and a
]
return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]

View File

@ -55,9 +55,7 @@ class BaseDocumentTransformer(ABC):
"""
@abstractmethod
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform a list of documents.
Args:
@ -68,9 +66,7 @@ class BaseDocumentTransformer(ABC):
"""
@abstractmethod
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Asynchronously transform a list of documents.
Args:

View File

@ -2,7 +2,5 @@ from enum import Enum
class RerankMode(Enum):
RERANKING_MODEL = 'reranking_model'
WEIGHTED_SCORE = 'weighted_score'
RERANKING_MODEL = "reranking_model"
WEIGHTED_SCORE = "weighted_score"

View File

@ -8,8 +8,14 @@ class RerankModelRunner:
def __init__(self, rerank_model_instance: ModelInstance) -> None:
self.rerank_model_instance = rerank_model_instance
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
def run(
self,
query: str,
documents: list[Document],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> list[Document]:
"""
Run rerank model
:param query: search query
@ -23,19 +29,15 @@ class RerankModelRunner:
doc_id = []
unique_documents = []
for document in documents:
if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id'])
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
)
rerank_documents = []
@ -45,12 +47,12 @@ class RerankModelRunner:
rerank_document = Document(
page_content=result.text,
metadata={
"doc_id": documents[result.index].metadata['doc_id'],
"doc_hash": documents[result.index].metadata['doc_hash'],
"document_id": documents[result.index].metadata['document_id'],
"dataset_id": documents[result.index].metadata['dataset_id'],
'score': result.score
}
"doc_id": documents[result.index].metadata["doc_id"],
"doc_hash": documents[result.index].metadata["doc_hash"],
"document_id": documents[result.index].metadata["document_id"],
"dataset_id": documents[result.index].metadata["dataset_id"],
"score": result.score,
},
)
rerank_documents.append(rerank_document)
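The doc_id de-duplication at the top of run() (the same pattern appears again in WeightRerankRunner below) keeps only the first occurrence of each chunk before reranking; a plain-dict sketch with toy data:

documents = [{"doc_id": "a"}, {"doc_id": "b"}, {"doc_id": "a"}]  # toy data
seen_ids, unique_documents = [], []
for document in documents:
    if document["doc_id"] not in seen_ids:
        seen_ids.append(document["doc_id"])
        unique_documents.append(document)
# unique_documents -> [{'doc_id': 'a'}, {'doc_id': 'b'}]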

View File

@ -13,13 +13,18 @@ from core.rag.rerank.entity.weight import VectorSetting, Weights
class WeightRerankRunner:
def __init__(self, tenant_id: str, weights: Weights) -> None:
self.tenant_id = tenant_id
self.weights = weights
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
def run(
self,
query: str,
documents: list[Document],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> list[Document]:
"""
Run rerank model
:param query: search query
@ -34,8 +39,8 @@ class WeightRerankRunner:
doc_id = []
unique_documents = []
for document in documents:
if document.metadata['doc_id'] not in doc_id:
doc_id.append(document.metadata['doc_id'])
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
@ -47,13 +52,15 @@ class WeightRerankRunner:
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
# format document
score = self.weights.vector_setting.vector_weight * query_vector_score + \
self.weights.keyword_setting.keyword_weight * query_score
score = (
self.weights.vector_setting.vector_weight * query_vector_score
+ self.weights.keyword_setting.keyword_weight * query_score
)
if score_threshold and score < score_threshold:
continue
document.metadata['score'] = score
document.metadata["score"] = score
rerank_documents.append(document)
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
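Worked numbers for the weighted score above, with illustrative weights of 0.7 (vector) and 0.3 (keyword): a document scoring 0.82 on embedding similarity and 0.45 on keyword match gets

score = 0.7 * 0.82 + 0.3 * 0.45  # = 0.709, kept only if it clears score_threshold
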
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
@ -70,7 +77,7 @@ class WeightRerankRunner:
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata['keywords'] = document_keywords
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
@ -132,8 +139,9 @@ class WeightRerankRunner:
return similarities
def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
vector_setting: VectorSetting) -> list[float]:
def _calculate_cosine(
self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting
) -> list[float]:
"""
Calculate Cosine scores
:param query: search query
@ -149,15 +157,14 @@ class WeightRerankRunner:
tenant_id=tenant_id,
provider=vector_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=vector_setting.embedding_model_name
model=vector_setting.embedding_model_name,
)
cache_embedding = CacheEmbedding(embedding_model)
query_vector = cache_embedding.embed_query(query)
for document in documents:
# calculate cosine similarity
if 'score' in document.metadata:
query_vector_scores.append(document.metadata['score'])
if "score" in document.metadata:
query_vector_scores.append(document.metadata["score"])
else:
# transform to NumPy
vec1 = np.array(query_vector)
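The hunk cuts off right after vec1 is built; what follows is the standard NumPy cosine similarity between the query vector and each document vector. A hedged sketch of that step with made-up vectors (not necessarily the exact upstream lines):

import numpy as np

vec1 = np.array([0.1, 0.3, 0.5])  # query embedding (illustrative)
vec2 = np.array([0.2, 0.1, 0.4])  # document embedding (illustrative)
cosine_similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))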

View File

@ -32,14 +32,11 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
@ -48,15 +45,18 @@ class DatasetRetrieval:
self.application_generate_entity = application_generate_entity
def retrieve(
self, app_id: str, user_id: str, tenant_id: str,
model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,
show_retrieve_source: bool,
hit_callback: DatasetIndexToolCallbackHandler,
message_id: str,
memory: Optional[TokenBufferMemory] = None,
self,
app_id: str,
user_id: str,
tenant_id: str,
model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,
show_retrieve_source: bool,
hit_callback: DatasetIndexToolCallbackHandler,
message_id: str,
memory: Optional[TokenBufferMemory] = None,
) -> Optional[str]:
"""
Retrieve dataset.
@ -84,16 +84,12 @@ class DatasetRetrieval:
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
credentials=model_config.credentials
model=model_config.model, credentials=model_config.credentials
)
if not model_schema:
@ -102,39 +98,46 @@ class DatasetRetrieval:
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
continue
# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
continue
available_datasets.append(dataset)
all_documents = []
user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
all_documents = self.single_retrieve(
app_id, tenant_id, user_id, user_from, available_datasets, query,
app_id,
tenant_id,
user_id,
user_from,
available_datasets,
query,
model_instance,
model_config, planning_strategy, message_id
model_config,
planning_strategy,
message_id,
)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
all_documents = self.multiple_retrieve(
app_id, tenant_id, user_id, user_from,
available_datasets, query, retrieve_config.top_k,
app_id,
tenant_id,
user_id,
user_from,
available_datasets,
query,
retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.rerank_mode,
retrieve_config.reranking_model,
@ -145,89 +148,89 @@ class DatasetRetrieval:
document_score_list = {}
for item in all_documents:
if item.metadata.get('score'):
document_score_list[item.metadata['doc_id']] = item.metadata['score']
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(segments,
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
float('inf')))
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if show_retrieve_source:
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(
id=segment.dataset_id
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': invoke_from.to_source(),
'score': document_score_list.get(segment.index_node_id, None)
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": document_score_list.get(segment.index_node_id, None),
}
if invoke_from.to_source() == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if invoke_from.to_source() == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source['content'] = segment.content
source["content"] = segment.content
context_list.append(source)
resource_number += 1
if hit_callback:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ''
return ""
def single_retrieve(
self, app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
):
tools = []
for dataset in available_datasets:
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
description = "useful for when you want to answer queries about the " + dataset.name
description = description.replace('\n', '').replace('\r', '')
description = description.replace("\n", "").replace("\r", "")
message_tool = PromptMessageTool(
name=dataset.id,
description=description,
@ -235,14 +238,15 @@ class DatasetRetrieval:
"type": "object",
"properties": {},
"required": [],
}
},
)
tools.append(message_tool)
dataset_id = None
if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
user_id, tenant_id)
dataset_id = react_multi_dataset_router.invoke(
query, tools, model_config, model_instance, user_id, tenant_id
)
elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter()
@ -250,37 +254,37 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
retrieval_method = 'keyword_search'
retrieval_method = "keyword_search"
else:
retrieval_method = retrieval_model_config['search_method']
retrieval_method = retrieval_model_config["search_method"]
# get reranking model
reranking_model = retrieval_model_config['reranking_model'] \
if retrieval_model_config['reranking_enable'] else None
reranking_model = (
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
)
# get score threshold
score_threshold = .0
score_threshold = 0.0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
with measure_time() as timer:
results = RetrievalService.retrieve(
retrieval_method=retrieval_method, dataset_id=dataset.id,
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k, score_threshold=score_threshold,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
weights=retrieval_model_config.get('weights', None),
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
weights=retrieval_model_config.get("weights", None),
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
@ -291,20 +295,20 @@ class DatasetRetrieval:
return []
def multiple_retrieve(
self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
top_k: int,
score_threshold: float,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
top_k: int,
score_threshold: float,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
threads = []
all_documents = []
@ -312,13 +316,16 @@ class DatasetRetrieval:
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset.id,
'query': query,
'top_k': top_k,
'all_documents': all_documents,
})
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
@ -327,16 +334,10 @@ class DatasetRetrieval:
with measure_time() as timer:
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(
tenant_id, reranking_mode,
reranking_model, weights, False
)
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
)
else:
if index_type == "economy":
@ -357,30 +358,26 @@ class DatasetRetrieval:
"""Handle retrieval end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
trace_manager: TraceQueueManager = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE,
message_id=message_id,
documents=documents,
timer=timer
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
)
@ -395,10 +392,10 @@ class DatasetRetrieval:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
@ -407,9 +404,7 @@ class DatasetRetrieval:
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
@ -419,38 +414,42 @@ class DatasetRetrieval:
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=top_k
)
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
)
if documents:
all_documents.extend(documents)
else:
if top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'],
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),
)
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else None,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode")
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None),
)
all_documents.extend(documents)
def to_dataset_retriever_tool(self, tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[list[DatasetRetrieverBaseTool]]:
def to_dataset_retriever_tool(
self,
tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
) -> Optional[list[DatasetRetrieverBaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id
@ -464,18 +463,14 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
continue
# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
continue
available_datasets.append(dataset)
@ -483,22 +478,18 @@ class DatasetRetrieval:
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
top_k = retrieval_model_config["top_k"]
# get score threshold
score_threshold = None
@ -512,7 +503,7 @@ class DatasetRetrieval:
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source()
retriever_from=invoke_from.to_source(),
)
tools.append(tool)
@ -525,8 +516,8 @@ class DatasetRetrieval:
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
)
tools.append(tool)
@ -547,7 +538,7 @@ class DatasetRetrieval:
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata['keywords'] = document_keywords
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
@ -606,21 +597,19 @@ class DatasetRetrieval:
for document, score in zip(documents, similarities):
# format document
document.metadata['score'] = score
documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
document.metadata["score"] = score
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents[:top_k] if top_k else documents
def calculate_vector_score(self, all_documents: list[Document],
top_k: int, score_threshold: float) -> list[Document]:
def calculate_vector_score(
self, all_documents: list[Document], top_k: int, score_threshold: float
) -> list[Document]:
filter_documents = []
for document in all_documents:
if score_threshold is None or document.metadata['score'] >= score_threshold:
if score_threshold is None or document.metadata["score"] >= score_threshold:
filter_documents.append(document)
if not filter_documents:
return []
filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
return filter_documents[:top_k] if top_k else filter_documents

View File

@ -16,9 +16,7 @@ class StructuredChatOutputParser:
if response["action"] == "Final Answer":
return ReactFinish({"output": response["action_input"]}, text)
else:
return ReactAction(
response["action"], response.get("action_input", {}), text
)
return ReactAction(response["action"], response.get("action_input", {}), text)
else:
return ReactFinish({"output": text}, text)
except Exception as e:

View File

@ -2,9 +2,9 @@ from enum import Enum
class RetrievalMethod(Enum):
SEMANTIC_SEARCH = 'semantic_search'
FULL_TEXT_SEARCH = 'full_text_search'
HYBRID_SEARCH = 'hybrid_search'
SEMANTIC_SEARCH = "semantic_search"
FULL_TEXT_SEARCH = "full_text_search"
HYBRID_SEARCH = "hybrid_search"
@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:

View File

@ -6,14 +6,12 @@ from core.model_runtime.entities.message_entities import PromptMessageTool, Syst
class FunctionCallMultiDatasetRouter:
def invoke(
self,
query: str,
dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
self,
query: str,
dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
) -> Union[str, None]:
"""Given input, decided what to do.
Returns:
@ -26,22 +24,18 @@ class FunctionCallMultiDatasetRouter:
try:
prompt_messages = [
SystemPromptMessage(content='You are a helpful AI assistant.'),
UserPromptMessage(content=query)
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
if result.message.tool_calls:
# get retrieval model config
return result.message.tool_calls[0].function.name
return None
except Exception as e:
return None
return None

View File

@ -50,16 +50,14 @@ Action:
class ReactMultiDatasetRouter:
def invoke(
self,
query: str,
dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
user_id: str,
tenant_id: str
self,
query: str,
dataset_tools: list[PromptMessageTool],
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
user_id: str,
tenant_id: str,
) -> Union[str, None]:
"""Given input, decided what to do.
Returns:
@ -71,23 +69,28 @@ class ReactMultiDatasetRouter:
return dataset_tools[0].name
try:
return self._react_invoke(query=query, model_config=model_config,
model_instance=model_instance,
tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
return self._react_invoke(
query=query,
model_config=model_config,
model_instance=model_instance,
tools=dataset_tools,
user_id=user_id,
tenant_id=tenant_id,
)
except Exception as e:
return None
def _react_invoke(
self,
query: str,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
tools: Sequence[PromptMessageTool],
user_id: str,
tenant_id: str,
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
self,
query: str,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
tools: Sequence[PromptMessageTool],
user_id: str,
tenant_id: str,
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
if model_config.mode == "chat":
prompt = self.create_chat_prompt(
@ -103,18 +106,18 @@ class ReactMultiDatasetRouter:
prefix=prefix,
format_instructions=format_instructions,
)
stop = ['Observation:']
stop = ["Observation:"]
# handle invoke result
prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt,
inputs={},
query='',
query="",
files=[],
context='',
context="",
memory_config=None,
memory=None,
model_config=model_config
model_config=model_config,
)
result_text, usage = self._invoke_llm(
completion_param=model_config.parameters,
@ -122,7 +125,7 @@ class ReactMultiDatasetRouter:
prompt_messages=prompt_messages,
stop=stop,
user_id=user_id,
tenant_id=tenant_id
tenant_id=tenant_id,
)
output_parser = StructuredChatOutputParser()
react_decision = output_parser.parse(result_text)
@ -130,17 +133,21 @@ class ReactMultiDatasetRouter:
return react_decision.tool
return None
def _invoke_llm(self, completion_param: dict,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str], user_id: str, tenant_id: str
) -> tuple[str, LLMUsage]:
def _invoke_llm(
self,
completion_param: dict,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str],
user_id: str,
tenant_id: str,
) -> tuple[str, LLMUsage]:
"""
Invoke large language model
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
:return:
Invoke large language model
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
:return:
"""
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@ -151,9 +158,7 @@ class ReactMultiDatasetRouter:
)
# handle invoke result
text, usage = self._handle_invoke_result(
invoke_result=invoke_result
)
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
@ -168,7 +173,7 @@ class ReactMultiDatasetRouter:
"""
model = None
prompt_messages = []
full_text = ''
full_text = ""
usage = None
for result in invoke_result:
text = result.delta.message.content
@ -189,40 +194,35 @@ class ReactMultiDatasetRouter:
return full_text, usage
def create_chat_prompt(
self,
query: str,
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
self,
query: str,
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> list[ChatModelMessage]:
tool_strings = []
for tool in tools:
tool_strings.append(
f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}"
)
formatted_tools = "\n".join(tool_strings)
unique_tool_names = {tool.name for tool in tools}
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
prompt_messages = []
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=template
)
system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=template)
prompt_messages.append(system_prompt_messages)
user_prompt_message = ChatModelMessage(
role=PromptMessageRole.USER,
text=query
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=query)
prompt_messages.append(user_prompt_message)
return prompt_messages
def create_completion_prompt(
self,
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
self,
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> CompletionModelPromptTemplate:
"""Create prompt in the style of the zero shot agent.

View File

@ -1,4 +1,5 @@
"""Functionality for splitting text."""
from __future__ import annotations
from typing import Any, Optional
@ -18,31 +19,29 @@ from core.rag.splitter.text_splitter import (
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
"""
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
"""
@classmethod
def from_encoder(
cls: type[TS],
embedding_model_instance: Optional[ModelInstance],
allowed_special: Union[Literal[all], Set[str]] = set(),
disallowed_special: Union[Literal[all], Collection[str]] = "all",
**kwargs: Any,
cls: type[TS],
embedding_model_instance: Optional[ModelInstance],
allowed_special: Union[Literal[all], Set[str]] = set(),
disallowed_special: Union[Literal[all], Collection[str]] = "all",
**kwargs: Any,
):
def _token_encoder(text: str) -> int:
if not text:
return 0
if embedding_model_instance:
return embedding_model_instance.get_text_embedding_num_tokens(
texts=[text]
)
return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
else:
return GPT2Tokenizer.get_num_tokens(text)
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}

View File

@ -22,9 +22,7 @@ logger = logging.getLogger(__name__)
TS = TypeVar("TS", bound="TextSplitter")
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> list[str]:
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
@ -37,19 +35,19 @@ def _split_text_with_regex(
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if (s != "" and s != '\n')]
return [s for s in splits if (s != "" and s != "\n")]
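The keep_separator branch is elided by this hunk; as in the upstream LangChain splitter this code mirrors, the usual trick is to wrap the separator in a capturing group so re.split() returns it and it can be re-attached to the following chunk. A small sketch of that behaviour (sample text is made up):

import re

text = "first. second. third."
parts = re.split(r"(\. )", text)  # ['first', '. ', 'second', '. ', 'third.']
splits = [parts[i] + parts[i + 1] for i in range(1, len(parts), 2)]
splits = ([parts[0]] if parts[0] else []) + splits
# splits -> ['first', '. second', '. third.']
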
class TextSplitter(BaseDocumentTransformer, ABC):
"""Interface for splitting text into chunks."""
def __init__(
self,
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
keep_separator: bool = False,
add_start_index: bool = False,
self,
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
keep_separator: bool = False,
add_start_index: bool = False,
) -> None:
"""Create a new TextSplitter.
@ -62,8 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
"""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller."
)
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
@ -75,9 +72,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def split_text(self, text: str) -> list[str]:
"""Split text into multiple components."""
def create_documents(
self, texts: list[str], metadatas: Optional[list[dict]] = None
) -> list[Document]:
def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
@ -119,14 +114,10 @@ class TextSplitter(BaseDocumentTransformer, ABC):
index = 0
for d in splits:
_len = lengths[index]
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
):
if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
if total > self._chunk_size:
logger.warning(
f"Created a chunk of size {total}, "
f"which is longer than the specified {self._chunk_size}"
f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}"
)
if len(current_doc) > 0:
doc = self._join_docs(current_doc, separator)
@ -136,13 +127,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
# - we have a larger chunk than in the chunk overlap
# - or if we still have any chunks and the length is long
while total > self._chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
and total > 0
total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
):
total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
@ -159,28 +146,25 @@ class TextSplitter(BaseDocumentTransformer, ABC):
from transformers import PreTrainedTokenizerBase
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError(
"Tokenizer received was not an instance of PreTrainedTokenizerBase"
)
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")
def _huggingface_tokenizer_length(text: str) -> int:
return len(tokenizer.encode(text))
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"Please install it with `pip install transformers`."
"Could not import transformers python package. " "Please install it with `pip install transformers`."
)
return cls(length_function=_huggingface_tokenizer_length, **kwargs)
@classmethod
def from_tiktoken_encoder(
cls: type[TS],
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], Set[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
cls: type[TS],
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], Set[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
) -> TS:
"""Text splitter that uses tiktoken encoder to count length."""
try:
@ -217,15 +201,11 @@ class TextSplitter(BaseDocumentTransformer, ABC):
return cls(length_function=_tiktoken_encoder, **kwargs)
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform sequence of documents by splitting them."""
return self.split_documents(list(documents))
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Asynchronously transform a sequence of documents by splitting them."""
raise NotImplementedError
@ -267,9 +247,7 @@ class HeaderType(TypedDict):
class MarkdownHeaderTextSplitter:
"""Splitting markdown files based on specified headers."""
def __init__(
self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
):
def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
"""Create a new MarkdownHeaderTextSplitter.
Args:
@ -280,9 +258,7 @@ class MarkdownHeaderTextSplitter:
self.return_each_line = return_each_line
# Given the headers we want to split on,
# (e.g., "#, ##, etc") order by length
self.headers_to_split_on = sorted(
headers_to_split_on, key=lambda split: len(split[0]), reverse=True
)
self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)
def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
"""Combine lines with common metadata into chunks
@ -292,10 +268,7 @@ class MarkdownHeaderTextSplitter:
aggregated_chunks: list[LineType] = []
for line in lines:
if (
aggregated_chunks
and aggregated_chunks[-1]["metadata"] == line["metadata"]
):
if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
# If the last line in the aggregated list
# has the same metadata as the current line,
# append the current content to the last lines's content
@ -304,10 +277,7 @@ class MarkdownHeaderTextSplitter:
# Otherwise, append the current line to the aggregated list
aggregated_chunks.append(line)
return [
Document(page_content=chunk["content"], metadata=chunk["metadata"])
for chunk in aggregated_chunks
]
return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]
def split_text(self, text: str) -> list[Document]:
"""Split markdown file
@ -332,10 +302,9 @@ class MarkdownHeaderTextSplitter:
for sep, name in self.headers_to_split_on:
# Check if line starts with a header that we intend to split on
if stripped_line.startswith(sep) and (
# Header with no text OR header is followed by space
# Both are valid conditions that sep is being used a header
len(stripped_line) == len(sep)
or stripped_line[len(sep)] == " "
# Header with no text OR header is followed by space
# Both are valid conditions that sep is being used a header
len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
):
# Ensure we are tracking the header as metadata
if name is not None:
@ -343,10 +312,7 @@ class MarkdownHeaderTextSplitter:
current_header_level = sep.count("#")
# Pop out headers of lower or same level from the stack
while (
header_stack
and header_stack[-1]["level"] >= current_header_level
):
while header_stack and header_stack[-1]["level"] >= current_header_level:
# We have encountered a new header
# at the same or higher level
popped_header = header_stack.pop()
@ -359,7 +325,7 @@ class MarkdownHeaderTextSplitter:
header: HeaderType = {
"level": current_header_level,
"name": name,
"data": stripped_line[len(sep):].strip(),
"data": stripped_line[len(sep) :].strip(),
}
header_stack.append(header)
# Update initial_metadata with the current header
@ -392,9 +358,7 @@ class MarkdownHeaderTextSplitter:
current_metadata = initial_metadata.copy()
if current_content:
lines_with_metadata.append(
{"content": "\n".join(current_content), "metadata": current_metadata}
)
lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})
# lines_with_metadata has each line with associated header metadata
# aggregate these into chunks based on common metadata
@ -402,8 +366,7 @@ class MarkdownHeaderTextSplitter:
return self.aggregate_lines_to_chunks(lines_with_metadata)
else:
return [
Document(page_content=chunk["content"], metadata=chunk["metadata"])
for chunk in lines_with_metadata
Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
]
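A hedged usage sketch of the class above; the import path is taken from the hunk headers earlier in this diff, and running it assumes the api package is importable:

from core.rag.splitter.text_splitter import MarkdownHeaderTextSplitter

markdown = "# Intro\nSome overview text.\n## Details\nMore text."
splitter = MarkdownHeaderTextSplitter(headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")])
for doc in splitter.split_text(markdown):
    print(doc.metadata, doc.page_content)
# -> {'Header 1': 'Intro'} Some overview text.
#    {'Header 1': 'Intro', 'Header 2': 'Details'} More text.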
@ -436,12 +399,12 @@ class TokenTextSplitter(TextSplitter):
"""Splitting text to tokens using model tokenizer."""
def __init__(
self,
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], Set[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
self,
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], Set[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(**kwargs)
@ -488,10 +451,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
"""
def __init__(
self,
separators: Optional[list[str]] = None,
keep_separator: bool = True,
**kwargs: Any,
self,
separators: Optional[list[str]] = None,
keep_separator: bool = True,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs)
@ -508,7 +471,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
break
if re.search(_s, text):
separator = _s
new_separators = separators[i + 1:]
new_separators = separators[i + 1 :]
break
splits = _split_text_with_regex(text, separator, self._keep_separator)