Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-11-01 16:23:04 +08:00
396 changed files with 15138 additions and 2610 deletions

View File

@ -34,6 +34,8 @@ class RetrievalService:
reranking_mode: Optional[str] = "reranking_model",
weights: Optional[dict] = None,
):
if not query:
return []
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []

View File

@ -3,11 +3,13 @@ import time
import uuid
from typing import Any
import numpy as np
from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
from pymochow.exception import ServerError
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
@ -116,6 +118,7 @@ class BaiduVector(BaseVector):
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
@ -149,7 +152,13 @@ class BaiduVector(BaseVector):
return docs
def delete(self) -> None:
self._db.drop_table(table_name=self._collection_name)
try:
self._db.drop_table(table_name=self._collection_name)
except ServerError as e:
if e.code == ServerErrCode.TABLE_NOT_EXIST:
pass
else:
raise
def _init_client(self, config) -> MochowClient:
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
@ -166,7 +175,14 @@ class BaiduVector(BaseVector):
if exists:
return self._client.database(self._client_config.database)
else:
return self._client.create_database(database_name=self._client_config.database)
try:
self._client.create_database(database_name=self._client_config.database)
except ServerError as e:
if e.code == ServerErrCode.DB_ALREADY_EXIST:
pass
else:
raise
return
def _table_existed(self) -> bool:
tables = self._db.list_table()
@ -175,7 +191,7 @@ class BaiduVector(BaseVector):
def _create_table(self, dimension: int) -> None:
# Try to grab distributed lock and create table
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
with redis_client.lock(lock_name, timeout=60):
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(table_exist_cache_key):
return
@ -238,15 +254,14 @@ class BaiduVector(BaseVector):
description="Table for Dify",
)
# Wait for table created
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break
redis_client.set(table_exist_cache_key, 1, ex=3600)
# Wait for table created
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break
class BaiduVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:

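As a side note (not part of this commit), a minimal sketch of the idempotent drop pattern the delete() hunk above introduces, assuming a pymochow database handle named db:

from pymochow.exception import ServerError
from pymochow.model.enum import ServerErrCode

def drop_table_if_exists(db, table_name: str) -> None:
    # Treat a missing table as a no-op; any other server error is re-raised.
    try:
        db.drop_table(table_name=table_name)
    except ServerError as e:
        if e.code != ServerErrCode.TABLE_NOT_EXIST:
            raise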
View File

@ -0,0 +1,378 @@
import json
import logging
import time
import uuid
from datetime import timedelta
from typing import Any
from couchbase import search
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.management.search import SearchIndex
# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
from couchbase.options import ClusterOptions, SearchOptions
from couchbase.vector_search import VectorQuery, VectorSearch
from flask import current_app
from pydantic import BaseModel, model_validator
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class CouchbaseConfig(BaseModel):
connection_string: str
user: str
password: str
bucket_name: str
scope_name: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values.get("connection_string"):
raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
if not values.get("user"):
raise ValueError("config COUCHBASE_USER is required")
if not values.get("password"):
raise ValueError("config COUCHBASE_PASSWORD is required")
if not values.get("bucket_name"):
raise ValueError("config COUCHBASE_BUCKET_NAME is required")
if not values.get("scope_name"):
raise ValueError("config COUCHBASE_SCOPE_NAME is required")
return values
class CouchbaseVector(BaseVector):
def __init__(self, collection_name: str, config: CouchbaseConfig):
super().__init__(collection_name)
self._client_config = config
"""Connect to couchbase"""
auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options)
self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name
self._scope_name = config.scope_name
# Wait until the cluster is ready for use.
self._cluster.wait_until_ready(timedelta(seconds=5))
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_id = str(uuid.uuid4()).replace("-", "")
self._create_collection(uuid=index_id, vector_length=len(embeddings[0]))
self.add_texts(texts, embeddings)
def _create_collection(self, vector_length: int, uuid: str):
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
if self._collection_exists(self._collection_name):
return
manager = self._bucket.collections()
manager.create_collection(self._client_config.scope_name, self._collection_name)
index_manager = self._scope.search_indexes()
index_definition = json.loads("""
{
"type": "fulltext-index",
"name": "Embeddings._default.Vector_Search",
"uuid": "26d4db528e78b716",
"sourceType": "gocbcore",
"sourceName": "Embeddings",
"sourceUUID": "2242e4a25b4decd6650c9c7b3afa1dbf",
"planParams": {
"maxPartitionsPerPIndex": 1024,
"indexPartitions": 1
},
"params": {
"doc_config": {
"docid_prefix_delim": "",
"docid_regexp": "",
"mode": "scope.collection.type_field",
"type_field": "type"
},
"mapping": {
"analysis": { },
"default_analyzer": "standard",
"default_datetime_parser": "dateTimeOptional",
"default_field": "_all",
"default_mapping": {
"dynamic": true,
"enabled": true
},
"default_type": "_default",
"docvalues_dynamic": false,
"index_dynamic": true,
"store_dynamic": true,
"type_field": "_type",
"types": {
"collection_name": {
"dynamic": true,
"enabled": true,
"properties": {
"embedding": {
"dynamic": false,
"enabled": true,
"fields": [
{
"dims": 1536,
"index": true,
"name": "embedding",
"similarity": "dot_product",
"type": "vector",
"vector_index_optimized_for": "recall"
}
]
},
"metadata": {
"dynamic": true,
"enabled": true
},
"text": {
"dynamic": false,
"enabled": true,
"fields": [
{
"index": true,
"name": "text",
"store": true,
"type": "text"
}
]
}
}
}
}
},
"store": {
"indexType": "scorch",
"segmentVersion": 16
}
},
"sourceParams": { }
}
""")
index_definition["name"] = self._collection_name + "_search"
index_definition["uuid"] = uuid
index_definition["params"]["mapping"]["types"]["collection_name"]["properties"]["embedding"]["fields"][0][
"dims"
] = vector_length
index_definition["params"]["mapping"]["types"][self._scope_name + "." + self._collection_name] = (
index_definition["params"]["mapping"]["types"].pop("collection_name")
)
time.sleep(2)
index_manager.upsert_index(
SearchIndex(
index_definition["name"],
params=index_definition["params"],
source_name=self._bucket_name,
),
)
time.sleep(1)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _collection_exists(self, name: str):
scope_collection_map: dict[str, Any] = {}
# Get a list of all scopes in the bucket
for scope in self._bucket.collections().get_all_scopes():
scope_collection_map[scope.name] = []
# Get a list of all the collections in the scope
for collection in scope.collections:
scope_collection_map[scope.name].append(collection.name)
# Check if the collection exists in the scope
return self._collection_name in scope_collection_map[self._scope_name]
def get_type(self) -> str:
return VectorType.COUCHBASE
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
doc_ids = []
documents_to_insert = [
{"text": text, "embedding": vector, "metadata": metadata}
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
]
for doc, id in zip(documents_to_insert, uuids):
result = self._scope.collection(self._collection_name).upsert(id, doc)
doc_ids.extend(uuids)
return doc_ids
def text_exists(self, id: str) -> bool:
# Use a parameterized query for safety and correctness
query = f"""
SELECT COUNT(1) AS count FROM
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id = $doc_id
"""
# Pass the id as a parameter to the query
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
for row in result:
return row["count"] > 0
return False # Return False if no rows are returned
def delete_by_ids(self, ids: list[str]) -> None:
query = f"""
DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id IN $doc_ids;
"""
try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e:
logger.error(e)
def delete_by_document_id(self, document_id: str):
query = f"""
DELETE FROM
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id = $doc_id;
"""
self._cluster.query(query, named_parameters={"doc_id": document_id}).execute()
# def get_ids_by_metadata_field(self, key: str, value: str):
# query = f"""
# SELECT id FROM
# `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
# WHERE `metadata.{key}` = $value;
# """
# result = self._cluster.query(query, named_parameters={'value':value})
# return [row['id'] for row in result.rows()]
def delete_by_metadata_field(self, key: str, value: str) -> None:
query = f"""
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE metadata.{key} = $value;
"""
self._cluster.query(query, named_parameters={"value": value}).execute()
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") or 0.0
search_req = search.SearchRequest.create(
VectorSearch.from_vector_query(
VectorQuery(
"embedding",
query_vector,
top_k,
)
)
)
try:
search_iter = self._scope.search(
self._collection_name + "_search",
search_req,
SearchOptions(limit=top_k, collections=[self._collection_name], fields=["*"]),
)
docs = []
# Parse the results
for row in search_iter.rows():
text = row.fields.pop("text")
metadata = self._format_metadata(row.fields)
score = row.score
metadata["score"] = score
doc = Document(page_content=text, metadata=metadata)
if score >= score_threshold:
docs.append(doc)
except Exception as e:
raise ValueError(f"Search failed with error: {e}")
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 2)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)
docs = []
for row in search_iter.rows():
text = row.fields.pop("text")
metadata = self._format_metadata(row.fields)
score = row.score
metadata["score"] = score
doc = Document(page_content=text, metadata=metadata)
docs.append(doc)
except Exception as e:
raise ValueError(f"Search failed with error: {e}")
return docs
def delete(self):
manager = self._bucket.collections()
scopes = manager.get_all_scopes()
for scope in scopes:
for collection in scope.collections:
if collection.name == self._collection_name:
manager.drop_collection("_default", self._collection_name)
def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
"""Helper method to format the metadata from the Couchbase Search API.
Args:
row_fields (Dict[str, Any]): The fields to format.
Returns:
Dict[str, Any]: The formatted metadata.
"""
metadata = {}
for key, value in row_fields.items():
# Couchbase Search returns the metadata key with a prefix
# `metadata.` We remove it to get the original metadata key
if key.startswith("metadata"):
new_key = key.split("metadata" + ".")[-1]
metadata[new_key] = value
else:
metadata[key] = value
return metadata
class CouchbaseVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name))
config = current_app.config
return CouchbaseVector(
collection_name=collection_name,
config=CouchbaseConfig(
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
user=config.get("COUCHBASE_USER"),
password=config.get("COUCHBASE_PASSWORD"),
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
),
)
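For orientation, a minimal connection sketch using the same couchbase SDK calls as __init__ above; the connection string, bucket, and credentials are placeholders:

from datetime import timedelta

from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions

cluster = Cluster("couchbase://localhost", ClusterOptions(PasswordAuthenticator("user", "password")))
cluster.wait_until_ready(timedelta(seconds=5))  # block until the cluster is ready for use
scope = cluster.bucket("Embeddings").scope("_default")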

View File

@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY.value: query}}
results = self._client.search(index=self._collection_name, query=query_str)
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
docs.append(

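A hedged sketch of what the added size argument does, assuming an elasticsearch-py client named client and that Field.CONTENT_KEY.value resolves to "page_content":

results = client.search(
    index=collection_name,
    query={"match": {"page_content": query}},
    size=top_k,  # without size, Elasticsearch caps the response at 10 hits by default
)
hits = results["hits"]["hits"]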
View File

@ -0,0 +1,209 @@
import json
import logging
import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient
from sqlalchemy import JSON, Column, String, func
from sqlalchemy.dialects.mysql import LONGTEXT
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
class OceanBaseVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config OCEANBASE_VECTOR_HOST is required")
if not values["port"]:
raise ValueError("config OCEANBASE_VECTOR_PORT is required")
if not values["user"]:
raise ValueError("config OCEANBASE_VECTOR_USER is required")
if not values["database"]:
raise ValueError("config OCEANBASE_VECTOR_DATABASE is required")
return values
class OceanBaseVector(BaseVector):
def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
super().__init__(collection_name)
self._config = config
self._hnsw_ef_search = -1
self._client = ObVecClient(
uri=f"{self._config.host}:{self._config.port}",
user=self._config.user,
password=self._config.password,
db_name=self._config.database,
)
def get_type(self) -> str:
return VectorType.OCEANBASE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._vec_dim = len(embeddings[0])
self._create_collection()
self.add_texts(texts, embeddings)
def _create_collection(self) -> None:
lock_name = "vector_indexing_lock_" + self._collection_name
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_" + self._collection_name
if redis_client.get(collection_exist_cache_key):
return
if self._client.check_table_exists(self._collection_name):
return
self.delete()
cols = [
Column("id", String(36), primary_key=True, autoincrement=False),
Column("vector", VECTOR(self._vec_dim)),
Column("text", LONGTEXT),
Column("metadata", JSON),
]
vidx_params = self._client.prepare_index_params()
vidx_params.add_index(
field_name="vector",
index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
index_name="vector_index",
metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
)
self._client.create_table_with_index_params(
table_name=self._collection_name,
columns=cols,
vidxs=vidx_params,
)
vals = []
params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
for row in params:
val = int(row[6])
vals.append(val)
if len(vals) == 0:
print("ob_vector_memory_limit_percentage not found in parameters.")
exit(1)
if any(val == 0 for val in vals):
try:
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
except Exception as e:
raise Exception(
"Failed to set ob_vector_memory_limit_percentage. "
+ "Maybe the database user has insufficient privilege.",
e,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
for id, doc, emb in zip(ids, documents, embeddings):
self._client.insert(
table_name=self._collection_name,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)
def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, id=id)
return cur.rowcount != 0
def delete_by_ids(self, ids: list[str]) -> None:
self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
cur = self._client.get(
table_name=self._collection_name,
where_clause=f"metadata->>'$.{key}' = '{value}'",
output_column_name=["id"],
)
return [row[0] for row in cur]
def delete_by_metadata_field(self, key: str, value: str) -> None:
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search)
self._hnsw_ef_search = ef_search
topk = kwargs.get("top_k", 10)
cur = self._client.ann_search(
table_name=self._collection_name,
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=func.l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
)
docs = []
for text, metadata, distance in cur:
metadata = json.loads(metadata)
metadata["score"] = 1 - distance / math.sqrt(2)
docs.append(
Document(
page_content=text,
metadata=metadata,
)
)
return docs
def delete(self) -> None:
self._client.drop_table_if_exist(self._collection_name)
class OceanBaseVectorFactory(AbstractVectorFactory):
def init_vector(
self,
dataset: Dataset,
attributes: list,
embeddings: Embeddings,
) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OCEANBASE, collection_name))
return OceanBaseVector(
collection_name,
OceanBaseVectorConfig(
host=dify_config.OCEANBASE_VECTOR_HOST,
port=dify_config.OCEANBASE_VECTOR_PORT,
user=dify_config.OCEANBASE_VECTOR_USER,
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
database=dify_config.OCEANBASE_VECTOR_DATABASE,
),
)
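A short sketch of the distance-to-score mapping used in search_by_vector above, assuming roughly unit-norm embeddings so an l2 distance of sqrt(2) (orthogonal vectors) maps to a score of 0:

import math

def l2_distance_to_score(distance: float) -> float:
    # Identical vectors (distance 0) score 1.0; larger distances score lower.
    return 1 - distance / math.sqrt(2)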

View File

@ -0,0 +1,17 @@
from typing import Optional
from pydantic import BaseModel
class ClusterEntity(BaseModel):
"""
Cluster Entity.
"""
name: str
cluster_id: str
displayName: str
region: str
spendingLimit: Optional[int] = 1000
version: str
createdBy: str

View File

@ -0,0 +1,526 @@
import json
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client
import requests
from flask import current_app
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
FilterSelector,
HnswConfigDiff,
PayloadSchemaType,
TextIndexParams,
TextIndexType,
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class TidbOnQdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str] = None
timeout: float = 20
root_path: Optional[str] = None
grpc_port: int = 6334
prefer_grpc: bool = False
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {"path": path}
else:
return {
"url": self.endpoint,
"api_key": self.api_key,
"timeout": self.timeout,
"verify": False,
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
}
class TidbConfig(BaseModel):
api_url: str
public_key: str
private_key: str
class TidbOnQdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
self._distance_func = distance_func.upper()
self._group_id = group_id
def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT
def to_index_struct(self) -> dict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# get embedding vector size
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = "vector_indexing_lock_{}".format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
)
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = list(islice(embeddings_iterator, batch_size))
points = [
rest.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
Field.GROUP_KEY.value,
),
)
]
yield batch_ids, points
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete(self):
from qdrant_client.http.exceptions import UnexpectedResponse
try:
self._client.delete_collection(collection_name=self._collection_name)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_filter=filter,
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", 0.0),
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25.
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document.metadata["vector"] = result.vector
documents.append(document)
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID,
dify_config.TIDB_API_URL,
dify_config.TIDB_IAM_API_URL,
dify_config.TIDB_PUBLIC_KEY,
dify_config.TIDB_PRIVATE_KEY,
dify_config.TIDB_REGION,
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
tenant_id=dataset.tenant_id,
active=True,
status="ACTIVE",
)
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))
config = current_app.config
return TidbOnQdrantVector(
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=dify_config.TIDB_ON_QDRANT_URL,
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=config.root_path,
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
),
)
def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
"""
Creates a new TiDB Serverless cluster.
:param tidb_config: The configuration for the TiDB Cloud API.
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": "1372813089454548012",
}
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
response = requests.post(
f"{tidb_config.api_url}/clusters",
json=cluster_data,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param tidb_config: The configuration for the TiDB Cloud API.
:param cluster_id: The ID of the cluster for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password}
response = requests.put(
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
json=body,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
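For reference, a minimal sketch of the islice batching pattern that _generate_rest_batches relies on; the helper name is hypothetical:

from itertools import islice

def batched(items, batch_size=64):
    it = iter(items)
    # Each pass consumes at most batch_size items; the loop stops once the
    # iterator is exhausted and the slice comes back empty.
    while batch := list(islice(it, batch_size)):
        yield batch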

View File

@ -0,0 +1,251 @@
import time
import uuid
import requests
from requests.auth import HTTPDigestAuth
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
class TidbService:
@staticmethod
def create_tidb_serverless_cluster(
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
):
"""
Creates a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": dify_config.TIDB_SPEND_LIMIT,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")[:16]
cluster_data = {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
cluster_id = response_data["clusterId"]
retry_count = 0
max_retries = 30
while retry_count < max_retries:
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if cluster_response["state"] == "ACTIVE":
user_prefix = cluster_response["userPrefix"]
return {
"cluster_id": cluster_id,
"cluster_name": display_name,
"account": f"{user_prefix}.root",
"password": password,
}
time.sleep(30) # wait 30 seconds before retrying
retry_count += 1
else:
response.raise_for_status()
@staticmethod
def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Deletes a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be deleted (required).
:return: The response from the API.
"""
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Retrieves a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be retrieved (required).
:return: The response from the API.
"""
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def change_tidb_serverless_root_password(
api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster for which the password is to be changed (required).
:param account: The account for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = requests.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=HTTPDigestAuth(public_key, private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def batch_update_tidb_serverless_cluster_status(
tidb_serverless_list: list[TidbAuthBinding],
project_id: str,
api_url: str,
iam_url: str,
public_key: str,
private_key: str,
) -> list[dict]:
"""
Batch-updates the status of existing TiDB Serverless clusters.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param tidb_serverless_list: The TidbAuthBinding records whose status should be refreshed (required).
:return: The response from the API.
"""
clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "FULL"}
response = requests.get(
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = "ACTIVE"
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()
else:
response.raise_for_status()
@staticmethod
def batch_create_tidb_serverless_cluster(
batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
) -> list[dict]:
"""
Batch-creates new TiDB Serverless clusters.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
clusters = []
for _ in range(batch_size):
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": dify_config.TIDB_SPEND_LIMIT,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")
cluster_data = {
"cluster": {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
}
cache_key = f"tidb_serverless_cluster_password:{display_name}"
redis_client.setex(cache_key, 3600, password)
clusters.append(cluster_data)
request_body = {"requests": clusters}
response = requests.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
password = redis_client.get(cache_key)
if not password:
continue
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": password.decode("utf-8"),
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
response.raise_for_status()
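As a usage sketch (placeholder URL and keys, not part of this commit), the polling idiom create_tidb_serverless_cluster uses against the TiDB Cloud API:

import time

import requests
from requests.auth import HTTPDigestAuth

def wait_until_active(api_url, public_key, private_key, cluster_id, max_retries=30):
    for _ in range(max_retries):
        resp = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
        resp.raise_for_status()
        if resp.json()["state"] == "ACTIVE":
            return resp.json()
        time.sleep(30)  # wait 30 seconds before polling again
    raise TimeoutError("cluster did not become ACTIVE in time")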

View File

@ -0,0 +1,129 @@
import json
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, model_validator
from upstash_vector import Index, Vector
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset
class UpstashVectorConfig(BaseModel):
url: str
token: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["url"]:
raise ValueError("Upstash URL is required")
if not values["token"]:
raise ValueError("Upstash Token is required")
return values
class UpstashVector(BaseVector):
def __init__(self, collection_name: str, config: UpstashVectorConfig):
super().__init__(collection_name)
self._table_name = collection_name
self.index = Index(url=config.url, token=config.token)
def _get_index_dimension(self) -> int:
index_info = self.index.info()
if index_info and index_info.dimension:
return index_info.dimension
else:
return 1536
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
vectors = [
Vector(
id=str(uuid4()),
vector=embedding,
metadata=doc.metadata,
data=doc.page_content,
)
for doc, embedding in zip(documents, embeddings)
]
self.index.upsert(vectors=vectors)
def text_exists(self, id: str) -> bool:
response = self.get_ids_by_metadata_field("doc_id", id)
return len(response) > 0
def delete_by_ids(self, ids: list[str]) -> None:
item_ids = []
for doc_id in ids:
matched_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
if matched_ids:
item_ids += matched_ids
self._delete_by_ids(ids=item_ids)
def _delete_by_ids(self, ids: list[str]) -> None:
if ids:
self.index.delete(ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
query_result = self.index.query(
vector=[1.001 * i for i in range(self._get_index_dimension())],
include_metadata=True,
top_k=1000,
filter=f"{key} = '{value}'",
)
return [result.id for result in query_result]
def delete_by_metadata_field(self, key: str, value: str) -> None:
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._delete_by_ids(ids)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in result:
metadata = record.metadata
text = record.data
score = record.score
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def delete(self) -> None:
self.index.reset()
def get_type(self) -> str:
return VectorType.UPSTASH
class UpstashVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> UpstashVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.UPSTASH, collection_name))
return UpstashVector(
collection_name=collection_name,
config=UpstashVectorConfig(
url=dify_config.UPSTASH_VECTOR_URL,
token=dify_config.UPSTASH_VECTOR_TOKEN,
),
)
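A hedged usage sketch of the metadata-filter query issued by get_ids_by_metadata_field, with a placeholder URL, token, and doc_id:

from upstash_vector import Index

index = Index(url="https://example-vector.upstash.io", token="<token>")
hits = index.query(
    vector=[0.0] * 1536,  # dummy query vector; only the metadata filter matters here
    top_k=1000,
    include_metadata=True,
    filter="doc_id = 'example-doc-id'",
)
ids = [hit.id for hit in hits]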

View File

@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.dataset import Dataset, Whitelist
class AbstractVectorFactory(ABC):
@ -35,8 +36,18 @@ class Vector:
def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
)
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT
if not vector_type:
raise ValueError("Vector store must be specified.")
@ -103,6 +114,10 @@ class Vector:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case VectorType.COUCHBASE:
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
return CouchbaseVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
@ -111,6 +126,18 @@ class Vector:
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
return VikingDBVectorFactory
case VectorType.UPSTASH:
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
return UpstashVectorFactory
case VectorType.TIDB_ON_QDRANT:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.OCEANBASE:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
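Stated more directly, a condensed sketch of the selection order the hunk above introduces, with a hypothetical whitelisted flag standing in for the Whitelist lookup:

def resolve_vector_type(dataset, whitelisted: bool) -> str:
    # 1. a dataset that already has an index_struct keeps its stored type;
    # 2. otherwise, whitelisted tenants are routed to TiDB on Qdrant;
    # 3. otherwise, the globally configured VECTOR_STORE is used.
    if dataset.index_struct_dict:
        return dataset.index_struct_dict["type"]
    if dify_config.VECTOR_STORE_WHITELIST_ENABLE and whitelisted:
        return VectorType.TIDB_ON_QDRANT
    return dify_config.VECTOR_STORE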

View File

@ -16,5 +16,9 @@ class VectorType(str, Enum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
COUCHBASE = "couchbase"
BAIDU = "baidu"
VIKINGDB = "vikingdb"
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"

View File

@ -105,7 +105,7 @@ class ExtractProcessor:
extractor = PdfExtractor(file_path)
elif file_extension in {".md", ".markdown"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
if is_automatic
else MarkdownExtractor(file_path, autodetect_encoding=True)
)
@ -116,17 +116,19 @@ class ExtractProcessor:
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == ".msg":
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".eml":
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".ppt":
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key)
# You must first specify the API key
# because unstructured_api_key is necessary to parse .ppt documents
elif file_extension == ".pptx":
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".xml":
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".epub":
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url)
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
else:
# txt
extractor = (

View File

@ -10,24 +10,26 @@ logger = logging.getLogger(__name__)
class UnstructuredEmailExtractor(BaseExtractor):
"""Load msg files.
"""Load eml files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
def __init__(self, file_path: str, api_url: str, api_key: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.partition.email import partition_email
if self._api_url:
from unstructured.partition.api import partition_via_api
elements = partition_email(filename=self._file_path)
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.email import partition_email
elements = partition_email(filename=self._file_path)
# noinspection PyBroadException
try:

View File

@ -19,15 +19,23 @@ class UnstructuredEpubExtractor(BaseExtractor):
self,
file_path: str,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.partition.epub import partition_epub
if self._api_url:
from unstructured.partition.api import partition_via_api
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.epub import partition_epub
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)

View File

@ -24,19 +24,21 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
if the specified encoding fails.
"""
def __init__(
self,
file_path: str,
api_url: str,
):
def __init__(self, file_path: str, api_url: str, api_key: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.partition.md import partition_md
if self._api_url:
from unstructured.partition.api import partition_via_api
elements = partition_md(filename=self._file_path)
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.md import partition_md
elements = partition_md(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)

View File

@ -14,15 +14,21 @@ class UnstructuredMsgExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(self, file_path: str, api_url: str):
def __init__(self, file_path: str, api_url: str, api_key: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.partition.msg import partition_msg
if self._api_url:
from unstructured.partition.api import partition_via_api
elements = partition_msg(filename=self._file_path)
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.msg import partition_msg
elements = partition_msg(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)

View File

@ -0,0 +1,47 @@
import logging
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class UnstructuredPDFExtractor(BaseExtractor):
"""Load pdf files.
Args:
file_path: Path to the file to load.
api_url: Unstructured API URL
api_key: Unstructured API Key
"""
def __init__(self, file_path: str, api_url: str, api_key: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
self._api_key = api_key
def extract(self) -> list[Document]:
if self._api_url:
from unstructured.partition.api import partition_via_api
elements = partition_via_api(
filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto"
)
else:
from unstructured.partition.pdf import partition_pdf
elements = partition_pdf(filename=self._file_path, strategy="auto")
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
documents = []
for chunk in chunks:
text = chunk.text.strip()
documents.append(Document(page_content=text))
return documents
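A brief usage sketch of the new extractor, with a hypothetical file path and a placeholder Unstructured endpoint and key:

extractor = UnstructuredPDFExtractor(
    file_path="/tmp/example.pdf",
    api_url="https://api.unstructured.io/general/v0/general",
    api_key="<unstructured-api-key>",
)
documents = extractor.extract()  # list of Document chunks grouped by title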

View File

@ -7,7 +7,7 @@ logger = logging.getLogger(__name__)
class UnstructuredPPTExtractor(BaseExtractor):
"""Load msg files.
"""Load ppt files.
Args:
@ -21,9 +21,12 @@ class UnstructuredPPTExtractor(BaseExtractor):
self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.partition.api import partition_via_api
if self._api_url:
from unstructured.partition.api import partition_via_api
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
raise NotImplementedError("Unstructured API Url is not configured")
text_by_page = {}
for element in elements:
page = element.metadata.page_number

View File

@ -7,22 +7,28 @@ logger = logging.getLogger(__name__)
class UnstructuredPPTXExtractor(BaseExtractor):
"""Load msg files.
"""Load pptx files.
Args:
file_path: Path to the file to load.
"""
def __init__(self, file_path: str, api_url: str):
def __init__(self, file_path: str, api_url: str, api_key: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.partition.pptx import partition_pptx
if self._api_url:
from unstructured.partition.api import partition_via_api
elements = partition_pptx(filename=self._file_path)
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.pptx import partition_pptx
elements = partition_pptx(filename=self._file_path)
text_by_page = {}
for element in elements:
page = element.metadata.page_number

View File

@ -7,22 +7,29 @@ logger = logging.getLogger(__name__)
class UnstructuredXmlExtractor(BaseExtractor):
"""Load msg files.
"""Load xml files.
Args:
file_path: Path to the file to load.
"""
def __init__(self, file_path: str, api_url: str):
def __init__(self, file_path: str, api_url: str, api_key: str):
"""Initialize with file path."""
self._file_path = file_path
self._api_url = api_url
self._api_key = api_key
def extract(self) -> list[Document]:
from unstructured.partition.xml import partition_xml
if self._api_url:
from unstructured.partition.api import partition_via_api
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.xml import partition_xml
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)

View File

@ -234,7 +234,7 @@ class WordExtractor(BaseExtractor):
def parse_paragraph(paragraph):
paragraph_content = []
for run in paragraph.runs:
if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"):
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
drawing_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
)

View File

@ -27,18 +27,17 @@ class RerankModelRunner(BaseRerankRunner):
:return:
"""
docs = []
doc_id = []
doc_id = set()
unique_documents = []
dify_documents = [item for item in documents if item.provider == "dify"]
external_documents = [item for item in documents if item.provider == "external"]
for document in dify_documents:
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
for document in documents:
if document.provider == "dify" and document.metadata["doc_id"] not in doc_id:
doc_id.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
for document in external_documents:
docs.append(document.page_content)
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents

View File

@ -1,6 +1,6 @@
from enum import Enum
class RerankMode(Enum):
class RerankMode(str, Enum):
RERANKING_MODEL = "reranking_model"
WEIGHTED_SCORE = "weighted_score"
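A one-line illustration of what the str mixin changes: members now compare equal to their raw string values, which is what downstream code relies on when reranking_mode arrives as a plain string:

from enum import Enum

class RerankMode(str, Enum):
    RERANKING_MODEL = "reranking_model"
    WEIGHTED_SCORE = "weighted_score"

assert RerankMode.WEIGHTED_SCORE == "weighted_score"  # holds only with the str mixin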

View File

@ -22,6 +22,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
@ -359,10 +360,39 @@ class DatasetRetrieval:
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
if not available_datasets:
return []
threads = []
all_documents = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type = None
index_type_check = all(
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
)
if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
raise ValueError(
"The configured knowledge bases use different indexing techniques; please set a reranking model."
)
index_type = available_datasets[0].indexing_technique
if index_type == "high_quality":
embedding_model_check = all(
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
)
embedding_model_provider_check = all(
item.embedding_model_provider == available_datasets[0].embedding_model_provider
for item in available_datasets
)
if (
reranking_enable
and reranking_mode == "weighted_score"
and (not embedding_model_check or not embedding_model_provider_check)
):
raise ValueError(
"The configured knowledge bases use different embedding models; please set a reranking model."
)
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(