Merge remote-tracking branch 'origin/main' into feat/collaboration
@@ -21,7 +21,7 @@ from models.dataset import Document as DatasetDocument
 from services.external_knowledge_service import ExternalDatasetService

 default_retrieval_model = {
-    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
     "reranking_enable": False,
     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
     "top_k": 4,
@@ -107,7 +107,7 @@ class RetrievalService:
             raise ValueError(";\n".join(exceptions))

         # Deduplicate documents for hybrid search to avoid duplicate chunks
-        if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
+        if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
             all_documents = cls._deduplicate_documents(all_documents)
             data_post_processor = DataPostProcessor(
                 str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
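Note: `cls._deduplicate_documents` is called above but defined outside this hunk. A minimal sketch of what such a helper typically does, deduplicating on the chunk's `doc_id` metadata; the body below is hypothetical, not taken from this commit:

from core.rag.models.document import Document


def _deduplicate_documents(documents: list[Document]) -> list[Document]:
    # Hybrid search can return the same chunk from both the semantic and the
    # full-text pass; keep only the first occurrence of each chunk id.
    seen: set[str] = set()
    unique: list[Document] = []
    for doc in documents:
        doc_id = (doc.metadata or {}).get("doc_id")
        if doc_id is not None and doc_id in seen:
            continue
        if doc_id is not None:
            seen.add(doc_id)
        unique.append(doc)
    return unique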
@@ -134,7 +134,7 @@ class RetrievalService:
         if not dataset:
             return []
         metadata_condition = (
-            MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
+            MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
         )
         all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
             dataset.tenant_id,
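Note: `MetadataCondition.model_validate(...)` is the canonical Pydantic v2 way to build a model from raw data; unlike `MetadataCondition(**conditions)` it also accepts an existing model instance and non-dict mappings. A minimal sketch with an illustrative field (the real model has more):

from pydantic import BaseModel


class MetadataCondition(BaseModel):
    logical_operator: str = "and"


condition = MetadataCondition.model_validate({"logical_operator": "or"})
print(condition.logical_operator)  # or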
@@ -245,10 +245,10 @@ class RetrievalService:
                 reranking_model
                 and reranking_model.get("reranking_model_name")
                 and reranking_model.get("reranking_provider_name")
-                and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
+                and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
             ):
                 data_post_processor = DataPostProcessor(
-                    str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
+                    str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
                 )
                 all_documents.extend(
                     data_post_processor.invoke(
@@ -293,10 +293,10 @@ class RetrievalService:
                 reranking_model
                 and reranking_model.get("reranking_model_name")
                 and reranking_model.get("reranking_provider_name")
-                and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
+                and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
             ):
                 data_post_processor = DataPostProcessor(
-                    str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
+                    str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
                 )
                 all_documents.extend(
                     data_post_processor.invoke(
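Note: every `.value` dropped in this commit presumes the enums involved are str-based (Python 3.11+ `StrEnum` or a `str, Enum` mixin). A `StrEnum` member compares equal to, and formats as, its raw string, so `== RetrievalMethod.SEMANTIC_SEARCH` and `str(RerankMode.RERANKING_MODEL)` behave exactly like the old `.value` forms. A minimal sketch of the property relied on (member values illustrative):

from enum import StrEnum


class RetrievalMethod(StrEnum):
    SEMANTIC_SEARCH = "semantic_search"
    HYBRID_SEARCH = "hybrid_search"


assert RetrievalMethod.SEMANTIC_SEARCH == "semantic_search"       # equality against the raw value
assert str(RetrievalMethod.SEMANTIC_SEARCH) == "semantic_search"  # str() yields the value, not the member name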
@@ -0,0 +1,388 @@
+import hashlib
+import json
+import logging
+import uuid
+from contextlib import contextmanager
+from typing import Any, Literal, cast
+
+import mysql.connector
+from mysql.connector import Error as MySQLError
+from pydantic import BaseModel, model_validator
+
+from configs import dify_config
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+class AlibabaCloudMySQLVectorConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+    max_connection: int
+    charset: str = "utf8mb4"
+    distance_function: Literal["cosine", "euclidean"] = "cosine"
+    hnsw_m: int = 6
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict):
+        if not values.get("host"):
+            raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required")
+        if not values.get("port"):
+            raise ValueError("config ALIBABACLOUD_MYSQL_PORT is required")
+        if not values.get("user"):
+            raise ValueError("config ALIBABACLOUD_MYSQL_USER is required")
+        if values.get("password") is None:
+            raise ValueError("config ALIBABACLOUD_MYSQL_PASSWORD is required")
+        if not values.get("database"):
+            raise ValueError("config ALIBABACLOUD_MYSQL_DATABASE is required")
+        if not values.get("max_connection"):
+            raise ValueError("config ALIBABACLOUD_MYSQL_MAX_CONNECTION is required")
+        return values
+
+
+SQL_CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS {table_name} (
+    id VARCHAR(36) PRIMARY KEY,
+    text LONGTEXT NOT NULL,
+    meta JSON NOT NULL,
+    embedding VECTOR({dimension}) NOT NULL,
+    VECTOR INDEX (embedding) M={hnsw_m} DISTANCE={distance_function}
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
+"""
+
+SQL_CREATE_META_INDEX = """
+CREATE INDEX idx_{index_hash}_meta ON {table_name}
+((CAST(JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) AS CHAR(36))));
+"""
+
+SQL_CREATE_FULLTEXT_INDEX = """
+CREATE FULLTEXT INDEX idx_{index_hash}_text ON {table_name} (text) WITH PARSER ngram;
+"""
+
+
+class AlibabaCloudMySQLVector(BaseVector):
+    def __init__(self, collection_name: str, config: AlibabaCloudMySQLVectorConfig):
+        super().__init__(collection_name)
+        self.pool = self._create_connection_pool(config)
+        self.table_name = collection_name.lower()
+        self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
+        self.distance_function = config.distance_function.lower()
+        self.hnsw_m = config.hnsw_m
+        self._check_vector_support()
+
+    def get_type(self) -> str:
+        return VectorType.ALIBABACLOUD_MYSQL
+
+    def _create_connection_pool(self, config: AlibabaCloudMySQLVectorConfig):
+        # Create connection pool using mysql-connector-python pooling
+        pool_config: dict[str, Any] = {
+            "host": config.host,
+            "port": config.port,
+            "user": config.user,
+            "password": config.password,
+            "database": config.database,
+            "charset": config.charset,
+            "autocommit": True,
+            "pool_name": f"pool_{self.collection_name}",
+            "pool_size": config.max_connection,
+            "pool_reset_session": True,
+        }
+        return mysql.connector.pooling.MySQLConnectionPool(**pool_config)
+
+    def _check_vector_support(self):
+        """Check if the MySQL server supports vector operations."""
+        try:
+            with self._get_cursor() as cur:
+                # Check MySQL version and vector support
+                cur.execute("SELECT VERSION()")
+                version = cur.fetchone()["VERSION()"]
+                logger.debug("Connected to MySQL version: %s", version)
+                # Try to execute a simple vector function to verify support
+                cur.execute("SELECT VEC_FromText('[1,2,3]') IS NOT NULL as vector_support")
+                result = cur.fetchone()
+                if not result or not result.get("vector_support"):
+                    raise ValueError(
+                        "RDS MySQL Vector functions are not available."
+                        " Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
+                    )
+
+        except MySQLError as e:
+            if "FUNCTION" in str(e) and "VEC_FromText" in str(e):
+                raise ValueError(
+                    "RDS MySQL Vector functions are not available."
+                    " Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
+                ) from e
+            raise e
+
+    @contextmanager
+    def _get_cursor(self):
+        conn = self.pool.get_connection()
+        cur = conn.cursor(dictionary=True)
+        try:
+            yield cur
+        finally:
+            cur.close()
+            conn.close()
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        dimension = len(embeddings[0])
+        self._create_collection(dimension)
+        return self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        values = []
+        pks = []
+        for i, doc in enumerate(documents):
+            if doc.metadata is not None:
+                doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+                pks.append(doc_id)
+                # Convert embedding list to Aliyun MySQL vector format
+                vector_str = "[" + ",".join(map(str, embeddings[i])) + "]"
+                values.append(
+                    (
+                        doc_id,
+                        doc.page_content,
+                        json.dumps(doc.metadata),
+                        vector_str,
+                    )
+                )
+
+        with self._get_cursor() as cur:
+            insert_sql = (
+                f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (%s, %s, %s, VEC_FromText(%s))"
+            )
+            cur.executemany(insert_sql, values)
+        return pks
+
+    def text_exists(self, id: str) -> bool:
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
+            return cur.fetchone() is not None
+
+    def get_by_ids(self, ids: list[str]) -> list[Document]:
+        if not ids:
+            return []
+
+        with self._get_cursor() as cur:
+            placeholders = ",".join(["%s"] * len(ids))
+            cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
+            docs = []
+            for record in cur:
+                metadata = record["meta"]
+                if isinstance(metadata, str):
+                    metadata = json.loads(metadata)
+                docs.append(Document(page_content=record["text"], metadata=metadata))
+        return docs
+
+    def delete_by_ids(self, ids: list[str]):
+        # Avoiding crashes caused by performing delete operations on empty lists
+        if not ids:
+            return
+
+        with self._get_cursor() as cur:
+            try:
+                placeholders = ",".join(["%s"] * len(ids))
+                cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
+            except MySQLError as e:
+                if e.errno == 1146:  # Table doesn't exist
+                    logger.warning("Table %s not found, skipping delete operation.", self.table_name)
+                    return
+                else:
+                    raise e
+
+    def delete_by_metadata_field(self, key: str, value: str):
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value)
+            )
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        """
+        Search the nearest neighbors to a vector using RDS MySQL vector distance functions.
+
+        :param query_vector: The input vector to search for similar items.
+        :return: List of Documents that are nearest to the query vector.
+        """
+        top_k = kwargs.get("top_k", 4)
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
+
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        params = []
+
+        if document_ids_filter:
+            placeholders = ",".join(["%s"] * len(document_ids_filter))
+            where_clause = f" WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
+            params.extend(document_ids_filter)
+
+        # Convert query vector to RDS MySQL vector format
+        query_vector_str = "[" + ",".join(map(str, query_vector)) + "]"
+
+        # Use RDS MySQL's native vector distance functions
+        with self._get_cursor() as cur:
+            # Choose distance function based on configuration
+            distance_func = "VEC_DISTANCE_COSINE" if self.distance_function == "cosine" else "VEC_DISTANCE_EUCLIDEAN"
+
+            # Note: RDS MySQL optimizer will use vector index when ORDER BY + LIMIT are present
+            # Use column alias in ORDER BY to avoid calculating distance twice
+            sql = f"""
+                SELECT meta, text,
+                       {distance_func}(embedding, VEC_FromText(%s)) AS distance
+                FROM {self.table_name}
+                {where_clause}
+                ORDER BY distance
+                LIMIT %s
+            """
+            query_params = [query_vector_str] + params + [top_k]
+
+            cur.execute(sql, query_params)
+
+            docs = []
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
+
+            for record in cur:
+                try:
+                    distance = float(record["distance"])
+                    # Convert distance to similarity score
+                    if self.distance_function == "cosine":
+                        # For cosine distance: similarity = 1 - distance
+                        similarity = 1.0 - distance
+                    else:
+                        # For euclidean distance: use inverse relationship
+                        # similarity = 1 / (1 + distance)
+                        similarity = 1.0 / (1.0 + distance)
+
+                    metadata = record["meta"]
+                    if isinstance(metadata, str):
+                        metadata = json.loads(metadata)
+                    metadata["score"] = similarity
+                    metadata["distance"] = distance
+
+                    if similarity >= score_threshold:
+                        docs.append(Document(page_content=record["text"], metadata=metadata))
+                except (ValueError, json.JSONDecodeError) as e:
+                    logger.warning("Error processing search result: %s", e)
+                    continue
+
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        top_k = kwargs.get("top_k", 5)
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
+
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = ""
+        params = []
+
+        if document_ids_filter:
+            placeholders = ",".join(["%s"] * len(document_ids_filter))
+            where_clause = f" AND JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
+            params.extend(document_ids_filter)
+
+        with self._get_cursor() as cur:
+            # Build query parameters: query (twice for MATCH clauses), document_ids_filter (if any), top_k
+            query_params = [query, query] + params + [top_k]
+            cur.execute(
+                f"""SELECT meta, text,
+                           MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE) AS score
+                    FROM {self.table_name}
+                    WHERE MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE)
+                    {where_clause}
+                    ORDER BY score DESC
+                    LIMIT %s""",
+                query_params,
+            )
+            docs = []
+            for record in cur:
+                metadata = record["meta"]
+                if isinstance(metadata, str):
+                    metadata = json.loads(metadata)
+                metadata["score"] = float(record["score"])
+                docs.append(Document(page_content=record["text"], metadata=metadata))
+        return docs
+
+    def delete(self):
+        with self._get_cursor() as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+
+    def _create_collection(self, dimension: int):
+        collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+        lock_name = f"{collection_exist_cache_key}_lock"
+        with redis_client.lock(lock_name, timeout=20):
+            if redis_client.get(collection_exist_cache_key):
+                return
+
+            with self._get_cursor() as cur:
+                # Create table with vector column and vector index
+                cur.execute(
+                    SQL_CREATE_TABLE.format(
+                        table_name=self.table_name,
+                        dimension=dimension,
+                        distance_function=self.distance_function,
+                        hnsw_m=self.hnsw_m,
+                    )
+                )
+                # Create metadata index (check if exists first)
+                try:
+                    cur.execute(SQL_CREATE_META_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
+                except MySQLError as e:
+                    if e.errno != 1061:  # Duplicate key name
+                        logger.warning("Could not create meta index: %s", e)
+
+                # Create full-text index for text search
+                try:
+                    cur.execute(
+                        SQL_CREATE_FULLTEXT_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)
+                    )
+                except MySQLError as e:
+                    if e.errno != 1061:  # Duplicate key name
+                        logger.warning("Could not create fulltext index: %s", e)
+
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+
+class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory):
+    def _validate_distance_function(self, distance_function: str) -> Literal["cosine", "euclidean"]:
+        """Validate and return the distance function as a proper Literal type."""
+        if distance_function not in ["cosine", "euclidean"]:
+            raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'")
+        return cast(Literal["cosine", "euclidean"], distance_function)
+
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.ALIBABACLOUD_MYSQL, collection_name)
+            )
+        return AlibabaCloudMySQLVector(
+            collection_name=collection_name,
+            config=AlibabaCloudMySQLVectorConfig(
+                host=dify_config.ALIBABACLOUD_MYSQL_HOST or "localhost",
+                port=dify_config.ALIBABACLOUD_MYSQL_PORT,
+                user=dify_config.ALIBABACLOUD_MYSQL_USER or "root",
+                password=dify_config.ALIBABACLOUD_MYSQL_PASSWORD or "",
+                database=dify_config.ALIBABACLOUD_MYSQL_DATABASE or "dify",
+                max_connection=dify_config.ALIBABACLOUD_MYSQL_MAX_CONNECTION,
+                charset=dify_config.ALIBABACLOUD_MYSQL_CHARSET or "utf8mb4",
+                distance_function=self._validate_distance_function(
+                    dify_config.ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION or "cosine"
+                ),
+                hnsw_m=dify_config.ALIBABACLOUD_MYSQL_HNSW_M or 6,
+            ),
+        )
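Note: a minimal usage sketch for the new Alibaba Cloud MySQL store (connection values, collection name, and sample data are illustrative, not from this commit):

from core.rag.models.document import Document

config = AlibabaCloudMySQLVectorConfig(
    host="127.0.0.1",
    port=3306,
    user="root",
    password="secret",
    database="dify",
    max_connection=5,
)
store = AlibabaCloudMySQLVector(collection_name="vector_index_example", config=config)

docs = [Document(page_content="hello world", metadata={"doc_id": "d1", "document_id": "doc-1"})]
vectors = [[0.1, 0.2, 0.3]]

store.create(texts=docs, embeddings=vectors)             # creates table plus vector/meta/fulltext indexes, then inserts
hits = store.search_by_vector([0.1, 0.2, 0.3], top_k=4)  # cosine by default; score stored as 1 - distance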
@@ -488,9 +488,9 @@ class ClickzettaVector(BaseVector):
         create_table_sql = f"""
         CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
             id STRING NOT NULL COMMENT 'Unique document identifier',
-            {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
-            {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
-            {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
+            {Field.CONTENT_KEY} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
+            {Field.METADATA_KEY} JSON COMMENT 'Document metadata including source, type, and other attributes',
+            {Field.VECTOR} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
             'High-dimensional embedding vector for semantic similarity search',
             PRIMARY KEY (id)
         ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
@@ -519,15 +519,15 @@ class ClickzettaVector(BaseVector):
                 existing_indexes = cursor.fetchall()
                 for idx in existing_indexes:
                     # Check if vector index already exists on the embedding column
-                    if Field.VECTOR.value in str(idx).lower():
-                        logger.info("Vector index already exists on column %s", Field.VECTOR.value)
+                    if Field.VECTOR in str(idx).lower():
+                        logger.info("Vector index already exists on column %s", Field.VECTOR)
                         return
             except (RuntimeError, ValueError) as e:
                 logger.warning("Failed to check existing indexes: %s", e)

             index_sql = f"""
             CREATE VECTOR INDEX IF NOT EXISTS {index_name}
-            ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
+            ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR})
             PROPERTIES (
                 "distance.function" = "{self._config.vector_distance_function}",
                 "scalar.type" = "f32",
@@ -560,17 +560,17 @@ class ClickzettaVector(BaseVector):
                     # More precise check: look for inverted index specifically on the content column
                     if (
                         "inverted" in idx_str
-                        and Field.CONTENT_KEY.value.lower() in idx_str
+                        and Field.CONTENT_KEY.lower() in idx_str
                         and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
                     ):
-                        logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
+                        logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY, idx)
                         return
             except (RuntimeError, ValueError) as e:
                 logger.warning("Failed to check existing indexes: %s", e)

             index_sql = f"""
             CREATE INVERTED INDEX IF NOT EXISTS {index_name}
-            ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
+            ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY})
             PROPERTIES (
                 "analyzer" = "{self._config.analyzer_type}",
                 "mode" = "{self._config.analyzer_mode}"
@@ -588,13 +588,13 @@ class ClickzettaVector(BaseVector):
                     or "with the same type" in error_msg
                     or "cannot create inverted index" in error_msg
                 ) and "already has index" in error_msg:
-                    logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
+                    logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY)
                     # Try to get the existing index name for logging
                     try:
                         cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
                         existing_indexes = cursor.fetchall()
                         for idx in existing_indexes:
-                            if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
+                            if "inverted" in str(idx).lower() and Field.CONTENT_KEY.lower() in str(idx).lower():
                                 logger.info("Found existing inverted index: %s", idx)
                                 break
                     except (RuntimeError, ValueError):
@@ -669,7 +669,7 @@ class ClickzettaVector(BaseVector):

         # Use parameterized INSERT with executemany for better performance and security
         # Cast JSON and VECTOR in SQL, pass raw data as parameters
-        columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
+        columns = f"id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}"
         insert_sql = (
             f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
             f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
@@ -767,7 +767,7 @@ class ClickzettaVector(BaseVector):
             # Use json_extract_string function for ClickZetta compatibility
             sql = (
                 f"DELETE FROM {self._config.schema_name}.{self._table_name} "
-                f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
+                f"WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?"
             )
             cursor.execute(sql, binding_params=[value])

@@ -795,9 +795,7 @@ class ClickzettaVector(BaseVector):
             safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
             doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
             # Use json_extract_string function for ClickZetta compatibility
-            filter_clauses.append(
-                f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
-            )
+            filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")

         # No need for dataset_id filter since each dataset has its own table

@@ -808,23 +806,21 @@ class ClickzettaVector(BaseVector):
                 distance_func = "COSINE_DISTANCE"
                 if score_threshold > 0:
                     query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
-                    filter_clauses.append(
-                        f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
-                    )
+                    filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {2 - score_threshold}")
             else:
                 # For L2 distance, smaller is better
                 distance_func = "L2_DISTANCE"
                 if score_threshold > 0:
                     query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
-                    filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")
+                    filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {score_threshold}")

         where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"

         # Execute vector search query
         query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
         search_sql = f"""
-        SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
-               {distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
+        SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY},
+               {distance_func}({Field.VECTOR}, {query_vector_str}) AS distance
         FROM {self._config.schema_name}.{self._table_name}
         WHERE {where_clause}
         ORDER BY distance
@@ -887,9 +883,7 @@ class ClickzettaVector(BaseVector):
             safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
             doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
             # Use json_extract_string function for ClickZetta compatibility
-            filter_clauses.append(
-                f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
-            )
+            filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")

         # No need for dataset_id filter since each dataset has its own table

@@ -897,13 +891,13 @@ class ClickzettaVector(BaseVector):
             # match_all requires all terms to be present
             # Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
             escaped_query = query.replace("'", "''")
-            filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")
+            filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY}, '{escaped_query}')")

         where_clause = " AND ".join(filter_clauses)

         # Execute full-text search query
         search_sql = f"""
-        SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
+        SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
         FROM {self._config.schema_name}.{self._table_name}
         WHERE {where_clause}
         LIMIT {top_k}
@@ -986,19 +980,17 @@ class ClickzettaVector(BaseVector):
             safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
             doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
             # Use json_extract_string function for ClickZetta compatibility
-            filter_clauses.append(
-                f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
-            )
+            filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")

         # No need for dataset_id filter since each dataset has its own table

         # Use simple quote escaping for LIKE clause
         escaped_query = query.replace("'", "''")
-        filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
+        filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
         where_clause = " AND ".join(filter_clauses)

         search_sql = f"""
-        SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
+        SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
         FROM {self._config.schema_name}.{self._table_name}
         WHERE {where_clause}
         LIMIT {top_k}
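Note: this section swaps `Field.X.value` for `Field.X` inside f-strings that build SQL. That only renders the bare column name because `Field` must be a str-based enum, whose formatting yields the raw value. A sketch of the rendered SQL (member values assumed):

from enum import StrEnum


class Field(StrEnum):
    CONTENT_KEY = "page_content"
    METADATA_KEY = "metadata"
    VECTOR = "vector"


sql = f"DELETE FROM s.t WHERE json_extract_string({Field.METADATA_KEY}, '$.document_id') = ?"
print(sql)  # DELETE FROM s.t WHERE json_extract_string(metadata, '$.document_id') = ?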
@@ -57,18 +57,18 @@ class ElasticSearchJaVector(ElasticSearchVector):
         }
         mappings = {
             "properties": {
-                Field.CONTENT_KEY.value: {
+                Field.CONTENT_KEY: {
                     "type": "text",
                     "analyzer": "ja_analyzer",
                     "search_analyzer": "ja_analyzer",
                 },
-                Field.VECTOR.value: {  # Make sure the dimension is correct here
+                Field.VECTOR: {  # Make sure the dimension is correct here
                     "type": "dense_vector",
                     "dims": dim,
                     "index": True,
                     "similarity": "cosine",
                 },
-                Field.METADATA_KEY.value: {
+                Field.METADATA_KEY: {
                     "type": "object",
                     "properties": {
                         "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
@@ -4,7 +4,7 @@ import math
 from typing import Any, cast
 from urllib.parse import urlparse

-import requests
+from elasticsearch import ConnectionError as ElasticsearchConnectionError
 from elasticsearch import Elasticsearch
 from flask import current_app
 from packaging.version import parse as parse_version
@@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
             if not client.ping():
                 raise ConnectionError("Failed to connect to Elasticsearch")

-        except requests.ConnectionError as e:
+        except ElasticsearchConnectionError as e:
             raise ConnectionError(f"Vector database connection error: {str(e)}")
         except Exception as e:
             raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
@@ -163,9 +163,9 @@ class ElasticSearchVector(BaseVector):
                 index=self._collection_name,
                 id=uuids[i],
                 document={
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i] or None,
-                    Field.METADATA_KEY.value: documents[i].metadata or {},
+                    Field.CONTENT_KEY: documents[i].page_content,
+                    Field.VECTOR: embeddings[i] or None,
+                    Field.METADATA_KEY: documents[i].metadata or {},
                 },
             )
         self._client.indices.refresh(index=self._collection_name)
@@ -193,7 +193,7 @@ class ElasticSearchVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
         num_candidates = math.ceil(top_k * 1.5)
-        knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
+        knn = {"field": Field.VECTOR, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
         document_ids_filter = kwargs.get("document_ids_filter")
         if document_ids_filter:
             knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
@@ -205,9 +205,9 @@ class ElasticSearchVector(BaseVector):
             docs_and_scores.append(
                 (
                     Document(
-                        page_content=hit["_source"][Field.CONTENT_KEY.value],
-                        vector=hit["_source"][Field.VECTOR.value],
-                        metadata=hit["_source"][Field.METADATA_KEY.value],
+                        page_content=hit["_source"][Field.CONTENT_KEY],
+                        vector=hit["_source"][Field.VECTOR],
+                        metadata=hit["_source"][Field.METADATA_KEY],
                     ),
                     hit["_score"],
                 )
@@ -224,13 +224,13 @@ class ElasticSearchVector(BaseVector):
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
+        query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}}
         document_ids_filter = kwargs.get("document_ids_filter")

         if document_ids_filter:
             query_str = {
                 "bool": {
-                    "must": {"match": {Field.CONTENT_KEY.value: query}},
+                    "must": {"match": {Field.CONTENT_KEY: query}},
                     "filter": {"terms": {"metadata.document_id": document_ids_filter}},
                 }
             }
@@ -240,9 +240,9 @@ class ElasticSearchVector(BaseVector):
         for hit in results["hits"]["hits"]:
             docs.append(
                 Document(
-                    page_content=hit["_source"][Field.CONTENT_KEY.value],
-                    vector=hit["_source"][Field.VECTOR.value],
-                    metadata=hit["_source"][Field.METADATA_KEY.value],
+                    page_content=hit["_source"][Field.CONTENT_KEY],
+                    vector=hit["_source"][Field.VECTOR],
+                    metadata=hit["_source"][Field.METADATA_KEY],
                 )
             )

@@ -270,14 +270,14 @@ class ElasticSearchVector(BaseVector):
         dim = len(embeddings[0])
         mappings = {
             "properties": {
-                Field.CONTENT_KEY.value: {"type": "text"},
-                Field.VECTOR.value: {  # Make sure the dimension is correct here
+                Field.CONTENT_KEY: {"type": "text"},
+                Field.VECTOR: {  # Make sure the dimension is correct here
                     "type": "dense_vector",
                     "dims": dim,
                     "index": True,
                     "similarity": "cosine",
                 },
-                Field.METADATA_KEY.value: {
+                Field.METADATA_KEY: {
                     "type": "object",
                     "properties": {
                         "doc_id": {"type": "keyword"},  # Map doc_id to keyword type
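Note: elasticsearch-py signals transport failures with its own `elasticsearch.ConnectionError`, so the old `except requests.ConnectionError` branch could never match the client's errors; importing it under an alias also keeps Python's builtin ConnectionError usable. A minimal sketch of the corrected handling (host value illustrative):

from elasticsearch import ConnectionError as ElasticsearchConnectionError
from elasticsearch import Elasticsearch


def check_connection(hosts: str) -> None:
    client = Elasticsearch(hosts=hosts)
    try:
        if not client.ping():
            raise ConnectionError("Failed to connect to Elasticsearch")  # builtin
    except ElasticsearchConnectionError as e:  # client-level transport error
        raise ConnectionError(f"Vector database connection error: {e}")


check_connection("http://localhost:9200")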
@@ -67,9 +67,9 @@ class HuaweiCloudVector(BaseVector):
                 index=self._collection_name,
                 id=uuids[i],
                 document={
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i] or None,
-                    Field.METADATA_KEY.value: documents[i].metadata or {},
+                    Field.CONTENT_KEY: documents[i].page_content,
+                    Field.VECTOR: embeddings[i] or None,
+                    Field.METADATA_KEY: documents[i].metadata or {},
                 },
             )
         self._client.indices.refresh(index=self._collection_name)
@@ -101,7 +101,7 @@ class HuaweiCloudVector(BaseVector):
             "size": top_k,
             "query": {
                 "vector": {
-                    Field.VECTOR.value: {
+                    Field.VECTOR: {
                         "vector": query_vector,
                         "topk": top_k,
                     }
@@ -116,9 +116,9 @@ class HuaweiCloudVector(BaseVector):
             docs_and_scores.append(
                 (
                     Document(
-                        page_content=hit["_source"][Field.CONTENT_KEY.value],
-                        vector=hit["_source"][Field.VECTOR.value],
-                        metadata=hit["_source"][Field.METADATA_KEY.value],
+                        page_content=hit["_source"][Field.CONTENT_KEY],
+                        vector=hit["_source"][Field.VECTOR],
+                        metadata=hit["_source"][Field.METADATA_KEY],
                     ),
                     hit["_score"],
                 )
@@ -135,15 +135,15 @@ class HuaweiCloudVector(BaseVector):
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        query_str = {"match": {Field.CONTENT_KEY.value: query}}
+        query_str = {"match": {Field.CONTENT_KEY: query}}
         results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
         docs = []
         for hit in results["hits"]["hits"]:
             docs.append(
                 Document(
-                    page_content=hit["_source"][Field.CONTENT_KEY.value],
-                    vector=hit["_source"][Field.VECTOR.value],
-                    metadata=hit["_source"][Field.METADATA_KEY.value],
+                    page_content=hit["_source"][Field.CONTENT_KEY],
+                    vector=hit["_source"][Field.VECTOR],
+                    metadata=hit["_source"][Field.METADATA_KEY],
                 )
             )

@@ -171,8 +171,8 @@ class HuaweiCloudVector(BaseVector):
         dim = len(embeddings[0])
         mappings = {
             "properties": {
-                Field.CONTENT_KEY.value: {"type": "text"},
-                Field.VECTOR.value: {  # Make sure the dimension is correct here
+                Field.CONTENT_KEY: {"type": "text"},
+                Field.VECTOR: {  # Make sure the dimension is correct here
                     "type": "vector",
                     "dimension": dim,
                     "indexing": True,
@@ -181,7 +181,7 @@ class HuaweiCloudVector(BaseVector):
                     "neighbors": 32,
                     "efc": 128,
                 },
-                Field.METADATA_KEY.value: {
+                Field.METADATA_KEY: {
                     "type": "object",
                     "properties": {
                         "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
@@ -125,9 +125,9 @@ class LindormVectorStore(BaseVector):
                 }
             }
             action_values: dict[str, Any] = {
-                Field.CONTENT_KEY.value: documents[i].page_content,
-                Field.VECTOR.value: embeddings[i],
-                Field.METADATA_KEY.value: documents[i].metadata,
+                Field.CONTENT_KEY: documents[i].page_content,
+                Field.VECTOR: embeddings[i],
+                Field.METADATA_KEY: documents[i].metadata,
             }
             if self._using_ugc:
                 action_header["index"]["routing"] = self._routing
@@ -149,7 +149,7 @@ class LindormVectorStore(BaseVector):

     def get_ids_by_metadata_field(self, key: str, value: str):
         query: dict[str, Any] = {
-            "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
+            "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
         }
         if self._using_ugc:
             query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
@@ -252,14 +252,14 @@ class LindormVectorStore(BaseVector):
         search_query: dict[str, Any] = {
             "size": top_k,
             "_source": True,
-            "query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}},
+            "query": {"knn": {Field.VECTOR: {"vector": query_vector, "k": top_k}}},
         }

         final_ext: dict[str, Any] = {"lvector": {}}
         if filters is not None and len(filters) > 0:
             # when using filter, transform filter from List[Dict] to Dict as valid format
             filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
-            search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict  # filter should be Dict
+            search_query["query"]["knn"][Field.VECTOR]["filter"] = filter_dict  # filter should be Dict
             final_ext["lvector"]["filter_type"] = "pre_filter"

         if final_ext != {"lvector": {}}:
@@ -279,9 +279,9 @@ class LindormVectorStore(BaseVector):
             docs_and_scores.append(
                 (
                     Document(
-                        page_content=hit["_source"][Field.CONTENT_KEY.value],
-                        vector=hit["_source"][Field.VECTOR.value],
-                        metadata=hit["_source"][Field.METADATA_KEY.value],
+                        page_content=hit["_source"][Field.CONTENT_KEY],
+                        vector=hit["_source"][Field.VECTOR],
+                        metadata=hit["_source"][Field.METADATA_KEY],
                     ),
                     hit["_score"],
                 )
@@ -318,9 +318,9 @@ class LindormVectorStore(BaseVector):

         docs = []
         for hit in response["hits"]["hits"]:
-            metadata = hit["_source"].get(Field.METADATA_KEY.value)
-            vector = hit["_source"].get(Field.VECTOR.value)
-            page_content = hit["_source"].get(Field.CONTENT_KEY.value)
+            metadata = hit["_source"].get(Field.METADATA_KEY)
+            vector = hit["_source"].get(Field.VECTOR)
+            page_content = hit["_source"].get(Field.CONTENT_KEY)
             doc = Document(page_content=page_content, vector=vector, metadata=metadata)
             docs.append(doc)

@@ -342,8 +342,8 @@ class LindormVectorStore(BaseVector):
         "settings": {"index": {"knn": True, "knn_routing": self._using_ugc}},
         "mappings": {
             "properties": {
-                Field.CONTENT_KEY.value: {"type": "text"},
-                Field.VECTOR.value: {
+                Field.CONTENT_KEY: {"type": "text"},
+                Field.VECTOR: {
                     "type": "knn_vector",
                     "dimension": len(embeddings[0]),  # Make sure the dimension is correct here
                     "method": {
@@ -85,7 +85,7 @@ class MilvusVector(BaseVector):
         collection_info = self._client.describe_collection(self._collection_name)
         fields = [field["name"] for field in collection_info["fields"]]
         # Since primary field is auto-id, no need to track it
-        self._fields = [f for f in fields if f != Field.PRIMARY_KEY.value]
+        self._fields = [f for f in fields if f != Field.PRIMARY_KEY]

     def _check_hybrid_search_support(self) -> bool:
         """
@@ -130,9 +130,9 @@ class MilvusVector(BaseVector):
             insert_dict = {
                 # Do not need to insert the sparse_vector field separately, as the text_bm25_emb
                 # function will automatically convert the native text into a sparse vector for us.
-                Field.CONTENT_KEY.value: documents[i].page_content,
-                Field.VECTOR.value: embeddings[i],
-                Field.METADATA_KEY.value: documents[i].metadata,
+                Field.CONTENT_KEY: documents[i].page_content,
+                Field.VECTOR: embeddings[i],
+                Field.METADATA_KEY: documents[i].metadata,
             }
             insert_dict_list.append(insert_dict)
         # Total insert count
@@ -243,15 +243,15 @@ class MilvusVector(BaseVector):
         results = self._client.search(
             collection_name=self._collection_name,
             data=[query_vector],
-            anns_field=Field.VECTOR.value,
+            anns_field=Field.VECTOR,
             limit=kwargs.get("top_k", 4),
-            output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+            output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
             filter=filter,
         )

         return self._process_search_results(
             results,
-            output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+            output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
             score_threshold=float(kwargs.get("score_threshold") or 0.0),
         )

@@ -264,7 +264,7 @@ class MilvusVector(BaseVector):
                 "Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
             )
             return []
-        if not self.field_exists(Field.SPARSE_VECTOR.value):
+        if not self.field_exists(Field.SPARSE_VECTOR):
             logger.warning(
                 "Full-text search unavailable: collection missing 'sparse_vector' field; "
                 "recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
@@ -279,15 +279,15 @@ class MilvusVector(BaseVector):
         results = self._client.search(
             collection_name=self._collection_name,
             data=[query],
-            anns_field=Field.SPARSE_VECTOR.value,
+            anns_field=Field.SPARSE_VECTOR,
             limit=kwargs.get("top_k", 4),
-            output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+            output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
             filter=filter,
         )

         return self._process_search_results(
             results,
-            output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+            output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
             score_threshold=float(kwargs.get("score_threshold") or 0.0),
         )

@@ -311,7 +311,7 @@ class MilvusVector(BaseVector):
         dim = len(embeddings[0])
         fields = []
         if metadatas:
-            fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
+            fields.append(FieldSchema(Field.METADATA_KEY, DataType.JSON, max_length=65_535))

         # Create the text field, enable_analyzer will be set True to support milvus automatically
         # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
@@ -326,15 +326,15 @@ class MilvusVector(BaseVector):
         ):
             content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params

-        fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
+        fields.append(FieldSchema(Field.CONTENT_KEY, DataType.VARCHAR, **content_field_kwargs))

         # Create the primary key field
-        fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
+        fields.append(FieldSchema(Field.PRIMARY_KEY, DataType.INT64, is_primary=True, auto_id=True))
         # Create the vector field, supports binary or float vectors
-        fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
+        fields.append(FieldSchema(Field.VECTOR, infer_dtype_bydata(embeddings[0]), dim=dim))
         # Create Sparse Vector Index for the collection
         if self._hybrid_search_enabled:
-            fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
+            fields.append(FieldSchema(Field.SPARSE_VECTOR, DataType.SPARSE_FLOAT_VECTOR))

         schema = CollectionSchema(fields)

@@ -342,8 +342,8 @@ class MilvusVector(BaseVector):
         if self._hybrid_search_enabled:
             bm25_function = Function(
                 name="text_bm25_emb",
-                input_field_names=[Field.CONTENT_KEY.value],
-                output_field_names=[Field.SPARSE_VECTOR.value],
+                input_field_names=[Field.CONTENT_KEY],
+                output_field_names=[Field.SPARSE_VECTOR],
                 function_type=FunctionType.BM25,
             )
             schema.add_function(bm25_function)
@@ -352,12 +352,12 @@ class MilvusVector(BaseVector):

         # Create Index params for the collection
         index_params_obj = IndexParams()
-        index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
+        index_params_obj.add_index(field_name=Field.VECTOR, **index_params)

         # Create Sparse Vector Index for the collection
         if self._hybrid_search_enabled:
             index_params_obj.add_index(
-                field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
+                field_name=Field.SPARSE_VECTOR, index_type="AUTOINDEX", metric_type="BM25"
             )

         # Create the collection
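Note: a minimal sketch of the two refactored Milvus search paths, mirroring the calls in the hunks above (URI, collection, and field names illustrative; the plain strings are what the str-based Field members resolve to):

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")

# Dense path: search the float vector field with an embedded query.
dense_hits = client.search(
    collection_name="c_example",
    data=[[0.1, 0.2, 0.3]],
    anns_field="vector",
    limit=4,
    output_fields=["page_content", "metadata"],
)

# Full-text path (hybrid search enabled): pass raw text; the BM25 function
# registered on the schema turns it into a sparse vector server-side.
sparse_hits = client.search(
    collection_name="c_example",
    data=["what is dify"],
    anns_field="sparse_vector",
    limit=4,
    output_fields=["page_content", "metadata"],
)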
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Literal
+from typing import Any
 from uuid import uuid4

 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
@@ -8,6 +8,7 @@ from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator

 from configs import dify_config
+from configs.middleware.vdb.opensearch_config import AuthMethod
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -25,7 +26,7 @@ class OpenSearchConfig(BaseModel):
     port: int
     secure: bool = False  # use_ssl
     verify_certs: bool = True
-    auth_method: Literal["basic", "aws_managed_iam"] = "basic"
+    auth_method: AuthMethod = AuthMethod.BASIC
     user: str | None = None
     password: str | None = None
     aws_region: str | None = None
@@ -98,9 +99,9 @@ class OpenSearchVector(BaseVector):
                 "_op_type": "index",
                 "_index": self._collection_name.lower(),
                 "_source": {
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                    Field.METADATA_KEY.value: documents[i].metadata,
+                    Field.CONTENT_KEY: documents[i].page_content,
+                    Field.VECTOR: embeddings[i],  # Make sure you pass an array here
+                    Field.METADATA_KEY: documents[i].metadata,
                 },
             }
             # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
@@ -116,7 +117,7 @@ class OpenSearchVector(BaseVector):
         )

     def get_ids_by_metadata_field(self, key: str, value: str):
-        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
+        query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}}
         response = self._client.search(index=self._collection_name.lower(), body=query)
         if response["hits"]["hits"]:
             return [hit["_id"] for hit in response["hits"]["hits"]]
@@ -180,17 +181,17 @@ class OpenSearchVector(BaseVector):

         query = {
             "size": kwargs.get("top_k", 4),
-            "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
+            "query": {"knn": {Field.VECTOR: {Field.VECTOR: query_vector, "k": kwargs.get("top_k", 4)}}},
         }
         document_ids_filter = kwargs.get("document_ids_filter")
         if document_ids_filter:
             query["query"] = {
                 "script_score": {
-                    "query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
+                    "query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID: document_ids_filter}}]}},
                     "script": {
                         "source": "knn_score",
                         "lang": "knn",
-                        "params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
+                        "params": {"field": Field.VECTOR, "query_value": query_vector, "space_type": "l2"},
                     },
                 }
             }
@@ -203,7 +204,7 @@ class OpenSearchVector(BaseVector):

         docs = []
         for hit in response["hits"]["hits"]:
-            metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
+            metadata = hit["_source"].get(Field.METADATA_KEY, {})

             # Make sure metadata is a dictionary
             if metadata is None:
@@ -212,7 +213,7 @@ class OpenSearchVector(BaseVector):
             metadata["score"] = hit["_score"]
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if hit["_score"] >= score_threshold:
-                doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
+                doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY), metadata=metadata)
                 docs.append(doc)

         return docs
@@ -227,9 +228,9 @@ class OpenSearchVector(BaseVector):

         docs = []
         for hit in response["hits"]["hits"]:
-            metadata = hit["_source"].get(Field.METADATA_KEY.value)
-            vector = hit["_source"].get(Field.VECTOR.value)
-            page_content = hit["_source"].get(Field.CONTENT_KEY.value)
+            metadata = hit["_source"].get(Field.METADATA_KEY)
+            vector = hit["_source"].get(Field.VECTOR)
+            page_content = hit["_source"].get(Field.CONTENT_KEY)
             doc = Document(page_content=page_content, vector=vector, metadata=metadata)
             docs.append(doc)

@@ -250,8 +251,8 @@ class OpenSearchVector(BaseVector):
             "settings": {"index": {"knn": True}},
             "mappings": {
                 "properties": {
-                    Field.CONTENT_KEY.value: {"type": "text"},
-                    Field.VECTOR.value: {
+                    Field.CONTENT_KEY: {"type": "text"},
+                    Field.VECTOR: {
                         "type": "knn_vector",
                         "dimension": len(embeddings[0]),  # Make sure the dimension is correct here
                         "method": {
@@ -261,7 +262,7 @@ class OpenSearchVector(BaseVector):
                         "parameters": {"ef_construction": 64, "m": 8},
                     },
                 },
-                Field.METADATA_KEY.value: {
+                Field.METADATA_KEY: {
                     "type": "object",
                     "properties": {
                         "doc_id": {"type": "keyword"},  # Map doc_id to keyword type
@@ -293,7 +294,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
             port=dify_config.OPENSEARCH_PORT,
             secure=dify_config.OPENSEARCH_SECURE,
             verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS,
-            auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
+            auth_method=dify_config.OPENSEARCH_AUTH_METHOD,
            user=dify_config.OPENSEARCH_USER,
             password=dify_config.OPENSEARCH_PASSWORD,
             aws_region=dify_config.OPENSEARCH_AWS_REGION,
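Note: typing `auth_method` as the shared `AuthMethod` enum (instead of a local `Literal`) keeps the accepted values defined once in the config module; pydantic still coerces raw strings from the environment. A minimal sketch of the pattern (enum values assumed to mirror the old Literal):

from enum import StrEnum

from pydantic import BaseModel


class AuthMethod(StrEnum):
    BASIC = "basic"
    AWS_MANAGED_IAM = "aws_managed_iam"


class OpenSearchConfig(BaseModel):
    auth_method: AuthMethod = AuthMethod.BASIC


print(OpenSearchConfig(auth_method="aws_managed_iam").auth_method)  # AuthMethod.AWS_MANAGED_IAM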
@@ -147,15 +147,13 @@ class QdrantVector(BaseVector):

             # create group_id payload index
             self._client.create_payload_index(
-                collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
+                collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
             )
             # create doc_id payload index
-            self._client.create_payload_index(
-                collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
-            )
+            self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
             # create document_id payload index
             self._client.create_payload_index(
-                collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
+                collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
             )
             # create full text index
             text_index_params = TextIndexParams(
@@ -165,9 +163,7 @@ class QdrantVector(BaseVector):
                 max_token_len=20,
                 lowercase=True,
             )
-            self._client.create_payload_index(
-                collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
-            )
+            self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
             redis_client.set(collection_exist_cache_key, 1, ex=3600)

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@@ -220,10 +216,10 @@ class QdrantVector(BaseVector):
                         self._build_payloads(
                             batch_texts,
                             batch_metadatas,
-                            Field.CONTENT_KEY.value,
-                            Field.METADATA_KEY.value,
+                            Field.CONTENT_KEY,
+                            Field.METADATA_KEY,
                             group_id or "",  # Ensure group_id is never None
-                            Field.GROUP_KEY.value,
+                            Field.GROUP_KEY,
                         ),
                     )
                 ]
@@ -381,12 +377,12 @@ class QdrantVector(BaseVector):
         for result in results:
             if result.payload is None:
                 continue
-            metadata = result.payload.get(Field.METADATA_KEY.value) or {}
+            metadata = result.payload.get(Field.METADATA_KEY) or {}
             # duplicate check score threshold
             if result.score >= score_threshold:
                 metadata["score"] = result.score
                 doc = Document(
-                    page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
+                    page_content=result.payload.get(Field.CONTENT_KEY, ""),
                     metadata=metadata,
                 )
                 docs.append(doc)
@@ -433,7 +429,7 @@ class QdrantVector(BaseVector):
         documents = []
         for result in results:
             if result:
-                document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
+                document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
                 documents.append(document)

         return documents
@ -55,7 +55,7 @@ class TableStoreVector(BaseVector):
|
||||
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
|
||||
self._table_name = f"{collection_name}"
|
||||
self._index_name = f"{collection_name}_idx"
|
||||
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
|
||||
self._tags_field = f"{Field.METADATA_KEY}_tags"
|
||||
|
||||
def create_collection(self, embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
@ -64,7 +64,7 @@ class TableStoreVector(BaseVector):
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
docs = []
|
||||
request = BatchGetRowRequest()
|
||||
columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value]
|
||||
columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY]
|
||||
rows_to_get = [[("id", _id)] for _id in ids]
|
||||
request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
|
||||
|
||||
@ -73,11 +73,7 @@ class TableStoreVector(BaseVector):
|
||||
for item in table_result:
|
||||
if item.is_ok and item.row:
|
||||
kv = {k: v for k, v, _ in item.row.attribute_columns}
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])
|
||||
)
|
||||
)
|
||||
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
|
||||
return docs
|
||||
|
||||
def get_type(self) -> str:
|
||||
@ -95,9 +91,9 @@ class TableStoreVector(BaseVector):
|
||||
self._write_row(
|
||||
primary_key=uuids[i],
|
||||
attributes={
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
Field.CONTENT_KEY: documents[i].page_content,
|
||||
Field.VECTOR: embeddings[i],
|
||||
Field.METADATA_KEY: documents[i].metadata,
|
||||
},
|
||||
)
|
||||
return uuids
|
||||
@ -180,7 +176,7 @@ class TableStoreVector(BaseVector):
|
||||
|
||||
field_schemas = [
|
||||
tablestore.FieldSchema(
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.CONTENT_KEY,
|
||||
tablestore.FieldType.TEXT,
|
||||
analyzer=tablestore.AnalyzerType.MAXWORD,
|
||||
index=True,
|
||||
@ -188,7 +184,7 @@ class TableStoreVector(BaseVector):
|
||||
store=False,
|
||||
),
|
||||
tablestore.FieldSchema(
|
||||
Field.VECTOR.value,
|
||||
Field.VECTOR,
|
||||
tablestore.FieldType.VECTOR,
|
||||
vector_options=tablestore.VectorOptions(
|
||||
data_type=tablestore.VectorDataType.VD_FLOAT_32,
|
||||
@ -197,7 +193,7 @@ class TableStoreVector(BaseVector):
|
||||
),
|
||||
),
|
||||
tablestore.FieldSchema(
|
||||
Field.METADATA_KEY.value,
|
||||
Field.METADATA_KEY,
|
||||
tablestore.FieldType.KEYWORD,
|
||||
index=True,
|
||||
store=False,
|
||||
@ -233,15 +229,15 @@ class TableStoreVector(BaseVector):
pk = [("id", primary_key)]

tags = []
for key, value in attributes[Field.METADATA_KEY.value].items():
for key, value in attributes[Field.METADATA_KEY].items():
tags.append(str(key) + "=" + str(value))

attribute_columns = [
(Field.CONTENT_KEY.value, attributes[Field.CONTENT_KEY.value]),
(Field.VECTOR.value, json.dumps(attributes[Field.VECTOR.value])),
(Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]),
(Field.VECTOR, json.dumps(attributes[Field.VECTOR])),
(
Field.METADATA_KEY.value,
json.dumps(attributes[Field.METADATA_KEY.value]),
Field.METADATA_KEY,
json.dumps(attributes[Field.METADATA_KEY]),
),
(self._tags_field, json.dumps(tags)),
]
@ -270,7 +266,7 @@ class TableStoreVector(BaseVector):
index_name=self._index_name,
search_query=query,
columns_to_get=tablestore.ColumnsToGet(
column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED
),
)

@ -288,7 +284,7 @@ class TableStoreVector(BaseVector):
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
knn_vector_query = tablestore.KnnVectorQuery(
field_name=Field.VECTOR.value,
field_name=Field.VECTOR,
top_k=top_k,
float32_query_vector=query_vector,
)
@ -311,8 +307,8 @@ class TableStoreVector(BaseVector):
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]

vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector_str = ots_column_map.get(Field.VECTOR)
metadata_str = ots_column_map.get(Field.METADATA_KEY)

vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
@ -321,7 +317,7 @@ class TableStoreVector(BaseVector):

documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)
@ -343,7 +339,7 @@ class TableStoreVector(BaseVector):
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY))

if document_ids_filter:
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
@ -374,10 +370,10 @@ class TableStoreVector(BaseVector):
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]

metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY)
metadata = json.loads(metadata_str) if metadata_str else {}

vector_str = ots_column_map.get(Field.VECTOR.value)
vector_str = ots_column_map.get(Field.VECTOR)
vector = json.loads(vector_str) if vector_str else None

if score:
@ -385,7 +381,7 @@ class TableStoreVector(BaseVector):

documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)

@ -5,9 +5,10 @@ from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union

import httpx
import qdrant_client
import requests
from flask import current_app
from httpx import DigestAuth
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
@ -19,7 +20,6 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select

from configs import dify_config
@ -141,15 +141,13 @@ class TidbOnQdrantVector(BaseVector):

# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
@ -159,9 +157,7 @@ class TidbOnQdrantVector(BaseVector):
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -211,10 +207,10 @@ class TidbOnQdrantVector(BaseVector):
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "",
Field.GROUP_KEY.value,
Field.GROUP_KEY,
),
)
]
@ -349,13 +345,13 @@ class TidbOnQdrantVector(BaseVector):
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
metadata = result.payload.get(Field.METADATA_KEY) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
@ -392,7 +388,7 @@ class TidbOnQdrantVector(BaseVector):
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)

return documents
@ -504,10 +500,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
}
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}

response = requests.post(
response = httpx.post(
f"{tidb_config.api_url}/clusters",
json=cluster_data,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
)

if response.status_code == 200:
@ -527,10 +523,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):

body = {"password": new_password}

response = requests.put(
response = httpx.put(
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
json=body,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
)

if response.status_code == 200:

@ -2,8 +2,8 @@ import time
import uuid
from collections.abc import Sequence

import requests
from requests.auth import HTTPDigestAuth
import httpx
from httpx import DigestAuth

from configs import dify_config
from extensions.ext_database import db
@ -49,7 +49,7 @@ class TidbService:
"rootPassword": password,
}

response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))

if response.status_code == 200:
response_data = response.json()
@ -83,7 +83,7 @@ class TidbService:
:return: The response from the API.
"""

response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))

if response.status_code == 200:
return response.json()
@ -102,7 +102,7 @@ class TidbService:
:return: The response from the API.
"""

response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))

if response.status_code == 200:
return response.json()
@ -127,10 +127,10 @@ class TidbService:

body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}

response = requests.patch(
response = httpx.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=HTTPDigestAuth(public_key, private_key),
auth=DigestAuth(public_key, private_key),
)

if response.status_code == 200:
@ -161,9 +161,7 @@ class TidbService:
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = requests.get(
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)
response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key))

if response.status_code == 200:
response_data = response.json()
@ -224,8 +222,8 @@ class TidbService:
clusters.append(cluster_data)

request_body = {"requests": clusters}
response = requests.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
response = httpx.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
)

if response.status_code == 200:

@ -55,13 +55,13 @@ class TiDBVector(BaseVector):
return Table(
self._collection_name,
self._orm_base.metadata,
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
Column(Field.PRIMARY_KEY, String(36), primary_key=True, nullable=False),
Column(
Field.VECTOR.value,
Field.VECTOR,
VectorType(dim),
nullable=False,
),
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
Column(Field.TEXT_KEY, TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column(

@ -71,6 +71,12 @@ class Vector:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory

return MilvusVectorFactory
case VectorType.ALIBABACLOUD_MYSQL:
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVectorFactory,
)

return AlibabaCloudMySQLVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory

@ -2,6 +2,7 @@ from enum import StrEnum


class VectorType(StrEnum):
ALIBABACLOUD_MYSQL = "alibabacloud_mysql"
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"

@ -76,11 +76,11 @@ class VikingDBVector(BaseVector):

if not self._has_collection():
fields = [
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=dimension),
]

self._client.create_collection(
@ -100,7 +100,7 @@ class VikingDBVector(BaseVector):
collection_name=self._collection_name,
index_name=self._index_name,
vector_index=vector_index,
partition_by=vdb_Field.GROUP_KEY.value,
partition_by=vdb_Field.GROUP_KEY,
description="Index For Dify",
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -126,11 +126,11 @@ class VikingDBVector(BaseVector):
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
vdb_Field.GROUP_KEY.value: self._group_id,
vdb_Field.PRIMARY_KEY: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY: page_content,
vdb_Field.METADATA_KEY: json.dumps(metadata),
vdb_Field.GROUP_KEY: self._group_id,
}
)
docs.append(doc)
@ -151,7 +151,7 @@ class VikingDBVector(BaseVector):
# Note: Metadata field value is an dict, but vikingdb field
# not support json type
results = self._client.get_index(self._collection_name, self._index_name).search(
filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]},
# max value is 5000
limit=5000,
)
@ -161,7 +161,7 @@ class VikingDBVector(BaseVector):

ids = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if metadata.get(key) == value:
@ -189,12 +189,12 @@ class VikingDBVector(BaseVector):

docs = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs

@ -2,7 +2,6 @@ import datetime
import json
from typing import Any

import requests
import weaviate # type: ignore
from pydantic import BaseModel, model_validator

@ -45,8 +44,8 @@ class WeaviateVector(BaseVector):
client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.ConnectionError:
raise ConnectionError("Vector database connection error")
except Exception as exc:
raise ConnectionError("Vector database connection error") from exc

client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
@ -105,7 +104,7 @@ class WeaviateVector(BaseVector):

with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
data_properties = {Field.TEXT_KEY: text}
if metadatas is not None:
# metadata maybe None
for key, val in (metadatas[i] or {}).items():
@ -183,7 +182,7 @@ class WeaviateVector(BaseVector):
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
properties.append(Field.TEXT_KEY)
query_obj = self._client.query.get(collection_name, properties)

vector = {"vector": query_vector}
@ -205,7 +204,7 @@ class WeaviateVector(BaseVector):

docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
text = res.pop(Field.TEXT_KEY)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))

@ -233,7 +232,7 @@ class WeaviateVector(BaseVector):
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
properties.append(Field.TEXT_KEY)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
@ -251,7 +250,7 @@ class WeaviateVector(BaseVector):
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
text = res.pop(Field.TEXT_KEY)
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs