mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@ -2,37 +2,35 @@ import re
|
||||
|
||||
|
||||
class CleanProcessor:
|
||||
|
||||
@classmethod
|
||||
def clean(cls, text: str, process_rule: dict) -> str:
|
||||
# default clean
|
||||
# remove invalid symbol
|
||||
text = re.sub(r'<\|', '<', text)
|
||||
text = re.sub(r'\|>', '>', text)
|
||||
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
|
||||
text = re.sub(r"<\|", "<", text)
|
||||
text = re.sub(r"\|>", ">", text)
|
||||
text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text)
|
||||
# Unicode U+FFFE
|
||||
text = re.sub('\uFFFE', '', text)
|
||||
text = re.sub("\ufffe", "", text)
|
||||
|
||||
rules = process_rule['rules'] if process_rule else None
|
||||
if 'pre_processing_rules' in rules:
|
||||
rules = process_rule["rules"] if process_rule else None
|
||||
if "pre_processing_rules" in rules:
|
||||
pre_processing_rules = rules["pre_processing_rules"]
|
||||
for pre_processing_rule in pre_processing_rules:
|
||||
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
|
||||
# Remove extra spaces
|
||||
pattern = r'\n{3,}'
|
||||
text = re.sub(pattern, '\n\n', text)
|
||||
pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
|
||||
text = re.sub(pattern, ' ', text)
|
||||
pattern = r"\n{3,}"
|
||||
text = re.sub(pattern, "\n\n", text)
|
||||
pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}"
|
||||
text = re.sub(pattern, " ", text)
|
||||
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
|
||||
# Remove email
|
||||
pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
|
||||
text = re.sub(pattern, '', text)
|
||||
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
|
||||
text = re.sub(pattern, "", text)
|
||||
|
||||
# Remove URL
|
||||
pattern = r'https?://[^\s]+'
|
||||
text = re.sub(pattern, '', text)
|
||||
pattern = r"https?://[^\s]+"
|
||||
text = re.sub(pattern, "", text)
|
||||
return text
|
||||
|
||||
def filter_string(self, text):
|
||||
|
||||
return text
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
"""Abstract interface for document cleaner implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseCleaner(ABC):
|
||||
"""Interface for clean chunk content.
|
||||
"""
|
||||
"""Interface for clean chunk content."""
|
||||
|
||||
@abstractmethod
|
||||
def clean(self, content: str):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""clean document content."""
|
||||
from unstructured.cleaners.core import clean_extra_whitespace
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""clean document content."""
|
||||
import re
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""clean document content."""
|
||||
from unstructured.cleaners.core import clean_non_ascii_chars
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""Replaces unicode quote characters, such as the \x91 character in a string."""
|
||||
|
||||
from unstructured.cleaners.core import replace_unicode_quotes
|
||||
|
||||
return replace_unicode_quotes(content)
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""Abstract interface for document clean implementations."""
|
||||
|
||||
from core.rag.cleaner.cleaner_base import BaseCleaner
|
||||
|
||||
|
||||
class UnstructuredTranslateTextCleaner(BaseCleaner):
|
||||
|
||||
def clean(self, content) -> str:
|
||||
"""clean document content."""
|
||||
from unstructured.cleaners.translate import translate_text
|
||||
|
||||
@ -12,17 +12,27 @@ from core.rag.rerank.weight_rerank import WeightRerankRunner
|
||||
|
||||
|
||||
class DataPostProcessor:
|
||||
"""Interface for data post-processing document.
|
||||
"""
|
||||
"""Interface for data post-processing document."""
|
||||
|
||||
def __init__(self, tenant_id: str, reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None, weights: Optional[dict] = None,
|
||||
reorder_enabled: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None,
|
||||
reorder_enabled: bool = False,
|
||||
):
|
||||
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
|
||||
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
|
||||
|
||||
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
|
||||
def invoke(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[Document],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> list[Document]:
|
||||
if self.rerank_runner:
|
||||
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
|
||||
|
||||
@ -31,21 +41,26 @@ class DataPostProcessor:
|
||||
|
||||
return documents
|
||||
|
||||
def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]:
|
||||
def _get_rerank_runner(
|
||||
self,
|
||||
reranking_mode: str,
|
||||
tenant_id: str,
|
||||
reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None,
|
||||
) -> Optional[RerankModelRunner | WeightRerankRunner]:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
|
||||
return WeightRerankRunner(
|
||||
tenant_id,
|
||||
Weights(
|
||||
vector_setting=VectorSetting(
|
||||
vector_weight=weights['vector_setting']['vector_weight'],
|
||||
embedding_provider_name=weights['vector_setting']['embedding_provider_name'],
|
||||
embedding_model_name=weights['vector_setting']['embedding_model_name'],
|
||||
vector_weight=weights["vector_setting"]["vector_weight"],
|
||||
embedding_provider_name=weights["vector_setting"]["embedding_provider_name"],
|
||||
embedding_model_name=weights["vector_setting"]["embedding_model_name"],
|
||||
),
|
||||
keyword_setting=KeywordSetting(
|
||||
keyword_weight=weights['keyword_setting']['keyword_weight'],
|
||||
)
|
||||
)
|
||||
keyword_weight=weights["keyword_setting"]["keyword_weight"],
|
||||
),
|
||||
),
|
||||
)
|
||||
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
|
||||
if reranking_model:
|
||||
@ -53,9 +68,9 @@ class DataPostProcessor:
|
||||
model_manager = ModelManager()
|
||||
rerank_model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=reranking_model['reranking_provider_name'],
|
||||
provider=reranking_model["reranking_provider_name"],
|
||||
model_type=ModelType.RERANK,
|
||||
model=reranking_model['reranking_model_name']
|
||||
model=reranking_model["reranking_model_name"],
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return None
|
||||
@ -67,5 +82,3 @@ class DataPostProcessor:
|
||||
if reorder_enabled:
|
||||
return ReorderRunner()
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ from core.rag.models.document import Document
|
||||
|
||||
|
||||
class ReorderRunner:
|
||||
|
||||
def run(self, documents: list[Document]) -> list[Document]:
|
||||
# Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list
|
||||
odd_elements = documents[::2]
|
||||
|
||||
@ -24,37 +24,42 @@ class Jieba(BaseKeyword):
|
||||
self._config = KeywordTableConfig()
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
|
||||
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for text in texts:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
keywords = keyword_table_handler.extract_keywords(
|
||||
text.page_content, self._config.max_keywords_per_chunk
|
||||
)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
return self
|
||||
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
|
||||
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keywords_list = kwargs.get('keywords_list', None)
|
||||
keywords_list = kwargs.get("keywords_list", None)
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
if keywords_list:
|
||||
keywords = keywords_list[i]
|
||||
if not keywords:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content,
|
||||
self._config.max_keywords_per_chunk)
|
||||
keywords = keyword_table_handler.extract_keywords(
|
||||
text.page_content, self._config.max_keywords_per_chunk
|
||||
)
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
|
||||
keywords = keyword_table_handler.extract_keywords(
|
||||
text.page_content, self._config.max_keywords_per_chunk
|
||||
)
|
||||
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@ -63,97 +68,91 @@ class Jieba(BaseKeyword):
|
||||
return id in set.union(*keyword_table.values())
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
|
||||
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
k = kwargs.get('top_k', 4)
|
||||
k = kwargs.get("top_k", 4)
|
||||
|
||||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
|
||||
|
||||
documents = []
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self.dataset.id,
|
||||
DocumentSegment.index_node_id == chunk_index
|
||||
).first()
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
|
||||
.first()
|
||||
)
|
||||
|
||||
if segment:
|
||||
|
||||
documents.append(Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": chunk_index,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
))
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": chunk_index,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
def delete(self) -> None:
|
||||
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
|
||||
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
db.session.delete(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
if dataset_keyword_table.data_source_type != 'database':
|
||||
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
|
||||
if dataset_keyword_table.data_source_type != "database":
|
||||
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
|
||||
storage.delete(file_key)
|
||||
|
||||
def _save_dataset_keyword_table(self, keyword_table):
|
||||
keyword_table_dict = {
|
||||
'__type__': 'keyword_table',
|
||||
'__data__': {
|
||||
"index_id": self.dataset.id,
|
||||
"summary": None,
|
||||
"table": keyword_table
|
||||
}
|
||||
"__type__": "keyword_table",
|
||||
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
|
||||
}
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
keyword_data_source_type = dataset_keyword_table.data_source_type
|
||||
if keyword_data_source_type == 'database':
|
||||
if keyword_data_source_type == "database":
|
||||
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
|
||||
db.session.commit()
|
||||
else:
|
||||
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
|
||||
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
|
||||
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
|
||||
|
||||
def _get_dataset_keyword_table(self) -> Optional[dict]:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
||||
if keyword_table_dict:
|
||||
return keyword_table_dict['__data__']['table']
|
||||
return keyword_table_dict["__data__"]["table"]
|
||||
else:
|
||||
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
dataset_id=self.dataset.id,
|
||||
keyword_table='',
|
||||
keyword_table="",
|
||||
data_source_type=keyword_data_source_type,
|
||||
)
|
||||
if keyword_data_source_type == 'database':
|
||||
dataset_keyword_table.keyword_table = json.dumps({
|
||||
'__type__': 'keyword_table',
|
||||
'__data__': {
|
||||
"index_id": self.dataset.id,
|
||||
"summary": None,
|
||||
"table": {}
|
||||
}
|
||||
}, cls=SetEncoder)
|
||||
if keyword_data_source_type == "database":
|
||||
dataset_keyword_table.keyword_table = json.dumps(
|
||||
{
|
||||
"__type__": "keyword_table",
|
||||
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
|
||||
},
|
||||
cls=SetEncoder,
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
|
||||
@ -174,9 +173,7 @@ class Jieba(BaseKeyword):
|
||||
keywords_to_delete = set()
|
||||
for keyword, node_idxs in keyword_table.items():
|
||||
if node_idxs_to_delete.intersection(node_idxs):
|
||||
keyword_table[keyword] = node_idxs.difference(
|
||||
node_idxs_to_delete
|
||||
)
|
||||
keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete)
|
||||
if not keyword_table[keyword]:
|
||||
keywords_to_delete.add(keyword)
|
||||
|
||||
@ -202,13 +199,14 @@ class Jieba(BaseKeyword):
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return sorted_chunk_indices[: k]
|
||||
return sorted_chunk_indices[:k]
|
||||
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id == node_id
|
||||
).first()
|
||||
document_segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
|
||||
.first()
|
||||
)
|
||||
if document_segment:
|
||||
document_segment.keywords = keywords
|
||||
db.session.add(document_segment)
|
||||
@ -224,14 +222,14 @@ class Jieba(BaseKeyword):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for pre_segment_data in pre_segment_data_list:
|
||||
segment = pre_segment_data['segment']
|
||||
if pre_segment_data['keywords']:
|
||||
segment.keywords = pre_segment_data['keywords']
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
|
||||
pre_segment_data['keywords'])
|
||||
segment = pre_segment_data["segment"]
|
||||
if pre_segment_data["keywords"]:
|
||||
segment.keywords = pre_segment_data["keywords"]
|
||||
keyword_table = self._add_text_to_keyword_table(
|
||||
keyword_table, segment.index_node_id, pre_segment_data["keywords"]
|
||||
)
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(segment.content,
|
||||
self._config.max_keywords_per_chunk)
|
||||
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
|
||||
segment.keywords = list(keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@ -8,7 +8,6 @@ from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
|
||||
|
||||
|
||||
class JiebaKeywordTableHandler:
|
||||
|
||||
def __init__(self):
|
||||
default_tfidf.stop_words = STOPWORDS
|
||||
|
||||
@ -30,4 +29,4 @@ class JiebaKeywordTableHandler:
|
||||
if len(sub_tokens) > 1:
|
||||
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
|
||||
|
||||
return results
|
||||
return results
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -8,7 +8,6 @@ from models.dataset import Dataset
|
||||
|
||||
|
||||
class BaseKeyword(ABC):
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
@ -31,15 +30,12 @@ class BaseKeyword(ABC):
|
||||
def delete(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts[:]:
|
||||
doc_id = text.metadata['doc_id']
|
||||
doc_id = text.metadata["doc_id"]
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
@ -47,4 +43,4 @@ class BaseKeyword(ABC):
|
||||
return texts
|
||||
|
||||
def _get_uuids(self, texts: list[Document]) -> list[str]:
|
||||
return [text.metadata['doc_id'] for text in texts]
|
||||
return [text.metadata["doc_id"] for text in texts]
|
||||
|
||||
@ -20,9 +20,7 @@ class Keyword:
|
||||
raise ValueError("Keyword store must be specified.")
|
||||
|
||||
if keyword_type == "jieba":
|
||||
return Jieba(
|
||||
dataset=self._dataset
|
||||
)
|
||||
return Jieba(dataset=self._dataset)
|
||||
else:
|
||||
raise ValueError(f"Keyword store {keyword_type} is not supported.")
|
||||
|
||||
@ -41,10 +39,7 @@ class Keyword:
|
||||
def delete(self) -> None:
|
||||
self._keyword_processor.delete()
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._keyword_processor.search(query, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
|
||||
@ -12,73 +12,83 @@ from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def retrieve(cls, retrieval_method: str, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float] = .0,
|
||||
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
|
||||
weights: Optional[dict] = None):
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
def retrieve(
|
||||
cls,
|
||||
retrieval_method: str,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float] = 0.0,
|
||||
reranking_model: Optional[dict] = None,
|
||||
reranking_mode: Optional[str] = "reranking_model",
|
||||
weights: Optional[dict] = None,
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
all_documents = []
|
||||
threads = []
|
||||
exceptions = []
|
||||
# retrieval_model source with keyword
|
||||
if retrieval_method == 'keyword_search':
|
||||
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'all_documents': all_documents,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
if retrieval_method == "keyword_search":
|
||||
keyword_thread = threading.Thread(
|
||||
target=RetrievalService.keyword_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(keyword_thread)
|
||||
keyword_thread.start()
|
||||
# retrieval_model source with semantic
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'score_threshold': score_threshold,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'retrieval_method': retrieval_method,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
embedding_thread = threading.Thread(
|
||||
target=RetrievalService.embedding_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"retrieval_method": retrieval_method,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'retrieval_method': retrieval_method,
|
||||
'score_threshold': score_threshold,
|
||||
'top_k': top_k,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
full_text_index_thread = threading.Thread(
|
||||
target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"retrieval_method": retrieval_method,
|
||||
"score_threshold": score_threshold,
|
||||
"top_k": top_k,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
@ -86,110 +96,117 @@ class RetrievalService:
|
||||
thread.join()
|
||||
|
||||
if exceptions:
|
||||
exception_message = ';\n'.join(exceptions)
|
||||
exception_message = ";\n".join(exceptions)
|
||||
raise Exception(exception_message)
|
||||
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
|
||||
reranking_model, weights, False)
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||
)
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k
|
||||
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||
top_k: int, all_documents: list, exceptions: list):
|
||||
def keyword_search(
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
keyword = Keyword(
|
||||
dataset=dataset
|
||||
)
|
||||
keyword = Keyword(dataset=dataset)
|
||||
|
||||
documents = keyword.search(
|
||||
cls.escape_query_for_search(query),
|
||||
top_k=top_k
|
||||
)
|
||||
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
|
||||
all_documents.extend(documents)
|
||||
except Exception as e:
|
||||
exceptions.append(str(e))
|
||||
|
||||
@classmethod
|
||||
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, retrieval_method: str, exceptions: list):
|
||||
def embedding_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
vector = Vector(
|
||||
dataset=dataset
|
||||
)
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
documents = vector.search_by_vector(
|
||||
cls.escape_query_for_search(query),
|
||||
search_type='similarity_score_threshold',
|
||||
search_type="similarity_score_threshold",
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
filter={
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
filter={"group_id": [dataset.id]},
|
||||
)
|
||||
|
||||
if documents:
|
||||
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
|
||||
RerankMode.RERANKING_MODEL.value,
|
||||
reranking_model, None, False)
|
||||
all_documents.extend(data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents)
|
||||
))
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
|
||||
)
|
||||
)
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
except Exception as e:
|
||||
exceptions.append(str(e))
|
||||
|
||||
@classmethod
|
||||
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, retrieval_method: str, exceptions: list):
|
||||
def full_text_index_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
vector_processor = Vector(
|
||||
dataset=dataset,
|
||||
)
|
||||
|
||||
documents = vector_processor.search_by_full_text(
|
||||
cls.escape_query_for_search(query),
|
||||
top_k=top_k
|
||||
)
|
||||
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
|
||||
if documents:
|
||||
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
|
||||
RerankMode.RERANKING_MODEL.value,
|
||||
reranking_model, None, False)
|
||||
all_documents.extend(data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=len(documents)
|
||||
))
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
|
||||
)
|
||||
)
|
||||
else:
|
||||
all_documents.extend(documents)
|
||||
except Exception as e:
|
||||
@ -197,4 +214,4 @@ class RetrievalService:
|
||||
|
||||
@staticmethod
|
||||
def escape_query_for_search(query: str) -> str:
|
||||
return query.replace('"', '\\"')
|
||||
return query.replace('"', '\\"')
|
||||
|
||||
@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel):
|
||||
namespace_password: str = (None,)
|
||||
metrics: str = ("cosine",)
|
||||
read_timeout: int = 60000
|
||||
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
"access_key_id": self.access_key_id,
|
||||
@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel):
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
|
||||
|
||||
class AnalyticdbVector(BaseVector):
|
||||
_instance = None
|
||||
_init = False
|
||||
@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector):
|
||||
except:
|
||||
raise ImportError(_import_err_msg)
|
||||
self.config = config
|
||||
self._client_config = open_api_models.Config(
|
||||
user_agent="dify", **config.to_analyticdb_client_params()
|
||||
)
|
||||
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
|
||||
self._client = Client(self._client_config)
|
||||
self._initialize()
|
||||
AnalyticdbVector._init = True
|
||||
@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def _initialize_vector_database(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.InitVectorDatabaseRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector):
|
||||
def _create_namespace_if_not_exists(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
self._client.create_namespace(request)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failed to create namespace {self.config.namespace}: {e}"
|
||||
)
|
||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
self._client.create_collection(request)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failed to create collection {self._collection_name}: {e}"
|
||||
)
|
||||
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def get_type(self) -> str:
|
||||
@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector):
|
||||
self._create_collection_if_not_exists(dimension)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(
|
||||
self, documents: list[Document], embeddings: list[list[float]], **kwargs
|
||||
):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
|
||||
for doc, embedding in zip(documents, embeddings, strict=True):
|
||||
metadata = {
|
||||
@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector):
|
||||
vector=None,
|
||||
content=None,
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'"
|
||||
filter=f"ref_doc_id='{id}'",
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
return len(response.body.matches.match) > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
ids_str = ",".join(f"'{id}'" for id in ids)
|
||||
ids_str = f"({ids_str})"
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
||||
def search_by_vector(
|
||||
self, query_vector: list[float], **kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
kwargs.get("score_threshold", 0.0)
|
||||
if kwargs.get("score_threshold", 0.0)
|
||||
else 0.0
|
||||
)
|
||||
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
kwargs.get("score_threshold", 0.0)
|
||||
if kwargs.get("score_threshold", 0.0)
|
||||
else 0.0
|
||||
)
|
||||
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector):
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
request = gpdb_20160503_models.DeleteCollectionRequest(
|
||||
collection=self._collection_name,
|
||||
dbinstance_id=self.config.instance_id,
|
||||
@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class AnalyticdbVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"][
|
||||
"class_prefix"
|
||||
]
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
|
||||
)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
|
||||
|
||||
# handle optional params
|
||||
if dify_config.ANALYTICDB_KEY_ID is None:
|
||||
|
||||
@ -27,21 +27,20 @@ class ChromaConfig(BaseModel):
|
||||
settings = Settings(
|
||||
# auth
|
||||
chroma_client_auth_provider=self.auth_provider,
|
||||
chroma_client_auth_credentials=self.auth_credentials
|
||||
chroma_client_auth_credentials=self.auth_credentials,
|
||||
)
|
||||
|
||||
return {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'ssl': False,
|
||||
'tenant': self.tenant,
|
||||
'database': self.database,
|
||||
'settings': settings,
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"ssl": False,
|
||||
"tenant": self.tenant,
|
||||
"database": self.database,
|
||||
"settings": settings,
|
||||
}
|
||||
|
||||
|
||||
class ChromaVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: ChromaConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
@ -58,9 +57,9 @@ class ChromaVector(BaseVector):
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(self, collection_name: str):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
self._client.get_or_create_collection(collection_name)
|
||||
@ -76,7 +75,7 @@ class ChromaVector(BaseVector):
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(where={key: {'$eq': value}})
|
||||
collection.delete(where={key: {"$eq": value}})
|
||||
|
||||
def delete(self):
|
||||
self._client.delete_collection(self._collection_name)
|
||||
@ -93,26 +92,26 @@ class ChromaVector(BaseVector):
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
|
||||
ids: list[str] = results['ids'][0]
|
||||
documents: list[str] = results['documents'][0]
|
||||
metadatas: dict[str, Any] = results['metadatas'][0]
|
||||
distances: list[float] = results['distances'][0]
|
||||
ids: list[str] = results["ids"][0]
|
||||
documents: list[str] = results["documents"][0]
|
||||
metadatas: dict[str, Any] = results["metadatas"][0]
|
||||
distances: list[float] = results["distances"][0]
|
||||
|
||||
docs = []
|
||||
for index in range(len(ids)):
|
||||
distance = distances[index]
|
||||
metadata = metadatas[index]
|
||||
if distance >= score_threshold:
|
||||
metadata['score'] = distance
|
||||
metadata["score"] = distance
|
||||
doc = Document(
|
||||
page_content=documents[index],
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@ -123,15 +122,12 @@ class ChromaVector(BaseVector):
|
||||
class ChromaVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
index_struct_dict = {
|
||||
"type": VectorType.CHROMA,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
return ChromaVector(
|
||||
|
||||
@ -26,15 +26,15 @@ class ElasticSearchConfig(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
if not values["host"]:
|
||||
raise ValueError("config HOST is required")
|
||||
if not values['port']:
|
||||
if not values["port"]:
|
||||
raise ValueError("config PORT is required")
|
||||
if not values['username']:
|
||||
if not values["username"]:
|
||||
raise ValueError("config USERNAME is required")
|
||||
if not values['password']:
|
||||
if not values["password"]:
|
||||
raise ValueError("config PASSWORD is required")
|
||||
return values
|
||||
|
||||
@ -50,10 +50,10 @@ class ElasticSearchVector(BaseVector):
|
||||
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
||||
try:
|
||||
parsed_url = urlparse(config.host)
|
||||
if parsed_url.scheme in ['http', 'https']:
|
||||
hosts = f'{config.host}:{config.port}'
|
||||
if parsed_url.scheme in ["http", "https"]:
|
||||
hosts = f"{config.host}:{config.port}"
|
||||
else:
|
||||
hosts = f'http://{config.host}:{config.port}'
|
||||
hosts = f"http://{config.host}:{config.port}"
|
||||
client = Elasticsearch(
|
||||
hosts=hosts,
|
||||
basic_auth=(config.username, config.password),
|
||||
@ -68,25 +68,27 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def _get_version(self) -> str:
|
||||
info = self._client.info()
|
||||
return info['version']['number']
|
||||
return info["version"]["number"]
|
||||
|
||||
def _check_version(self):
|
||||
if self._version < '8.0.0':
|
||||
if self._version < "8.0.0":
|
||||
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'elasticsearch'
|
||||
return "elasticsearch"
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
for i in range(len(documents)):
|
||||
self._client.index(index=self._collection_name,
|
||||
id=uuids[i],
|
||||
document={
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
|
||||
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
|
||||
})
|
||||
self._client.index(
|
||||
index=self._collection_name,
|
||||
id=uuids[i],
|
||||
document={
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
|
||||
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
|
||||
},
|
||||
)
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
return uuids
|
||||
|
||||
@ -98,15 +100,9 @@ class ElasticSearchVector(BaseVector):
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
query_str = {
|
||||
'query': {
|
||||
'match': {
|
||||
f'metadata.{key}': f'{value}'
|
||||
}
|
||||
}
|
||||
}
|
||||
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
|
||||
results = self._client.search(index=self._collection_name, body=query_str)
|
||||
ids = [hit['_id'] for hit in results['hits']['hits']]
|
||||
ids = [hit["_id"] for hit in results["hits"]["hits"]]
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
@ -115,44 +111,44 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
knn = {
|
||||
"field": Field.VECTOR.value,
|
||||
"query_vector": query_vector,
|
||||
"k": top_k
|
||||
}
|
||||
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
|
||||
|
||||
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
|
||||
|
||||
docs_and_scores = []
|
||||
for hit in results['hits']['hits']:
|
||||
for hit in results["hits"]["hits"]:
|
||||
docs_and_scores.append(
|
||||
(Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
|
||||
vector=hit['_source'][Field.VECTOR.value],
|
||||
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
)
|
||||
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
if score > score_threshold:
|
||||
doc.metadata['score'] = score
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {
|
||||
"match": {
|
||||
Field.CONTENT_KEY.value: query
|
||||
}
|
||||
}
|
||||
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
results = self._client.search(index=self._collection_name, query=query_str)
|
||||
docs = []
|
||||
for hit in results['hits']['hits']:
|
||||
docs.append(Document(
|
||||
page_content=hit['_source'][Field.CONTENT_KEY.value],
|
||||
vector=hit['_source'][Field.VECTOR.value],
|
||||
metadata=hit['_source'][Field.METADATA_KEY.value],
|
||||
))
|
||||
for hit in results["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
)
|
||||
)
|
||||
|
||||
return docs
|
||||
|
||||
@ -162,11 +158,11 @@ class ElasticSearchVector(BaseVector):
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
lock_name = f'vector_indexing_lock_{self._collection_name}'
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name} already exists.")
|
||||
return
|
||||
@ -179,14 +175,14 @@ class ElasticSearchVector(BaseVector):
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
"type": "dense_vector",
|
||||
"dims": dim,
|
||||
"similarity": "cosine"
|
||||
"similarity": "cosine",
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
self._client.indices.create(index=self._collection_name, mappings=mappings)
|
||||
@ -197,22 +193,21 @@ class ElasticSearchVector(BaseVector):
|
||||
class ElasticSearchVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return ElasticSearchVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get('ELASTICSEARCH_HOST'),
|
||||
port=config.get('ELASTICSEARCH_PORT'),
|
||||
username=config.get('ELASTICSEARCH_USERNAME'),
|
||||
password=config.get('ELASTICSEARCH_PASSWORD'),
|
||||
host=config.get("ELASTICSEARCH_HOST"),
|
||||
port=config.get("ELASTICSEARCH_PORT"),
|
||||
username=config.get("ELASTICSEARCH_USERNAME"),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD"),
|
||||
),
|
||||
attributes=[]
|
||||
attributes=[],
|
||||
)
|
||||
|
||||
@ -27,44 +27,39 @@ class MilvusConfig(BaseModel):
|
||||
batch_size: int = 100
|
||||
database: str = "default"
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get('uri'):
|
||||
if not values.get("uri"):
|
||||
raise ValueError("config MILVUS_URI is required")
|
||||
if not values.get('user'):
|
||||
if not values.get("user"):
|
||||
raise ValueError("config MILVUS_USER is required")
|
||||
if not values.get('password'):
|
||||
if not values.get("password"):
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
return {
|
||||
'uri': self.uri,
|
||||
'token': self.token,
|
||||
'user': self.user,
|
||||
'password': self.password,
|
||||
'db_name': self.database,
|
||||
"uri": self.uri,
|
||||
"token": self.token,
|
||||
"user": self.user,
|
||||
"password": self.password,
|
||||
"db_name": self.database,
|
||||
}
|
||||
|
||||
|
||||
class MilvusVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: MilvusConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = self._init_client(config)
|
||||
self._consistency_level = 'Session'
|
||||
self._consistency_level = "Session"
|
||||
self._fields = []
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MILVUS
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_params = {
|
||||
'metric_type': 'IP',
|
||||
'index_type': "HNSW",
|
||||
'params': {"M": 8, "efConstruction": 64}
|
||||
}
|
||||
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
||||
metadatas = [d.metadata for d in texts]
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
@ -75,7 +70,7 @@ class MilvusVector(BaseVector):
|
||||
insert_dict = {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
}
|
||||
insert_dict_list.append(insert_dict)
|
||||
# Total insert count
|
||||
@ -84,22 +79,20 @@ class MilvusVector(BaseVector):
|
||||
pks: list[str] = []
|
||||
|
||||
for i in range(0, total_count, 1000):
|
||||
batch_insert_list = insert_dict_list[i:i + 1000]
|
||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||
# Insert into the collection.
|
||||
try:
|
||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||
pks.extend(ids)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to insert batch starting at entity: %s/%s", i, total_count
|
||||
)
|
||||
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
|
||||
raise e
|
||||
return pks
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = self._client.query(collection_name=self._collection_name,
|
||||
filter=f'metadata["{key}"] == "{value}"',
|
||||
output_fields=["id"])
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
||||
)
|
||||
if result:
|
||||
return [item["id"] for item in result]
|
||||
else:
|
||||
@ -107,17 +100,15 @@ class MilvusVector(BaseVector):
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
if self._client.has_collection(self._collection_name):
|
||||
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if self._client.has_collection(self._collection_name):
|
||||
|
||||
result = self._client.query(collection_name=self._collection_name,
|
||||
filter=f'metadata["doc_id"] in {ids}',
|
||||
output_fields=["id"])
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
||||
)
|
||||
if result:
|
||||
ids = [item["id"] for item in result]
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
@ -130,29 +121,28 @@ class MilvusVector(BaseVector):
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
return False
|
||||
|
||||
result = self._client.query(collection_name=self._collection_name,
|
||||
filter=f'metadata["doc_id"] == "{id}"',
|
||||
output_fields=["id"])
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"]
|
||||
)
|
||||
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
|
||||
# Set search parameters.
|
||||
results = self._client.search(collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
limit=kwargs.get('top_k', 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result['entity'].get(Field.METADATA_KEY.value)
|
||||
metadata['score'] = result['distance']
|
||||
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
|
||||
if result['distance'] > score_threshold:
|
||||
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
|
||||
metadata=metadata)
|
||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
||||
metadata["score"] = result["distance"]
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@ -161,11 +151,11 @@ class MilvusVector(BaseVector):
|
||||
return []
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
# Grab the existing collection if it exists
|
||||
@ -180,19 +170,11 @@ class MilvusVector(BaseVector):
|
||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field
|
||||
fields.append(
|
||||
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
|
||||
)
|
||||
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
|
||||
# Create the primary key field
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
|
||||
)
|
||||
)
|
||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(
|
||||
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
|
||||
)
|
||||
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
@ -208,9 +190,12 @@ class MilvusVector(BaseVector):
|
||||
|
||||
# Create the collection
|
||||
collection_name = self._collection_name
|
||||
self._client.create_collection(collection_name=collection_name,
|
||||
schema=schema, index_params=index_params_obj,
|
||||
consistency_level=self._consistency_level)
|
||||
self._client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params_obj,
|
||||
consistency_level=self._consistency_level,
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _init_client(self, config) -> MilvusClient:
|
||||
@ -221,13 +206,12 @@ class MilvusVector(BaseVector):
|
||||
class MilvusVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
|
||||
|
||||
return MilvusVector(
|
||||
collection_name=collection_name,
|
||||
@ -237,5 +221,5 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
||||
user=dify_config.MILVUS_USER,
|
||||
password=dify_config.MILVUS_PASSWORD,
|
||||
database=dify_config.MILVUS_DATABASE,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@ -31,7 +31,6 @@ class SortOrder(Enum):
|
||||
|
||||
|
||||
class MyScaleVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
@ -80,7 +79,7 @@ class MyScaleVector(BaseVector):
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {}
|
||||
json.dumps(doc.metadata) if doc.metadata else {},
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
@ -101,7 +100,8 @@ class MyScaleVector(BaseVector):
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
rows = self._client.query(
|
||||
@ -122,9 +122,12 @@ class MyScaleVector(BaseVector):
|
||||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get('score_threshold') or 0.0
|
||||
where_str = f"WHERE dist < {1 - score_threshold}" if \
|
||||
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
where_str = (
|
||||
f"WHERE dist < {1 - score_threshold}"
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
else ""
|
||||
)
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
@ -133,7 +136,7 @@ class MyScaleVector(BaseVector):
|
||||
return [
|
||||
Document(
|
||||
page_content=r["text"],
|
||||
vector=r['vector'],
|
||||
vector=r["vector"],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
@ -149,13 +152,12 @@ class MyScaleVector(BaseVector):
|
||||
class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
|
||||
|
||||
return MyScaleVector(
|
||||
collection_name=collection_name,
|
||||
|
||||
@ -28,11 +28,11 @@ class OpenSearchConfig(BaseModel):
|
||||
password: Optional[str] = None
|
||||
secure: bool = False
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get('host'):
|
||||
if not values.get("host"):
|
||||
raise ValueError("config OPENSEARCH_HOST is required")
|
||||
if not values.get('port'):
|
||||
if not values.get("port"):
|
||||
raise ValueError("config OPENSEARCH_PORT is required")
|
||||
return values
|
||||
|
||||
@ -44,19 +44,18 @@ class OpenSearchConfig(BaseModel):
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
'hosts': [{'host': self.host, 'port': self.port}],
|
||||
'use_ssl': self.secure,
|
||||
'verify_certs': self.secure,
|
||||
"hosts": [{"host": self.host, "port": self.port}],
|
||||
"use_ssl": self.secure,
|
||||
"verify_certs": self.secure,
|
||||
}
|
||||
if self.user and self.password:
|
||||
params['http_auth'] = (self.user, self.password)
|
||||
params["http_auth"] = (self.user, self.password)
|
||||
if self.secure:
|
||||
params['ssl_context'] = self.create_ssl_context()
|
||||
params["ssl_context"] = self.create_ssl_context()
|
||||
return params
|
||||
|
||||
|
||||
class OpenSearchVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: OpenSearchConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
@ -81,7 +80,7 @@ class OpenSearchVector(BaseVector):
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
}
|
||||
},
|
||||
}
|
||||
actions.append(action)
|
||||
|
||||
@ -90,8 +89,8 @@ class OpenSearchVector(BaseVector):
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
if response['hits']['hits']:
|
||||
return [hit['_id'] for hit in response['hits']['hits']]
|
||||
if response["hits"]["hits"]:
|
||||
return [hit["_id"] for hit in response["hits"]["hits"]]
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -110,7 +109,7 @@ class OpenSearchVector(BaseVector):
|
||||
actual_ids = []
|
||||
|
||||
for doc_id in ids:
|
||||
es_ids = self.get_ids_by_metadata_field('doc_id', doc_id)
|
||||
es_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
|
||||
if es_ids:
|
||||
actual_ids.extend(es_ids)
|
||||
else:
|
||||
@ -122,9 +121,9 @@ class OpenSearchVector(BaseVector):
|
||||
helpers.bulk(self._client, actions)
|
||||
except BulkIndexError as e:
|
||||
for error in e.errors:
|
||||
delete_error = error.get('delete', {})
|
||||
status = delete_error.get('status')
|
||||
doc_id = delete_error.get('_id')
|
||||
delete_error = error.get("delete", {})
|
||||
status = delete_error.get("status")
|
||||
doc_id = delete_error.get("_id")
|
||||
|
||||
if status == 404:
|
||||
logger.warning(f"Document not found for deletion: {doc_id}")
|
||||
@ -151,15 +150,8 @@ class OpenSearchVector(BaseVector):
|
||||
raise ValueError("All elements in query_vector should be floats")
|
||||
|
||||
query = {
|
||||
"size": kwargs.get('top_k', 4),
|
||||
"query": {
|
||||
"knn": {
|
||||
Field.VECTOR.value: {
|
||||
Field.VECTOR.value: query_vector,
|
||||
"k": kwargs.get('top_k', 4)
|
||||
}
|
||||
}
|
||||
}
|
||||
"size": kwargs.get("top_k", 4),
|
||||
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
|
||||
}
|
||||
|
||||
try:
|
||||
@ -169,17 +161,17 @@ class OpenSearchVector(BaseVector):
|
||||
raise
|
||||
|
||||
docs = []
|
||||
for hit in response['hits']['hits']:
|
||||
metadata = hit['_source'].get(Field.METADATA_KEY.value, {})
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
|
||||
|
||||
# Make sure metadata is a dictionary
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
metadata['score'] = hit['_score']
|
||||
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
|
||||
if hit['_score'] > score_threshold:
|
||||
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
metadata["score"] = hit["_score"]
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
if hit["_score"] > score_threshold:
|
||||
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
@ -190,32 +182,28 @@ class OpenSearchVector(BaseVector):
|
||||
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
||||
|
||||
docs = []
|
||||
for hit in response['hits']['hits']:
|
||||
metadata = hit['_source'].get(Field.METADATA_KEY.value)
|
||||
vector = hit['_source'].get(Field.VECTOR.value)
|
||||
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY.value)
|
||||
vector = hit["_source"].get(Field.VECTOR.value)
|
||||
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
|
||||
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
lock_name = f'vector_indexing_lock_{self._collection_name.lower()}'
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}'
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name.lower()} already exists.")
|
||||
return
|
||||
|
||||
if not self._client.indices.exists(index=self._collection_name.lower()):
|
||||
index_body = {
|
||||
"settings": {
|
||||
"index": {
|
||||
"knn": True
|
||||
}
|
||||
},
|
||||
"settings": {"index": {"knn": True}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {"type": "text"},
|
||||
@ -226,20 +214,17 @@ class OpenSearchVector(BaseVector):
|
||||
"name": "hnsw",
|
||||
"space_type": "l2",
|
||||
"engine": "faiss",
|
||||
"parameters": {
|
||||
"ef_construction": 64,
|
||||
"m": 8
|
||||
}
|
||||
}
|
||||
"parameters": {"ef_construction": 64, "m": 8},
|
||||
},
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
|
||||
@ -248,17 +233,14 @@ class OpenSearchVector(BaseVector):
|
||||
|
||||
|
||||
class OpenSearchVectorFactory(AbstractVectorFactory):
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
|
||||
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
|
||||
|
||||
open_search_config = OpenSearchConfig(
|
||||
host=dify_config.OPENSEARCH_HOST,
|
||||
@ -268,7 +250,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
||||
secure=dify_config.OPENSEARCH_SECURE,
|
||||
)
|
||||
|
||||
return OpenSearchVector(
|
||||
collection_name=collection_name,
|
||||
config=open_search_config
|
||||
)
|
||||
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
|
||||
|
||||
@ -31,7 +31,7 @@ class OracleVectorConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config ORACLE_HOST is required")
|
||||
@ -103,9 +103,16 @@ class OracleVector(BaseVector):
|
||||
arraysize=cursor.arraysize,
|
||||
outconverter=self.numpy_converter_out,
|
||||
)
|
||||
def _create_connection_pool(self, config: OracleVectorConfig):
|
||||
return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1)
|
||||
|
||||
def _create_connection_pool(self, config: OracleVectorConfig):
|
||||
return oracledb.create_pool(
|
||||
user=config.user,
|
||||
password=config.password,
|
||||
dsn="{}:{}/{}".format(config.host, config.port, config.database),
|
||||
min=1,
|
||||
max=50,
|
||||
increment=1,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
@ -136,13 +143,15 @@ class OracleVector(BaseVector):
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
#array.array("f", embeddings[i]),
|
||||
# array.array("f", embeddings[i]),
|
||||
numpy.array(embeddings[i]),
|
||||
)
|
||||
)
|
||||
#print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
|
||||
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
|
||||
with self._get_cursor() as cur:
|
||||
cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values)
|
||||
cur.executemany(
|
||||
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
|
||||
)
|
||||
return pks
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
@ -157,7 +166,8 @@ class OracleVector(BaseVector):
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
return docs
|
||||
#def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
# def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
# with self._get_cursor() as cur:
|
||||
# cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
|
||||
# idss = []
|
||||
@ -184,7 +194,8 @@ class OracleVector(BaseVector):
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)]
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only",
|
||||
[numpy.array(query_vector)],
|
||||
)
|
||||
docs = []
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
@ -202,7 +213,7 @@ class OracleVector(BaseVector):
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
if len(query) > 0:
|
||||
# Check which language the query is in
|
||||
zh_pattern = re.compile('[\u4e00-\u9fa5]+')
|
||||
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
|
||||
match = zh_pattern.search(query)
|
||||
entities = []
|
||||
# match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split
|
||||
@ -210,7 +221,15 @@ class OracleVector(BaseVector):
|
||||
words = pseg.cut(query)
|
||||
current_entity = ""
|
||||
for word, pos in words:
|
||||
if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: 人名, ns: 地名, nt: 机构名
|
||||
if (
|
||||
pos == "nr"
|
||||
or pos == "Ng"
|
||||
or pos == "eng"
|
||||
or pos == "nz"
|
||||
or pos == "n"
|
||||
or pos == "ORG"
|
||||
or pos == "v"
|
||||
): # nr: 人名, ns: 地名, nt: 机构名
|
||||
current_entity += word
|
||||
else:
|
||||
if current_entity:
|
||||
@ -220,22 +239,22 @@ class OracleVector(BaseVector):
|
||||
entities.append(current_entity)
|
||||
else:
|
||||
try:
|
||||
nltk.data.find('tokenizers/punkt')
|
||||
nltk.data.find('corpora/stopwords')
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
nltk.data.find("corpora/stopwords")
|
||||
except LookupError:
|
||||
nltk.download('punkt')
|
||||
nltk.download('stopwords')
|
||||
nltk.download("punkt")
|
||||
nltk.download("stopwords")
|
||||
print("run download")
|
||||
e_str = re.sub(r'[^\w ]', '', query)
|
||||
e_str = re.sub(r"[^\w ]", "", query)
|
||||
all_tokens = nltk.word_tokenize(e_str)
|
||||
stop_words = stopwords.words('english')
|
||||
stop_words = stopwords.words("english")
|
||||
for token in all_tokens:
|
||||
if token not in stop_words:
|
||||
entities.append(token)
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
|
||||
[" ACCUM ".join(entities)]
|
||||
[" ACCUM ".join(entities)],
|
||||
)
|
||||
docs = []
|
||||
for record in cur:
|
||||
@ -273,8 +292,7 @@ class OracleVectorFactory(AbstractVectorFactory):
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
|
||||
|
||||
return OracleVector(
|
||||
collection_name=collection_name,
|
||||
|
||||
@ -31,27 +31,28 @@ class PgvectoRSConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
if not values["host"]:
|
||||
raise ValueError("config PGVECTO_RS_HOST is required")
|
||||
if not values['port']:
|
||||
if not values["port"]:
|
||||
raise ValueError("config PGVECTO_RS_PORT is required")
|
||||
if not values['user']:
|
||||
if not values["user"]:
|
||||
raise ValueError("config PGVECTO_RS_USER is required")
|
||||
if not values['password']:
|
||||
if not values["password"]:
|
||||
raise ValueError("config PGVECTO_RS_PASSWORD is required")
|
||||
if not values['database']:
|
||||
if not values["database"]:
|
||||
raise ValueError("config PGVECTO_RS_DATABASE is required")
|
||||
return values
|
||||
|
||||
|
||||
class PGVectoRS(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
self._url = (
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self._client = create_engine(self._url)
|
||||
with Session(self._client) as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
|
||||
@ -80,9 +81,9 @@ class PGVectoRS(BaseVector):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def create_collection(self, dimension: int):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
@ -133,9 +134,7 @@ class PGVectoRS(BaseVector):
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = None
|
||||
with Session(self._client) as session:
|
||||
select_statement = sql_text(
|
||||
f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
|
||||
)
|
||||
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
|
||||
result = session.execute(select_statement).fetchall()
|
||||
if result:
|
||||
return [item[0] for item in result]
|
||||
@ -143,12 +142,11 @@ class PGVectoRS(BaseVector):
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {'ids': ids})
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
@ -156,13 +154,13 @@ class PGVectoRS(BaseVector):
|
||||
select_statement = sql_text(
|
||||
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
|
||||
)
|
||||
result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
|
||||
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
|
||||
if result:
|
||||
ids = [item[0] for item in result]
|
||||
if ids:
|
||||
with Session(self._client) as session:
|
||||
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
|
||||
session.execute(select_statement, {'ids': ids})
|
||||
session.execute(select_statement, {"ids": ids})
|
||||
session.commit()
|
||||
|
||||
def delete(self) -> None:
|
||||
@ -187,7 +185,7 @@ class PGVectoRS(BaseVector):
|
||||
query_vector,
|
||||
).label("distance"),
|
||||
)
|
||||
.limit(kwargs.get('top_k', 2))
|
||||
.limit(kwargs.get("top_k", 2))
|
||||
.order_by("distance")
|
||||
)
|
||||
res = session.execute(stmt)
|
||||
@ -198,11 +196,10 @@ class PGVectoRS(BaseVector):
|
||||
for record, dis in results:
|
||||
metadata = record.meta
|
||||
score = 1 - dis
|
||||
metadata['score'] = score
|
||||
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
|
||||
metadata["score"] = score
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
if score > score_threshold:
|
||||
doc = Document(page_content=record.text,
|
||||
metadata=metadata)
|
||||
doc = Document(page_content=record.text, metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@ -225,13 +222,12 @@ class PGVectoRS(BaseVector):
|
||||
class PGVectoRSFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
dim = len(embeddings.embed_query("pgvecto_rs"))
|
||||
|
||||
return PGVectoRS(
|
||||
@ -243,5 +239,5 @@ class PGVectoRSFactory(AbstractVectorFactory):
|
||||
password=dify_config.PGVECTO_RS_PASSWORD,
|
||||
database=dify_config.PGVECTO_RS_DATABASE,
|
||||
),
|
||||
dim=dim
|
||||
dim=dim,
|
||||
)
|
||||
|
||||
@ -24,7 +24,7 @@ class PGVectorConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config PGVECTOR_HOST is required")
|
||||
@ -201,8 +201,7 @@ class PGVectorFactory(AbstractVectorFactory):
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
|
||||
|
||||
return PGVector(
|
||||
collection_name=collection_name,
|
||||
|
||||
@ -48,28 +48,25 @@ class QdrantConfig(BaseModel):
|
||||
prefer_grpc: bool = False
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith('path:'):
|
||||
path = self.endpoint.replace('path:', '')
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {
|
||||
'path': path
|
||||
}
|
||||
return {"path": path}
|
||||
else:
|
||||
return {
|
||||
'url': self.endpoint,
|
||||
'api_key': self.api_key,
|
||||
'timeout': self.timeout,
|
||||
'verify': self.endpoint.startswith('https'),
|
||||
'grpc_port': self.grpc_port,
|
||||
'prefer_grpc': self.prefer_grpc
|
||||
"url": self.endpoint,
|
||||
"api_key": self.api_key,
|
||||
"timeout": self.timeout,
|
||||
"verify": self.endpoint.startswith("https"),
|
||||
"grpc_port": self.grpc_port,
|
||||
"prefer_grpc": self.prefer_grpc,
|
||||
}
|
||||
|
||||
|
||||
class QdrantVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
|
||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
||||
@ -80,10 +77,7 @@ class QdrantVector(BaseVector):
|
||||
return VectorType.QDRANT
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
@ -97,9 +91,9 @@ class QdrantVector(BaseVector):
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(self, collection_name: str, vector_size: int):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
@ -110,12 +104,19 @@ class QdrantVector(BaseVector):
|
||||
all_collection_name.append(collection.name)
|
||||
if collection_name not in all_collection_name:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[self._distance_func],
|
||||
)
|
||||
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
|
||||
max_indexing_threads=0, on_disk=False)
|
||||
hnsw_config = HnswConfigDiff(
|
||||
m=0,
|
||||
payload_m=16,
|
||||
ef_construct=100,
|
||||
full_scan_threshold=10000,
|
||||
max_indexing_threads=0,
|
||||
on_disk=False,
|
||||
)
|
||||
self._client.recreate_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vectors_config,
|
||||
@ -124,21 +125,24 @@ class QdrantVector(BaseVector):
|
||||
)
|
||||
|
||||
# create group_id payload index
|
||||
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
|
||||
field_schema=PayloadSchemaType.KEYWORD)
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create doc_id payload index
|
||||
self._client.create_payload_index(collection_name, Field.DOC_ID.value,
|
||||
field_schema=PayloadSchemaType.KEYWORD)
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create full text index
|
||||
text_index_params = TextIndexParams(
|
||||
type=TextIndexType.TEXT,
|
||||
tokenizer=TokenizerType.MULTILINGUAL,
|
||||
min_token_len=2,
|
||||
max_token_len=20,
|
||||
lowercase=True
|
||||
lowercase=True,
|
||||
)
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
|
||||
)
|
||||
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
|
||||
field_schema=text_index_params)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
@ -147,26 +151,23 @@ class QdrantVector(BaseVector):
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
added_ids = []
|
||||
for batch_ids, points in self._generate_rest_batches(
|
||||
texts, embeddings, metadatas, uuids, 64, self._group_id
|
||||
):
|
||||
self._client.upsert(
|
||||
collection_name=self._collection_name, points=points
|
||||
)
|
||||
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
|
||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
|
||||
def _generate_rest_batches(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
texts_iterator = iter(texts)
|
||||
embeddings_iterator = iter(embeddings)
|
||||
metadatas_iterator = iter(metadatas or [])
|
||||
@ -203,13 +204,13 @@ class QdrantVector(BaseVector):
|
||||
|
||||
@classmethod
|
||||
def _build_payloads(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]],
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
group_payload_key: str
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]],
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
group_payload_key: str,
|
||||
) -> list[dict]:
|
||||
payloads = []
|
||||
for i, text in enumerate(texts):
|
||||
@ -219,18 +220,11 @@ class QdrantVector(BaseVector):
|
||||
"calling .from_texts or .add_texts on Qdrant instance."
|
||||
)
|
||||
metadata = metadatas[i] if metadatas is not None else None
|
||||
payloads.append(
|
||||
{
|
||||
content_payload_key: text,
|
||||
metadata_payload_key: metadata,
|
||||
group_payload_key: group_id
|
||||
}
|
||||
)
|
||||
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
|
||||
|
||||
return payloads
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
@ -248,9 +242,7 @@ class QdrantVector(BaseVector):
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
@ -275,9 +267,7 @@ class QdrantVector(BaseVector):
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
@ -288,7 +278,6 @@ class QdrantVector(BaseVector):
|
||||
raise e
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
@ -304,9 +293,7 @@ class QdrantVector(BaseVector):
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
@ -324,15 +311,13 @@ class QdrantVector(BaseVector):
|
||||
all_collection_name.append(collection.name)
|
||||
if self._collection_name not in all_collection_name:
|
||||
return False
|
||||
response = self._client.retrieve(
|
||||
collection_name=self._collection_name,
|
||||
ids=[id]
|
||||
)
|
||||
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
|
||||
|
||||
return len(response) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from qdrant_client.http import models
|
||||
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
@ -348,22 +333,22 @@ class QdrantVector(BaseVector):
|
||||
limit=kwargs.get("top_k", 4),
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
score_threshold=kwargs.get("score_threshold", .0)
|
||||
score_threshold=kwargs.get("score_threshold", 0.0),
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
if result.score > score_threshold:
|
||||
metadata['score'] = result.score
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@ -372,6 +357,7 @@ class QdrantVector(BaseVector):
|
||||
List of documents most similar to the query text and distance for each.
|
||||
"""
|
||||
from qdrant_client.http import models
|
||||
|
||||
scroll_filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
@ -381,24 +367,21 @@ class QdrantVector(BaseVector):
|
||||
models.FieldCondition(
|
||||
key="page_content",
|
||||
match=models.MatchText(text=query),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
limit=kwargs.get('top_k', 2),
|
||||
limit=kwargs.get("top_k", 2),
|
||||
with_payload=True,
|
||||
with_vectors=True
|
||||
|
||||
with_vectors=True,
|
||||
)
|
||||
results = response[0]
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
document = self._document_from_scored_point(
|
||||
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
|
||||
)
|
||||
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
|
||||
documents.append(document)
|
||||
|
||||
return documents
|
||||
@ -410,10 +393,10 @@ class QdrantVector(BaseVector):
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
cls,
|
||||
scored_point: Any,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
cls,
|
||||
scored_point: Any,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
) -> Document:
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
@ -425,24 +408,25 @@ class QdrantVector(BaseVector):
|
||||
class QdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
||||
one_or_none()
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_collection_binding:
|
||||
collection_name = dataset_collection_binding.collection_name
|
||||
else:
|
||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
||||
raise ValueError("Dataset Collection Bindings is not exist!")
|
||||
else:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
if not dataset.index_struct_dict:
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return QdrantVector(
|
||||
@ -454,6 +438,6 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
||||
root_path=config.root_path,
|
||||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
|
||||
)
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||
),
|
||||
)
|
||||
|
||||
@ -33,28 +33,29 @@ class RelytConfig(BaseModel):
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
if not values["host"]:
|
||||
raise ValueError("config RELYT_HOST is required")
|
||||
if not values['port']:
|
||||
if not values["port"]:
|
||||
raise ValueError("config RELYT_PORT is required")
|
||||
if not values['user']:
|
||||
if not values["user"]:
|
||||
raise ValueError("config RELYT_USER is required")
|
||||
if not values['password']:
|
||||
if not values["password"]:
|
||||
raise ValueError("config RELYT_PASSWORD is required")
|
||||
if not values['database']:
|
||||
if not values["database"]:
|
||||
raise ValueError("config RELYT_DATABASE is required")
|
||||
return values
|
||||
|
||||
|
||||
class RelytVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
|
||||
super().__init__(collection_name)
|
||||
self.embedding_dimension = 1536
|
||||
self._client_config = config
|
||||
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
self._url = (
|
||||
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
|
||||
)
|
||||
self.client = create_engine(self._url)
|
||||
self._fields = []
|
||||
self._group_id = group_id
|
||||
@ -70,9 +71,9 @@ class RelytVector(BaseVector):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def create_collection(self, dimension: int):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
@ -110,7 +111,7 @@ class RelytVector(BaseVector):
|
||||
ids = [str(uuid.uuid1()) for _ in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
for metadata in metadatas:
|
||||
metadata['group_id'] = self._group_id
|
||||
metadata["group_id"] = self._group_id
|
||||
texts = [d.page_content for d in documents]
|
||||
|
||||
# Define the table schema
|
||||
@ -127,9 +128,7 @@ class RelytVector(BaseVector):
|
||||
chunks_table_data = []
|
||||
with self.client.connect() as conn:
|
||||
with conn.begin():
|
||||
for document, metadata, chunk_id, embedding in zip(
|
||||
texts, metadatas, ids, embeddings
|
||||
):
|
||||
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
|
||||
chunks_table_data.append(
|
||||
{
|
||||
"id": chunk_id,
|
||||
@ -196,15 +195,13 @@ class RelytVector(BaseVector):
|
||||
return False
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self.delete_by_uuids(ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
with Session(self.client) as session:
|
||||
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
|
||||
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
|
||||
)
|
||||
@ -228,38 +225,34 @@ class RelytVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
results = self.similarity_search_with_score_by_vector(
|
||||
k=int(kwargs.get('top_k')),
|
||||
embedding=query_vector,
|
||||
filter=kwargs.get('filter')
|
||||
k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
|
||||
)
|
||||
|
||||
# Organize results.
|
||||
docs = []
|
||||
for document, score in results:
|
||||
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
if 1 - score > score_threshold:
|
||||
docs.append(document)
|
||||
return docs
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
) -> list[tuple[Document, float]]:
|
||||
# Add the filter if provided
|
||||
try:
|
||||
from sqlalchemy.engine import Row
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import Row from sqlalchemy.engine. "
|
||||
"Please 'pip install sqlalchemy>=1.4'."
|
||||
)
|
||||
raise ImportError("Could not import Row from sqlalchemy.engine. " "Please 'pip install sqlalchemy>=1.4'.")
|
||||
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
conditions = [
|
||||
f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
|
||||
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
|
||||
if len(value) > 1
|
||||
else f"metadata->>{key!r} = {value[0]!r}"
|
||||
for key, value in filter.items()
|
||||
]
|
||||
@ -305,13 +298,12 @@ class RelytVector(BaseVector):
|
||||
class RelytVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))
|
||||
|
||||
return RelytVector(
|
||||
collection_name=collection_name,
|
||||
@ -322,5 +314,5 @@ class RelytVectorFactory(AbstractVectorFactory):
|
||||
password=dify_config.RELYT_PASSWORD,
|
||||
database=dify_config.RELYT_DATABASE,
|
||||
),
|
||||
group_id=dataset.id
|
||||
group_id=dataset.id,
|
||||
)
|
||||
|
||||
@ -25,16 +25,11 @@ class TencentConfig(BaseModel):
|
||||
database: Optional[str]
|
||||
index_type: str = "HNSW"
|
||||
metric_type: str = "L2"
|
||||
shard: int = 1,
|
||||
replicas: int = 2,
|
||||
shard: int = (1,)
|
||||
replicas: int = (2,)
|
||||
|
||||
def to_tencent_params(self):
|
||||
return {
|
||||
'url': self.url,
|
||||
'username': self.username,
|
||||
'key': self.api_key,
|
||||
'timeout': self.timeout
|
||||
}
|
||||
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
|
||||
|
||||
|
||||
class TencentVector(BaseVector):
|
||||
@ -61,13 +56,10 @@ class TencentVector(BaseVector):
|
||||
return self._client.create_database(database_name=self._client_config.database)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'tencent'
|
||||
return "tencent"
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
||||
def _has_collection(self) -> bool:
|
||||
collections = self._db.list_collections()
|
||||
@ -77,9 +69,9 @@ class TencentVector(BaseVector):
|
||||
return False
|
||||
|
||||
def _create_collection(self, dimension: int) -> None:
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
@ -101,9 +93,7 @@ class TencentVector(BaseVector):
|
||||
raise ValueError("unsupported metric_type")
|
||||
params = vdb_index.HNSWParams(m=16, efconstruction=200)
|
||||
index = vdb_index.Index(
|
||||
vdb_index.FilterIndex(
|
||||
self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
|
||||
),
|
||||
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
|
||||
vdb_index.VectorIndex(
|
||||
self.field_vector,
|
||||
dimension,
|
||||
@ -111,12 +101,8 @@ class TencentVector(BaseVector):
|
||||
metric_type,
|
||||
params,
|
||||
),
|
||||
vdb_index.FilterIndex(
|
||||
self.field_text, enum.FieldType.String, enum.IndexType.FILTER
|
||||
),
|
||||
vdb_index.FilterIndex(
|
||||
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
|
||||
),
|
||||
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
|
||||
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
|
||||
)
|
||||
|
||||
self._db.create_collection(
|
||||
@ -163,15 +149,14 @@ class TencentVector(BaseVector):
|
||||
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
|
||||
res = self._db.collection(self._collection_name).search(vectors=[query_vector],
|
||||
params=document.HNSWSearchParams(
|
||||
ef=kwargs.get("ef", 10)),
|
||||
retrieve_vector=False,
|
||||
limit=kwargs.get('top_k', 4),
|
||||
timeout=self._client_config.timeout,
|
||||
)
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
res = self._db.collection(self._collection_name).search(
|
||||
vectors=[query_vector],
|
||||
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
|
||||
retrieve_vector=False,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
timeout=self._client_config.timeout,
|
||||
)
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@ -200,15 +185,13 @@ class TencentVector(BaseVector):
|
||||
|
||||
class TencentVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
|
||||
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
|
||||
|
||||
return TencentVector(
|
||||
collection_name=collection_name,
|
||||
@ -220,5 +203,5 @@ class TencentVectorFactory(AbstractVectorFactory):
|
||||
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
|
||||
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
|
||||
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@ -28,47 +28,56 @@ class TiDBVectorConfig(BaseModel):
|
||||
database: str
|
||||
program_name: str
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['host']:
|
||||
if not values["host"]:
|
||||
raise ValueError("config TIDB_VECTOR_HOST is required")
|
||||
if not values['port']:
|
||||
if not values["port"]:
|
||||
raise ValueError("config TIDB_VECTOR_PORT is required")
|
||||
if not values['user']:
|
||||
if not values["user"]:
|
||||
raise ValueError("config TIDB_VECTOR_USER is required")
|
||||
if not values['password']:
|
||||
if not values["password"]:
|
||||
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
|
||||
if not values['database']:
|
||||
if not values["database"]:
|
||||
raise ValueError("config TIDB_VECTOR_DATABASE is required")
|
||||
if not values['program_name']:
|
||||
if not values["program_name"]:
|
||||
raise ValueError("config APPLICATION_NAME is required")
|
||||
return values
|
||||
|
||||
|
||||
class TiDBVector(BaseVector):
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.TIDB_VECTOR
|
||||
|
||||
def _table(self, dim: int) -> Table:
|
||||
from tidb_vector.sqlalchemy import VectorType
|
||||
|
||||
return Table(
|
||||
self._collection_name,
|
||||
self._orm_base.metadata,
|
||||
Column('id', String(36), primary_key=True, nullable=False),
|
||||
Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
|
||||
Column("id", String(36), primary_key=True, nullable=False),
|
||||
Column(
|
||||
"vector",
|
||||
VectorType(dim),
|
||||
nullable=False,
|
||||
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
|
||||
),
|
||||
Column("text", TEXT, nullable=False),
|
||||
Column("meta", JSON, nullable=False),
|
||||
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
|
||||
Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
|
||||
extend_existing=True
|
||||
Column(
|
||||
"update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
|
||||
),
|
||||
extend_existing=True,
|
||||
)
|
||||
|
||||
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
|
||||
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
|
||||
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}")
|
||||
self._url = (
|
||||
f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
|
||||
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}"
|
||||
)
|
||||
self._distance_func = distance_func.lower()
|
||||
self._engine = create_engine(self._url)
|
||||
self._orm_base = declarative_base()
|
||||
@ -83,9 +92,9 @@ class TiDBVector(BaseVector):
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
logger.info("_create_collection, collection_name " + self._collection_name)
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
with Session(self._engine) as session:
|
||||
@ -116,9 +125,7 @@ class TiDBVector(BaseVector):
|
||||
chunks_table_data = []
|
||||
with self._engine.connect() as conn:
|
||||
with conn.begin():
|
||||
for id, text, meta, embedding in zip(
|
||||
ids, texts, metas, embeddings
|
||||
):
|
||||
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
|
||||
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
|
||||
|
||||
# Execute the batch insert when the batch size is reached
|
||||
@ -133,12 +140,12 @@ class TiDBVector(BaseVector):
|
||||
return ids
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
result = self.get_ids_by_metadata_field('doc_id', id)
|
||||
result = self.get_ids_by_metadata_field("doc_id", id)
|
||||
return bool(result)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
with Session(self._engine) as session:
|
||||
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
|
||||
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
|
||||
)
|
||||
@ -180,20 +187,22 @@ class TiDBVector(BaseVector):
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
|
||||
filter = kwargs.get('filter')
|
||||
filter = kwargs.get("filter")
|
||||
distance = 1 - score_threshold
|
||||
|
||||
query_vector_str = ", ".join(format(x) for x in query_vector)
|
||||
query_vector_str = "[" + query_vector_str + "]"
|
||||
logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")
|
||||
logger.debug(
|
||||
f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}"
|
||||
)
|
||||
|
||||
docs = []
|
||||
if self._distance_func == 'l2':
|
||||
tidb_func = 'Vec_l2_distance'
|
||||
elif self._distance_func == 'cosine':
|
||||
tidb_func = 'Vec_Cosine_distance'
|
||||
if self._distance_func == "l2":
|
||||
tidb_func = "Vec_l2_distance"
|
||||
elif self._distance_func == "cosine":
|
||||
tidb_func = "Vec_Cosine_distance"
|
||||
else:
|
||||
tidb_func = 'Vec_Cosine_distance'
|
||||
tidb_func = "Vec_Cosine_distance"
|
||||
|
||||
with Session(self._engine) as session:
|
||||
select_statement = sql_text(
|
||||
@ -208,7 +217,7 @@ class TiDBVector(BaseVector):
|
||||
results = [(row[0], row[1], row[2]) for row in res]
|
||||
for meta, text, distance in results:
|
||||
metadata = json.loads(meta)
|
||||
metadata['score'] = 1 - distance
|
||||
metadata["score"] = 1 - distance
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
@ -224,15 +233,13 @@ class TiDBVector(BaseVector):
|
||||
|
||||
class TiDBVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
|
||||
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
|
||||
|
||||
return TiDBVector(
|
||||
collection_name=collection_name,
|
||||
|
||||
@ -7,7 +7,6 @@ from core.rag.models.document import Document
|
||||
|
||||
|
||||
class BaseVector(ABC):
|
||||
|
||||
def __init__(self, collection_name: str):
|
||||
self._collection_name = collection_name
|
||||
|
||||
@ -39,18 +38,11 @@ class BaseVector(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_by_vector(
|
||||
self,
|
||||
query_vector: list[float],
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def search_by_full_text(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self) -> None:
|
||||
@ -58,7 +50,7 @@ class BaseVector(ABC):
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts[:]:
|
||||
doc_id = text.metadata['doc_id']
|
||||
doc_id = text.metadata["doc_id"]
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
@ -66,7 +58,7 @@ class BaseVector(ABC):
|
||||
return texts
|
||||
|
||||
def _get_uuids(self, texts: list[Document]) -> list[str]:
|
||||
return [text.metadata['doc_id'] for text in texts]
|
||||
return [text.metadata["doc_id"] for text in texts]
|
||||
|
||||
@property
|
||||
def collection_name(self):
|
||||
|
||||
@ -20,17 +20,14 @@ class AbstractVectorFactory(ABC):
|
||||
|
||||
@staticmethod
|
||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
|
||||
index_struct_dict = {
|
||||
"type": vector_type,
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
||||
return index_struct_dict
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list = None):
|
||||
if attributes is None:
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash', 'page']
|
||||
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "page"]
|
||||
self._dataset = dataset
|
||||
self._embeddings = self._get_embeddings()
|
||||
self._attributes = attributes
|
||||
@ -39,7 +36,7 @@ class Vector:
|
||||
def _init_vector(self) -> BaseVector:
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict['type']
|
||||
vector_type = self._dataset.index_struct_dict["type"]
|
||||
|
||||
if not vector_type:
|
||||
raise ValueError("Vector store must be specified.")
|
||||
@ -52,45 +49,59 @@ class Vector:
|
||||
match vector_type:
|
||||
case VectorType.CHROMA:
|
||||
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
|
||||
|
||||
return ChromaVectorFactory
|
||||
case VectorType.MILVUS:
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
|
||||
|
||||
return MilvusVectorFactory
|
||||
case VectorType.MYSCALE:
|
||||
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
|
||||
|
||||
return MyScaleVectorFactory
|
||||
case VectorType.PGVECTOR:
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
|
||||
|
||||
return PGVectorFactory
|
||||
case VectorType.PGVECTO_RS:
|
||||
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
|
||||
|
||||
return PGVectoRSFactory
|
||||
case VectorType.QDRANT:
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
|
||||
|
||||
return QdrantVectorFactory
|
||||
case VectorType.RELYT:
|
||||
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
|
||||
|
||||
return RelytVectorFactory
|
||||
case VectorType.ELASTICSEARCH:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
return ElasticSearchVectorFactory
|
||||
case VectorType.TIDB_VECTOR:
|
||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
||||
|
||||
return TiDBVectorFactory
|
||||
case VectorType.WEAVIATE:
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
|
||||
|
||||
return WeaviateVectorFactory
|
||||
case VectorType.TENCENT:
|
||||
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
|
||||
|
||||
return TencentVectorFactory
|
||||
case VectorType.ORACLE:
|
||||
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
|
||||
|
||||
return OracleVectorFactory
|
||||
case VectorType.OPENSEARCH:
|
||||
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
|
||||
|
||||
return OpenSearchVectorFactory
|
||||
case VectorType.ANALYTICDB:
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
|
||||
|
||||
return AnalyticdbVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
@ -98,22 +109,14 @@ class Vector:
|
||||
def create(self, texts: list = None, **kwargs):
|
||||
if texts:
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
|
||||
self._vector_processor.create(
|
||||
texts=texts,
|
||||
embeddings=embeddings,
|
||||
**kwargs
|
||||
)
|
||||
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
|
||||
|
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
if kwargs.get('duplicate_check', False):
|
||||
if kwargs.get("duplicate_check", False):
|
||||
documents = self._filter_duplicate_texts(documents)
|
||||
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
|
||||
self._vector_processor.create(
|
||||
texts=documents,
|
||||
embeddings=embeddings,
|
||||
**kwargs
|
||||
)
|
||||
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self._vector_processor.text_exists(id)
|
||||
@ -124,24 +127,18 @@ class Vector:
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._vector_processor.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_vector = self._embeddings.embed_query(query)
|
||||
return self._vector_processor.search_by_vector(query_vector, **kwargs)
|
||||
|
||||
def search_by_full_text(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> list[Document]:
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._vector_processor.search_by_full_text(query, **kwargs)
|
||||
|
||||
def delete(self) -> None:
|
||||
self._vector_processor.delete()
|
||||
# delete collection redis cache
|
||||
if self._vector_processor.collection_name:
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name)
|
||||
redis_client.delete(collection_exist_cache_key)
|
||||
|
||||
def _get_embeddings(self) -> Embeddings:
|
||||
@ -151,14 +148,13 @@ class Vector:
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model
|
||||
|
||||
model=self._dataset.embedding_model,
|
||||
)
|
||||
return CacheEmbedding(embedding_model)
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
for text in texts[:]:
|
||||
doc_id = text.metadata['doc_id']
|
||||
doc_id = text.metadata["doc_id"]
|
||||
exists_duplicate_node = self.text_exists(doc_id)
|
||||
if exists_duplicate_node:
|
||||
texts.remove(text)
|
||||
|
||||
@ -2,17 +2,17 @@ from enum import Enum
|
||||
|
||||
|
||||
class VectorType(str, Enum):
|
||||
ANALYTICDB = 'analyticdb'
|
||||
CHROMA = 'chroma'
|
||||
MILVUS = 'milvus'
|
||||
MYSCALE = 'myscale'
|
||||
PGVECTOR = 'pgvector'
|
||||
PGVECTO_RS = 'pgvecto-rs'
|
||||
QDRANT = 'qdrant'
|
||||
RELYT = 'relyt'
|
||||
TIDB_VECTOR = 'tidb_vector'
|
||||
WEAVIATE = 'weaviate'
|
||||
OPENSEARCH = 'opensearch'
|
||||
TENCENT = 'tencent'
|
||||
ORACLE = 'oracle'
|
||||
ELASTICSEARCH = 'elasticsearch'
|
||||
ANALYTICDB = "analyticdb"
|
||||
CHROMA = "chroma"
|
||||
MILVUS = "milvus"
|
||||
MYSCALE = "myscale"
|
||||
PGVECTOR = "pgvector"
|
||||
PGVECTO_RS = "pgvecto-rs"
|
||||
QDRANT = "qdrant"
|
||||
RELYT = "relyt"
|
||||
TIDB_VECTOR = "tidb_vector"
|
||||
WEAVIATE = "weaviate"
|
||||
OPENSEARCH = "opensearch"
|
||||
TENCENT = "tencent"
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
|
||||
@ -22,15 +22,14 @@ class WeaviateConfig(BaseModel):
|
||||
api_key: Optional[str] = None
|
||||
batch_size: int = 100
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
if not values["endpoint"]:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
return values
|
||||
|
||||
|
||||
class WeaviateVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
|
||||
super().__init__(collection_name)
|
||||
self._client = self._init_client(config)
|
||||
@ -43,10 +42,7 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
try:
|
||||
client = weaviate.Client(
|
||||
url=config.endpoint,
|
||||
auth_client_secret=auth_config,
|
||||
timeout_config=(5, 60),
|
||||
startup_period=None
|
||||
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
@ -68,10 +64,10 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
def get_collection_name(self, dataset: Dataset) -> str:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
if not class_prefix.endswith("_Node"):
|
||||
# original class_prefix
|
||||
class_prefix += '_Node'
|
||||
class_prefix += "_Node"
|
||||
|
||||
return class_prefix
|
||||
|
||||
@ -79,10 +75,7 @@ class WeaviateVector(BaseVector):
|
||||
return Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
"vector_store": {"class_prefix": self._collection_name}
|
||||
}
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
# create collection
|
||||
@ -91,9 +84,9 @@ class WeaviateVector(BaseVector):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def _create_collection(self):
|
||||
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
schema = self._default_schema(self._collection_name)
|
||||
@ -129,17 +122,9 @@ class WeaviateVector(BaseVector):
|
||||
# check whether the index already exists
|
||||
schema = self._default_schema(self._collection_name)
|
||||
if self._client.schema.contains(schema):
|
||||
where_filter = {
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
}
|
||||
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
|
||||
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where=where_filter,
|
||||
output='minimal'
|
||||
)
|
||||
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
|
||||
|
||||
def delete(self):
|
||||
# check whether the index already exists
|
||||
@ -154,11 +139,19 @@ class WeaviateVector(BaseVector):
|
||||
# check whether the index already exists
|
||||
if not self._client.schema.contains(schema):
|
||||
return False
|
||||
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": id,
|
||||
}).with_limit(1).do()
|
||||
result = (
|
||||
self._client.query.get(collection_name)
|
||||
.with_additional(["id"])
|
||||
.with_where(
|
||||
{
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": id,
|
||||
}
|
||||
)
|
||||
.with_limit(1)
|
||||
.do()
|
||||
)
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
@ -211,13 +204,13 @@ class WeaviateVector(BaseVector):
|
||||
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
|
||||
# check score threshold
|
||||
if score > score_threshold:
|
||||
doc.metadata['score'] = score
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@ -240,15 +233,15 @@ class WeaviateVector(BaseVector):
|
||||
if kwargs.get("where_filter"):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
query_obj = query_obj.with_additional(["vector"])
|
||||
properties = ['text']
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
|
||||
properties = ["text"]
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
additional = res.pop('_additional')
|
||||
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
|
||||
additional = res.pop("_additional")
|
||||
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
|
||||
return docs
|
||||
|
||||
def _default_schema(self, index_name: str) -> dict:
|
||||
@ -271,20 +264,19 @@ class WeaviateVector(BaseVector):
|
||||
class WeaviateVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
|
||||
return WeaviateVector(
|
||||
collection_name=collection_name,
|
||||
config=WeaviateConfig(
|
||||
endpoint=dify_config.WEAVIATE_ENDPOINT,
|
||||
api_key=dify_config.WEAVIATE_API_KEY,
|
||||
batch_size=dify_config.WEAVIATE_BATCH_SIZE
|
||||
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
|
||||
),
|
||||
attributes=attributes
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
@ -12,10 +12,10 @@ from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
class DatasetDocumentStore:
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
user_id: str,
|
||||
document_id: Optional[str] = None,
|
||||
self,
|
||||
dataset: Dataset,
|
||||
user_id: str,
|
||||
document_id: Optional[str] = None,
|
||||
):
|
||||
self._dataset = dataset
|
||||
self._user_id = user_id
|
||||
@ -41,9 +41,9 @@ class DatasetDocumentStore:
|
||||
|
||||
@property
|
||||
def docs(self) -> dict[str, Document]:
|
||||
document_segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self._dataset.id
|
||||
).all()
|
||||
document_segments = (
|
||||
db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all()
|
||||
)
|
||||
|
||||
output = {}
|
||||
for document_segment in document_segments:
|
||||
@ -55,48 +55,45 @@ class DatasetDocumentStore:
|
||||
"doc_hash": document_segment.index_node_hash,
|
||||
"document_id": document_segment.document_id,
|
||||
"dataset_id": document_segment.dataset_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def add_documents(
|
||||
self, docs: Sequence[Document], allow_update: bool = True
|
||||
) -> None:
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == self._document_id
|
||||
).scalar()
|
||||
def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None:
|
||||
max_position = (
|
||||
db.session.query(func.max(DocumentSegment.position))
|
||||
.filter(DocumentSegment.document_id == self._document_id)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
if max_position is None:
|
||||
max_position = 0
|
||||
embedding_model = None
|
||||
if self._dataset.indexing_technique == 'high_quality':
|
||||
if self._dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model
|
||||
model=self._dataset.embedding_model,
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError("doc must be a Document")
|
||||
|
||||
segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])
|
||||
segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"])
|
||||
|
||||
# NOTE: doc could already exist in the store, but we overwrite it
|
||||
if not allow_update and segment_document:
|
||||
raise ValueError(
|
||||
f"doc_id {doc.metadata['doc_id']} already exists. "
|
||||
"Set allow_update to True to overwrite."
|
||||
f"doc_id {doc.metadata['doc_id']} already exists. " "Set allow_update to True to overwrite."
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
if embedding_model:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[doc.page_content]
|
||||
)
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
|
||||
else:
|
||||
tokens = 0
|
||||
|
||||
@ -107,8 +104,8 @@ class DatasetDocumentStore:
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
dataset_id=self._dataset.id,
|
||||
document_id=self._document_id,
|
||||
index_node_id=doc.metadata['doc_id'],
|
||||
index_node_hash=doc.metadata['doc_hash'],
|
||||
index_node_id=doc.metadata["doc_id"],
|
||||
index_node_hash=doc.metadata["doc_hash"],
|
||||
position=max_position,
|
||||
content=doc.page_content,
|
||||
word_count=len(doc.page_content),
|
||||
@ -116,15 +113,15 @@ class DatasetDocumentStore:
|
||||
enabled=False,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
if doc.metadata.get('answer'):
|
||||
segment_document.answer = doc.metadata.pop('answer', '')
|
||||
if doc.metadata.get("answer"):
|
||||
segment_document.answer = doc.metadata.pop("answer", "")
|
||||
|
||||
db.session.add(segment_document)
|
||||
else:
|
||||
segment_document.content = doc.page_content
|
||||
if doc.metadata.get('answer'):
|
||||
segment_document.answer = doc.metadata.pop('answer', '')
|
||||
segment_document.index_node_hash = doc.metadata['doc_hash']
|
||||
if doc.metadata.get("answer"):
|
||||
segment_document.answer = doc.metadata.pop("answer", "")
|
||||
segment_document.index_node_hash = doc.metadata["doc_hash"]
|
||||
segment_document.word_count = len(doc.page_content)
|
||||
segment_document.tokens = tokens
|
||||
|
||||
@ -135,9 +132,7 @@ class DatasetDocumentStore:
|
||||
result = self.get_document_segment(doc_id)
|
||||
return result is not None
|
||||
|
||||
def get_document(
|
||||
self, doc_id: str, raise_error: bool = True
|
||||
) -> Optional[Document]:
|
||||
def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]:
|
||||
document_segment = self.get_document_segment(doc_id)
|
||||
|
||||
if document_segment is None:
|
||||
@ -153,7 +148,7 @@ class DatasetDocumentStore:
|
||||
"doc_hash": document_segment.index_node_hash,
|
||||
"document_id": document_segment.document_id,
|
||||
"dataset_id": document_segment.dataset_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
|
||||
@ -188,9 +183,10 @@ class DatasetDocumentStore:
|
||||
return document_segment.index_node_hash
|
||||
|
||||
def get_document_segment(self, doc_id: str) -> DocumentSegment:
|
||||
document_segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self._dataset.id,
|
||||
DocumentSegment.index_node_id == doc_id
|
||||
).first()
|
||||
document_segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
return document_segment
|
||||
|
||||
@ -4,6 +4,7 @@ The goal is to facilitate decoupling of content loading from content parsing cod
|
||||
|
||||
In addition, content loading code should provide a lazy loading interface by default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import csv
|
||||
from typing import Optional
|
||||
|
||||
@ -18,12 +19,12 @@ class CSVExtractor(BaseExtractor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False,
|
||||
source_column: Optional[str] = None,
|
||||
csv_args: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
@ -57,7 +58,7 @@ class CSVExtractor(BaseExtractor):
|
||||
docs = []
|
||||
try:
|
||||
# load csv file into pandas dataframe
|
||||
df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args)
|
||||
df = pd.read_csv(csvfile, on_bad_lines="skip", **self.csv_args)
|
||||
|
||||
# check source column exists
|
||||
if self.source_column and self.source_column not in df.columns:
|
||||
@ -67,7 +68,7 @@ class CSVExtractor(BaseExtractor):
|
||||
|
||||
for i, row in df.iterrows():
|
||||
content = ";".join(f"{col.strip()}: {str(row[col]).strip()}" for col in df.columns)
|
||||
source = row[self.source_column] if self.source_column else ''
|
||||
source = row[self.source_column] if self.source_column else ""
|
||||
metadata = {"source": source, "row": i}
|
||||
doc = Document(page_content=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
@ -10,6 +10,7 @@ class NotionInfo(BaseModel):
|
||||
"""
|
||||
Notion import info.
|
||||
"""
|
||||
|
||||
notion_workspace_id: str
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
@ -25,6 +26,7 @@ class WebsiteInfo(BaseModel):
|
||||
"""
|
||||
website import info.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
job_id: str
|
||||
url: str
|
||||
@ -43,6 +45,7 @@ class ExtractSetting(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
datasource_type: str
|
||||
upload_file: Optional[UploadFile] = None
|
||||
notion_info: Optional[NotionInfo] = None
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
@ -17,23 +18,18 @@ class ExcelExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
""" Load from Excel file in xls or xlsx format using Pandas and openpyxl."""
|
||||
"""Load from Excel file in xls or xlsx format using Pandas and openpyxl."""
|
||||
documents = []
|
||||
file_extension = os.path.splitext(self._file_path)[-1].lower()
|
||||
|
||||
if file_extension == '.xlsx':
|
||||
if file_extension == ".xlsx":
|
||||
wb = load_workbook(self._file_path, data_only=True)
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
@ -44,35 +40,38 @@ class ExcelExtractor(BaseExtractor):
|
||||
continue
|
||||
df = pd.DataFrame(data, columns=cols)
|
||||
|
||||
df.dropna(how='all', inplace=True)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for index, row in df.iterrows():
|
||||
page_content = []
|
||||
for col_index, (k, v) in enumerate(row.items()):
|
||||
if pd.notna(v):
|
||||
cell = sheet.cell(row=index + 2,
|
||||
column=col_index + 1) # +2 to account for header and 1-based index
|
||||
cell = sheet.cell(
|
||||
row=index + 2, column=col_index + 1
|
||||
) # +2 to account for header and 1-based index
|
||||
if cell.hyperlink:
|
||||
value = f"[{v}]({cell.hyperlink.target})"
|
||||
page_content.append(f'"{k}":"{value}"')
|
||||
else:
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(Document(page_content=';'.join(page_content),
|
||||
metadata={'source': self._file_path}))
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
|
||||
elif file_extension == '.xls':
|
||||
excel_file = pd.ExcelFile(self._file_path, engine='xlrd')
|
||||
elif file_extension == ".xls":
|
||||
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
|
||||
for sheet_name in excel_file.sheet_names:
|
||||
df = excel_file.parse(sheet_name=sheet_name)
|
||||
df.dropna(how='all', inplace=True)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
for _, row in df.iterrows():
|
||||
page_content = []
|
||||
for k, v in row.items():
|
||||
if pd.notna(v):
|
||||
page_content.append(f'"{k}":"{v}"')
|
||||
documents.append(Document(page_content=';'.join(page_content),
|
||||
metadata={'source': self._file_path}))
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file extension: {file_extension}")
|
||||
|
||||
|
||||
@ -29,61 +29,60 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain', 'application/json']
|
||||
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
|
||||
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
|
||||
|
||||
class ExtractProcessor:
|
||||
@classmethod
|
||||
def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
|
||||
-> Union[list[Document], str]:
|
||||
def load_from_upload_file(
|
||||
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
|
||||
) -> Union[list[Document], str]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
upload_file=upload_file,
|
||||
document_model='text_model'
|
||||
datasource_type="upload_file", upload_file=upload_file, document_model="text_model"
|
||||
)
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
delimiter = "\n"
|
||||
return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
|
||||
else:
|
||||
return cls.extract(extract_setting, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
response = ssrf_proxy.get(url, headers={
|
||||
"User-Agent": USER_AGENT
|
||||
})
|
||||
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(url).suffix
|
||||
if not suffix and suffix != '.':
|
||||
if not suffix and suffix != ".":
|
||||
# get content-type
|
||||
if response.headers.get('Content-Type'):
|
||||
suffix = '.' + response.headers.get('Content-Type').split('/')[-1]
|
||||
if response.headers.get("Content-Type"):
|
||||
suffix = "." + response.headers.get("Content-Type").split("/")[-1]
|
||||
else:
|
||||
content_disposition = response.headers.get('Content-Disposition')
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
suffix = '.' + re.search(r'\.(\w+)$', filename).group(1)
|
||||
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
|
||||
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
with open(file_path, 'wb') as file:
|
||||
with open(file_path, "wb") as file:
|
||||
file.write(response.content)
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="upload_file",
|
||||
document_model='text_model'
|
||||
)
|
||||
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
|
||||
if return_text:
|
||||
delimiter = '\n'
|
||||
return delimiter.join([document.page_content for document in cls.extract(
|
||||
extract_setting=extract_setting, file_path=file_path)])
|
||||
delimiter = "\n"
|
||||
return delimiter.join(
|
||||
[
|
||||
document.page_content
|
||||
for document in cls.extract(extract_setting=extract_setting, file_path=file_path)
|
||||
]
|
||||
)
|
||||
else:
|
||||
return cls.extract(extract_setting=extract_setting, file_path=file_path)
|
||||
|
||||
@classmethod
|
||||
def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
|
||||
file_path: str = None) -> list[Document]:
|
||||
def extract(
|
||||
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None
|
||||
) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE.value:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if not file_path:
|
||||
@ -96,50 +95,56 @@ class ExtractProcessor:
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
|
||||
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
|
||||
if etl_type == 'Unstructured':
|
||||
if file_extension == '.xlsx' or file_extension == '.xls':
|
||||
if etl_type == "Unstructured":
|
||||
if file_extension == ".xlsx" or file_extension == ".xls":
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
elif file_extension in [".md", ".markdown"]:
|
||||
extractor = (
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
|
||||
if is_automatic
|
||||
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
)
|
||||
elif file_extension in [".htm", ".html"]:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
elif file_extension in [".docx"]:
|
||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension == '.csv':
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == '.msg':
|
||||
elif file_extension == ".msg":
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.eml':
|
||||
elif file_extension == ".eml":
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.ppt':
|
||||
elif file_extension == ".ppt":
|
||||
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == '.pptx':
|
||||
elif file_extension == ".pptx":
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == '.xml':
|
||||
elif file_extension == ".xml":
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
|
||||
elif file_extension == 'epub':
|
||||
elif file_extension == "epub":
|
||||
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url)
|
||||
else:
|
||||
# txt
|
||||
extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
|
||||
extractor = (
|
||||
UnstructuredTextExtractor(file_path, unstructured_api_url)
|
||||
if is_automatic
|
||||
else TextExtractor(file_path, autodetect_encoding=True)
|
||||
)
|
||||
else:
|
||||
if file_extension == '.xlsx' or file_extension == '.xls':
|
||||
if file_extension == ".xlsx" or file_extension == ".xls":
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == '.pdf':
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in ['.md', '.markdown']:
|
||||
elif file_extension in [".md", ".markdown"]:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in ['.htm', '.html']:
|
||||
elif file_extension in [".htm", ".html"]:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension in ['.docx']:
|
||||
elif file_extension in [".docx"]:
|
||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
elif file_extension == '.csv':
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == 'epub':
|
||||
elif file_extension == "epub":
|
||||
extractor = UnstructuredEpubExtractor(file_path)
|
||||
else:
|
||||
# txt
|
||||
@ -155,13 +160,13 @@ class ExtractProcessor:
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
|
||||
if extract_setting.website_info.provider == 'firecrawl':
|
||||
if extract_setting.website_info.provider == "firecrawl":
|
||||
extractor = FirecrawlWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
tenant_id=extract_setting.website_info.tenant_id,
|
||||
mode=extract_setting.website_info.mode,
|
||||
only_main_content=extract_setting.website_info.only_main_content
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Interface for extract files.
|
||||
"""
|
||||
"""Interface for extract files."""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -9,108 +9,98 @@ from extensions.ext_storage import storage
|
||||
class FirecrawlApp:
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.firecrawl.dev'
|
||||
if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
|
||||
raise ValueError('No API key provided')
|
||||
self.base_url = base_url or "https://api.firecrawl.dev"
|
||||
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def scrape_url(self, url, params=None) -> dict:
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
json_data = {'url': url}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = requests.post(
|
||||
f'{self.base_url}/v0/scrape',
|
||||
headers=headers,
|
||||
json=json_data
|
||||
)
|
||||
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
|
||||
if response.status_code == 200:
|
||||
response = response.json()
|
||||
if response['success'] == True:
|
||||
data = response['data']
|
||||
if response["success"] == True:
|
||||
data = response["data"]
|
||||
return {
|
||||
'title': data.get('metadata').get('title'),
|
||||
'description': data.get('metadata').get('description'),
|
||||
'source_url': data.get('metadata').get('sourceURL'),
|
||||
'markdown': data.get('markdown')
|
||||
"title": data.get("metadata").get("title"),
|
||||
"description": data.get("metadata").get("description"),
|
||||
"source_url": data.get("metadata").get("sourceURL"),
|
||||
"markdown": data.get("markdown"),
|
||||
}
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
|
||||
|
||||
elif response.status_code in [402, 409, 500]:
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
|
||||
|
||||
def crawl_url(self, url, params=None) -> str:
|
||||
headers = self._prepare_headers()
|
||||
json_data = {'url': url}
|
||||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
job_id = response.json().get('jobId')
|
||||
job_id = response.json().get("jobId")
|
||||
return job_id
|
||||
else:
|
||||
self._handle_error(response, 'start crawl job')
|
||||
self._handle_error(response, "start crawl job")
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
|
||||
response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers)
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get('status') == 'completed':
|
||||
total = crawl_status_response.get('total', 0)
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
total = crawl_status_response.get("total", 0)
|
||||
if total == 0:
|
||||
raise Exception('Failed to check crawl status. Error: No page found')
|
||||
data = crawl_status_response.get('data', [])
|
||||
raise Exception("Failed to check crawl status. Error: No page found")
|
||||
data = crawl_status_response.get("data", [])
|
||||
url_data_list = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data = {
|
||||
'title': item.get('metadata').get('title'),
|
||||
'description': item.get('metadata').get('description'),
|
||||
'source_url': item.get('metadata').get('sourceURL'),
|
||||
'markdown': item.get('markdown')
|
||||
"title": item.get("metadata").get("title"),
|
||||
"description": item.get("metadata").get("description"),
|
||||
"source_url": item.get("metadata").get("sourceURL"),
|
||||
"markdown": item.get("markdown"),
|
||||
}
|
||||
url_data_list.append(url_data)
|
||||
if url_data_list:
|
||||
file_key = 'website_files/' + job_id + '.txt'
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
|
||||
storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
|
||||
return {
|
||||
'status': 'completed',
|
||||
'total': crawl_status_response.get('total'),
|
||||
'current': crawl_status_response.get('current'),
|
||||
'data': url_data_list
|
||||
"status": "completed",
|
||||
"total": crawl_status_response.get("total"),
|
||||
"current": crawl_status_response.get("current"),
|
||||
"data": url_data_list,
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
'status': crawl_status_response.get('status'),
|
||||
'total': crawl_status_response.get('total'),
|
||||
'current': crawl_status_response.get('current'),
|
||||
'data': []
|
||||
"status": crawl_status_response.get("status"),
|
||||
"total": crawl_status_response.get("total"),
|
||||
"current": crawl_status_response.get("current"),
|
||||
"data": [],
|
||||
}
|
||||
|
||||
else:
|
||||
self._handle_error(response, 'check crawl status')
|
||||
self._handle_error(response, "check crawl status")
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
|
||||
for attempt in range(retries):
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
if response.status_code == 502:
|
||||
time.sleep(backoff_factor * (2 ** attempt))
|
||||
time.sleep(backoff_factor * (2**attempt))
|
||||
else:
|
||||
return response
|
||||
return response
|
||||
@ -119,13 +109,11 @@ class FirecrawlApp:
|
||||
for attempt in range(retries):
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 502:
|
||||
time.sleep(backoff_factor * (2 ** attempt))
|
||||
time.sleep(backoff_factor * (2**attempt))
|
||||
else:
|
||||
return response
|
||||
return response
|
||||
|
||||
def _handle_error(self, response, action):
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
|
||||
|
||||
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
|
||||
|
||||
@ -5,7 +5,7 @@ from services.website_service import WebsiteService
|
||||
|
||||
class FirecrawlWebExtractor(BaseExtractor):
|
||||
"""
|
||||
Crawl and scrape websites and return content in clean llm-ready markdown.
|
||||
Crawl and scrape websites and return content in clean llm-ready markdown.
|
||||
|
||||
|
||||
Args:
|
||||
@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor):
|
||||
mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
mode: str = 'crawl',
|
||||
only_main_content: bool = False
|
||||
):
|
||||
def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
|
||||
"""Initialize with url, api_key, base_url and mode."""
|
||||
self._url = url
|
||||
self.job_id = job_id
|
||||
@ -33,28 +26,31 @@ class FirecrawlWebExtractor(BaseExtractor):
|
||||
def extract(self) -> list[Document]:
|
||||
"""Extract content from the URL."""
|
||||
documents = []
|
||||
if self.mode == 'crawl':
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
|
||||
if self.mode == "crawl":
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id)
|
||||
if crawl_data is None:
|
||||
return []
|
||||
document = Document(page_content=crawl_data.get('markdown', ''),
|
||||
metadata={
|
||||
'source_url': crawl_data.get('source_url'),
|
||||
'description': crawl_data.get('description'),
|
||||
'title': crawl_data.get('title')
|
||||
}
|
||||
)
|
||||
document = Document(
|
||||
page_content=crawl_data.get("markdown", ""),
|
||||
metadata={
|
||||
"source_url": crawl_data.get("source_url"),
|
||||
"description": crawl_data.get("description"),
|
||||
"title": crawl_data.get("title"),
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
elif self.mode == 'scrape':
|
||||
scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
|
||||
self.only_main_content)
|
||||
elif self.mode == "scrape":
|
||||
scrape_data = WebsiteService.get_scrape_url_data(
|
||||
"firecrawl", self._url, self.tenant_id, self.only_main_content
|
||||
)
|
||||
|
||||
document = Document(page_content=scrape_data.get('markdown', ''),
|
||||
metadata={
|
||||
'source_url': scrape_data.get('source_url'),
|
||||
'description': scrape_data.get('description'),
|
||||
'title': scrape_data.get('title')
|
||||
}
|
||||
)
|
||||
document = Document(
|
||||
page_content=scrape_data.get("markdown", ""),
|
||||
metadata={
|
||||
"source_url": scrape_data.get("source_url"),
|
||||
"description": scrape_data.get("description"),
|
||||
"title": scrape_data.get("title"),
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
return documents
|
||||
|
||||
@ -37,9 +37,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
|
||||
try:
|
||||
encodings = future.result(timeout=timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Timeout reached while detecting encoding for {file_path}"
|
||||
)
|
||||
raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}")
|
||||
|
||||
if all(encoding["encoding"] is None for encoding in encodings):
|
||||
raise RuntimeError(f"Could not detect encoding for {file_path}")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
@ -6,7 +7,6 @@ from core.rag.models.document import Document
|
||||
|
||||
|
||||
class HtmlExtractor(BaseExtractor):
|
||||
|
||||
"""
|
||||
Load html files.
|
||||
|
||||
@ -15,10 +15,7 @@ class HtmlExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str
|
||||
):
|
||||
def __init__(self, file_path: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
|
||||
@ -27,8 +24,8 @@ class HtmlExtractor(BaseExtractor):
|
||||
|
||||
def _load_as_text(self) -> str:
|
||||
with open(self._file_path, "rb") as fp:
|
||||
soup = BeautifulSoup(fp, 'html.parser')
|
||||
soup = BeautifulSoup(fp, "html.parser")
|
||||
text = soup.get_text()
|
||||
text = text.strip() if text else ''
|
||||
text = text.strip() if text else ""
|
||||
|
||||
return text
|
||||
return text
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import re
|
||||
from typing import Optional, cast
|
||||
|
||||
@ -16,12 +17,12 @@ class MarkdownExtractor(BaseExtractor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = False,
|
||||
remove_images: bool = False,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
self,
|
||||
file_path: str,
|
||||
remove_hyperlinks: bool = False,
|
||||
remove_images: bool = False,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = True,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
@ -78,13 +79,10 @@ class MarkdownExtractor(BaseExtractor):
|
||||
if current_header is not None:
|
||||
# pass linting, assert keys are defined
|
||||
markdown_tups = [
|
||||
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
|
||||
for key, value in markdown_tups
|
||||
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups
|
||||
]
|
||||
else:
|
||||
markdown_tups = [
|
||||
(key, re.sub("\n", "", value)) for key, value in markdown_tups
|
||||
]
|
||||
markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups]
|
||||
|
||||
return markdown_tups
|
||||
|
||||
|
||||
@ -21,22 +21,21 @@ RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
||||
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
||||
# if user want split by headings, use the corresponding splitter
|
||||
HEADING_SPLITTER = {
|
||||
'heading_1': '# ',
|
||||
'heading_2': '## ',
|
||||
'heading_3': '### ',
|
||||
"heading_1": "# ",
|
||||
"heading_2": "## ",
|
||||
"heading_3": "### ",
|
||||
}
|
||||
|
||||
|
||||
class NotionExtractor(BaseExtractor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
notion_workspace_id: str,
|
||||
notion_obj_id: str,
|
||||
notion_page_type: str,
|
||||
tenant_id: str,
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None,
|
||||
|
||||
self,
|
||||
notion_workspace_id: str,
|
||||
notion_obj_id: str,
|
||||
notion_page_type: str,
|
||||
tenant_id: str,
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None,
|
||||
):
|
||||
self._notion_access_token = None
|
||||
self._document_model = document_model
|
||||
@ -46,46 +45,38 @@ class NotionExtractor(BaseExtractor):
|
||||
if notion_access_token:
|
||||
self._notion_access_token = notion_access_token
|
||||
else:
|
||||
self._notion_access_token = self._get_access_token(tenant_id,
|
||||
self._notion_workspace_id)
|
||||
self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id)
|
||||
if not self._notion_access_token:
|
||||
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
|
||||
if integration_token is None:
|
||||
raise ValueError(
|
||||
"Must specify `integration_token` or set environment "
|
||||
"variable `NOTION_INTEGRATION_TOKEN`."
|
||||
"Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`."
|
||||
)
|
||||
|
||||
self._notion_access_token = integration_token
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
self.update_last_edited_time(
|
||||
self._document_model
|
||||
)
|
||||
self.update_last_edited_time(self._document_model)
|
||||
|
||||
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
|
||||
|
||||
return text_docs
|
||||
|
||||
def _load_data_as_documents(
|
||||
self, notion_obj_id: str, notion_page_type: str
|
||||
) -> list[Document]:
|
||||
def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> list[Document]:
|
||||
docs = []
|
||||
if notion_page_type == 'database':
|
||||
if notion_page_type == "database":
|
||||
# get all the pages in the database
|
||||
page_text_documents = self._get_notion_database_data(notion_obj_id)
|
||||
docs.extend(page_text_documents)
|
||||
elif notion_page_type == 'page':
|
||||
elif notion_page_type == "page":
|
||||
page_text_list = self._get_notion_block_data(notion_obj_id)
|
||||
docs.append(Document(page_content='\n'.join(page_text_list)))
|
||||
docs.append(Document(page_content="\n".join(page_text_list)))
|
||||
else:
|
||||
raise ValueError("notion page type not supported")
|
||||
|
||||
return docs
|
||||
|
||||
def _get_notion_database_data(
|
||||
self, database_id: str, query_dict: dict[str, Any] = {}
|
||||
) -> list[Document]:
|
||||
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
|
||||
"""Get all the pages from a Notion database."""
|
||||
res = requests.post(
|
||||
DATABASE_URL_TMPL.format(database_id=database_id),
|
||||
@ -100,50 +91,50 @@ class NotionExtractor(BaseExtractor):
|
||||
data = res.json()
|
||||
|
||||
database_content = []
|
||||
if 'results' not in data or data["results"] is None:
|
||||
if "results" not in data or data["results"] is None:
|
||||
return []
|
||||
for result in data["results"]:
|
||||
properties = result['properties']
|
||||
properties = result["properties"]
|
||||
data = {}
|
||||
for property_name, property_value in properties.items():
|
||||
type = property_value['type']
|
||||
if type == 'multi_select':
|
||||
type = property_value["type"]
|
||||
if type == "multi_select":
|
||||
value = []
|
||||
multi_select_list = property_value[type]
|
||||
for multi_select in multi_select_list:
|
||||
value.append(multi_select['name'])
|
||||
elif type == 'rich_text' or type == 'title':
|
||||
value.append(multi_select["name"])
|
||||
elif type == "rich_text" or type == "title":
|
||||
if len(property_value[type]) > 0:
|
||||
value = property_value[type][0]['plain_text']
|
||||
value = property_value[type][0]["plain_text"]
|
||||
else:
|
||||
value = ''
|
||||
elif type == 'select' or type == 'status':
|
||||
value = ""
|
||||
elif type == "select" or type == "status":
|
||||
if property_value[type]:
|
||||
value = property_value[type]['name']
|
||||
value = property_value[type]["name"]
|
||||
else:
|
||||
value = ''
|
||||
value = ""
|
||||
else:
|
||||
value = property_value[type]
|
||||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ''
|
||||
row_content = ""
|
||||
for key, value in row_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items())
|
||||
row_content = row_content + f'{key}:{value_content}\n'
|
||||
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
|
||||
row_content = row_content + f"{key}:{value_content}\n"
|
||||
else:
|
||||
row_content = row_content + f'{key}:{value}\n'
|
||||
row_content = row_content + f"{key}:{value}\n"
|
||||
database_content.append(row_content)
|
||||
|
||||
return [Document(page_content='\n'.join(database_content))]
|
||||
return [Document(page_content="\n".join(database_content))]
|
||||
|
||||
def _get_notion_block_data(self, page_id: str) -> list[str]:
|
||||
result_lines_arr = []
|
||||
start_cursor = None
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
|
||||
while True:
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
|
||||
res = requests.request(
|
||||
"GET",
|
||||
block_url,
|
||||
@ -152,14 +143,14 @@ class NotionExtractor(BaseExtractor):
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
params=query_dict
|
||||
params=query_dict,
|
||||
)
|
||||
data = res.json()
|
||||
for result in data["results"]:
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
cur_result_text_arr = []
|
||||
if result_type == 'table':
|
||||
if result_type == "table":
|
||||
result_block_id = result["id"]
|
||||
text = self._read_table_rows(result_block_id)
|
||||
text += "\n\n"
|
||||
@ -175,17 +166,15 @@ class NotionExtractor(BaseExtractor):
|
||||
result_block_id = result["id"]
|
||||
has_children = result["has_children"]
|
||||
block_type = result["type"]
|
||||
if has_children and block_type != 'child_page':
|
||||
children_text = self._read_block(
|
||||
result_block_id, num_tabs=1
|
||||
)
|
||||
if has_children and block_type != "child_page":
|
||||
children_text = self._read_block(result_block_id, num_tabs=1)
|
||||
cur_result_text_arr.append(children_text)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
if result_type in HEADING_SPLITTER:
|
||||
result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
|
||||
else:
|
||||
result_lines_arr.append(cur_result_text + '\n\n')
|
||||
result_lines_arr.append(cur_result_text + "\n\n")
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
@ -199,7 +188,7 @@ class NotionExtractor(BaseExtractor):
|
||||
start_cursor = None
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
|
||||
while True:
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
@ -209,16 +198,16 @@ class NotionExtractor(BaseExtractor):
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
params=query_dict
|
||||
params=query_dict,
|
||||
)
|
||||
data = res.json()
|
||||
if 'results' not in data or data["results"] is None:
|
||||
if "results" not in data or data["results"] is None:
|
||||
break
|
||||
for result in data["results"]:
|
||||
result_type = result["type"]
|
||||
result_obj = result[result_type]
|
||||
cur_result_text_arr = []
|
||||
if result_type == 'table':
|
||||
if result_type == "table":
|
||||
result_block_id = result["id"]
|
||||
text = self._read_table_rows(result_block_id)
|
||||
result_lines_arr.append(text)
|
||||
@ -233,17 +222,15 @@ class NotionExtractor(BaseExtractor):
|
||||
result_block_id = result["id"]
|
||||
has_children = result["has_children"]
|
||||
block_type = result["type"]
|
||||
if has_children and block_type != 'child_page':
|
||||
children_text = self._read_block(
|
||||
result_block_id, num_tabs=num_tabs + 1
|
||||
)
|
||||
if has_children and block_type != "child_page":
|
||||
children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1)
|
||||
cur_result_text_arr.append(children_text)
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
if result_type in HEADING_SPLITTER:
|
||||
result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}')
|
||||
result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}")
|
||||
else:
|
||||
result_lines_arr.append(cur_result_text + '\n\n')
|
||||
result_lines_arr.append(cur_result_text + "\n\n")
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
@ -260,7 +247,7 @@ class NotionExtractor(BaseExtractor):
|
||||
start_cursor = None
|
||||
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
|
||||
while not done:
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor}
|
||||
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
|
||||
|
||||
res = requests.request(
|
||||
"GET",
|
||||
@ -270,28 +257,28 @@ class NotionExtractor(BaseExtractor):
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
params=query_dict
|
||||
params=query_dict,
|
||||
)
|
||||
data = res.json()
|
||||
# get table headers text
|
||||
table_header_cell_texts = []
|
||||
table_header_cells = data["results"][0]['table_row']['cells']
|
||||
table_header_cells = data["results"][0]["table_row"]["cells"]
|
||||
for table_header_cell in table_header_cells:
|
||||
if table_header_cell:
|
||||
for table_header_cell_text in table_header_cell:
|
||||
text = table_header_cell_text["text"]["content"]
|
||||
table_header_cell_texts.append(text)
|
||||
else:
|
||||
table_header_cell_texts.append('')
|
||||
table_header_cell_texts.append("")
|
||||
# Initialize Markdown table with headers
|
||||
markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n"
|
||||
markdown_table += "| " + " | ".join(['---'] * len(table_header_cell_texts)) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(table_header_cell_texts)) + " |\n"
|
||||
|
||||
# Process data to format each row in Markdown table format
|
||||
results = data["results"]
|
||||
for i in range(len(results) - 1):
|
||||
column_texts = []
|
||||
table_column_cells = data["results"][i + 1]['table_row']['cells']
|
||||
table_column_cells = data["results"][i + 1]["table_row"]["cells"]
|
||||
for j in range(len(table_column_cells)):
|
||||
if table_column_cells[j]:
|
||||
for table_column_cell_text in table_column_cells[j]:
|
||||
@ -315,10 +302,8 @@ class NotionExtractor(BaseExtractor):
|
||||
|
||||
last_edited_time = self.get_notion_last_edited_time()
|
||||
data_source_info = document_model.data_source_info_dict
|
||||
data_source_info['last_edited_time'] = last_edited_time
|
||||
update_params = {
|
||||
DocumentModel.data_source_info: json.dumps(data_source_info)
|
||||
}
|
||||
data_source_info["last_edited_time"] = last_edited_time
|
||||
update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
|
||||
|
||||
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
|
||||
db.session.commit()
|
||||
@ -326,7 +311,7 @@ class NotionExtractor(BaseExtractor):
|
||||
def get_notion_last_edited_time(self) -> str:
|
||||
obj_id = self._notion_obj_id
|
||||
page_type = self._notion_page_type
|
||||
if page_type == 'database':
|
||||
if page_type == "database":
|
||||
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
|
||||
else:
|
||||
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
|
||||
@ -341,7 +326,7 @@ class NotionExtractor(BaseExtractor):
|
||||
"Content-Type": "application/json",
|
||||
"Notion-Version": "2022-06-28",
|
||||
},
|
||||
json=query_dict
|
||||
json=query_dict,
|
||||
)
|
||||
|
||||
data = res.json()
|
||||
@ -352,14 +337,16 @@ class NotionExtractor(BaseExtractor):
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
|
||||
)
|
||||
).first()
|
||||
|
||||
if not data_source_binding:
|
||||
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
|
||||
f'and notion workspace {notion_workspace_id}')
|
||||
raise Exception(
|
||||
f"No notion data source binding found for tenant {tenant_id} "
|
||||
f"and notion workspace {notion_workspace_id}"
|
||||
)
|
||||
|
||||
return data_source_binding.access_token
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional
|
||||
|
||||
@ -16,21 +17,17 @@ class PdfExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
file_cache_key: Optional[str] = None
|
||||
):
|
||||
def __init__(self, file_path: str, file_cache_key: Optional[str] = None):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = ''
|
||||
plaintext_file_key = ""
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
text = storage.load(self._file_cache_key).decode('utf-8')
|
||||
text = storage.load(self._file_cache_key).decode("utf-8")
|
||||
plaintext_file_exists = True
|
||||
return [Document(page_content=text)]
|
||||
except FileNotFoundError:
|
||||
@ -43,12 +40,12 @@ class PdfExtractor(BaseExtractor):
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode('utf-8'))
|
||||
storage.save(plaintext_file_key, text.encode("utf-8"))
|
||||
|
||||
return documents
|
||||
|
||||
def load(
|
||||
self,
|
||||
self,
|
||||
) -> Iterator[Document]:
|
||||
"""Lazy load given path as pages."""
|
||||
blob = Blob.from_path(self._file_path)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
@ -14,12 +15,7 @@ class TextExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
encoding: Optional[str] = None,
|
||||
autodetect_encoding: bool = False
|
||||
):
|
||||
def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._encoding = encoding
|
||||
|
||||
@ -8,13 +8,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredWordExtractor(BaseExtractor):
|
||||
"""Loader that uses unstructured to load word documents.
|
||||
"""
|
||||
"""Loader that uses unstructured to load word documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
@ -24,9 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor):
|
||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
||||
|
||||
unstructured_version = tuple(
|
||||
int(x) for x in __unstructured_version__.split(".")
|
||||
)
|
||||
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
|
||||
# check the file extension
|
||||
try:
|
||||
import magic # noqa: F401
|
||||
@ -53,6 +50,7 @@ class UnstructuredWordExtractor(BaseExtractor):
|
||||
elements = partition_docx(filename=self._file_path)
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@ -26,6 +26,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
|
||||
elements = partition_email(filename=self._file_path)
|
||||
|
||||
# noinspection PyBroadException
|
||||
@ -34,15 +35,16 @@ class UnstructuredEmailExtractor(BaseExtractor):
|
||||
element_text = element.text.strip()
|
||||
|
||||
padding_needed = 4 - len(element_text) % 4
|
||||
element_text += '=' * padding_needed
|
||||
element_text += "=" * padding_needed
|
||||
|
||||
element_decode = base64.b64decode(element_text)
|
||||
soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser')
|
||||
soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser")
|
||||
element.text = soup.get_text()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@ -28,6 +28,7 @@ class UnstructuredEpubExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@ -38,6 +38,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_md(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@ -14,11 +14,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
@ -28,6 +24,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_msg(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@ -14,12 +14,7 @@ class UnstructuredPPTExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
api_key: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
@ -14,11 +14,7 @@ class UnstructuredPPTXExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
|
||||
@ -14,11 +14,7 @@ class UnstructuredTextExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
@ -28,6 +24,7 @@ class UnstructuredTextExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_text(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@ -14,11 +14,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
@ -28,6 +24,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import mimetypes
|
||||
@ -21,6 +22,7 @@ from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WordExtractor(BaseExtractor):
|
||||
"""Load docx files.
|
||||
|
||||
@ -43,9 +45,7 @@ class WordExtractor(BaseExtractor):
|
||||
r = requests.get(self.file_path)
|
||||
|
||||
if r.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Check the url of your file; returned status code {r.status_code}"
|
||||
)
|
||||
raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
|
||||
|
||||
self.web_path = self.file_path
|
||||
self.temp_file = tempfile.NamedTemporaryFile()
|
||||
@ -60,11 +60,13 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
content = self.parse_docx(self.file_path, 'storage')
|
||||
return [Document(
|
||||
page_content=content,
|
||||
metadata={"source": self.file_path},
|
||||
)]
|
||||
content = self.parse_docx(self.file_path, "storage")
|
||||
return [
|
||||
Document(
|
||||
page_content=content,
|
||||
metadata={"source": self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_url(url: str) -> bool:
|
||||
@ -84,18 +86,18 @@ class WordExtractor(BaseExtractor):
|
||||
url = rel.reltype
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
image_ext = mimetypes.guess_extension(response.headers['Content-Type'])
|
||||
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
|
||||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
storage.save(file_key, response.content)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
image_ext = rel.target_ref.split('.')[-1]
|
||||
image_ext = rel.target_ref.split(".")[-1]
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
|
||||
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
|
||||
mime_type, _ = mimetypes.guess_type(file_key)
|
||||
|
||||
storage.save(file_key, rel.target_part.blob)
|
||||
@ -112,12 +114,14 @@ class WordExtractor(BaseExtractor):
|
||||
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
image_map[rel.target_part] = f""
|
||||
image_map[rel.target_part] = (
|
||||
f""
|
||||
)
|
||||
|
||||
return image_map
|
||||
|
||||
@ -167,8 +171,8 @@ class WordExtractor(BaseExtractor):
|
||||
def _parse_cell_paragraph(self, paragraph, image_map):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if run.element.xpath('.//a:blip'):
|
||||
for blip in run.element.xpath('.//a:blip'):
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if not image_id:
|
||||
continue
|
||||
@ -184,16 +188,16 @@ class WordExtractor(BaseExtractor):
|
||||
def _parse_paragraph(self, paragraph, image_map):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if run.element.xpath('.//a:blip'):
|
||||
for blip in run.element.xpath('.//a:blip'):
|
||||
embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if embed_id:
|
||||
rel_target = run.part.rels[embed_id].target_ref
|
||||
if rel_target in image_map:
|
||||
paragraph_content.append(image_map[rel_target])
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
return ' '.join(paragraph_content) if paragraph_content else ''
|
||||
return " ".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
def parse_docx(self, docx_path, image_folder):
|
||||
doc = DocxDocument(docx_path)
|
||||
@ -204,60 +208,59 @@ class WordExtractor(BaseExtractor):
|
||||
image_map = self._extract_images_from_docx(doc, image_folder)
|
||||
|
||||
hyperlinks_url = None
|
||||
url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+')
|
||||
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
|
||||
for para in doc.paragraphs:
|
||||
for run in para.runs:
|
||||
if run.text and hyperlinks_url:
|
||||
result = f' [{run.text}]({hyperlinks_url}) '
|
||||
result = f" [{run.text}]({hyperlinks_url}) "
|
||||
run.text = result
|
||||
hyperlinks_url = None
|
||||
if 'HYPERLINK' in run.element.xml:
|
||||
if "HYPERLINK" in run.element.xml:
|
||||
try:
|
||||
xml = ET.XML(run.element.xml)
|
||||
x_child = [c for c in xml.iter() if c is not None]
|
||||
for x in x_child:
|
||||
if x_child is None:
|
||||
continue
|
||||
if x.tag.endswith('instrText'):
|
||||
if x.tag.endswith("instrText"):
|
||||
for i in url_pattern.findall(x.text):
|
||||
hyperlinks_url = str(i)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
|
||||
|
||||
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'):
|
||||
if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"):
|
||||
drawing_elements = run.element.findall(
|
||||
'.//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing')
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
|
||||
)
|
||||
for drawing in drawing_elements:
|
||||
blip_elements = drawing.findall(
|
||||
'.//{http://schemas.openxmlformats.org/drawingml/2006/main}blip')
|
||||
".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"
|
||||
)
|
||||
for blip in blip_elements:
|
||||
embed_id = blip.get(
|
||||
'{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
|
||||
)
|
||||
if embed_id:
|
||||
image_part = doc.part.related_parts.get(embed_id)
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
return ''.join(paragraph_content) if paragraph_content else ''
|
||||
return "".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
paragraphs = doc.paragraphs.copy()
|
||||
tables = doc.tables.copy()
|
||||
for element in doc.element.body:
|
||||
if hasattr(element, 'tag'):
|
||||
if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph
|
||||
if hasattr(element, "tag"):
|
||||
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
|
||||
para = paragraphs.pop(0)
|
||||
parsed_paragraph = parse_paragraph(para)
|
||||
if parsed_paragraph:
|
||||
content.append(parsed_paragraph)
|
||||
elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table
|
||||
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
|
||||
table = tables.pop(0)
|
||||
content.append(self._table_to_markdown(table, image_map))
|
||||
return '\n'.join(content)
|
||||
|
||||
return "\n".join(content)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
@ -15,8 +16,7 @@ from models.dataset import Dataset, DatasetProcessRule
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
"""Interface for extract files.
|
||||
"""
|
||||
"""Interface for extract files."""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
@ -34,18 +34,24 @@ class BaseIndexProcessor(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict) -> list[Document]:
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_splitter(self, processing_rule: dict,
|
||||
embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
|
||||
def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
|
||||
"""
|
||||
Get the NodeParser object according to the processing rule.
|
||||
"""
|
||||
if processing_rule['mode'] == "custom":
|
||||
if processing_rule["mode"] == "custom":
|
||||
# The user-defined segmentation rule
|
||||
rules = processing_rule['rules']
|
||||
rules = processing_rule["rules"]
|
||||
segmentation = rules["segmentation"]
|
||||
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
|
||||
@ -53,22 +59,22 @@ class BaseIndexProcessor(ABC):
|
||||
|
||||
separator = segmentation["separator"]
|
||||
if separator:
|
||||
separator = separator.replace('\\n', '\n')
|
||||
separator = separator.replace("\\n", "\n")
|
||||
|
||||
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
|
||||
chunk_size=segmentation["max_tokens"],
|
||||
chunk_overlap=segmentation.get('chunk_overlap', 0) or 0,
|
||||
chunk_overlap=segmentation.get("chunk_overlap", 0) or 0,
|
||||
fixed_separator=separator,
|
||||
separators=["\n\n", "。", ". ", " ", ""],
|
||||
embedding_model_instance=embedding_model_instance
|
||||
embedding_model_instance=embedding_model_instance,
|
||||
)
|
||||
else:
|
||||
# Automatic segmentation
|
||||
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
|
||||
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
|
||||
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
|
||||
chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
|
||||
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
|
||||
separators=["\n\n", "。", ". ", " ", ""],
|
||||
embedding_model_instance=embedding_model_instance
|
||||
embedding_model_instance=embedding_model_instance,
|
||||
)
|
||||
|
||||
return character_splitter
|
||||
|
||||
@ -7,8 +7,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
|
||||
|
||||
|
||||
class IndexProcessorFactory:
|
||||
"""IndexProcessorInit.
|
||||
"""
|
||||
"""IndexProcessorInit."""
|
||||
|
||||
def __init__(self, index_type: str):
|
||||
self._index_type = index_type
|
||||
@ -22,7 +21,6 @@ class IndexProcessorFactory:
|
||||
if self._index_type == IndexType.PARAGRAPH_INDEX.value:
|
||||
return ParagraphIndexProcessor()
|
||||
elif self._index_type == IndexType.QA_INDEX.value:
|
||||
|
||||
return QAIndexProcessor()
|
||||
else:
|
||||
raise ValueError(f"Index type {self._index_type} is not supported.")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Paragraph index processor."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
@ -15,33 +16,32 @@ from models.dataset import Dataset
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
|
||||
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
|
||||
is_automatic=kwargs.get('process_rule_mode') == "automatic")
|
||||
text_docs = ExtractProcessor.extract(
|
||||
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
|
||||
)
|
||||
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
# Split the text documents into nodes.
|
||||
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
|
||||
embedding_model_instance=kwargs.get('embedding_model_instance'))
|
||||
splitter = self._get_splitter(
|
||||
processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
|
||||
)
|
||||
all_documents = []
|
||||
for document in documents:
|
||||
# document clean
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
|
||||
document.page_content = document_text
|
||||
# parse document to nodes
|
||||
document_nodes = splitter.split_documents([document])
|
||||
split_documents = []
|
||||
for document_node in document_nodes:
|
||||
|
||||
if document_node.page_content.strip():
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
document_node.metadata["doc_id"] = doc_id
|
||||
document_node.metadata["doc_hash"] = hash
|
||||
# delete Splitter character
|
||||
page_content = document_node.page_content
|
||||
if page_content.startswith(".") or page_content.startswith("。"):
|
||||
@ -55,7 +55,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
return all_documents
|
||||
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if with_keywords:
|
||||
@ -63,7 +63,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
keyword.create(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
@ -76,17 +76,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
keyword.delete()
|
||||
|
||||
def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict) -> list[Document]:
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
) -> list[Document]:
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata['score'] = result.score
|
||||
metadata["score"] = result.score
|
||||
if result.score > score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Paragraph index processor."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
@ -23,33 +24,33 @@ from models.dataset import Dataset
|
||||
|
||||
class QAIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
|
||||
text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
|
||||
is_automatic=kwargs.get('process_rule_mode') == "automatic")
|
||||
text_docs = ExtractProcessor.extract(
|
||||
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
|
||||
)
|
||||
return text_docs
|
||||
|
||||
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
|
||||
splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
|
||||
embedding_model_instance=kwargs.get('embedding_model_instance'))
|
||||
splitter = self._get_splitter(
|
||||
processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
|
||||
)
|
||||
|
||||
# Split the text documents into nodes.
|
||||
all_documents = []
|
||||
all_qa_documents = []
|
||||
for document in documents:
|
||||
# document clean
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
|
||||
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
|
||||
document.page_content = document_text
|
||||
|
||||
# parse document to nodes
|
||||
document_nodes = splitter.split_documents([document])
|
||||
split_documents = []
|
||||
for document_node in document_nodes:
|
||||
|
||||
if document_node.page_content.strip():
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
document_node.metadata["doc_id"] = doc_id
|
||||
document_node.metadata["doc_hash"] = hash
|
||||
# delete Splitter character
|
||||
page_content = document_node.page_content
|
||||
if page_content.startswith(".") or page_content.startswith("。"):
|
||||
@ -61,14 +62,18 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
all_documents.extend(split_documents)
|
||||
for i in range(0, len(all_documents), 10):
|
||||
threads = []
|
||||
sub_documents = all_documents[i:i + 10]
|
||||
sub_documents = all_documents[i : i + 10]
|
||||
for doc in sub_documents:
|
||||
document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'tenant_id': kwargs.get('tenant_id'),
|
||||
'document_node': doc,
|
||||
'all_qa_documents': all_qa_documents,
|
||||
'document_language': kwargs.get('doc_language', 'English')})
|
||||
document_format_thread = threading.Thread(
|
||||
target=self._format_qa_document,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"tenant_id": kwargs.get("tenant_id"),
|
||||
"document_node": doc,
|
||||
"all_qa_documents": all_qa_documents,
|
||||
"document_language": kwargs.get("doc_language", "English"),
|
||||
},
|
||||
)
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
@ -76,9 +81,8 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
return all_qa_documents
|
||||
|
||||
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
|
||||
|
||||
# check file type
|
||||
if not file.filename.endswith('.csv'):
|
||||
if not file.filename.endswith(".csv"):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
@ -86,7 +90,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
df = pd.read_csv(file)
|
||||
text_docs = []
|
||||
for index, row in df.iterrows():
|
||||
data = Document(page_content=row[0], metadata={'answer': row[1]})
|
||||
data = Document(page_content=row[0], metadata={"answer": row[1]})
|
||||
text_docs.append(data)
|
||||
if len(text_docs) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
@ -96,7 +100,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
return text_docs
|
||||
|
||||
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
|
||||
@ -107,17 +111,29 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
vector.delete()
|
||||
|
||||
def retrieve(self, retrieval_method: str, query: str, dataset: Dataset, top_k: int,
|
||||
score_threshold: float, reranking_model: dict):
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_model: dict,
|
||||
):
|
||||
# Set search parameters.
|
||||
results = RetrievalService.retrieve(retrieval_method=retrieval_method, dataset_id=dataset.id, query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.metadata
|
||||
metadata['score'] = result.score
|
||||
metadata["score"] = result.score
|
||||
if result.score > score_threshold:
|
||||
doc = Document(page_content=result.page_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
@ -134,12 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
document_qa_list = self._format_split_text(response)
|
||||
qa_documents = []
|
||||
for result in document_qa_list:
|
||||
qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
|
||||
qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy())
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(result['question'])
|
||||
qa_document.metadata['answer'] = result['answer']
|
||||
qa_document.metadata['doc_id'] = doc_id
|
||||
qa_document.metadata['doc_hash'] = hash
|
||||
hash = helper.generate_text_hash(result["question"])
|
||||
qa_document.metadata["answer"] = result["answer"]
|
||||
qa_document.metadata["doc_id"] = doc_id
|
||||
qa_document.metadata["doc_hash"] = hash
|
||||
qa_documents.append(qa_document)
|
||||
format_documents.extend(qa_documents)
|
||||
except Exception as e:
|
||||
@ -151,10 +167,4 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
|
||||
matches = re.findall(regex, text, re.UNICODE)
|
||||
|
||||
return [
|
||||
{
|
||||
"question": q,
|
||||
"answer": re.sub(r"\n\s*", "\n", a.strip())
|
||||
}
|
||||
for q, a in matches if q and a
|
||||
]
|
||||
return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]
|
||||
|
||||
@ -55,9 +55,7 @@ class BaseDocumentTransformer(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Transform a list of documents.
|
||||
|
||||
Args:
|
||||
@ -68,9 +66,7 @@ class BaseDocumentTransformer(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Asynchronously transform a list of documents.
|
||||
|
||||
Args:
|
||||
|
||||
@ -2,7 +2,5 @@ from enum import Enum
|
||||
|
||||
|
||||
class RerankMode(Enum):
|
||||
|
||||
RERANKING_MODEL = 'reranking_model'
|
||||
WEIGHTED_SCORE = 'weighted_score'
|
||||
|
||||
RERANKING_MODEL = "reranking_model"
|
||||
WEIGHTED_SCORE = "weighted_score"
|
||||
|
||||
@ -8,8 +8,14 @@ class RerankModelRunner:
|
||||
def __init__(self, rerank_model_instance: ModelInstance) -> None:
|
||||
self.rerank_model_instance = rerank_model_instance
|
||||
|
||||
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[Document],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
:param query: search query
|
||||
@ -23,19 +29,15 @@ class RerankModelRunner:
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
if document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.append(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
rerank_result = self.rerank_model_instance.invoke_rerank(
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user
|
||||
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
@ -45,12 +47,12 @@ class RerankModelRunner:
|
||||
rerank_document = Document(
|
||||
page_content=result.text,
|
||||
metadata={
|
||||
"doc_id": documents[result.index].metadata['doc_id'],
|
||||
"doc_hash": documents[result.index].metadata['doc_hash'],
|
||||
"document_id": documents[result.index].metadata['document_id'],
|
||||
"dataset_id": documents[result.index].metadata['dataset_id'],
|
||||
'score': result.score
|
||||
}
|
||||
"doc_id": documents[result.index].metadata["doc_id"],
|
||||
"doc_hash": documents[result.index].metadata["doc_hash"],
|
||||
"document_id": documents[result.index].metadata["document_id"],
|
||||
"dataset_id": documents[result.index].metadata["dataset_id"],
|
||||
"score": result.score,
|
||||
},
|
||||
)
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
|
||||
@ -13,13 +13,18 @@ from core.rag.rerank.entity.weight import VectorSetting, Weights
|
||||
|
||||
|
||||
class WeightRerankRunner:
|
||||
|
||||
def __init__(self, tenant_id: str, weights: Weights) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.weights = weights
|
||||
|
||||
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[Document],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
:param query: search query
|
||||
@ -34,8 +39,8 @@ class WeightRerankRunner:
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
if document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.append(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
@ -47,13 +52,15 @@ class WeightRerankRunner:
|
||||
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
|
||||
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
|
||||
# format document
|
||||
score = self.weights.vector_setting.vector_weight * query_vector_score + \
|
||||
self.weights.keyword_setting.keyword_weight * query_score
|
||||
score = (
|
||||
self.weights.vector_setting.vector_weight * query_vector_score
|
||||
+ self.weights.keyword_setting.keyword_weight * query_score
|
||||
)
|
||||
if score_threshold and score < score_threshold:
|
||||
continue
|
||||
document.metadata['score'] = score
|
||||
document.metadata["score"] = score
|
||||
rerank_documents.append(document)
|
||||
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
|
||||
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return rerank_documents[:top_n] if top_n else rerank_documents
|
||||
|
||||
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
|
||||
@ -70,7 +77,7 @@ class WeightRerankRunner:
|
||||
for document in documents:
|
||||
# get the document keywords
|
||||
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
||||
document.metadata['keywords'] = document_keywords
|
||||
document.metadata["keywords"] = document_keywords
|
||||
documents_keywords.append(document_keywords)
|
||||
|
||||
# Counter query keywords(TF)
|
||||
@ -132,8 +139,9 @@ class WeightRerankRunner:
|
||||
|
||||
return similarities
|
||||
|
||||
def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
|
||||
vector_setting: VectorSetting) -> list[float]:
|
||||
def _calculate_cosine(
|
||||
self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting
|
||||
) -> list[float]:
|
||||
"""
|
||||
Calculate Cosine scores
|
||||
:param query: search query
|
||||
@ -149,15 +157,14 @@ class WeightRerankRunner:
|
||||
tenant_id=tenant_id,
|
||||
provider=vector_setting.embedding_provider_name,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=vector_setting.embedding_model_name
|
||||
|
||||
model=vector_setting.embedding_model_name,
|
||||
)
|
||||
cache_embedding = CacheEmbedding(embedding_model)
|
||||
query_vector = cache_embedding.embed_query(query)
|
||||
for document in documents:
|
||||
# calculate cosine similarity
|
||||
if 'score' in document.metadata:
|
||||
query_vector_scores.append(document.metadata['score'])
|
||||
if "score" in document.metadata:
|
||||
query_vector_scores.append(document.metadata["score"])
|
||||
else:
|
||||
# transform to NumPy
|
||||
vec1 = np.array(query_vector)
|
||||
|
||||
@ -32,14 +32,11 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
# Fallback retrieval settings, used when a dataset has no retrieval_model of its own
# (see `dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model`).
default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}
|
||||
|
||||
|
||||
@ -48,15 +45,18 @@ class DatasetRetrieval:
|
||||
self.application_generate_entity = application_generate_entity
|
||||
|
||||
def retrieve(
|
||||
self, app_id: str, user_id: str, tenant_id: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: DatasetEntity,
|
||||
query: str,
|
||||
invoke_from: InvokeFrom,
|
||||
show_retrieve_source: bool,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
message_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
self,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: DatasetEntity,
|
||||
query: str,
|
||||
invoke_from: InvokeFrom,
|
||||
show_retrieve_source: bool,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
message_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Retrieve dataset.
|
||||
@ -84,16 +84,12 @@ class DatasetRetrieval:
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.provider,
|
||||
model=model_config.model
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
|
||||
)
|
||||
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model_config.credentials
|
||||
model=model_config.model, credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if not model_schema:
|
||||
@ -102,39 +98,46 @@ class DatasetRetrieval:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features \
|
||||
or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if (dataset and dataset.available_document_count == 0
|
||||
and dataset.available_document_count == 0):
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
all_documents = []
|
||||
user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
|
||||
user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
all_documents = self.single_retrieve(
|
||||
app_id, tenant_id, user_id, user_from, available_datasets, query,
|
||||
app_id,
|
||||
tenant_id,
|
||||
user_id,
|
||||
user_from,
|
||||
available_datasets,
|
||||
query,
|
||||
model_instance,
|
||||
model_config, planning_strategy, message_id
|
||||
model_config,
|
||||
planning_strategy,
|
||||
message_id,
|
||||
)
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
all_documents = self.multiple_retrieve(
|
||||
app_id, tenant_id, user_id, user_from,
|
||||
available_datasets, query, retrieve_config.top_k,
|
||||
app_id,
|
||||
tenant_id,
|
||||
user_id,
|
||||
user_from,
|
||||
available_datasets,
|
||||
query,
|
||||
retrieve_config.top_k,
|
||||
retrieve_config.score_threshold,
|
||||
retrieve_config.rerank_mode,
|
||||
retrieve_config.reranking_model,
|
||||
@ -145,89 +148,89 @@ class DatasetRetrieval:
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata.get('score'):
|
||||
document_score_list[item.metadata['doc_id']] = item.metadata['score']
|
||||
if item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids)
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
sorted_segments = sorted(segments,
|
||||
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
|
||||
float('inf')))
|
||||
sorted_segments = sorted(
|
||||
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
|
||||
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
|
||||
else:
|
||||
document_context_list.append(segment.get_sign_content())
|
||||
if show_retrieve_source:
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=segment.dataset_id
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
'position': resource_number,
|
||||
'dataset_id': dataset.id,
|
||||
'dataset_name': dataset.name,
|
||||
'document_id': document.id,
|
||||
'document_name': document.name,
|
||||
'data_source_type': document.data_source_type,
|
||||
'segment_id': segment.id,
|
||||
'retriever_from': invoke_from.to_source(),
|
||||
'score': document_score_list.get(segment.index_node_id, None)
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
}
|
||||
|
||||
if invoke_from.to_source() == 'dev':
|
||||
source['hit_count'] = segment.hit_count
|
||||
source['word_count'] = segment.word_count
|
||||
source['segment_position'] = segment.position
|
||||
source['index_node_hash'] = segment.index_node_hash
|
||||
if invoke_from.to_source() == "dev":
|
||||
source["hit_count"] = segment.hit_count
|
||||
source["word_count"] = segment.word_count
|
||||
source["segment_position"] = segment.position
|
||||
source["index_node_hash"] = segment.index_node_hash
|
||||
if segment.answer:
|
||||
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
|
||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source['content'] = segment.content
|
||||
source["content"] = segment.content
|
||||
context_list.append(source)
|
||||
resource_number += 1
|
||||
if hit_callback:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def single_retrieve(
|
||||
self, app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
tools = []
|
||||
for dataset in available_datasets:
|
||||
description = dataset.description
|
||||
if not description:
|
||||
description = 'useful for when you want to answer queries about the ' + dataset.name
|
||||
description = "useful for when you want to answer queries about the " + dataset.name
|
||||
|
||||
description = description.replace('\n', '').replace('\r', '')
|
||||
description = description.replace("\n", "").replace("\r", "")
|
||||
message_tool = PromptMessageTool(
|
||||
name=dataset.id,
|
||||
description=description,
|
||||
@ -235,14 +238,15 @@ class DatasetRetrieval:
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
tools.append(message_tool)
|
||||
dataset_id = None
|
||||
if planning_strategy == PlanningStrategy.REACT_ROUTER:
|
||||
react_multi_dataset_router = ReactMultiDatasetRouter()
|
||||
dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
|
||||
user_id, tenant_id)
|
||||
dataset_id = react_multi_dataset_router.invoke(
|
||||
query, tools, model_config, model_instance, user_id, tenant_id
|
||||
)
|
||||
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
function_call_router = FunctionCallMultiDatasetRouter()
|
||||
@ -250,37 +254,37 @@ class DatasetRetrieval:
|
||||
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
retrieval_model_config = dataset.retrieval_model \
|
||||
if dataset.retrieval_model else default_retrieval_model
|
||||
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config['top_k']
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
retrieval_method = 'keyword_search'
|
||||
retrieval_method = "keyword_search"
|
||||
else:
|
||||
retrieval_method = retrieval_model_config['search_method']
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
# get reranking model
|
||||
reranking_model = retrieval_model_config['reranking_model'] \
|
||||
if retrieval_model_config['reranking_enable'] else None
|
||||
reranking_model = (
|
||||
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
|
||||
)
|
||||
# get score threshold
|
||||
score_threshold = .0
|
||||
score_threshold = 0.0
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
with measure_time() as timer:
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method, dataset_id=dataset.id,
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
reranking_model=reranking_model,
|
||||
reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
|
||||
weights=retrieval_model_config.get('weights', None),
|
||||
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
|
||||
weights=retrieval_model_config.get("weights", None),
|
||||
)
|
||||
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
||||
|
||||
@ -291,20 +295,20 @@ class DatasetRetrieval:
|
||||
return []
|
||||
|
||||
def multiple_retrieve(
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
self,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
reranking_mode: str,
|
||||
reranking_model: Optional[dict] = None,
|
||||
weights: Optional[dict] = None,
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
threads = []
|
||||
all_documents = []
|
||||
@ -312,13 +316,16 @@ class DatasetRetrieval:
|
||||
index_type = None
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset.id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'all_documents': all_documents,
|
||||
})
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"dataset_id": dataset.id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
},
|
||||
)
|
||||
threads.append(retrieval_thread)
|
||||
retrieval_thread.start()
|
||||
for thread in threads:
|
||||
@ -327,16 +334,10 @@ class DatasetRetrieval:
|
||||
with measure_time() as timer:
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(
|
||||
tenant_id, reranking_mode,
|
||||
reranking_model, weights, False
|
||||
)
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k
|
||||
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
|
||||
)
|
||||
else:
|
||||
if index_type == "economy":
|
||||
@ -357,30 +358,26 @@ class DatasetRetrieval:
|
||||
"""Handle retrieval end."""
|
||||
for document in documents:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata['doc_id']
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if 'dataset_id' in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False
|
||||
)
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# get tracing instance
|
||||
trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
trace_manager: TraceQueueManager = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
message_id=message_id,
|
||||
documents=documents,
|
||||
timer=timer
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
@ -395,10 +392,10 @@ class DatasetRetrieval:
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
content=query,
|
||||
source='app',
|
||||
source="app",
|
||||
source_app_id=app_id,
|
||||
created_by_role=user_from,
|
||||
created_by=user_id
|
||||
created_by=user_id,
|
||||
)
|
||||
dataset_queries.append(dataset_query)
|
||||
if dataset_queries:
|
||||
@ -407,9 +404,7 @@ class DatasetRetrieval:
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
@ -419,38 +414,42 @@ class DatasetRetrieval:
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k
|
||||
)
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
else:
|
||||
if top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model.get('reranking_model', None)
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
reranking_mode=retrieval_model.get('reranking_mode')
|
||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else None,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode")
|
||||
if retrieval_model.get("reranking_mode")
|
||||
else "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
|
||||
def to_dataset_retriever_tool(self, tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler) \
|
||||
-> Optional[list[DatasetRetrieverBaseTool]]:
|
||||
def to_dataset_retriever_tool(
|
||||
self,
|
||||
tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
) -> Optional[list[DatasetRetrieverBaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param tenant_id: tenant id
|
||||
@ -464,18 +463,14 @@ class DatasetRetrieval:
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if (dataset and dataset.available_document_count == 0
|
||||
and dataset.available_document_count == 0):
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
@ -483,22 +478,18 @@ class DatasetRetrieval:
|
||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# get retrieval model config
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
for dataset in available_datasets:
|
||||
retrieval_model_config = dataset.retrieval_model \
|
||||
if dataset.retrieval_model else default_retrieval_model
|
||||
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config['top_k']
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
|
||||
# get score threshold
|
||||
score_threshold = None
|
||||
@ -512,7 +503,7 @@ class DatasetRetrieval:
|
||||
score_threshold=score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source()
|
||||
retriever_from=invoke_from.to_source(),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
@ -525,8 +516,8 @@ class DatasetRetrieval:
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
|
||||
reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
|
||||
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
|
||||
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
@ -547,7 +538,7 @@ class DatasetRetrieval:
|
||||
for document in documents:
|
||||
# get the document keywords
|
||||
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
||||
document.metadata['keywords'] = document_keywords
|
||||
document.metadata["keywords"] = document_keywords
|
||||
documents_keywords.append(document_keywords)
|
||||
|
||||
# Counter query keywords(TF)
|
||||
@ -606,21 +597,19 @@ class DatasetRetrieval:
|
||||
|
||||
for document, score in zip(documents, similarities):
|
||||
# format document
|
||||
document.metadata['score'] = score
|
||||
documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
|
||||
document.metadata["score"] = score
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return documents[:top_k] if top_k else documents
|
||||
|
||||
def calculate_vector_score(
    self, all_documents: list[Document], top_k: int, score_threshold: Optional[float]
) -> list[Document]:
    """
    Filter documents by their vector score and return the best ones first.

    :param all_documents: documents whose ``metadata`` carries a ``score`` entry
    :param top_k: maximum number of documents to return; a falsy value returns every match
    :param score_threshold: minimum score a document needs to be kept; ``None`` keeps all
    :return: the kept documents ordered by descending score (empty list when none qualify)
    """
    filter_documents = [
        document
        for document in all_documents
        if score_threshold is None or document.metadata["score"] >= score_threshold
    ]
    # list.sort is stable, so documents with equal scores keep their input order.
    filter_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
    return filter_documents[:top_k] if top_k else filter_documents
|
||||
|
||||
|
||||
|
||||
|
||||
@ -16,9 +16,7 @@ class StructuredChatOutputParser:
|
||||
if response["action"] == "Final Answer":
|
||||
return ReactFinish({"output": response["action_input"]}, text)
|
||||
else:
|
||||
return ReactAction(
|
||||
response["action"], response.get("action_input", {}), text
|
||||
)
|
||||
return ReactAction(response["action"], response.get("action_input", {}), text)
|
||||
else:
|
||||
return ReactFinish({"output": text}, text)
|
||||
except Exception as e:
|
||||
|
||||
@ -2,9 +2,9 @@ from enum import Enum
|
||||
|
||||
|
||||
class RetrievalMethod(Enum):
|
||||
SEMANTIC_SEARCH = 'semantic_search'
|
||||
FULL_TEXT_SEARCH = 'full_text_search'
|
||||
HYBRID_SEARCH = 'hybrid_search'
|
||||
SEMANTIC_SEARCH = "semantic_search"
|
||||
FULL_TEXT_SEARCH = "full_text_search"
|
||||
HYBRID_SEARCH = "hybrid_search"
|
||||
|
||||
@staticmethod
|
||||
def is_support_semantic_search(retrieval_method: str) -> bool:
|
||||
|
||||
@ -6,14 +6,12 @@ from core.model_runtime.entities.message_entities import PromptMessageTool, Syst
|
||||
|
||||
|
||||
class FunctionCallMultiDatasetRouter:
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
query: str,
|
||||
dataset_tools: list[PromptMessageTool],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
|
||||
self,
|
||||
query: str,
|
||||
dataset_tools: list[PromptMessageTool],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
) -> Union[str, None]:
|
||||
"""Given input, decided what to do.
|
||||
Returns:
|
||||
@ -26,22 +24,18 @@ class FunctionCallMultiDatasetRouter:
|
||||
|
||||
try:
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content='You are a helpful AI assistant.'),
|
||||
UserPromptMessage(content=query)
|
||||
SystemPromptMessage(content="You are a helpful AI assistant."),
|
||||
UserPromptMessage(content=query),
|
||||
]
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
tools=dataset_tools,
|
||||
stream=False,
|
||||
model_parameters={
|
||||
'temperature': 0.2,
|
||||
'top_p': 0.3,
|
||||
'max_tokens': 1500
|
||||
}
|
||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||
)
|
||||
if result.message.tool_calls:
|
||||
# get retrieval model config
|
||||
return result.message.tool_calls[0].function.name
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
return None
|
||||
|
||||
@ -50,16 +50,14 @@ Action:
|
||||
|
||||
|
||||
class ReactMultiDatasetRouter:
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
query: str,
|
||||
dataset_tools: list[PromptMessageTool],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
user_id: str,
|
||||
tenant_id: str
|
||||
|
||||
self,
|
||||
query: str,
|
||||
dataset_tools: list[PromptMessageTool],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
) -> Union[str, None]:
|
||||
"""Given input, decided what to do.
|
||||
Returns:
|
||||
@ -71,23 +69,28 @@ class ReactMultiDatasetRouter:
|
||||
return dataset_tools[0].name
|
||||
|
||||
try:
|
||||
return self._react_invoke(query=query, model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
tools=dataset_tools, user_id=user_id, tenant_id=tenant_id)
|
||||
return self._react_invoke(
|
||||
query=query,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
tools=dataset_tools,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def _react_invoke(
|
||||
self,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
self,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
) -> Union[str, None]:
|
||||
if model_config.mode == "chat":
|
||||
prompt = self.create_chat_prompt(
|
||||
@ -103,18 +106,18 @@ class ReactMultiDatasetRouter:
|
||||
prefix=prefix,
|
||||
format_instructions=format_instructions,
|
||||
)
|
||||
stop = ['Observation:']
|
||||
stop = ["Observation:"]
|
||||
# handle invoke result
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt,
|
||||
inputs={},
|
||||
query='',
|
||||
query="",
|
||||
files=[],
|
||||
context='',
|
||||
context="",
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
model_config=model_config,
|
||||
)
|
||||
result_text, usage = self._invoke_llm(
|
||||
completion_param=model_config.parameters,
|
||||
@ -122,7 +125,7 @@ class ReactMultiDatasetRouter:
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
output_parser = StructuredChatOutputParser()
|
||||
react_decision = output_parser.parse(result_text)
|
||||
@ -130,17 +133,21 @@ class ReactMultiDatasetRouter:
|
||||
return react_decision.tool
|
||||
return None
|
||||
|
||||
def _invoke_llm(self, completion_param: dict,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: list[str], user_id: str, tenant_id: str
|
||||
) -> tuple[str, LLMUsage]:
|
||||
def _invoke_llm(
|
||||
self,
|
||||
completion_param: dict,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: list[str],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param model_instance: model instance
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop
|
||||
:return:
|
||||
Invoke large language model
|
||||
:param model_instance: model instance
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
@ -151,9 +158,7 @@ class ReactMultiDatasetRouter:
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
text, usage = self._handle_invoke_result(
|
||||
invoke_result=invoke_result
|
||||
)
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
@ -168,7 +173,7 @@ class ReactMultiDatasetRouter:
|
||||
"""
|
||||
model = None
|
||||
prompt_messages = []
|
||||
full_text = ''
|
||||
full_text = ""
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
@ -189,40 +194,35 @@ class ReactMultiDatasetRouter:
|
||||
return full_text, usage
|
||||
|
||||
def create_chat_prompt(
|
||||
self,
|
||||
query: str,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
self,
|
||||
query: str,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
) -> list[ChatModelMessage]:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
tool_strings.append(
|
||||
f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
|
||||
f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}"
|
||||
)
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
unique_tool_names = {tool.name for tool in tools}
|
||||
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||
prompt_messages = []
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=template
|
||||
)
|
||||
system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=template)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=query
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=query)
|
||||
prompt_messages.append(user_prompt_message)
|
||||
return prompt_messages
|
||||
|
||||
def create_completion_prompt(
|
||||
self,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
self,
|
||||
tools: Sequence[PromptMessageTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
) -> CompletionModelPromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Functionality for splitting text."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
@ -18,31 +19,29 @@ from core.rag.splitter.text_splitter import (
|
||||
|
||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
"""
|
||||
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
|
||||
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_encoder(
|
||||
cls: type[TS],
|
||||
embedding_model_instance: Optional[ModelInstance],
|
||||
allowed_special: Union[Literal[all], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal[all], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
cls: type[TS],
|
||||
embedding_model_instance: Optional[ModelInstance],
|
||||
allowed_special: Union[Literal[all], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal[all], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
):
|
||||
def _token_encoder(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
if embedding_model_instance:
|
||||
return embedding_model_instance.get_text_embedding_num_tokens(
|
||||
texts=[text]
|
||||
)
|
||||
return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
|
||||
else:
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
"model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
|
||||
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
|
||||
"allowed_special": allowed_special,
|
||||
"disallowed_special": disallowed_special,
|
||||
}
|
||||
|
||||
@ -22,9 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
TS = TypeVar("TS", bound="TextSplitter")
|
||||
|
||||
|
||||
def _split_text_with_regex(
|
||||
text: str, separator: str, keep_separator: bool
|
||||
) -> list[str]:
|
||||
def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
if keep_separator:
|
||||
@ -37,19 +35,19 @@ def _split_text_with_regex(
|
||||
splits = re.split(separator, text)
|
||||
else:
|
||||
splits = list(text)
|
||||
return [s for s in splits if (s != "" and s != '\n')]
|
||||
return [s for s in splits if (s != "" and s != "\n")]
|
||||
|
||||
|
||||
class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
"""Interface for splitting text into chunks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: Callable[[str], int] = len,
|
||||
keep_separator: bool = False,
|
||||
add_start_index: bool = False,
|
||||
self,
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: Callable[[str], int] = len,
|
||||
keep_separator: bool = False,
|
||||
add_start_index: bool = False,
|
||||
) -> None:
|
||||
"""Create a new TextSplitter.
|
||||
|
||||
@ -62,8 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
"""
|
||||
if chunk_overlap > chunk_size:
|
||||
raise ValueError(
|
||||
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
||||
f"({chunk_size}), should be smaller."
|
||||
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller."
|
||||
)
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
@ -75,9 +72,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
"""Split text into multiple components."""
|
||||
|
||||
def create_documents(
|
||||
self, texts: list[str], metadatas: Optional[list[dict]] = None
|
||||
) -> list[Document]:
|
||||
def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
|
||||
"""Create documents from a list of texts."""
|
||||
_metadatas = metadatas or [{}] * len(texts)
|
||||
documents = []
|
||||
@ -119,14 +114,10 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
index = 0
|
||||
for d in splits:
|
||||
_len = lengths[index]
|
||||
if (
|
||||
total + _len + (separator_len if len(current_doc) > 0 else 0)
|
||||
> self._chunk_size
|
||||
):
|
||||
if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
|
||||
if total > self._chunk_size:
|
||||
logger.warning(
|
||||
f"Created a chunk of size {total}, "
|
||||
f"which is longer than the specified {self._chunk_size}"
|
||||
f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}"
|
||||
)
|
||||
if len(current_doc) > 0:
|
||||
doc = self._join_docs(current_doc, separator)
|
||||
@ -136,13 +127,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
# - we have a larger chunk than in the chunk overlap
|
||||
# - or if we still have any chunks and the length is long
|
||||
while total > self._chunk_overlap or (
|
||||
total + _len + (separator_len if len(current_doc) > 0 else 0)
|
||||
> self._chunk_size
|
||||
and total > 0
|
||||
total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
|
||||
):
|
||||
total -= self._length_function(current_doc[0]) + (
|
||||
separator_len if len(current_doc) > 1 else 0
|
||||
)
|
||||
total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
|
||||
current_doc = current_doc[1:]
|
||||
current_doc.append(d)
|
||||
total += _len + (separator_len if len(current_doc) > 1 else 0)
|
||||
@ -159,28 +146,25 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
||||
raise ValueError(
|
||||
"Tokenizer received was not an instance of PreTrainedTokenizerBase"
|
||||
)
|
||||
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")
|
||||
|
||||
def _huggingface_tokenizer_length(text: str) -> int:
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"Please install it with `pip install transformers`."
|
||||
"Could not import transformers python package. " "Please install it with `pip install transformers`."
|
||||
)
|
||||
return cls(length_function=_huggingface_tokenizer_length, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_tiktoken_encoder(
|
||||
cls: type[TS],
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
cls: type[TS],
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
) -> TS:
|
||||
"""Text splitter that uses tiktoken encoder to count length."""
|
||||
try:
|
||||
@ -217,15 +201,11 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
|
||||
return cls(length_function=_tiktoken_encoder, **kwargs)
|
||||
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
return self.split_documents(list(documents))
|
||||
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Asynchronously transform a sequence of documents by splitting them."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -267,9 +247,7 @@ class HeaderType(TypedDict):
|
||||
class MarkdownHeaderTextSplitter:
|
||||
"""Splitting markdown files based on specified headers."""
|
||||
|
||||
def __init__(
|
||||
self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
|
||||
):
|
||||
def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
|
||||
"""Create a new MarkdownHeaderTextSplitter.
|
||||
|
||||
Args:
|
||||
@ -280,9 +258,7 @@ class MarkdownHeaderTextSplitter:
|
||||
self.return_each_line = return_each_line
|
||||
# Given the headers we want to split on,
|
||||
# (e.g., "#, ##, etc") order by length
|
||||
self.headers_to_split_on = sorted(
|
||||
headers_to_split_on, key=lambda split: len(split[0]), reverse=True
|
||||
)
|
||||
self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)
|
||||
|
||||
def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
|
||||
"""Combine lines with common metadata into chunks
|
||||
@ -292,10 +268,7 @@ class MarkdownHeaderTextSplitter:
|
||||
aggregated_chunks: list[LineType] = []
|
||||
|
||||
for line in lines:
|
||||
if (
|
||||
aggregated_chunks
|
||||
and aggregated_chunks[-1]["metadata"] == line["metadata"]
|
||||
):
|
||||
if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
|
||||
# If the last line in the aggregated list
|
||||
# has the same metadata as the current line,
|
||||
# append the current content to the last lines's content
|
||||
@ -304,10 +277,7 @@ class MarkdownHeaderTextSplitter:
|
||||
# Otherwise, append the current line to the aggregated list
|
||||
aggregated_chunks.append(line)
|
||||
|
||||
return [
|
||||
Document(page_content=chunk["content"], metadata=chunk["metadata"])
|
||||
for chunk in aggregated_chunks
|
||||
]
|
||||
return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]
|
||||
|
||||
def split_text(self, text: str) -> list[Document]:
|
||||
"""Split markdown file
|
||||
@ -332,10 +302,9 @@ class MarkdownHeaderTextSplitter:
|
||||
for sep, name in self.headers_to_split_on:
|
||||
# Check if line starts with a header that we intend to split on
|
||||
if stripped_line.startswith(sep) and (
|
||||
# Header with no text OR header is followed by space
|
||||
# Both are valid conditions that sep is being used a header
|
||||
len(stripped_line) == len(sep)
|
||||
or stripped_line[len(sep)] == " "
|
||||
# Header with no text OR header is followed by space
|
||||
# Both are valid conditions that sep is being used a header
|
||||
len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
|
||||
):
|
||||
# Ensure we are tracking the header as metadata
|
||||
if name is not None:
|
||||
@ -343,10 +312,7 @@ class MarkdownHeaderTextSplitter:
|
||||
current_header_level = sep.count("#")
|
||||
|
||||
# Pop out headers of lower or same level from the stack
|
||||
while (
|
||||
header_stack
|
||||
and header_stack[-1]["level"] >= current_header_level
|
||||
):
|
||||
while header_stack and header_stack[-1]["level"] >= current_header_level:
|
||||
# We have encountered a new header
|
||||
# at the same or higher level
|
||||
popped_header = header_stack.pop()
|
||||
@ -359,7 +325,7 @@ class MarkdownHeaderTextSplitter:
|
||||
header: HeaderType = {
|
||||
"level": current_header_level,
|
||||
"name": name,
|
||||
"data": stripped_line[len(sep):].strip(),
|
||||
"data": stripped_line[len(sep) :].strip(),
|
||||
}
|
||||
header_stack.append(header)
|
||||
# Update initial_metadata with the current header
|
||||
@ -392,9 +358,7 @@ class MarkdownHeaderTextSplitter:
|
||||
current_metadata = initial_metadata.copy()
|
||||
|
||||
if current_content:
|
||||
lines_with_metadata.append(
|
||||
{"content": "\n".join(current_content), "metadata": current_metadata}
|
||||
)
|
||||
lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})
|
||||
|
||||
# lines_with_metadata has each line with associated header metadata
|
||||
# aggregate these into chunks based on common metadata
|
||||
@ -402,8 +366,7 @@ class MarkdownHeaderTextSplitter:
|
||||
return self.aggregate_lines_to_chunks(lines_with_metadata)
|
||||
else:
|
||||
return [
|
||||
Document(page_content=chunk["content"], metadata=chunk["metadata"])
|
||||
for chunk in lines_with_metadata
|
||||
Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
|
||||
]
|
||||
|
||||
|
||||
@ -436,12 +399,12 @@ class TokenTextSplitter(TextSplitter):
|
||||
"""Splitting text to tokens using model tokenizer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
self,
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
@ -488,10 +451,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
separators: Optional[list[str]] = None,
|
||||
keep_separator: bool = True,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
separators: Optional[list[str]] = None,
|
||||
keep_separator: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(keep_separator=keep_separator, **kwargs)
|
||||
@ -508,7 +471,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
break
|
||||
if re.search(_s, text):
|
||||
separator = _s
|
||||
new_separators = separators[i + 1:]
|
||||
new_separators = separators[i + 1 :]
|
||||
break
|
||||
|
||||
splits = _split_text_with_regex(text, separator, self._keep_separator)
|
||||
|
||||
Reference in New Issue
Block a user