Merge branch 'main' into feat/external-knowledge

# Conflicts:
#	api/core/rag/datasource/retrieval_service.py
#	api/models/dataset.py
#	api/services/dataset_service.py
This commit is contained in:
jyong
2024-09-18 14:40:43 +08:00
1428 changed files with 44957 additions and 30983 deletions

View File

@ -24,37 +24,42 @@ class Jieba(BaseKeyword):
self._config = KeywordTableConfig()
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get('keywords_list', None)
keywords_list = kwargs.get("keywords_list", None)
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(text.page_content,
self._config.max_keywords_per_chunk)
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
self._save_dataset_keyword_table(keyword_table)
@ -63,97 +68,91 @@ class Jieba(BaseKeyword):
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search(self, query: str, **kwargs: Any) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
k = kwargs.get('top_k', 4)
k = kwargs.get("top_k", 4)
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
.first()
)
if segment:
documents.append(Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
))
documents.append(
Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
)
return documents
def delete(self) -> None:
lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
if dataset_keyword_table.data_source_type != 'database':
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
if dataset_keyword_table.data_source_type != "database":
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
storage.delete(file_key)
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": keyword_table
}
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == 'database':
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
db.session.commit()
else:
file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict['__data__']['table']
return keyword_table_dict["__data__"]["table"]
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table='',
keyword_table="",
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == 'database':
dataset_keyword_table.keyword_table = json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(
{
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
},
cls=SetEncoder,
)
db.session.add(dataset_keyword_table)
db.session.commit()
@ -174,9 +173,7 @@ class Jieba(BaseKeyword):
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
@ -202,13 +199,14 @@ class Jieba(BaseKeyword):
reverse=True,
)
return sorted_chunk_indices[: k]
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id == node_id
).first()
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)
@ -224,14 +222,14 @@ class Jieba(BaseKeyword):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data['segment']
if pre_segment_data['keywords']:
segment.keywords = pre_segment_data['keywords']
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
pre_segment_data['keywords'])
segment = pre_segment_data["segment"]
if pre_segment_data["keywords"]:
segment.keywords = pre_segment_data["keywords"]
keyword_table = self._add_text_to_keyword_table(
keyword_table, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keywords = keyword_table_handler.extract_keywords(segment.content,
self._config.max_keywords_per_chunk)
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
self._save_dataset_keyword_table(keyword_table)

View File

@ -8,7 +8,6 @@ from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
@ -30,4 +29,4 @@ class JiebaKeywordTableHandler:
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
return results

File diff suppressed because it is too large Load Diff

View File

@ -8,7 +8,6 @@ from models.dataset import Dataset
class BaseKeyword(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@ -31,15 +30,12 @@ class BaseKeyword(ABC):
def delete(self) -> None:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
@ -47,4 +43,4 @@ class BaseKeyword(ABC):
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
return [text.metadata["doc_id"] for text in texts]

View File

@ -20,9 +20,7 @@ class Keyword:
raise ValueError("Keyword store must be specified.")
if keyword_type == "jieba":
return Jieba(
dataset=self._dataset
)
return Jieba(dataset=self._dataset)
else:
raise ValueError(f"Keyword store {keyword_type} is not supported.")
@ -41,10 +39,7 @@ class Keyword:
def delete(self) -> None:
self._keyword_processor.delete()
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search(self, query: str, **kwargs: Any) -> list[Document]:
return self._keyword_processor.search(query, **kwargs)
def __getattr__(self, name):

View File

@ -7,25 +7,21 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0,
@ -113,95 +109,104 @@ class RetrievalService:
return all_documents
@classmethod
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, all_documents: list, exceptions: list):
def keyword_search(
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
keyword = Keyword(
dataset=dataset
)
keyword = Keyword(dataset=dataset)
documents = keyword.search(
cls.escape_query_for_search(query),
top_k=top_k
)
documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@classmethod
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, retrival_method: str, exceptions: list):
def embedding_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
vector = Vector(
dataset=dataset
)
vector = Vector(dataset=dataset)
documents = vector.search_by_vector(
cls.escape_query_for_search(query),
search_type='similarity_score_threshold',
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
filter={"group_id": [dataset.id]},
)
if documents:
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents)
))
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
)
)
else:
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@classmethod
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, retrival_method: str, exceptions: list):
def full_text_index_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
):
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
vector_processor = Vector(
dataset=dataset,
)
documents = vector_processor.search_by_full_text(
cls.escape_query_for_search(query),
top_k=top_k
)
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
if documents:
if reranking_model and reranking_model.get('reranking_model_name') and reranking_model.get('reranking_provider_name') and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id),
RerankMode.RERANKING_MODEL.value,
reranking_model, None, False)
all_documents.extend(data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents)
))
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
)
)
else:
all_documents.extend(documents)
except Exception as e:
@ -209,4 +214,4 @@ class RetrievalService:
@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')
return query.replace('"', '\\"')

View File

@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel):
namespace_password: str = (None,)
metrics: str = ("cosine",)
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
@ -37,6 +38,7 @@ class AnalyticdbConfig(BaseModel):
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
_instance = None
_init = False
@ -57,9 +59,7 @@ class AnalyticdbVector(BaseVector):
except:
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(
user_agent="dify", **config.to_analyticdb_client_params()
)
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
AnalyticdbVector._init = True
@ -77,6 +77,7 @@ class AnalyticdbVector(BaseVector):
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -88,6 +89,7 @@ class AnalyticdbVector(BaseVector):
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
@ -109,13 +111,12 @@ class AnalyticdbVector(BaseVector):
)
self._client.create_namespace(request)
else:
raise ValueError(
f"failed to create namespace {self.config.namespace}: {e}"
)
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
@ -149,9 +150,7 @@ class AnalyticdbVector(BaseVector):
)
self._client.create_collection(request)
else:
raise ValueError(
f"failed to create collection {self._collection_name}: {e}"
)
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def get_type(self) -> str:
@ -162,10 +161,9 @@ class AnalyticdbVector(BaseVector):
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
def add_texts(
self, documents: list[Document], embeddings: list[list[float]], **kwargs
):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
@ -191,6 +189,7 @@ class AnalyticdbVector(BaseVector):
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -202,13 +201,14 @@ class AnalyticdbVector(BaseVector):
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'"
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
@ -224,6 +224,7 @@ class AnalyticdbVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -235,15 +236,10 @@ class AnalyticdbVector(BaseVector):
)
self._client.delete_collection_data(request)
def search_by_vector(
self, query_vector: list[float], **kwargs: Any
) -> list[Document]:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -270,11 +266,8 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
kwargs.get("score_threshold", 0.0)
if kwargs.get("score_threshold", 0.0)
else 0.0
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@ -304,6 +297,7 @@ class AnalyticdbVector(BaseVector):
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
@ -315,19 +309,16 @@ class AnalyticdbVector(BaseVector):
except Exception as e:
raise e
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"][
"class_prefix"
]
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params
if dify_config.ANALYTICDB_KEY_ID is None:

View File

@ -27,21 +27,20 @@ class ChromaConfig(BaseModel):
settings = Settings(
# auth
chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials
chroma_client_auth_credentials=self.auth_credentials,
)
return {
'host': self.host,
'port': self.port,
'ssl': False,
'tenant': self.tenant,
'database': self.database,
'settings': settings,
"host": self.host,
"port": self.port,
"ssl": False,
"tenant": self.tenant,
"database": self.database,
"settings": settings,
}
class ChromaVector(BaseVector):
def __init__(self, collection_name: str, config: ChromaConfig):
super().__init__(collection_name)
self._client_config = config
@ -58,9 +57,9 @@ class ChromaVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
self._client.get_or_create_collection(collection_name)
@ -76,7 +75,7 @@ class ChromaVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(where={key: {'$eq': value}})
collection.delete(where={key: {"$eq": value}})
def delete(self):
self._client.delete_collection(self._collection_name)
@ -93,26 +92,26 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = float(kwargs.get("score_threshold") or 0.0)
ids: list[str] = results['ids'][0]
documents: list[str] = results['documents'][0]
metadatas: dict[str, Any] = results['metadatas'][0]
distances: list[float] = results['distances'][0]
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
metadatas: dict[str, Any] = results["metadatas"][0]
distances: list[float] = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = metadatas[index]
if distance >= score_threshold:
metadata['score'] = distance
metadata["score"] = distance
doc = Document(
page_content=documents[index],
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -123,15 +122,12 @@ class ChromaVector(BaseVector):
class ChromaVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector(

View File

@ -26,15 +26,16 @@ class ElasticSearchConfig(BaseModel):
username: str
password: str
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config PORT is required")
if not values['username']:
if not values["username"]:
raise ValueError("config USERNAME is required")
if not values['password']:
if not values["password"]:
raise ValueError("config PASSWORD is required")
return values
@ -50,10 +51,10 @@ class ElasticSearchVector(BaseVector):
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
parsed_url = urlparse(config.host)
if parsed_url.scheme in ['http', 'https']:
hosts = f'{config.host}:{config.port}'
if parsed_url.scheme in {"http", "https"}:
hosts = f"{config.host}:{config.port}"
else:
hosts = f'http://{config.host}:{config.port}'
hosts = f"http://{config.host}:{config.port}"
client = Elasticsearch(
hosts=hosts,
basic_auth=(config.username, config.password),
@ -68,45 +69,41 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return info['version']['number']
return info["version"]["number"]
def _check_version(self):
if self._version < '8.0.0':
if self._version < "8.0.0":
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
def get_type(self) -> str:
return 'elasticsearch'
return "elasticsearch"
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
self._client.index(index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
})
self._client.index(
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
return uuids
def text_exists(self, id: str) -> bool:
return self._client.exists(index=self._collection_name, id=id).__bool__()
return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None:
for id in ids:
self._client.delete(index=self._collection_name, id=id)
def delete_by_metadata_field(self, key: str, value: str) -> None:
query_str = {
'query': {
'match': {
f'metadata.{key}': f'{value}'
}
}
}
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit['_id'] for hit in results['hits']['hits']]
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
@ -115,44 +112,44 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 10)
knn = {
"field": Field.VECTOR.value,
"query_vector": query_vector,
"k": top_k
}
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
docs_and_scores = []
for hit in results['hits']['hits']:
for hit in results["hits"]["hits"]:
docs_and_scores.append(
(Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {
"match": {
Field.CONTENT_KEY.value: query
}
}
query_str = {"match": {Field.CONTENT_KEY.value: query}}
results = self._client.search(index=self._collection_name, query=query_str)
docs = []
for hit in results['hits']['hits']:
docs.append(Document(
page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value],
))
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
)
)
return docs
@ -162,11 +159,11 @@ class ElasticSearchVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name}'
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
@ -179,14 +176,14 @@ class ElasticSearchVector(BaseVector):
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"similarity": "cosine"
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
}
}
},
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mappings)
@ -197,22 +194,21 @@ class ElasticSearchVector(BaseVector):
class ElasticSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get('ELASTICSEARCH_HOST'),
port=config.get('ELASTICSEARCH_PORT'),
username=config.get('ELASTICSEARCH_USERNAME'),
password=config.get('ELASTICSEARCH_PASSWORD'),
host=config.get("ELASTICSEARCH_HOST"),
port=config.get("ELASTICSEARCH_PORT"),
username=config.get("ELASTICSEARCH_USERNAME"),
password=config.get("ELASTICSEARCH_PASSWORD"),
),
attributes=[]
attributes=[],
)

View File

@ -1,10 +1,9 @@
import json
import logging
from typing import Any, Optional
from uuid import uuid4
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException, connections
from pymilvus import MilvusClient, MilvusException
from pymilvus.milvus_client import IndexParams
from configs import dify_config
@ -21,55 +20,47 @@ logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
host: str
port: int
uri: str
token: Optional[str] = None
user: str
password: str
secure: bool = False
batch_size: int = 100
database: str = "default"
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values.get('host'):
raise ValueError("config MILVUS_HOST is required")
if not values.get('port'):
raise ValueError("config MILVUS_PORT is required")
if not values.get('user'):
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get('password'):
if not values.get("password"):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure,
'db_name': self.database,
"uri": self.uri,
"token": self.token,
"user": self.user,
"password": self.password,
"db_name": self.database,
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._consistency_level = "Session"
self._fields = []
def get_type(self) -> str:
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@ -80,7 +71,7 @@ class MilvusVector(BaseVector):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
Field.METADATA_KEY.value: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
@ -89,111 +80,70 @@ class MilvusVector(BaseVector):
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise e
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
db_name=self._client_config.database)
from pymilvus import utility
if utility.has_collection(self._collection_name, using=alias):
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
db_name=self._client_config.database)
from pymilvus import utility
if utility.has_collection(self._collection_name, using=alias):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {ids}',
output_fields=["id"])
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
)
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
db_name=self._client_config.database)
from pymilvus import utility
if utility.has_collection(self._collection_name, using=alias):
utility.drop_collection(self._collection_name, None, using=alias)
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
db_name=self._client_config.database)
from pymilvus import utility
if not utility.has_collection(self._collection_name, using=alias):
if not self._client.has_collection(self._collection_name):
return False
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"]
)
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
@ -202,23 +152,15 @@ class MilvusVector(BaseVector):
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
# Grab the existing collection if it exists
from pymilvus import utility
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user,
password=self._client_config.password, db_name=self._client_config.database)
if not utility.has_collection(self._collection_name, using=alias):
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
@ -229,19 +171,11 @@ class MilvusVector(BaseVector):
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create the schema for the collection
schema = CollectionSchema(fields)
@ -257,39 +191,36 @@ class MilvusVector(BaseVector):
# Create the collection
collection_name = self._collection_name
self._client.create_collection(collection_name=collection_name,
schema=schema, index_params=index_params_obj,
consistency_level=self._consistency_level)
self._client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
if config.secure:
uri = "https://" + str(config.host) + ":" + str(config.port)
else:
uri = "http://" + str(config.host) + ":" + str(config.port)
client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client
class MilvusVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
host=dify_config.MILVUS_HOST,
port=dify_config.MILVUS_PORT,
uri=dify_config.MILVUS_URI,
token=dify_config.MILVUS_TOKEN,
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
secure=dify_config.MILVUS_SECURE,
database=dify_config.MILVUS_DATABASE,
)
),
)

View File

@ -31,12 +31,11 @@ class SortOrder(Enum):
class MyScaleVector(BaseVector):
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config
self._metric = metric
self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC
self._client = get_client(
host=config.host,
port=config.port,
@ -80,7 +79,7 @@ class MyScaleVector(BaseVector):
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {}
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
@ -93,7 +92,7 @@ class MyScaleVector(BaseVector):
@staticmethod
def escape_str(value: Any) -> str:
return "".join(" " if c in ("\\", "'") else c for c in str(value))
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
@ -101,7 +100,8 @@ class MyScaleVector(BaseVector):
def delete_by_ids(self, ids: list[str]) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)
def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(
@ -122,9 +122,12 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get('score_threshold') or 0.0
where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
@ -133,7 +136,7 @@ class MyScaleVector(BaseVector):
return [
Document(
page_content=r["text"],
vector=r['vector'],
vector=r["vector"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
@ -149,13 +152,12 @@ class MyScaleVector(BaseVector):
class MyScaleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
return MyScaleVector(
collection_name=collection_name,

View File

@ -28,11 +28,12 @@ class OpenSearchConfig(BaseModel):
password: Optional[str] = None
secure: bool = False
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values.get('host'):
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get('port'):
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
return values
@ -44,19 +45,18 @@ class OpenSearchConfig(BaseModel):
def to_opensearch_params(self) -> dict[str, Any]:
params = {
'hosts': [{'host': self.host, 'port': self.port}],
'use_ssl': self.secure,
'verify_certs': self.secure,
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
}
if self.user and self.password:
params['http_auth'] = (self.user, self.password)
params["http_auth"] = (self.user, self.password)
if self.secure:
params['ssl_context'] = self.create_ssl_context()
params["ssl_context"] = self.create_ssl_context()
return params
class OpenSearchVector(BaseVector):
def __init__(self, collection_name: str, config: OpenSearchConfig):
super().__init__(collection_name)
self._client_config = config
@ -81,7 +81,7 @@ class OpenSearchVector(BaseVector):
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
},
}
actions.append(action)
@ -90,8 +90,8 @@ class OpenSearchVector(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response['hits']['hits']:
return [hit['_id'] for hit in response['hits']['hits']]
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None
@ -110,7 +110,7 @@ class OpenSearchVector(BaseVector):
actual_ids = []
for doc_id in ids:
es_ids = self.get_ids_by_metadata_field('doc_id', doc_id)
es_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
if es_ids:
actual_ids.extend(es_ids)
else:
@ -122,9 +122,9 @@ class OpenSearchVector(BaseVector):
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get('delete', {})
status = delete_error.get('status')
doc_id = delete_error.get('_id')
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")
if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}")
@ -151,15 +151,8 @@ class OpenSearchVector(BaseVector):
raise ValueError("All elements in query_vector should be floats")
query = {
"size": kwargs.get('top_k', 4),
"query": {
"knn": {
Field.VECTOR.value: {
Field.VECTOR.value: query_vector,
"k": kwargs.get('top_k', 4)
}
}
}
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
}
try:
@ -169,17 +162,17 @@ class OpenSearchVector(BaseVector):
raise
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value, {})
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
# Make sure metadata is a dictionary
if metadata is None:
metadata = {}
metadata['score'] = hit['_score']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if hit['_score'] > score_threshold:
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs
@ -190,32 +183,28 @@ class OpenSearchVector(BaseVector):
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value)
vector = hit['_source'].get(Field.VECTOR.value)
page_content = hit['_source'].get(Field.CONTENT_KEY.value)
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name.lower()}'
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}'
collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name.lower()} already exists.")
return
if not self._client.indices.exists(index=self._collection_name.lower()):
index_body = {
"settings": {
"index": {
"knn": True
}
},
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
@ -226,20 +215,17 @@ class OpenSearchVector(BaseVector):
"name": "hnsw",
"space_type": "l2",
"engine": "faiss",
"parameters": {
"ef_construction": 64,
"m": 8
}
}
"parameters": {"ef_construction": 64, "m": 8},
},
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
}
}
},
},
}
}
},
}
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
@ -248,17 +234,14 @@ class OpenSearchVector(BaseVector):
class OpenSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST,
@ -268,7 +251,4 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
secure=dify_config.OPENSEARCH_SECURE,
)
return OpenSearchVector(
collection_name=collection_name,
config=open_search_config
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)

View File

@ -31,7 +31,8 @@ class OracleVectorConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ORACLE_HOST is required")
@ -103,9 +104,16 @@ class OracleVector(BaseVector):
arraysize=cursor.arraysize,
outconverter=self.numpy_converter_out,
)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(
user=config.user,
password=config.password,
dsn="{}:{}/{}".format(config.host, config.port, config.database),
min=1,
max=50,
increment=1,
)
@contextmanager
def _get_cursor(self):
@ -136,13 +144,15 @@ class OracleVector(BaseVector):
doc_id,
doc.page_content,
json.dumps(doc.metadata),
#array.array("f", embeddings[i]),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
#print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values)
cur.executemany(
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
)
return pks
def text_exists(self, id: str) -> bool:
@ -157,7 +167,8 @@ class OracleVector(BaseVector):
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
#def get_ids_by_metadata_field(self, key: str, value: str):
# def get_ids_by_metadata_field(self, key: str, value: str):
# with self._get_cursor() as cur:
# cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
# idss = []
@ -184,10 +195,12 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)]
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
f" ORDER BY distance fetch first {top_k} rows only",
[numpy.array(query_vector)],
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
@ -199,10 +212,10 @@ class OracleVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile('[\u4e00-\u9fa5]+')
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
match = zh_pattern.search(query)
entities = []
# match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split
@ -210,7 +223,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: 人名, ns: 地名, nt: 机构名
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
current_entity += word
else:
if current_entity:
@ -220,22 +233,23 @@ class OracleVector(BaseVector):
entities.append(current_entity)
else:
try:
nltk.data.find('tokenizers/punkt')
nltk.data.find('corpora/stopwords')
nltk.data.find("tokenizers/punkt")
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download('punkt')
nltk.download('stopwords')
nltk.download("punkt")
nltk.download("stopwords")
print("run download")
e_str = re.sub(r'[^\w ]', '', query)
e_str = re.sub(r"[^\w ]", "", query)
all_tokens = nltk.word_tokenize(e_str)
stop_words = stopwords.words('english')
stop_words = stopwords.words("english")
for token in all_tokens:
if token not in stop_words:
entities.append(token)
with self._get_cursor() as cur:
cur.execute(
f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)]
f"select meta, text, embedding FROM {self.table_name}"
f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
[" ACCUM ".join(entities)],
)
docs = []
for record in cur:
@ -273,8 +287,7 @@ class OracleVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
return OracleVector(
collection_name=collection_name,

View File

@ -31,27 +31,29 @@ class PgvectoRSConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config PGVECTO_RS_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config PGVECTO_RS_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config PGVECTO_RS_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config PGVECTO_RS_DATABASE is required")
return values
class PGVectoRS(BaseVector):
def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
super().__init__(collection_name)
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self._client = create_engine(self._url)
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
@ -80,9 +82,9 @@ class PGVectoRS(BaseVector):
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
@ -133,9 +135,7 @@ class PGVectoRS(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; "
)
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
@ -143,12 +143,11 @@ class PGVectoRS(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.execute(select_statement, {"ids": ids})
session.commit()
def delete_by_ids(self, ids: list[str]) -> None:
@ -156,13 +155,13 @@ class PGVectoRS(BaseVector):
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
)
result = session.execute(select_statement, {'doc_ids': ids}).fetchall()
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {'ids': ids})
session.execute(select_statement, {"ids": ids})
session.commit()
def delete(self) -> None:
@ -187,7 +186,7 @@ class PGVectoRS(BaseVector):
query_vector,
).label("distance"),
)
.limit(kwargs.get('top_k', 2))
.limit(kwargs.get("top_k", 2))
.order_by("distance")
)
res = session.execute(stmt)
@ -198,11 +197,10 @@ class PGVectoRS(BaseVector):
for record, dis in results:
metadata = record.meta
score = 1 - dis
metadata['score'] = score
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
metadata["score"] = score
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc = Document(page_content=record.text,
metadata=metadata)
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs
@ -225,13 +223,12 @@ class PGVectoRS(BaseVector):
class PGVectoRSFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dim = len(embeddings.embed_query("pgvecto_rs"))
return PGVectoRS(
@ -243,5 +240,5 @@ class PGVectoRSFactory(AbstractVectorFactory):
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
),
dim=dim
dim=dim,
)

View File

@ -24,7 +24,8 @@ class PGVectorConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")
@ -138,11 +139,12 @@ class PGVector(BaseVector):
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}",
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
f" ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
@ -201,8 +203,7 @@ class PGVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
return PGVector(
collection_name=collection_name,

View File

@ -48,28 +48,25 @@ class QdrantConfig(BaseModel):
prefer_grpc: bool = False
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
return {"path": path}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout,
'verify': self.endpoint.startswith('https'),
'grpc_port': self.grpc_port,
'prefer_grpc': self.prefer_grpc
"url": self.endpoint,
"api_key": self.api_key,
"timeout": self.timeout,
"verify": self.endpoint.startswith("https"),
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
}
class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
@ -80,10 +77,7 @@ class QdrantVector(BaseVector):
return VectorType.QDRANT
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
@ -97,9 +91,9 @@ class QdrantVector(BaseVector):
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
@ -110,12 +104,19 @@ class QdrantVector(BaseVector):
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
@ -124,21 +125,24 @@ class QdrantVector(BaseVector):
)
# create group_id payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD)
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(collection_name, Field.DOC_ID.value,
field_schema=PayloadSchemaType.KEYWORD)
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -147,26 +151,23 @@ class QdrantVector(BaseVector):
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, metadatas, uuids, 64, self._group_id
):
self._client.upsert(
collection_name=self._collection_name, points=points
)
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
@ -203,13 +204,13 @@ class QdrantVector(BaseVector):
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
@ -219,18 +220,11 @@ class QdrantVector(BaseVector):
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
@ -248,9 +242,7 @@ class QdrantVector(BaseVector):
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@ -275,9 +267,7 @@ class QdrantVector(BaseVector):
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@ -288,7 +278,6 @@ class QdrantVector(BaseVector):
raise e
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
@ -304,9 +293,7 @@ class QdrantVector(BaseVector):
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
@ -324,15 +311,13 @@ class QdrantVector(BaseVector):
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[id]
)
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
@ -348,22 +333,22 @@ class QdrantVector(BaseVector):
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", .0)
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata['score'] = result.score
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -372,6 +357,7 @@ class QdrantVector(BaseVector):
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
@ -381,24 +367,21 @@ class QdrantVector(BaseVector):
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
),
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get('top_k', 2),
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
)
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
documents.append(document)
return documents
@ -410,10 +393,10 @@ class QdrantVector(BaseVector):
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
@ -425,24 +408,25 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
raise ValueError("Dataset Collection Bindings is not exist!")
else:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
if not dataset.index_struct_dict:
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
config = current_app.config
return QdrantVector(
@ -454,6 +438,6 @@ class QdrantVectorFactory(AbstractVectorFactory):
root_path=config.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
)
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
),
)

View File

@ -33,28 +33,30 @@ class RelytConfig(BaseModel):
password: str
database: str
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config RELYT_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config RELYT_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config RELYT_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config RELYT_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config RELYT_DATABASE is required")
return values
class RelytVector(BaseVector):
def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
super().__init__(collection_name)
self.embedding_dimension = 1536
self._client_config = config
self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields = []
self._group_id = group_id
@ -70,9 +72,9 @@ class RelytVector(BaseVector):
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
@ -110,7 +112,7 @@ class RelytVector(BaseVector):
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
for metadata in metadatas:
metadata['group_id'] = self._group_id
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
# Define the table schema
@ -125,29 +127,26 @@ class RelytVector(BaseVector):
)
chunks_table_data = []
with self.client.connect() as conn:
with conn.begin():
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)
with self.client.connect() as conn, conn.begin():
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data))
return ids
@ -186,25 +185,22 @@ class RelytVector(BaseVector):
)
try:
with self.client.connect() as conn:
with conn.begin():
delete_condition = chunks_table.c.id.in_(ids)
conn.execute(chunks_table.delete().where(delete_condition))
return True
with self.client.connect() as conn, conn.begin():
delete_condition = chunks_table.c.id.in_(ids)
conn.execute(chunks_table.delete().where(delete_condition))
return True
except Exception as e:
print("Delete operation failed:", str(e))
return False
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_uuids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self.client) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
)
@ -228,38 +224,34 @@ class RelytVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get('top_k')),
embedding=query_vector,
filter=kwargs.get('filter')
k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
)
# Organize results.
docs = []
for document, score in results:
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score > score_threshold:
docs.append(document)
return docs
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
self,
embedding: list[float],
k: int = 4,
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError(
"Could not import Row from sqlalchemy.engine. "
"Please 'pip install sqlalchemy>=1.4'."
)
raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.")
filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
if len(value) > 1
else f"metadata->>{key!r} = {value[0]!r}"
for key, value in filter.items()
]
@ -305,13 +297,12 @@ class RelytVector(BaseVector):
class RelytVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))
return RelytVector(
collection_name=collection_name,
@ -322,5 +313,5 @@ class RelytVectorFactory(AbstractVectorFactory):
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
),
group_id=dataset.id
group_id=dataset.id,
)

View File

@ -25,16 +25,11 @@ class TencentConfig(BaseModel):
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = 1,
replicas: int = 2,
shard: int = (1,)
replicas: int = (2,)
def to_tencent_params(self):
return {
'url': self.url,
'username': self.username,
'key': self.api_key,
'timeout': self.timeout
}
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
class TencentVector(BaseVector):
@ -61,25 +56,19 @@ class TencentVector(BaseVector):
return self._client.create_database(database_name=self._client_config.database)
def get_type(self) -> str:
return 'tencent'
return "tencent"
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def _has_collection(self) -> bool:
collections = self._db.list_collections()
for collection in collections:
if collection.collection_name == self._collection_name:
return True
return False
return any(collection.collection_name == self._collection_name for collection in collections)
def _create_collection(self, dimension: int) -> None:
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
@ -101,9 +90,7 @@ class TencentVector(BaseVector):
raise ValueError("unsupported metric_type")
params = vdb_index.HNSWParams(m=16, efconstruction=200)
index = vdb_index.Index(
vdb_index.FilterIndex(
self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
),
vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
vdb_index.VectorIndex(
self.field_vector,
dimension,
@ -111,12 +98,8 @@ class TencentVector(BaseVector):
metric_type,
params,
),
vdb_index.FilterIndex(
self.field_text, enum.FieldType.String, enum.IndexType.FILTER
),
vdb_index.FilterIndex(
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
),
vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER),
vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER),
)
self._db.create_collection(
@ -163,15 +146,14 @@ class TencentVector(BaseVector):
self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
res = self._db.collection(self._collection_name).search(vectors=[query_vector],
params=document.HNSWSearchParams(
ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get('top_k', 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
res = self._db.collection(self._collection_name).search(
vectors=[query_vector],
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -200,15 +182,13 @@ class TencentVector(BaseVector):
class TencentVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
return TencentVector(
collection_name=collection_name,
@ -220,5 +200,5 @@ class TencentVectorFactory(AbstractVectorFactory):
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
)
),
)

View File

@ -28,47 +28,57 @@ class TiDBVectorConfig(BaseModel):
database: str
program_name: str
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values['host']:
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")
if not values['port']:
if not values["port"]:
raise ValueError("config TIDB_VECTOR_PORT is required")
if not values['user']:
if not values["user"]:
raise ValueError("config TIDB_VECTOR_USER is required")
if not values['password']:
if not values["password"]:
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
if not values['database']:
if not values["database"]:
raise ValueError("config TIDB_VECTOR_DATABASE is required")
if not values['program_name']:
if not values["program_name"]:
raise ValueError("config APPLICATION_NAME is required")
return values
class TiDBVector(BaseVector):
def get_type(self) -> str:
return VectorType.TIDB_VECTOR
def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType
return Table(
self._collection_name,
self._orm_base.metadata,
Column('id', String(36), primary_key=True, nullable=False),
Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
Column("id", String(36), primary_key=True, nullable=False),
Column(
"vector",
VectorType(dim),
nullable=False,
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
),
Column("text", TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
extend_existing=True
Column(
"update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
),
extend_existing=True,
)
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"):
super().__init__(collection_name)
self._client_config = config
self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}")
self._url = (
f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}"
)
self._distance_func = distance_func.lower()
self._engine = create_engine(self._url)
self._orm_base = declarative_base()
@ -83,9 +93,9 @@ class TiDBVector(BaseVector):
def _create_collection(self, dimension: int):
logger.info("_create_collection, collection_name " + self._collection_name)
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
with Session(self._engine) as session:
@ -114,31 +124,28 @@ class TiDBVector(BaseVector):
texts = [d.page_content for d in documents]
chunks_table_data = []
with self._engine.connect() as conn:
with conn.begin():
for id, text, meta, embedding in zip(
ids, texts, metas, embeddings
):
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
with self._engine.connect() as conn, conn.begin():
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(table).values(chunks_table_data))
return ids
def text_exists(self, id: str) -> bool:
result = self.get_ids_by_metadata_field('doc_id', id)
result = self.get_ids_by_metadata_field("doc_id", id)
return bool(result)
def delete_by_ids(self, ids: list[str]) -> None:
with Session(self._engine) as session:
ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
)
@ -152,11 +159,10 @@ class TiDBVector(BaseVector):
raise ValueError("No ids provided to delete.")
table = self._table(self._dimension)
try:
with self._engine.connect() as conn:
with conn.begin():
delete_condition = table.c.id.in_(ids)
conn.execute(table.delete().where(delete_condition))
return True
with self._engine.connect() as conn, conn.begin():
delete_condition = table.c.id.in_(ids)
conn.execute(table.delete().where(delete_condition))
return True
except Exception as e:
print("Delete operation failed:", str(e))
return False
@ -179,21 +185,23 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
filter = kwargs.get('filter')
score_threshold = float(kwargs.get("score_threshold") or 0.0)
filter = kwargs.get("filter")
distance = 1 - score_threshold
query_vector_str = ", ".join(format(x) for x in query_vector)
query_vector_str = "[" + query_vector_str + "]"
logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")
logger.debug(
f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}"
)
docs = []
if self._distance_func == 'l2':
tidb_func = 'Vec_l2_distance'
elif self._distance_func == 'cosine':
tidb_func = 'Vec_Cosine_distance'
if self._distance_func == "l2":
tidb_func = "Vec_l2_distance"
elif self._distance_func == "cosine":
tidb_func = "Vec_Cosine_distance"
else:
tidb_func = 'Vec_Cosine_distance'
tidb_func = "Vec_Cosine_distance"
with Session(self._engine) as session:
select_statement = sql_text(
@ -208,7 +216,7 @@ class TiDBVector(BaseVector):
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = json.loads(meta)
metadata['score'] = 1 - distance
metadata["score"] = 1 - distance
docs.append(Document(page_content=text, metadata=metadata))
return docs
@ -224,15 +232,13 @@ class TiDBVector(BaseVector):
class TiDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
return TiDBVector(
collection_name=collection_name,

View File

@ -7,7 +7,6 @@ from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@ -39,26 +38,19 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
@ -66,7 +58,7 @@ class BaseVector(ABC):
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
return [text.metadata["doc_id"] for text in texts]
@property
def collection_name(self):

View File

@ -20,17 +20,14 @@ class AbstractVectorFactory(ABC):
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
index_struct_dict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
return index_struct_dict
class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
@ -39,7 +36,7 @@ class Vector:
def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
vector_type = self._dataset.index_struct_dict["type"]
if not vector_type:
raise ValueError("Vector store must be specified.")
@ -52,45 +49,59 @@ class Vector:
match vector_type:
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
@ -98,21 +109,14 @@ class Vector:
def create(self, texts: list = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(
texts=texts,
embeddings=embeddings,
**kwargs
)
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create(
texts=documents,
embeddings=embeddings,
**kwargs
)
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
@ -123,24 +127,18 @@ class Vector:
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self) -> None:
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:
collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name)
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
@ -150,14 +148,13 @@ class Vector:
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
model=self._dataset.embedding_model,
)
return CacheEmbedding(embedding_model)
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
doc_id = text.metadata['doc_id']
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)

View File

@ -2,17 +2,17 @@ from enum import Enum
class VectorType(str, Enum):
ANALYTICDB = 'analyticdb'
CHROMA = 'chroma'
MILVUS = 'milvus'
MYSCALE = 'myscale'
PGVECTOR = 'pgvector'
PGVECTO_RS = 'pgvecto-rs'
QDRANT = 'qdrant'
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
ORACLE = 'oracle'
ELASTICSEARCH = 'elasticsearch'
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"
MYSCALE = "myscale"
PGVECTOR = "pgvector"
PGVECTO_RS = "pgvecto-rs"
QDRANT = "qdrant"
RELYT = "relyt"
TIDB_VECTOR = "tidb_vector"
WEAVIATE = "weaviate"
OPENSEARCH = "opensearch"
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"

View File

@ -22,15 +22,15 @@ class WeaviateConfig(BaseModel):
api_key: Optional[str] = None
batch_size: int = 100
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
@ -43,10 +43,7 @@ class WeaviateVector(BaseVector):
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
@ -68,10 +65,10 @@ class WeaviateVector(BaseVector):
def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"):
# original class_prefix
class_prefix += '_Node'
class_prefix += "_Node"
return class_prefix
@ -79,10 +76,7 @@ class WeaviateVector(BaseVector):
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection
@ -91,9 +85,9 @@ class WeaviateVector(BaseVector):
self.add_texts(texts, embeddings)
def _create_collection(self):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
schema = self._default_schema(self._collection_name)
@ -129,17 +123,9 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
def delete(self):
# check whether the index already exists
@ -154,11 +140,19 @@ class WeaviateVector(BaseVector):
# check whether the index already exists
if not self._client.schema.contains(schema):
return False
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
result = (
self._client.query.get(collection_name)
.with_additional(["id"])
.with_where(
{
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}
)
.with_limit(1)
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
@ -211,13 +205,13 @@ class WeaviateVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -240,15 +234,15 @@ class WeaviateVector(BaseVector):
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
query_obj = query_obj.with_additional(["vector"])
properties = ['text']
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
additional = res.pop('_additional')
docs.append(Document(page_content=text, vector=additional['vector'], metadata=res))
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:
@ -271,20 +265,19 @@ class WeaviateVector(BaseVector):
class WeaviateVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT,
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
attributes=attributes
attributes=attributes,
)