Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-12-16 14:29:05 +08:00
371 changed files with 10899 additions and 6959 deletions

View File

@ -1,13 +1,10 @@
import copy
import json
import logging
from collections.abc import Iterable
from typing import Any, Optional
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_fixed
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -20,15 +17,18 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)
ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index"
class LindormVectorStoreConfig(BaseModel):
hosts: str
username: Optional[str] = None
password: Optional[str] = None
using_ugc: Optional[bool] = False
@model_validator(mode="before")
@classmethod
@ -42,9 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
return values
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts,
}
params = {"hosts": self.hosts}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
@ -52,9 +50,21 @@ class LindormVectorStoreConfig(BaseModel):
class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
super().__init__(collection_name.lower())
self._routing = None
self._routing_field = None
if config.using_ugc:
routing_value: str = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
self._routing_field = ROUTING_FIELD
ugc_index_name = collection_name
super().__init__(ugc_index_name.lower())
else:
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self._using_ugc = config.using_ugc
self.kwargs = kwargs
def get_type(self) -> str:
@ -67,94 +77,37 @@ class LindormVectorStore(BaseVector):
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
def __filter_existed_ids(
self,
texts: list[str],
metadatas: list[dict],
ids: list[str],
bulk_size: int = 1024,
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(index=self._collection_name, body={
"ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}")
return set()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(
body={
"docs": [
{"_index": self._collection_name,
"_id": id, "routing": routing}
for id, routing in zip(batch_ids, route_ids)
]
},
_source=False,
)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch ids: {batch_ids}")
return set()
if ids is None:
return texts, metadatas, ids
if len(texts) != len(ids):
raise RuntimeError(f"texts {len(texts)} != {ids}")
filtered_texts = []
filtered_metadatas = []
filtered_ids = []
def batch(iterable, n):
length = len(iterable)
for idx in range(0, length, n):
yield iterable[idx: min(idx + n, length)]
for ids_batch, texts_batch, metadatas_batch in zip(
batch(ids, bulk_size),
batch(texts, bulk_size),
batch(metadatas, bulk_size) if metadatas is not None else batch(
[None] * len(ids), bulk_size),
):
existing_ids_set = __fetch_existing_ids(ids_batch)
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
if doc_id not in existing_ids_set:
filtered_texts.append(text)
filtered_ids.append(doc_id)
if metadatas is not None:
filtered_metadatas.append(metadata)
return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuids[i],
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
# Make sure you pass an array here
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
},
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],
}
}
actions.append(action)
bulk(self._client, actions)
self.refresh()
action_values = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
if response["errors"]:
for item in response["items"]:
print(f"{item['index']['status']}: {item['index']['error']['type']}")
else:
self.refresh()
def get_ids_by_metadata_field(self, key: str, value: str):
query = {
"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
@ -162,57 +115,63 @@ class LindormVectorStore(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(
index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
params = {}
if self._using_ugc:
params["routing"] = self._routing
for id in ids:
if self._client.exists(index=self._collection_name, id=id):
self._client.delete(index=self._collection_name, id=id)
if self._client.exists(index=self._collection_name, id=id, params=params):
params = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.delete(index=self._collection_name, id=id, params=params)
self.refresh()
else:
logger.warning(
f"DELETE BY ID: ID {id} does not exist in the index.")
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
def delete(self) -> None:
try:
if self._using_ugc:
routing_filter_query = {
"query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
}
self._client.delete_by_query(self._collection_name, body=routing_filter_query)
self.refresh()
else:
if self._client.indices.exists(index=self._collection_name):
self._client.indices.delete(
index=self._collection_name, params={"timeout": 60})
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
logger.info("Delete index success")
else:
logger.warning(
f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e:
logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
raise e
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
def text_exists(self, id: str) -> bool:
try:
self._client.get(index=self._collection_name, id=id)
params = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.get(index=self._collection_name, id=id, params=params)
return True
except:
return False
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Make sure query_vector is a list
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")
# Check whether query_vector is a floating-point number list
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")
top_k = kwargs.get("top_k", 10)
query = default_vector_search_query(
query_vector=query_vector, k=top_k, **kwargs)
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
try:
response = self._client.search(
index=self._collection_name, body=query)
except Exception as e:
params = {}
if self._using_ugc:
params["routing"] = self._routing
response = self._client.search(index=self._collection_name, body=query, params=params)
except Exception:
logger.exception(f"Error executing vector search, query: {query}")
raise
@ -244,7 +203,7 @@ class LindormVectorStore(BaseVector):
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 10)
filters = kwargs.get("filter")
routing = kwargs.get("routing")
routing = self._routing
full_text_query = default_text_search_query(
query_text=query,
k=top_k,
@ -255,9 +214,9 @@ class LindormVectorStore(BaseVector):
minimum_should_match=minimum_should_match,
filters=filters,
routing=routing,
routing_field=self._routing_field,
)
response = self._client.search(
index=self._collection_name, body=full_text_query)
response = self._client.search(index=self._collection_name, body=full_text_query)
docs = []
for hit in response["hits"]["hits"]:
docs.append(
@ -275,40 +234,37 @@ class LindormVectorStore(BaseVector):
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(
f"Collection {self._collection_name} already exists.")
logger.info(f"Collection {self._collection_name} already exists.")
return
if self._client.indices.exists(index=self._collection_name):
logger.info("{self._collection_name.lower()} already exists.")
logger.info(f"{self._collection_name.lower()} already exists.")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
return
if len(self.kwargs) == 0 and len(kwargs) != 0:
self.kwargs = copy.deepcopy(kwargs)
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
shards = kwargs.pop("shards", 2)
shards = kwargs.pop("shards", 4)
engine = kwargs.pop("engine", "lvector")
method_name = kwargs.pop("method_name", "hnsw")
method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
data_type = kwargs.pop("data_type", "float")
space_type = kwargs.pop("space_type", "cosinesimil")
hnsw_m = kwargs.pop("hnsw_m", 24)
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
nlist = kwargs.pop("nlist", 1000)
centroids_use_hnsw = kwargs.pop(
"centroids_use_hnsw", True if nlist >= 5000 else False)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
centroids_hnsw_ef_construct = kwargs.pop(
"centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop(
"centroids_hnsw_ef_search", 100)
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
mapping = default_text_mapping(
dimension,
method_name,
space_type=space_type,
shards=shards,
engine=engine,
data_type=data_type,
space_type=space_type,
vector_field=vector_field,
hnsw_m=hnsw_m,
hnsw_ef_construction=hnsw_ef_construction,
@ -318,24 +274,29 @@ class LindormVectorStore(BaseVector):
centroids_hnsw_m=centroids_hnsw_m,
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
using_ugc=self._using_ugc,
**kwargs,
)
self._client.indices.create(
index=self._collection_name.lower(), body=mapping)
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
# logger.info(f"create index success: {self._collection_name}")
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
routing_field = kwargs.get("routing_field")
excludes_from_source = kwargs.get("excludes_from_source")
analyzer = kwargs.get("analyzer", "ik_max_word")
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
engine = kwargs["engine"]
shard = kwargs["shards"]
space_type = kwargs["space_type"]
space_type = kwargs.get("space_type")
if space_type is None:
if method_name == "hnsw":
space_type = "l2"
else:
space_type = "cosine"
data_type = kwargs["data_type"]
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
using_ugc = kwargs.get("using_ugc", False)
if method_name == "ivfpq":
ivfpq_m = kwargs["ivfpq_m"]
@ -385,13 +346,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
# e.g. {"excludes": ["vector_field"]}
mapping["mappings"]["_source"] = {"excludes": excludes_from_source}
if method_name == "ivfpq" and routing_field is not None:
if using_ugc and method_name == "ivfpq":
mapping["settings"]["index"]["knn_routing"] = True
mapping["settings"]["index"]["knn.offline.construction"] = True
if method_name == "flat" and routing_field is not None:
elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat":
mapping["settings"]["index"]["knn_routing"] = True
return mapping
@ -405,14 +364,12 @@ def default_text_search_query(
minimum_should_match: int = 0,
filters: Optional[list[dict]] = None,
routing: Optional[str] = None,
routing_field: Optional[str] = None,
**kwargs,
) -> dict:
if routing is not None:
routing_field = kwargs.get("routing_field", "routing_field")
query_clause = {
"bool": {
"must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
}
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
}
else:
query_clause = {"match": {text_field: query_text}}
@ -424,8 +381,7 @@ def default_text_search_query(
# build complex search_query when either of must/must_not/should/filter is specified
if must:
if not isinstance(must, list):
raise RuntimeError(
f"unexpected [must] clause with {type(filters)}")
raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
if query_clause not in must:
must.append(query_clause)
else:
@ -435,22 +391,19 @@ def default_text_search_query(
if must_not:
if not isinstance(must_not, list):
raise RuntimeError(
f"unexpected [must_not] clause with {type(filters)}")
raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
boolean_query["must_not"] = must_not
if should:
if not isinstance(should, list):
raise RuntimeError(
f"unexpected [should] clause with {type(filters)}")
raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
boolean_query["should"] = should
if minimum_should_match != 0:
boolean_query["minimum_should_match"] = minimum_should_match
if filters:
if not isinstance(filters, list):
raise RuntimeError(
f"unexpected [filter] clause with {type(filters)}")
raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
boolean_query["filter"] = filters
search_query = {"size": k, "query": {"bool": boolean_query}}
@ -472,7 +425,7 @@ def default_vector_search_query(
) -> dict:
if filters is not None:
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filter, list):
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext = {"lvector": {}}
if min_score != "0.0":
@ -494,8 +447,7 @@ def default_vector_search_query(
if filters is not None:
# when using filter, transform filter from List[Dict] to Dict as valid format
filters = {"bool": {"must": filters}} if len(
filters) > 1 else filters[0]
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
# filter should be Dict
search_query["query"]["knn"][vector_field]["filter"] = filters
if filter_type:
@ -508,17 +460,40 @@ def default_vector_search_query(
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
)
return LindormVectorStore(collection_name, lindorm_config)
using_ugc = dify_config.USING_UGC_INDEX
routing_value = None
if dataset.index_struct:
if using_ugc:
dimension = dataset.index_struct_dict["dimension"]
index_type = dataset.index_struct_dict["index_type"]
distance_type = dataset.index_struct_dict["distance_type"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
else:
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
else:
embedding_vector = embeddings.embed_query("hello word")
dimension = len(embedding_vector)
index_type = dify_config.DEFAULT_INDEX_TYPE
distance_type = dify_config.DEFAULT_DISTANCE_TYPE
class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
index_struct_dict = {
"type": VectorType.LINDORM,
"vector_store": {"class_prefix": class_prefix},
"index_type": index_type,
"dimension": dimension,
"distance_type": distance_type,
}
dataset.index_struct = json.dumps(index_struct_dict)
if using_ugc:
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
routing_value = class_prefix
else:
index_name = class_prefix
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)

View File

@ -37,8 +37,6 @@ class TiDBVectorConfig(BaseModel):
raise ValueError("config TIDB_VECTOR_PORT is required")
if not values["user"]:
raise ValueError("config TIDB_VECTOR_USER is required")
if not values["password"]:
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
if not values["database"]:
raise ValueError("config TIDB_VECTOR_DATABASE is required")
if not values["program_name"]: