mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
Merge branch 'main' into fix/chore-fix
This commit is contained in:
@ -34,6 +34,8 @@ class RetrievalService:
|
||||
reranking_mode: Optional[str] = "reranking_model",
|
||||
weights: Optional[dict] = None,
|
||||
):
|
||||
if not query:
|
||||
return []
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
@ -3,11 +3,13 @@ import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymochow import MochowClient
|
||||
from pymochow.auth.bce_credentials import BceCredentials
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
|
||||
from pymochow.exception import ServerError
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
|
||||
|
||||
@ -116,6 +118,7 @@ class BaiduVector(BaseVector):
|
||||
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
@ -149,7 +152,13 @@ class BaiduVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
try:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.TABLE_NOT_EXIST:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def _init_client(self, config) -> MochowClient:
|
||||
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
|
||||
@ -166,7 +175,14 @@ class BaiduVector(BaseVector):
|
||||
if exists:
|
||||
return self._client.database(self._client_config.database)
|
||||
else:
|
||||
return self._client.create_database(database_name=self._client_config.database)
|
||||
try:
|
||||
self._client.create_database(database_name=self._client_config.database)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.DB_ALREADY_EXIST:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
return
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
@ -175,7 +191,7 @@ class BaiduVector(BaseVector):
|
||||
def _create_table(self, dimension: int) -> None:
|
||||
# Try to grab distributed lock and create table
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
with redis_client.lock(lock_name, timeout=60):
|
||||
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(table_exist_cache_key):
|
||||
return
|
||||
@ -238,15 +254,14 @@ class BaiduVector(BaseVector):
|
||||
description="Table for Dify",
|
||||
)
|
||||
|
||||
# Wait for table created
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
|
||||
# Wait for table created
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
|
||||
|
||||
0
api/core/rag/datasource/vdb/couchbase/__init__.py
Normal file
0
api/core/rag/datasource/vdb/couchbase/__init__.py
Normal file
378
api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
Normal file
378
api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
Normal file
@ -0,0 +1,378 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from couchbase import search
|
||||
from couchbase.auth import PasswordAuthenticator
|
||||
from couchbase.cluster import Cluster
|
||||
from couchbase.management.search import SearchIndex
|
||||
|
||||
# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
|
||||
from couchbase.options import ClusterOptions, SearchOptions
|
||||
from couchbase.vector_search import VectorQuery, VectorSearch
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CouchbaseConfig(BaseModel):
|
||||
connection_string: str
|
||||
user: str
|
||||
password: str
|
||||
bucket_name: str
|
||||
scope_name: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get("connection_string"):
|
||||
raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
|
||||
if not values.get("user"):
|
||||
raise ValueError("config COUCHBASE_USER is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("config COUCHBASE_PASSWORD is required")
|
||||
if not values.get("bucket_name"):
|
||||
raise ValueError("config COUCHBASE_PASSWORD is required")
|
||||
if not values.get("scope_name"):
|
||||
raise ValueError("config COUCHBASE_SCOPE_NAME is required")
|
||||
return values
|
||||
|
||||
|
||||
class CouchbaseVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: CouchbaseConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
|
||||
"""Connect to couchbase"""
|
||||
|
||||
auth = PasswordAuthenticator(config.user, config.password)
|
||||
options = ClusterOptions(auth)
|
||||
self._cluster = Cluster(config.connection_string, options)
|
||||
self._bucket = self._cluster.bucket(config.bucket_name)
|
||||
self._scope = self._bucket.scope(config.scope_name)
|
||||
self._bucket_name = config.bucket_name
|
||||
self._scope_name = config.scope_name
|
||||
|
||||
# Wait until the cluster is ready for use.
|
||||
self._cluster.wait_until_ready(timedelta(seconds=5))
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
index_id = str(uuid.uuid4()).replace("-", "")
|
||||
self._create_collection(uuid=index_id, vector_length=len(embeddings[0]))
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def _create_collection(self, vector_length: int, uuid: str):
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
if self._collection_exists(self._collection_name):
|
||||
return
|
||||
manager = self._bucket.collections()
|
||||
manager.create_collection(self._client_config.scope_name, self._collection_name)
|
||||
|
||||
index_manager = self._scope.search_indexes()
|
||||
|
||||
index_definition = json.loads("""
|
||||
{
|
||||
"type": "fulltext-index",
|
||||
"name": "Embeddings._default.Vector_Search",
|
||||
"uuid": "26d4db528e78b716",
|
||||
"sourceType": "gocbcore",
|
||||
"sourceName": "Embeddings",
|
||||
"sourceUUID": "2242e4a25b4decd6650c9c7b3afa1dbf",
|
||||
"planParams": {
|
||||
"maxPartitionsPerPIndex": 1024,
|
||||
"indexPartitions": 1
|
||||
},
|
||||
"params": {
|
||||
"doc_config": {
|
||||
"docid_prefix_delim": "",
|
||||
"docid_regexp": "",
|
||||
"mode": "scope.collection.type_field",
|
||||
"type_field": "type"
|
||||
},
|
||||
"mapping": {
|
||||
"analysis": { },
|
||||
"default_analyzer": "standard",
|
||||
"default_datetime_parser": "dateTimeOptional",
|
||||
"default_field": "_all",
|
||||
"default_mapping": {
|
||||
"dynamic": true,
|
||||
"enabled": true
|
||||
},
|
||||
"default_type": "_default",
|
||||
"docvalues_dynamic": false,
|
||||
"index_dynamic": true,
|
||||
"store_dynamic": true,
|
||||
"type_field": "_type",
|
||||
"types": {
|
||||
"collection_name": {
|
||||
"dynamic": true,
|
||||
"enabled": true,
|
||||
"properties": {
|
||||
"embedding": {
|
||||
"dynamic": false,
|
||||
"enabled": true,
|
||||
"fields": [
|
||||
{
|
||||
"dims": 1536,
|
||||
"index": true,
|
||||
"name": "embedding",
|
||||
"similarity": "dot_product",
|
||||
"type": "vector",
|
||||
"vector_index_optimized_for": "recall"
|
||||
}
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"dynamic": true,
|
||||
"enabled": true
|
||||
},
|
||||
"text": {
|
||||
"dynamic": false,
|
||||
"enabled": true,
|
||||
"fields": [
|
||||
{
|
||||
"index": true,
|
||||
"name": "text",
|
||||
"store": true,
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"store": {
|
||||
"indexType": "scorch",
|
||||
"segmentVersion": 16
|
||||
}
|
||||
},
|
||||
"sourceParams": { }
|
||||
}
|
||||
""")
|
||||
index_definition["name"] = self._collection_name + "_search"
|
||||
index_definition["uuid"] = uuid
|
||||
index_definition["params"]["mapping"]["types"]["collection_name"]["properties"]["embedding"]["fields"][0][
|
||||
"dims"
|
||||
] = vector_length
|
||||
index_definition["params"]["mapping"]["types"][self._scope_name + "." + self._collection_name] = (
|
||||
index_definition["params"]["mapping"]["types"].pop("collection_name")
|
||||
)
|
||||
time.sleep(2)
|
||||
index_manager.upsert_index(
|
||||
SearchIndex(
|
||||
index_definition["name"],
|
||||
params=index_definition["params"],
|
||||
source_name=self._bucket_name,
|
||||
),
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _collection_exists(self, name: str):
|
||||
scope_collection_map: dict[str, Any] = {}
|
||||
|
||||
# Get a list of all scopes in the bucket
|
||||
for scope in self._bucket.collections().get_all_scopes():
|
||||
scope_collection_map[scope.name] = []
|
||||
|
||||
# Get a list of all the collections in the scope
|
||||
for collection in scope.collections:
|
||||
scope_collection_map[scope.name].append(collection.name)
|
||||
|
||||
# Check if the collection exists in the scope
|
||||
return self._collection_name in scope_collection_map[self._scope_name]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.COUCHBASE
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
doc_ids = []
|
||||
|
||||
documents_to_insert = [
|
||||
{"text": text, "embedding": vector, "metadata": metadata}
|
||||
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
|
||||
]
|
||||
for doc, id in zip(documents_to_insert, uuids):
|
||||
result = self._scope.collection(self._collection_name).upsert(id, doc)
|
||||
|
||||
doc_ids.extend(uuids)
|
||||
|
||||
return doc_ids
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
# Use a parameterized query for safety and correctness
|
||||
query = f"""
|
||||
SELECT COUNT(1) AS count FROM
|
||||
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
|
||||
WHERE META().id = $doc_id
|
||||
"""
|
||||
# Pass the id as a parameter to the query
|
||||
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
|
||||
for row in result:
|
||||
return row["count"] > 0
|
||||
return False # Return False if no rows are returned
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
query = f"""
|
||||
DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
|
||||
WHERE META().id IN $doc_ids;
|
||||
"""
|
||||
try:
|
||||
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
query = f"""
|
||||
DELETE FROM
|
||||
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
|
||||
WHERE META().id = $doc_id;
|
||||
"""
|
||||
self._cluster.query(query, named_parameters={"doc_id": document_id}).execute()
|
||||
|
||||
# def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
# query = f"""
|
||||
# SELECT id FROM
|
||||
# `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
|
||||
# WHERE `metadata.{key}` = $value;
|
||||
# """
|
||||
# result = self._cluster.query(query, named_parameters={'value':value})
|
||||
# return [row['id'] for row in result.rows()]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
query = f"""
|
||||
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
|
||||
WHERE metadata.{key} = $value;
|
||||
"""
|
||||
self._cluster.query(query, named_parameters={"value": value}).execute()
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
|
||||
search_req = search.SearchRequest.create(
|
||||
VectorSearch.from_vector_query(
|
||||
VectorQuery(
|
||||
"embedding",
|
||||
query_vector,
|
||||
top_k,
|
||||
)
|
||||
)
|
||||
)
|
||||
try:
|
||||
search_iter = self._scope.search(
|
||||
self._collection_name + "_search",
|
||||
search_req,
|
||||
SearchOptions(limit=top_k, collections=[self._collection_name], fields=["*"]),
|
||||
)
|
||||
|
||||
docs = []
|
||||
# Parse the results
|
||||
for row in search_iter.rows():
|
||||
text = row.fields.pop("text")
|
||||
metadata = self._format_metadata(row.fields)
|
||||
score = row.score
|
||||
metadata["score"] = score
|
||||
doc = Document(page_content=text, metadata=metadata)
|
||||
if score >= score_threshold:
|
||||
docs.append(doc)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Search failed with error: {e}")
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 2)
|
||||
try:
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||
search_iter = self._scope.search(
|
||||
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
|
||||
)
|
||||
|
||||
docs = []
|
||||
for row in search_iter.rows():
|
||||
text = row.fields.pop("text")
|
||||
metadata = self._format_metadata(row.fields)
|
||||
score = row.score
|
||||
metadata["score"] = score
|
||||
doc = Document(page_content=text, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Search failed with error: {e}")
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self):
|
||||
manager = self._bucket.collections()
|
||||
scopes = manager.get_all_scopes()
|
||||
|
||||
for scope in scopes:
|
||||
for collection in scope.collections:
|
||||
if collection.name == self._collection_name:
|
||||
manager.drop_collection("_default", self._collection_name)
|
||||
|
||||
def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Helper method to format the metadata from the Couchbase Search API.
|
||||
Args:
|
||||
row_fields (Dict[str, Any]): The fields to format.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The formatted metadata.
|
||||
"""
|
||||
metadata = {}
|
||||
for key, value in row_fields.items():
|
||||
# Couchbase Search returns the metadata key with a prefix
|
||||
# `metadata.` We remove it to get the original metadata key
|
||||
if key.startswith("metadata"):
|
||||
new_key = key.split("metadata" + ".")[-1]
|
||||
metadata[new_key] = value
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
class CouchbaseVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return CouchbaseVector(
|
||||
collection_name=collection_name,
|
||||
config=CouchbaseConfig(
|
||||
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
|
||||
user=config.get("COUCHBASE_USER"),
|
||||
password=config.get("COUCHBASE_PASSWORD"),
|
||||
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
|
||||
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
|
||||
),
|
||||
)
|
||||
@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
results = self._client.search(index=self._collection_name, query=query_str)
|
||||
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
|
||||
docs = []
|
||||
for hit in results["hits"]["hits"]:
|
||||
docs.append(
|
||||
|
||||
0
api/core/rag/datasource/vdb/oceanbase/__init__.py
Normal file
0
api/core/rag/datasource/vdb/oceanbase/__init__.py
Normal file
209
api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
Normal file
209
api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
Normal file
@ -0,0 +1,209 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pyobvector import VECTOR, ObVecClient
|
||||
from sqlalchemy import JSON, Column, String, func
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
|
||||
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
|
||||
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
|
||||
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
|
||||
|
||||
|
||||
class OceanBaseVectorConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config OCEANBASE_VECTOR_HOST is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config OCEANBASE_VECTOR_PORT is required")
|
||||
if not values["user"]:
|
||||
raise ValueError("config OCEANBASE_VECTOR_USER is required")
|
||||
if not values["database"]:
|
||||
raise ValueError("config OCEANBASE_VECTOR_DATABASE is required")
|
||||
return values
|
||||
|
||||
|
||||
class OceanBaseVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
self._hnsw_ef_search = -1
|
||||
self._client = ObVecClient(
|
||||
uri=f"{self._config.host}:{self._config.port}",
|
||||
user=self._config.user,
|
||||
password=self._config.password,
|
||||
db_name=self._config.database,
|
||||
)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.OCEANBASE
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self._vec_dim = len(embeddings[0])
|
||||
self._create_collection()
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def _create_collection(self) -> None:
|
||||
lock_name = "vector_indexing_lock_" + self._collection_name
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = "vector_indexing_" + self._collection_name
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
if self._client.check_table_exists(self._collection_name):
|
||||
return
|
||||
|
||||
self.delete()
|
||||
|
||||
cols = [
|
||||
Column("id", String(36), primary_key=True, autoincrement=False),
|
||||
Column("vector", VECTOR(self._vec_dim)),
|
||||
Column("text", LONGTEXT),
|
||||
Column("metadata", JSON),
|
||||
]
|
||||
vidx_params = self._client.prepare_index_params()
|
||||
vidx_params.add_index(
|
||||
field_name="vector",
|
||||
index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
|
||||
index_name="vector_index",
|
||||
metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
|
||||
params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
|
||||
)
|
||||
|
||||
self._client.create_table_with_index_params(
|
||||
table_name=self._collection_name,
|
||||
columns=cols,
|
||||
vidxs=vidx_params,
|
||||
)
|
||||
vals = []
|
||||
params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
|
||||
for row in params:
|
||||
val = int(row[6])
|
||||
vals.append(val)
|
||||
if len(vals) == 0:
|
||||
print("ob_vector_memory_limit_percentage not found in parameters.")
|
||||
exit(1)
|
||||
if any(val == 0 for val in vals):
|
||||
try:
|
||||
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
"Failed to set ob_vector_memory_limit_percentage. "
|
||||
+ "Maybe the database user has insufficient privilege.",
|
||||
e,
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = self._get_uuids(documents)
|
||||
for id, doc, emb in zip(ids, documents, embeddings):
|
||||
self._client.insert(
|
||||
table_name=self._collection_name,
|
||||
data={
|
||||
"id": id,
|
||||
"vector": emb,
|
||||
"text": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
},
|
||||
)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
cur = self._client.get(table_name=self._collection_name, id=id)
|
||||
return cur.rowcount != 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
cur = self._client.get(
|
||||
table_name=self._collection_name,
|
||||
where_clause=f"metadata->>'$.{key}' = '{value}'",
|
||||
output_column_name=["id"],
|
||||
)
|
||||
return [row[0] for row in cur]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
|
||||
if ef_search != self._hnsw_ef_search:
|
||||
self._client.set_ob_hnsw_ef_search(ef_search)
|
||||
self._hnsw_ef_search = ef_search
|
||||
topk = kwargs.get("top_k", 10)
|
||||
cur = self._client.ann_search(
|
||||
table_name=self._collection_name,
|
||||
vec_column_name="vector",
|
||||
vec_data=query_vector,
|
||||
topk=topk,
|
||||
distance_func=func.l2_distance,
|
||||
output_column_names=["text", "metadata"],
|
||||
with_dist=True,
|
||||
)
|
||||
docs = []
|
||||
for text, metadata, distance in cur:
|
||||
metadata = json.loads(metadata)
|
||||
metadata["score"] = 1 - distance / math.sqrt(2)
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=text,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
self._client.drop_table_if_exist(self._collection_name)
|
||||
|
||||
|
||||
class OceanBaseVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
attributes: list,
|
||||
embeddings: Embeddings,
|
||||
) -> BaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OCEANBASE, collection_name))
|
||||
return OceanBaseVector(
|
||||
collection_name,
|
||||
OceanBaseVectorConfig(
|
||||
host=dify_config.OCEANBASE_VECTOR_HOST,
|
||||
port=dify_config.OCEANBASE_VECTOR_PORT,
|
||||
user=dify_config.OCEANBASE_VECTOR_USER,
|
||||
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
|
||||
database=dify_config.OCEANBASE_VECTOR_DATABASE,
|
||||
),
|
||||
)
|
||||
17
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py
Normal file
17
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py
Normal file
@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ClusterEntity(BaseModel):
|
||||
"""
|
||||
Model Config Entity.
|
||||
"""
|
||||
|
||||
name: str
|
||||
cluster_id: str
|
||||
displayName: str
|
||||
region: str
|
||||
spendingLimit: Optional[int] = 1000
|
||||
version: str
|
||||
createdBy: str
|
||||
@ -0,0 +1,526 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
import qdrant_client
|
||||
import requests
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http import models as rest
|
||||
from qdrant_client.http.models import (
|
||||
FilterSelector,
|
||||
HnswConfigDiff,
|
||||
PayloadSchemaType,
|
||||
TextIndexParams,
|
||||
TextIndexType,
|
||||
TokenizerType,
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
from requests.auth import HTTPDigestAuth
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, TidbAuthBinding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
from qdrant_client.conversions import common_types
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
DictFilter = dict[str, Union[str, int, bool, dict, list]]
|
||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
||||
|
||||
|
||||
class TidbOnQdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str] = None
|
||||
timeout: float = 20
|
||||
root_path: Optional[str] = None
|
||||
grpc_port: int = 6334
|
||||
prefer_grpc: bool = False
|
||||
|
||||
def to_qdrant_params(self):
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {"path": path}
|
||||
else:
|
||||
return {
|
||||
"url": self.endpoint,
|
||||
"api_key": self.api_key,
|
||||
"timeout": self.timeout,
|
||||
"verify": False,
|
||||
"grpc_port": self.grpc_port,
|
||||
"prefer_grpc": self.prefer_grpc,
|
||||
}
|
||||
|
||||
|
||||
class TidbConfig(BaseModel):
|
||||
api_url: str
|
||||
public_key: str
|
||||
private_key: str
|
||||
|
||||
|
||||
class TidbOnQdrantVector(BaseVector):
|
||||
def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
||||
self._distance_func = distance_func.upper()
|
||||
self._group_id = group_id
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.TIDB_ON_QDRANT
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
if texts:
|
||||
# get embedding vector size
|
||||
vector_size = len(embeddings[0])
|
||||
# get collection name
|
||||
collection_name = self._collection_name
|
||||
# create collection
|
||||
self.create_collection(collection_name, vector_size)
|
||||
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def create_collection(self, collection_name: str, vector_size: int):
|
||||
lock_name = "vector_indexing_lock_{}".format(collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
collection_name = collection_name or uuid.uuid4().hex
|
||||
all_collection_name = []
|
||||
collections_response = self._client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if collection_name not in all_collection_name:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[self._distance_func],
|
||||
)
|
||||
hnsw_config = HnswConfigDiff(
|
||||
m=0,
|
||||
payload_m=16,
|
||||
ef_construct=100,
|
||||
full_scan_threshold=10000,
|
||||
max_indexing_threads=0,
|
||||
on_disk=False,
|
||||
)
|
||||
self._client.recreate_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vectors_config,
|
||||
hnsw_config=hnsw_config,
|
||||
timeout=int(self._client_config.timeout),
|
||||
)
|
||||
|
||||
# create group_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create doc_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create full text index
|
||||
text_index_params = TextIndexParams(
|
||||
type=TextIndexType.TEXT,
|
||||
tokenizer=TokenizerType.MULTILINGUAL,
|
||||
min_token_len=2,
|
||||
max_token_len=20,
|
||||
lowercase=True,
|
||||
)
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
|
||||
added_ids = []
|
||||
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
|
||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
||||
added_ids.extend(batch_ids)
|
||||
|
||||
return added_ids
|
||||
|
||||
def _generate_rest_batches(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
batch_size: int = 64,
|
||||
group_id: Optional[str] = None,
|
||||
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
|
||||
from qdrant_client.http import models as rest
|
||||
|
||||
texts_iterator = iter(texts)
|
||||
embeddings_iterator = iter(embeddings)
|
||||
metadatas_iterator = iter(metadatas or [])
|
||||
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
|
||||
while batch_texts := list(islice(texts_iterator, batch_size)):
|
||||
# Take the corresponding metadata and id for each text in a batch
|
||||
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
|
||||
batch_ids = list(islice(ids_iterator, batch_size))
|
||||
|
||||
# Generate the embeddings for all the texts in a batch
|
||||
batch_embeddings = list(islice(embeddings_iterator, batch_size))
|
||||
|
||||
points = [
|
||||
rest.PointStruct(
|
||||
id=point_id,
|
||||
vector=vector,
|
||||
payload=payload,
|
||||
)
|
||||
for point_id, vector, payload in zip(
|
||||
batch_ids,
|
||||
batch_embeddings,
|
||||
self._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
group_id,
|
||||
Field.GROUP_KEY.value,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
yield batch_ids, points
|
||||
|
||||
@classmethod
|
||||
def _build_payloads(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]],
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
group_id: str,
|
||||
group_payload_key: str,
|
||||
) -> list[dict]:
|
||||
payloads = []
|
||||
for i, text in enumerate(texts):
|
||||
if text is None:
|
||||
raise ValueError(
|
||||
"At least one of the texts is None. Please remove it before "
|
||||
"calling .from_texts or .add_texts on Qdrant instance."
|
||||
)
|
||||
metadata = metadatas[i] if metadatas is not None else None
|
||||
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
|
||||
|
||||
return payloads
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete(self):
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
try:
|
||||
self._client.delete_collection(collection_name=self._collection_name)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
collections_response = self._client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if self._collection_name not in all_collection_name:
|
||||
return False
|
||||
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
|
||||
|
||||
return len(response) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from qdrant_client.http import models
|
||||
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=filter,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
score_threshold=kwargs.get("score_threshold", 0.0),
|
||||
)
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
if result.score > score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Return docs most similar by bm25.
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance for each.
|
||||
"""
|
||||
from qdrant_client.http import models
|
||||
|
||||
scroll_filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="page_content",
|
||||
match=models.MatchText(text=query),
|
||||
)
|
||||
]
|
||||
)
|
||||
response = self._client.scroll(
|
||||
collection_name=self._collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
limit=kwargs.get("top_k", 2),
|
||||
with_payload=True,
|
||||
with_vectors=True,
|
||||
)
|
||||
results = response[0]
|
||||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
|
||||
document.metadata["vector"] = result.vector
|
||||
documents.append(document)
|
||||
|
||||
return documents
|
||||
|
||||
def _reload_if_needed(self):
|
||||
if isinstance(self._client, QdrantLocal):
|
||||
self._client = cast(QdrantLocal, self._client)
|
||||
self._client._load()
|
||||
|
||||
@classmethod
|
||||
def _document_from_scored_point(
|
||||
cls,
|
||||
scored_point: Any,
|
||||
content_payload_key: str,
|
||||
metadata_payload_key: str,
|
||||
) -> Document:
|
||||
return Document(
|
||||
page_content=scored_point.payload.get(content_payload_key),
|
||||
metadata=scored_point.payload.get(metadata_payload_key) or {},
|
||||
)
|
||||
|
||||
|
||||
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
if not tidb_auth_binding:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID,
|
||||
dify_config.TIDB_API_URL,
|
||||
dify_config.TIDB_IAM_API_URL,
|
||||
dify_config.TIDB_PUBLIC_KEY,
|
||||
dify_config.TIDB_PRIVATE_KEY,
|
||||
dify_config.TIDB_REGION,
|
||||
)
|
||||
new_tidb_auth_binding = TidbAuthBinding(
|
||||
cluster_id=new_cluster["cluster_id"],
|
||||
cluster_name=new_cluster["cluster_name"],
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status="ACTIVE",
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
|
||||
return TidbOnQdrantVector(
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=TidbOnQdrantConfig(
|
||||
endpoint=dify_config.TIDB_ON_QDRANT_URL,
|
||||
api_key=TIDB_ON_QDRANT_API_KEY,
|
||||
root_path=config.root_path,
|
||||
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
|
||||
),
|
||||
)
|
||||
|
||||
def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
|
||||
"""
|
||||
Creates a new TiDB Serverless cluster.
|
||||
:param tidb_config: The configuration for the TiDB Cloud API.
|
||||
:param display_name: The user-friendly display name of the cluster (required).
|
||||
:param region: The region where the cluster will be created (required).
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
region_object = {
|
||||
"name": region,
|
||||
}
|
||||
|
||||
labels = {
|
||||
"tidb.cloud/project": "1372813089454548012",
|
||||
}
|
||||
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
|
||||
|
||||
response = requests.post(
|
||||
f"{tidb_config.api_url}/clusters",
|
||||
json=cluster_data,
|
||||
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
|
||||
"""
|
||||
Changes the root password of a specific TiDB Serverless cluster.
|
||||
|
||||
:param tidb_config: The configuration for the TiDB Cloud API.
|
||||
:param cluster_id: The ID of the cluster for which the password is to be changed (required).
|
||||
:param new_password: The new password for the root user (required).
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
body = {"password": new_password}
|
||||
|
||||
response = requests.put(
|
||||
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
|
||||
json=body,
|
||||
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
251
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
Normal file
251
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
Normal file
@ -0,0 +1,251 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPDigestAuth
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
|
||||
|
||||
class TidbService:
|
||||
@staticmethod
|
||||
def create_tidb_serverless_cluster(
|
||||
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
|
||||
):
|
||||
"""
|
||||
Creates a new TiDB Serverless cluster.
|
||||
:param project_id: The project ID of the TiDB Cloud project (required).
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param iam_url: The URL of the TiDB Cloud IAM API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param display_name: The user-friendly display name of the cluster (required).
|
||||
:param region: The region where the cluster will be created (required).
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
region_object = {
|
||||
"name": region,
|
||||
}
|
||||
|
||||
labels = {
|
||||
"tidb.cloud/project": project_id,
|
||||
}
|
||||
|
||||
spending_limit = {
|
||||
"monthly": dify_config.TIDB_SPEND_LIMIT,
|
||||
}
|
||||
password = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
display_name = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
cluster_data = {
|
||||
"displayName": display_name,
|
||||
"region": region_object,
|
||||
"labels": labels,
|
||||
"spendingLimit": spending_limit,
|
||||
"rootPassword": password,
|
||||
}
|
||||
|
||||
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_id = response_data["clusterId"]
|
||||
retry_count = 0
|
||||
max_retries = 30
|
||||
while retry_count < max_retries:
|
||||
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
|
||||
if cluster_response["state"] == "ACTIVE":
|
||||
user_prefix = cluster_response["userPrefix"]
|
||||
return {
|
||||
"cluster_id": cluster_id,
|
||||
"cluster_name": display_name,
|
||||
"account": f"{user_prefix}.root",
|
||||
"password": password,
|
||||
}
|
||||
time.sleep(30) # wait 30 seconds before retrying
|
||||
retry_count += 1
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
|
||||
"""
|
||||
Deletes a specific TiDB Serverless cluster.
|
||||
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param cluster_id: The ID of the cluster to be deleted (required).
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
|
||||
"""
|
||||
Deletes a specific TiDB Serverless cluster.
|
||||
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param cluster_id: The ID of the cluster to be deleted (required).
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def change_tidb_serverless_root_password(
|
||||
api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
|
||||
):
|
||||
"""
|
||||
Changes the root password of a specific TiDB Serverless cluster.
|
||||
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param cluster_id: The ID of the cluster for which the password is to be changed (required).+
|
||||
:param account: The account for which the password is to be changed (required).
|
||||
:param new_password: The new password for the root user (required).
|
||||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
|
||||
|
||||
response = requests.patch(
|
||||
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
|
||||
json=body,
|
||||
auth=HTTPDigestAuth(public_key, private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def batch_update_tidb_serverless_cluster_status(
|
||||
tidb_serverless_list: list[TidbAuthBinding],
|
||||
project_id: str,
|
||||
api_url: str,
|
||||
iam_url: str,
|
||||
public_key: str,
|
||||
private_key: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Update the status of a new TiDB Serverless cluster.
|
||||
:param project_id: The project ID of the TiDB Cloud project (required).
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param iam_url: The URL of the TiDB Cloud IAM API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param display_name: The user-friendly display name of the cluster (required).
|
||||
:param region: The region where the cluster will be created (required).
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
clusters = []
|
||||
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
|
||||
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
|
||||
params = {"clusterIds": cluster_ids, "view": "FULL"}
|
||||
response = requests.get(
|
||||
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_infos = []
|
||||
for item in response_data["clusters"]:
|
||||
state = item["state"]
|
||||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = "ACTIVE"
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def batch_create_tidb_serverless_cluster(
|
||||
batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Creates a new TiDB Serverless cluster.
|
||||
:param project_id: The project ID of the TiDB Cloud project (required).
|
||||
:param api_url: The URL of the TiDB Cloud API (required).
|
||||
:param iam_url: The URL of the TiDB Cloud IAM API (required).
|
||||
:param public_key: The public key for the API (required).
|
||||
:param private_key: The private key for the API (required).
|
||||
:param display_name: The user-friendly display name of the cluster (required).
|
||||
:param region: The region where the cluster will be created (required).
|
||||
|
||||
:return: The response from the API.
|
||||
"""
|
||||
clusters = []
|
||||
for _ in range(batch_size):
|
||||
region_object = {
|
||||
"name": region,
|
||||
}
|
||||
|
||||
labels = {
|
||||
"tidb.cloud/project": project_id,
|
||||
}
|
||||
|
||||
spending_limit = {
|
||||
"monthly": dify_config.TIDB_SPEND_LIMIT,
|
||||
}
|
||||
password = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
display_name = str(uuid.uuid4()).replace("-", "")
|
||||
cluster_data = {
|
||||
"cluster": {
|
||||
"displayName": display_name,
|
||||
"region": region_object,
|
||||
"labels": labels,
|
||||
"spendingLimit": spending_limit,
|
||||
"rootPassword": password,
|
||||
}
|
||||
}
|
||||
cache_key = f"tidb_serverless_cluster_password:{display_name}"
|
||||
redis_client.setex(cache_key, 3600, password)
|
||||
clusters.append(cluster_data)
|
||||
|
||||
request_body = {"requests": clusters}
|
||||
response = requests.post(
|
||||
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
cluster_infos = []
|
||||
for item in response_data["clusters"]:
|
||||
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
|
||||
password = redis_client.get(cache_key)
|
||||
if not password:
|
||||
continue
|
||||
cluster_info = {
|
||||
"cluster_id": item["clusterId"],
|
||||
"cluster_name": item["displayName"],
|
||||
"account": "root",
|
||||
"password": password.decode("utf-8"),
|
||||
}
|
||||
cluster_infos.append(cluster_info)
|
||||
return cluster_infos
|
||||
else:
|
||||
response.raise_for_status()
|
||||
0
api/core/rag/datasource/vdb/upstash/__init__.py
Normal file
0
api/core/rag/datasource/vdb/upstash/__init__.py
Normal file
129
api/core/rag/datasource/vdb/upstash/upstash_vector.py
Normal file
129
api/core/rag/datasource/vdb/upstash/upstash_vector.py
Normal file
@ -0,0 +1,129 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from upstash_vector import Index, Vector
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class UpstashVectorConfig(BaseModel):
|
||||
url: str
|
||||
token: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["url"]:
|
||||
raise ValueError("Upstash URL is required")
|
||||
if not values["token"]:
|
||||
raise ValueError("Upstash Token is required")
|
||||
return values
|
||||
|
||||
|
||||
class UpstashVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: UpstashVectorConfig):
|
||||
super().__init__(collection_name)
|
||||
self._table_name = collection_name
|
||||
self.index = Index(url=config.url, token=config.token)
|
||||
|
||||
def _get_index_dimension(self) -> int:
|
||||
index_info = self.index.info()
|
||||
if index_info and index_info.dimension:
|
||||
return index_info.dimension
|
||||
else:
|
||||
return 1536
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
vectors = [
|
||||
Vector(
|
||||
id=str(uuid4()),
|
||||
vector=embedding,
|
||||
metadata=doc.metadata,
|
||||
data=doc.page_content,
|
||||
)
|
||||
for doc, embedding in zip(documents, embeddings)
|
||||
]
|
||||
self.index.upsert(vectors=vectors)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
response = self.get_ids_by_metadata_field("doc_id", id)
|
||||
return len(response) > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
item_ids = []
|
||||
for doc_id in ids:
|
||||
ids = self.get_ids_by_metadata_field("doc_id", doc_id)
|
||||
if id:
|
||||
item_ids += ids
|
||||
self._delete_by_ids(ids=item_ids)
|
||||
|
||||
def _delete_by_ids(self, ids: list[str]) -> None:
|
||||
if ids:
|
||||
self.index.delete(ids=ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
query_result = self.index.query(
|
||||
vector=[1.001 * i for i in range(self._get_index_dimension())],
|
||||
include_metadata=True,
|
||||
top_k=1000,
|
||||
filter=f"{key} = '{value}'",
|
||||
)
|
||||
return [result.id for result in query_result]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._delete_by_ids(ids)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in result:
|
||||
metadata = record.metadata
|
||||
text = record.data
|
||||
score = record.score
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
self.index.reset()
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.UPSTASH
|
||||
|
||||
|
||||
class UpstashVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> UpstashVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.UPSTASH, collection_name))
|
||||
|
||||
return UpstashVector(
|
||||
collection_name=collection_name,
|
||||
config=UpstashVectorConfig(
|
||||
url=dify_config.UPSTASH_VECTOR_URL,
|
||||
token=dify_config.UPSTASH_VECTOR_TOKEN,
|
||||
),
|
||||
)
|
||||
@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, Whitelist
|
||||
|
||||
|
||||
class AbstractVectorFactory(ABC):
|
||||
@ -35,8 +36,18 @@ class Vector:
|
||||
|
||||
def _init_vector(self) -> BaseVector:
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict["type"]
|
||||
else:
|
||||
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
|
||||
whitelist = (
|
||||
db.session.query(Whitelist)
|
||||
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
|
||||
.one_or_none()
|
||||
)
|
||||
if whitelist:
|
||||
vector_type = VectorType.TIDB_ON_QDRANT
|
||||
|
||||
if not vector_type:
|
||||
raise ValueError("Vector store must be specified.")
|
||||
@ -103,6 +114,10 @@ class Vector:
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
|
||||
|
||||
return AnalyticdbVectorFactory
|
||||
case VectorType.COUCHBASE:
|
||||
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
|
||||
|
||||
return CouchbaseVectorFactory
|
||||
case VectorType.BAIDU:
|
||||
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
|
||||
|
||||
@ -111,6 +126,18 @@ class Vector:
|
||||
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
|
||||
|
||||
return VikingDBVectorFactory
|
||||
case VectorType.UPSTASH:
|
||||
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
|
||||
|
||||
return UpstashVectorFactory
|
||||
case VectorType.TIDB_ON_QDRANT:
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
|
||||
|
||||
return TidbOnQdrantVectorFactory
|
||||
case VectorType.OCEANBASE:
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
|
||||
|
||||
return OceanBaseVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
||||
@ -16,5 +16,9 @@ class VectorType(str, Enum):
|
||||
TENCENT = "tencent"
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
COUCHBASE = "couchbase"
|
||||
BAIDU = "baidu"
|
||||
VIKINGDB = "vikingdb"
|
||||
UPSTASH = "upstash"
|
||||
TIDB_ON_QDRANT = "tidb_on_qdrant"
|
||||
OCEANBASE = "oceanbase"
|
||||
|
||||
@ -105,7 +105,7 @@ class ExtractProcessor:
|
||||
extractor = PdfExtractor(file_path)
|
||||
elif file_extension in {".md", ".markdown"}:
|
||||
extractor = (
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
if is_automatic
|
||||
else MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
)
|
||||
@ -116,17 +116,19 @@ class ExtractProcessor:
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == ".msg":
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".eml":
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".ppt":
|
||||
extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
# You must first specify the API key
|
||||
# because unstructured_api_key is necessary to parse .ppt documents
|
||||
elif file_extension == ".pptx":
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".xml":
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".epub":
|
||||
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url)
|
||||
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
else:
|
||||
# txt
|
||||
extractor = (
|
||||
|
||||
@ -10,24 +10,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredEmailExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
"""Load eml files.
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.email import partition_email
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_email(filename=self._file_path)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.email import partition_email
|
||||
|
||||
elements = partition_email(filename=self._file_path)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
|
||||
@ -19,15 +19,23 @@ class UnstructuredEpubExtractor(BaseExtractor):
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.epub import partition_epub
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.epub import partition_epub
|
||||
|
||||
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
|
||||
|
||||
elements = partition_epub(filename=self._file_path, xml_keep_tags=True)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
|
||||
@ -24,19 +24,21 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
|
||||
if the specified encoding fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_url: str,
|
||||
):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.md import partition_md
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_md(filename=self._file_path)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.md import partition_md
|
||||
|
||||
elements = partition_md(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
|
||||
@ -14,15 +14,21 @@ class UnstructuredMsgExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_msg(filename=self._file_path)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
elements = partition_msg(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
|
||||
@ -0,0 +1,47 @@
|
||||
import logging
|
||||
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPDFExtractor(BaseExtractor):
|
||||
"""Load pdf files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
|
||||
api_url: Unstructured API URL
|
||||
|
||||
api_key: Unstructured API Key
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(
|
||||
filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto"
|
||||
)
|
||||
else:
|
||||
from unstructured.partition.pdf import partition_pdf
|
||||
|
||||
elements = partition_pdf(filename=self._file_path, strategy="auto")
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
documents.append(Document(page_content=text))
|
||||
|
||||
return documents
|
||||
@ -7,7 +7,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
"""Load ppt files.
|
||||
|
||||
|
||||
Args:
|
||||
@ -21,9 +21,12 @@ class UnstructuredPPTExtractor(BaseExtractor):
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
raise NotImplementedError("Unstructured API Url is not configured")
|
||||
text_by_page = {}
|
||||
for element in elements:
|
||||
page = element.metadata.page_number
|
||||
|
||||
@ -7,22 +7,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredPPTXExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
"""Load pptx files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_pptx(filename=self._file_path)
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
elements = partition_pptx(filename=self._file_path)
|
||||
text_by_page = {}
|
||||
for element in elements:
|
||||
page = element.metadata.page_number
|
||||
|
||||
@ -7,22 +7,29 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnstructuredXmlExtractor(BaseExtractor):
|
||||
"""Load msg files.
|
||||
"""Load xml files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, api_url: str):
|
||||
def __init__(self, file_path: str, api_url: str, api_key: str):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._api_url = api_url
|
||||
self._api_key = api_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
if self._api_url:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
|
||||
else:
|
||||
from unstructured.partition.xml import partition_xml
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
|
||||
|
||||
elements = partition_xml(filename=self._file_path, xml_keep_tags=True)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
|
||||
@ -234,7 +234,7 @@ class WordExtractor(BaseExtractor):
|
||||
def parse_paragraph(paragraph):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if hasattr(run.element, "tag") and isinstance(element.tag, str) and run.element.tag.endswith("r"):
|
||||
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
|
||||
drawing_elements = run.element.findall(
|
||||
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
|
||||
)
|
||||
|
||||
@ -27,18 +27,17 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
:return:
|
||||
"""
|
||||
docs = []
|
||||
doc_id = []
|
||||
doc_id = set()
|
||||
unique_documents = []
|
||||
dify_documents = [item for item in documents if item.provider == "dify"]
|
||||
external_documents = [item for item in documents if item.provider == "external"]
|
||||
for document in dify_documents:
|
||||
if document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.append(document.metadata["doc_id"])
|
||||
for document in documents:
|
||||
if document.provider == "dify" and document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.add(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
for document in external_documents:
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RerankMode(Enum):
|
||||
class RerankMode(str, Enum):
|
||||
RERANKING_MODEL = "reranking_model"
|
||||
WEIGHTED_SCORE = "weighted_score"
|
||||
|
||||
@ -22,6 +22,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
@ -359,10 +360,39 @@ class DatasetRetrieval:
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
return []
|
||||
threads = []
|
||||
all_documents = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
index_type = None
|
||||
index_type_check = all(
|
||||
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
|
||||
)
|
||||
if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
|
||||
raise ValueError(
|
||||
"The configured knowledge base list have different indexing technique, please set reranking model."
|
||||
)
|
||||
index_type = available_datasets[0].indexing_technique
|
||||
if index_type == "high_quality":
|
||||
embedding_model_check = all(
|
||||
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
|
||||
)
|
||||
embedding_model_provider_check = all(
|
||||
item.embedding_model_provider == available_datasets[0].embedding_model_provider
|
||||
for item in available_datasets
|
||||
)
|
||||
if (
|
||||
reranking_enable
|
||||
and reranking_mode == "weighted_score"
|
||||
and (not embedding_model_check or not embedding_model_provider_check)
|
||||
):
|
||||
raise ValueError(
|
||||
"The configured knowledge base list have different embedding model, please set reranking model."
|
||||
)
|
||||
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
|
||||
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
retrieval_thread = threading.Thread(
|
||||
|
||||
Reference in New Issue
Block a user