Merge remote-tracking branch 'origin/main' into feat/collaboration

lyzno1
2025-10-11 14:43:20 +08:00
463 changed files with 9705 additions and 5680 deletions

View File

@ -21,7 +21,7 @@ from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
@ -107,7 +107,7 @@ class RetrievalService:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
@ -134,7 +134,7 @@ class RetrievalService:
if not dataset:
return []
metadata_condition = (
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,
@ -245,10 +245,10 @@ class RetrievalService:
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
@ -293,10 +293,10 @@ class RetrievalService:
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
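
The recurring change in these hunks drops `.value` from StrEnum members and replaces dict unpacking with Pydantic v2's `model_validate`. A minimal sketch of why both rewrites are behavior-preserving, assuming `RetrievalMethod` is a `StrEnum` (consistent with how it is compared above) and `MetadataCondition` is a Pydantic v2 model; the member values here are illustrative:

from enum import StrEnum
from pydantic import BaseModel

class RetrievalMethod(StrEnum):
    SEMANTIC_SEARCH = "semantic_search"
    HYBRID_SEARCH = "hybrid_search"

# StrEnum members are real str instances, so comparisons against plain
# strings (and serialization into dicts) work without .value:
assert RetrievalMethod.SEMANTIC_SEARCH == "semantic_search"
assert RetrievalMethod.HYBRID_SEARCH == RetrievalMethod.HYBRID_SEARCH.value

class MetadataCondition(BaseModel):  # stand-in for the real model
    logical_operator: str = "and"

# model_validate is the Pydantic v2 idiom for validating an existing dict;
# it avoids **-unpacking and raises a ValidationError (rather than a
# TypeError) when the input is not a mapping.
cond = MetadataCondition.model_validate({"logical_operator": "or"})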

View File

@ -0,0 +1,388 @@
import hashlib
import json
import logging
import uuid
from contextlib import contextmanager
from typing import Any, Literal, cast
import mysql.connector
from mysql.connector import Error as MySQLError
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class AlibabaCloudMySQLVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
max_connection: int
charset: str = "utf8mb4"
distance_function: Literal["cosine", "euclidean"] = "cosine"
hnsw_m: int = 6
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values.get("host"):
raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required")
if not values.get("port"):
raise ValueError("config ALIBABACLOUD_MYSQL_PORT is required")
if not values.get("user"):
raise ValueError("config ALIBABACLOUD_MYSQL_USER is required")
if values.get("password") is None:
raise ValueError("config ALIBABACLOUD_MYSQL_PASSWORD is required")
if not values.get("database"):
raise ValueError("config ALIBABACLOUD_MYSQL_DATABASE is required")
if not values.get("max_connection"):
raise ValueError("config ALIBABACLOUD_MYSQL_MAX_CONNECTION is required")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id VARCHAR(36) PRIMARY KEY,
text LONGTEXT NOT NULL,
meta JSON NOT NULL,
embedding VECTOR({dimension}) NOT NULL,
VECTOR INDEX (embedding) M={hnsw_m} DISTANCE={distance_function}
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
"""
SQL_CREATE_META_INDEX = """
CREATE INDEX idx_{index_hash}_meta ON {table_name}
((CAST(JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) AS CHAR(36))));
"""
SQL_CREATE_FULLTEXT_INDEX = """
CREATE FULLTEXT INDEX idx_{index_hash}_text ON {table_name} (text) WITH PARSER ngram;
"""
class AlibabaCloudMySQLVector(BaseVector):
def __init__(self, collection_name: str, config: AlibabaCloudMySQLVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = collection_name.lower()
self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
self.distance_function = config.distance_function.lower()
self.hnsw_m = config.hnsw_m
self._check_vector_support()
def get_type(self) -> str:
return VectorType.ALIBABACLOUD_MYSQL
def _create_connection_pool(self, config: AlibabaCloudMySQLVectorConfig):
# Create connection pool using mysql-connector-python pooling
pool_config: dict[str, Any] = {
"host": config.host,
"port": config.port,
"user": config.user,
"password": config.password,
"database": config.database,
"charset": config.charset,
"autocommit": True,
"pool_name": f"pool_{self.collection_name}",
"pool_size": config.max_connection,
"pool_reset_session": True,
}
return mysql.connector.pooling.MySQLConnectionPool(**pool_config)
def _check_vector_support(self):
"""Check if the MySQL server supports vector operations."""
try:
with self._get_cursor() as cur:
# Check MySQL version and vector support
cur.execute("SELECT VERSION()")
version = cur.fetchone()["VERSION()"]
logger.debug("Connected to MySQL version: %s", version)
# Try to execute a simple vector function to verify support
cur.execute("SELECT VEC_FromText('[1,2,3]') IS NOT NULL as vector_support")
result = cur.fetchone()
if not result or not result.get("vector_support"):
raise ValueError(
"RDS MySQL Vector functions are not available."
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
)
except MySQLError as e:
if "FUNCTION" in str(e) and "VEC_FromText" in str(e):
raise ValueError(
"RDS MySQL Vector functions are not available."
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
) from e
raise e
@contextmanager
def _get_cursor(self):
conn = self.pool.get_connection()
cur = conn.cursor(dictionary=True)
try:
yield cur
finally:
cur.close()
conn.close()
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
# Convert embedding list to Aliyun MySQL vector format
vector_str = "[" + ",".join(map(str, embeddings[i])) + "]"
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
vector_str,
)
)
with self._get_cursor() as cur:
insert_sql = (
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (%s, %s, %s, VEC_FromText(%s))"
)
cur.executemany(insert_sql, values)
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_cursor() as cur:
placeholders = ",".join(["%s"] * len(ids))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
def delete_by_ids(self, ids: list[str]):
# Avoid crashes from running DELETE with an empty id list
if not ids:
return
with self._get_cursor() as cur:
try:
placeholders = ",".join(["%s"] * len(ids))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
except MySQLError as e:
if e.errno == 1146: # Table doesn't exist
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
return
else:
raise e
def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(
f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value)
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector using RDS MySQL vector distance functions.
:param query_vector: The input vector to search for similar items.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = []
if document_ids_filter:
placeholders = ",".join(["%s"] * len(document_ids_filter))
where_clause = f" WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
params.extend(document_ids_filter)
# Convert query vector to RDS MySQL vector format
query_vector_str = "[" + ",".join(map(str, query_vector)) + "]"
# Use RDS MySQL's native vector distance functions
with self._get_cursor() as cur:
# Choose distance function based on configuration
distance_func = "VEC_DISTANCE_COSINE" if self.distance_function == "cosine" else "VEC_DISTANCE_EUCLIDEAN"
# Note: RDS MySQL optimizer will use vector index when ORDER BY + LIMIT are present
# Use column alias in ORDER BY to avoid calculating distance twice
sql = f"""
SELECT meta, text,
{distance_func}(embedding, VEC_FromText(%s)) AS distance
FROM {self.table_name}
{where_clause}
ORDER BY distance
LIMIT %s
"""
query_params = [query_vector_str] + params + [top_k]
cur.execute(sql, query_params)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
try:
distance = float(record["distance"])
# Convert distance to similarity score
if self.distance_function == "cosine":
# For cosine distance: similarity = 1 - distance
similarity = 1.0 - distance
else:
# For euclidean distance: use inverse relationship
# similarity = 1 / (1 + distance)
similarity = 1.0 / (1.0 + distance)
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata["score"] = similarity
metadata["distance"] = distance
if similarity >= score_threshold:
docs.append(Document(page_content=record["text"], metadata=metadata))
except (ValueError, json.JSONDecodeError) as e:
logger.warning("Error processing search result: %s", e)
continue
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = []
if document_ids_filter:
placeholders = ",".join(["%s"] * len(document_ids_filter))
where_clause = f" AND JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
params.extend(document_ids_filter)
with self._get_cursor() as cur:
# Build query parameters: query (twice for MATCH clauses), document_ids_filter (if any), top_k
query_params = [query, query] + params + [top_k]
cur.execute(
f"""SELECT meta, text,
MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE) AS score
FROM {self.table_name}
WHERE MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE)
{where_clause}
ORDER BY score DESC
LIMIT %s""",
query_params,
)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata["score"] = float(record["score"])
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
def _create_collection(self, dimension: int):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{collection_exist_cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
# Create table with vector column and vector index
cur.execute(
SQL_CREATE_TABLE.format(
table_name=self.table_name,
dimension=dimension,
distance_function=self.distance_function,
hnsw_m=self.hnsw_m,
)
)
# Create metadata index (check if exists first)
try:
cur.execute(SQL_CREATE_META_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
except MySQLError as e:
if e.errno != 1061: # Duplicate key name
logger.warning("Could not create meta index: %s", e)
# Create full-text index for text search
try:
cur.execute(
SQL_CREATE_FULLTEXT_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)
)
except MySQLError as e:
if e.errno != 1061: # Duplicate key name
logger.warning("Could not create fulltext index: %s", e)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory):
def _validate_distance_function(self, distance_function: str) -> Literal["cosine", "euclidean"]:
"""Validate and return the distance function as a proper Literal type."""
if distance_function not in ["cosine", "euclidean"]:
raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'")
return cast(Literal["cosine", "euclidean"], distance_function)
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ALIBABACLOUD_MYSQL, collection_name)
)
return AlibabaCloudMySQLVector(
collection_name=collection_name,
config=AlibabaCloudMySQLVectorConfig(
host=dify_config.ALIBABACLOUD_MYSQL_HOST or "localhost",
port=dify_config.ALIBABACLOUD_MYSQL_PORT,
user=dify_config.ALIBABACLOUD_MYSQL_USER or "root",
password=dify_config.ALIBABACLOUD_MYSQL_PASSWORD or "",
database=dify_config.ALIBABACLOUD_MYSQL_DATABASE or "dify",
max_connection=dify_config.ALIBABACLOUD_MYSQL_MAX_CONNECTION,
charset=dify_config.ALIBABACLOUD_MYSQL_CHARSET or "utf8mb4",
distance_function=self._validate_distance_function(
dify_config.ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION or "cosine"
),
hnsw_m=dify_config.ALIBABACLOUD_MYSQL_HNSW_M or 6,
),
)
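
This new backend plugs into the BaseVector contract above. A hedged usage sketch, with placeholder connection values and a toy 3-dimensional embedding; it assumes a reachable RDS MySQL instance with vector support and a running Dify app context (redis, config), with imports as in the module above:

# Hypothetical wiring — host and credentials are placeholders.
config = AlibabaCloudMySQLVectorConfig(
    host="127.0.0.1",
    port=3306,
    user="dify",
    password="secret",
    database="dify",
    max_connection=5,
    distance_function="cosine",
)
store = AlibabaCloudMySQLVector(collection_name="Vector_index_demo", config=config)

docs = [Document(page_content="hello vectors", metadata={"doc_id": "d-1", "document_id": "doc-1"})]
store.create(docs, embeddings=[[0.1, 0.2, 0.3]])         # creates table + indexes, then inserts
hits = store.search_by_vector([0.1, 0.2, 0.3], top_k=4)  # cosine: score = 1 - distance

Note the score normalization in search_by_vector: cosine distance d maps to similarity 1 - d, while Euclidean distance maps to 1 / (1 + d), so both distance functions yield "higher is better" scores that the shared score_threshold can filter.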

View File

@ -488,9 +488,9 @@ class ClickzettaVector(BaseVector):
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
id STRING NOT NULL COMMENT 'Unique document identifier',
{Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
{Field.CONTENT_KEY} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
'High-dimensional embedding vector for semantic similarity search',
PRIMARY KEY (id)
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
@ -519,15 +519,15 @@ class ClickzettaVector(BaseVector):
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
# Check if vector index already exists on the embedding column
if Field.VECTOR.value in str(idx).lower():
logger.info("Vector index already exists on column %s", Field.VECTOR.value)
if Field.VECTOR in str(idx).lower():
logger.info("Vector index already exists on column %s", Field.VECTOR)
return
except (RuntimeError, ValueError) as e:
logger.warning("Failed to check existing indexes: %s", e)
index_sql = f"""
CREATE VECTOR INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR})
PROPERTIES (
"distance.function" = "{self._config.vector_distance_function}",
"scalar.type" = "f32",
@ -560,17 +560,17 @@ class ClickzettaVector(BaseVector):
# More precise check: look for inverted index specifically on the content column
if (
"inverted" in idx_str
and Field.CONTENT_KEY.value.lower() in idx_str
and Field.CONTENT_KEY.lower() in idx_str
and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
):
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY, idx)
return
except (RuntimeError, ValueError) as e:
logger.warning("Failed to check existing indexes: %s", e)
index_sql = f"""
CREATE INVERTED INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY})
PROPERTIES (
"analyzer" = "{self._config.analyzer_type}",
"mode" = "{self._config.analyzer_mode}"
@ -588,13 +588,13 @@ class ClickzettaVector(BaseVector):
or "with the same type" in error_msg
or "cannot create inverted index" in error_msg
) and "already has index" in error_msg:
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY)
# Try to get the existing index name for logging
try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.lower() in str(idx).lower():
logger.info("Found existing inverted index: %s", idx)
break
except (RuntimeError, ValueError):
@ -669,7 +669,7 @@ class ClickzettaVector(BaseVector):
# Use parameterized INSERT with executemany for better performance and security
# Cast JSON and VECTOR in SQL, pass raw data as parameters
columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
columns = f"id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}"
insert_sql = (
f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
@ -767,7 +767,7 @@ class ClickzettaVector(BaseVector):
# Use json_extract_string function for ClickZetta compatibility
sql = (
f"DELETE FROM {self._config.schema_name}.{self._table_name} "
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
f"WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?"
)
cursor.execute(sql, binding_params=[value])
@ -795,9 +795,7 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
@ -808,23 +806,21 @@ class ClickzettaVector(BaseVector):
distance_func = "COSINE_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(
f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
)
filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {2 - score_threshold}")
else:
# For L2 distance, smaller is better
distance_func = "L2_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")
filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {score_threshold}")
where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"
# Execute vector search query
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
{distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY},
{distance_func}({Field.VECTOR}, {query_vector_str}) AS distance
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
ORDER BY distance
@ -887,9 +883,7 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
@ -897,13 +891,13 @@ class ClickzettaVector(BaseVector):
# match_all requires all terms to be present
# Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY}, '{escaped_query}')")
where_clause = " AND ".join(filter_clauses)
# Execute full-text search query
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
LIMIT {top_k}
@ -986,19 +980,17 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
LIMIT {top_k}
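
These ClickZetta hunks lean on a Python 3.11+ detail: for enum.StrEnum, str() and f-string interpolation yield the member's value, so {Field.CONTENT_KEY} renders the same SQL as {Field.CONTENT_KEY.value}, and str methods like .lower() can be called on the member directly. With the older `class Field(str, Enum)` mixin this would not hold. A quick check, with an illustrative value:

from enum import Enum, StrEnum

class MixinField(str, Enum):
    CONTENT_KEY = "page_content"

class Field(StrEnum):
    CONTENT_KEY = "page_content"

# On Python 3.11+:
print(f"{MixinField.CONTENT_KEY}")  # MixinField.CONTENT_KEY -> still needs .value
print(f"{Field.CONTENT_KEY}")       # page_content           -> .value is redundant
print(Field.CONTENT_KEY.lower())    # page_content            (plain str methods work)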

View File

@ -57,18 +57,18 @@ class ElasticSearchJaVector(ElasticSearchVector):
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
Field.CONTENT_KEY: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.VECTOR: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type

View File

@ -4,7 +4,7 @@ import math
from typing import Any, cast
from urllib.parse import urlparse
import requests
from elasticsearch import ConnectionError as ElasticsearchConnectionError
from elasticsearch import Elasticsearch
from flask import current_app
from packaging.version import parse as parse_version
@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
if not client.ping():
raise ConnectionError("Failed to connect to Elasticsearch")
except requests.ConnectionError as e:
except ElasticsearchConnectionError as e:
raise ConnectionError(f"Vector database connection error: {str(e)}")
except Exception as e:
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
@ -163,9 +163,9 @@ class ElasticSearchVector(BaseVector):
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata or {},
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i] or None,
Field.METADATA_KEY: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
@ -193,7 +193,7 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
knn = {"field": Field.VECTOR, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
@ -205,9 +205,9 @@ class ElasticSearchVector(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -224,13 +224,13 @@ class ElasticSearchVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query_str = {
"bool": {
"must": {"match": {Field.CONTENT_KEY.value: query}},
"must": {"match": {Field.CONTENT_KEY: query}},
"filter": {"terms": {"metadata.document_id": document_ids_filter}},
}
}
@ -240,9 +240,9 @@ class ElasticSearchVector(BaseVector):
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
)
)
@ -270,14 +270,14 @@ class ElasticSearchVector(BaseVector):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type

View File

@ -67,9 +67,9 @@ class HuaweiCloudVector(BaseVector):
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata or {},
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i] or None,
Field.METADATA_KEY: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
@ -101,7 +101,7 @@ class HuaweiCloudVector(BaseVector):
"size": top_k,
"query": {
"vector": {
Field.VECTOR.value: {
Field.VECTOR: {
"vector": query_vector,
"topk": top_k,
}
@ -116,9 +116,9 @@ class HuaweiCloudVector(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -135,15 +135,15 @@ class HuaweiCloudVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY.value: query}}
query_str = {"match": {Field.CONTENT_KEY: query}}
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
)
)
@ -171,8 +171,8 @@ class HuaweiCloudVector(BaseVector):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "vector",
"dimension": dim,
"indexing": True,
@ -181,7 +181,7 @@ class HuaweiCloudVector(BaseVector):
"neighbors": 32,
"efc": 128,
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type

View File

@ -125,9 +125,9 @@ class LindormVectorStore(BaseVector):
}
}
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
@ -149,7 +149,7 @@ class LindormVectorStore(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
@ -252,14 +252,14 @@ class LindormVectorStore(BaseVector):
search_query: dict[str, Any] = {
"size": top_k,
"_source": True,
"query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}},
"query": {"knn": {Field.VECTOR: {"vector": query_vector, "k": top_k}}},
}
final_ext: dict[str, Any] = {"lvector": {}}
if filters is not None and len(filters) > 0:
# when using a filter, transform it from list[dict] to a single dict, the format the API expects
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict # filter should be Dict
search_query["query"]["knn"][Field.VECTOR]["filter"] = filter_dict # filter should be Dict
final_ext["lvector"]["filter_type"] = "pre_filter"
if final_ext != {"lvector": {}}:
@ -279,9 +279,9 @@ class LindormVectorStore(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -318,9 +318,9 @@ class LindormVectorStore(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
metadata = hit["_source"].get(Field.METADATA_KEY)
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
@ -342,8 +342,8 @@ class LindormVectorStore(BaseVector):
"settings": {"index": {"knn": True, "knn_routing": self._using_ugc}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {

View File

@ -85,7 +85,7 @@ class MilvusVector(BaseVector):
collection_info = self._client.describe_collection(self._collection_name)
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY.value]
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
def _check_hybrid_search_support(self) -> bool:
"""
@ -130,9 +130,9 @@ class MilvusVector(BaseVector):
insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
@ -243,15 +243,15 @@ class MilvusVector(BaseVector):
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
anns_field=Field.VECTOR,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
filter=filter,
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
@ -264,7 +264,7 @@ class MilvusVector(BaseVector):
"Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
)
return []
if not self.field_exists(Field.SPARSE_VECTOR.value):
if not self.field_exists(Field.SPARSE_VECTOR):
logger.warning(
"Full-text search unavailable: collection missing 'sparse_vector' field; "
"recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
@ -279,15 +279,15 @@ class MilvusVector(BaseVector):
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
anns_field=Field.SPARSE_VECTOR,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
filter=filter,
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
@ -311,7 +311,7 @@ class MilvusVector(BaseVector):
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
fields.append(FieldSchema(Field.METADATA_KEY, DataType.JSON, max_length=65_535))
# Create the text field; enable_analyzer is set True so Milvus can automatically
# convert the text into a sparse_vector, reference: https://milvus.io/docs/full-text-search.md
@ -326,15 +326,15 @@ class MilvusVector(BaseVector):
):
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
fields.append(FieldSchema(Field.CONTENT_KEY, DataType.VARCHAR, **content_field_kwargs))
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
fields.append(FieldSchema(Field.PRIMARY_KEY, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
fields.append(FieldSchema(Field.VECTOR, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
fields.append(FieldSchema(Field.SPARSE_VECTOR, DataType.SPARSE_FLOAT_VECTOR))
schema = CollectionSchema(fields)
@ -342,8 +342,8 @@ class MilvusVector(BaseVector):
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
input_field_names=[Field.CONTENT_KEY],
output_field_names=[Field.SPARSE_VECTOR],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
@ -352,12 +352,12 @@ class MilvusVector(BaseVector):
# Create Index params for the collection
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
field_name=Field.SPARSE_VECTOR, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection
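
Every pymilvus call in these hunks (FieldSchema, Function, add_index, search) accepts the bare StrEnum member because each member is a genuine str instance; the SDK never sees a difference. A two-assert check, with an illustrative value:

import json
from enum import StrEnum

class Field(StrEnum):
    CONTENT_KEY = "page_content"

assert isinstance(Field.CONTENT_KEY, str)  # str-typed APIs accept the member as-is
assert json.dumps({Field.CONTENT_KEY: 1}) == '{"page_content": 1}'  # dict keys serialize to the value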

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Literal
from typing import Any
from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
@ -8,6 +8,7 @@ from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@ -25,7 +26,7 @@ class OpenSearchConfig(BaseModel):
port: int
secure: bool = False # use_ssl
verify_certs: bool = True
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
auth_method: AuthMethod = AuthMethod.BASIC
user: str | None = None
password: str | None = None
aws_region: str | None = None
@ -98,9 +99,9 @@ class OpenSearchVector(BaseVector):
"_op_type": "index",
"_index": self._collection_name.lower(),
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
@ -116,7 +117,7 @@ class OpenSearchVector(BaseVector):
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
@ -180,17 +181,17 @@ class OpenSearchVector(BaseVector):
query = {
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
"query": {"knn": {Field.VECTOR: {Field.VECTOR: query_vector, "k": kwargs.get("top_k", 4)}}},
}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query["query"] = {
"script_score": {
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID: document_ids_filter}}]}},
"script": {
"source": "knn_score",
"lang": "knn",
"params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
"params": {"field": Field.VECTOR, "query_value": query_vector, "space_type": "l2"},
},
}
}
@ -203,7 +204,7 @@ class OpenSearchVector(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
metadata = hit["_source"].get(Field.METADATA_KEY, {})
# Make sure metadata is a dictionary
if metadata is None:
@ -212,7 +213,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
return docs
@ -227,9 +228,9 @@ class OpenSearchVector(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
metadata = hit["_source"].get(Field.METADATA_KEY)
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
@ -250,8 +251,8 @@ class OpenSearchVector(BaseVector):
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {
@ -261,7 +262,7 @@ class OpenSearchVector(BaseVector):
"parameters": {"ef_construction": 64, "m": 8},
},
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
@ -293,7 +294,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
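
The OpenSearch config swaps the inline Literal for the shared AuthMethod enum, letting the factory pass dify_config.OPENSEARCH_AUTH_METHOD straight through without .value. Assuming AuthMethod is a StrEnum shaped roughly like this (the real definition lives in configs/middleware/vdb/opensearch_config.py):

from enum import StrEnum

class AuthMethod(StrEnum):  # assumed shape, mirroring the old Literal values
    BASIC = "basic"
    AWS_MANAGED_IAM = "aws_managed_iam"

# Code that compared against the old literal strings keeps working:
assert AuthMethod.BASIC == "basic"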

View File

@ -147,15 +147,13 @@ class QdrantVector(BaseVector):
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
@ -165,9 +163,7 @@ class QdrantVector(BaseVector):
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -220,10 +216,10 @@ class QdrantVector(BaseVector):
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "", # Ensure group_id is never None
Field.GROUP_KEY.value,
Field.GROUP_KEY,
),
)
]
@ -381,12 +377,12 @@ class QdrantVector(BaseVector):
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
metadata = result.payload.get(Field.METADATA_KEY) or {}
# duplicate check score threshold
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
@ -433,7 +429,7 @@ class QdrantVector(BaseVector):
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
return documents

View File

@ -55,7 +55,7 @@ class TableStoreVector(BaseVector):
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
self._table_name = f"{collection_name}"
self._index_name = f"{collection_name}_idx"
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
self._tags_field = f"{Field.METADATA_KEY}_tags"
def create_collection(self, embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
@ -64,7 +64,7 @@ class TableStoreVector(BaseVector):
def get_by_ids(self, ids: list[str]) -> list[Document]:
docs = []
request = BatchGetRowRequest()
columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value]
columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY]
rows_to_get = [[("id", _id)] for _id in ids]
request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
@ -73,11 +73,7 @@ class TableStoreVector(BaseVector):
for item in table_result:
if item.is_ok and item.row:
kv = {k: v for k, v, _ in item.row.attribute_columns}
docs.append(
Document(
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])
)
)
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
return docs
def get_type(self) -> str:
@ -95,9 +91,9 @@ class TableStoreVector(BaseVector):
self._write_row(
primary_key=uuids[i],
attributes={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
},
)
return uuids
@ -180,7 +176,7 @@ class TableStoreVector(BaseVector):
field_schemas = [
tablestore.FieldSchema(
Field.CONTENT_KEY.value,
Field.CONTENT_KEY,
tablestore.FieldType.TEXT,
analyzer=tablestore.AnalyzerType.MAXWORD,
index=True,
@ -188,7 +184,7 @@ class TableStoreVector(BaseVector):
store=False,
),
tablestore.FieldSchema(
Field.VECTOR.value,
Field.VECTOR,
tablestore.FieldType.VECTOR,
vector_options=tablestore.VectorOptions(
data_type=tablestore.VectorDataType.VD_FLOAT_32,
@ -197,7 +193,7 @@ class TableStoreVector(BaseVector):
),
),
tablestore.FieldSchema(
Field.METADATA_KEY.value,
Field.METADATA_KEY,
tablestore.FieldType.KEYWORD,
index=True,
store=False,
@ -233,15 +229,15 @@ class TableStoreVector(BaseVector):
pk = [("id", primary_key)]
tags = []
for key, value in attributes[Field.METADATA_KEY.value].items():
for key, value in attributes[Field.METADATA_KEY].items():
tags.append(str(key) + "=" + str(value))
attribute_columns = [
(Field.CONTENT_KEY.value, attributes[Field.CONTENT_KEY.value]),
(Field.VECTOR.value, json.dumps(attributes[Field.VECTOR.value])),
(Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]),
(Field.VECTOR, json.dumps(attributes[Field.VECTOR])),
(
Field.METADATA_KEY.value,
json.dumps(attributes[Field.METADATA_KEY.value]),
Field.METADATA_KEY,
json.dumps(attributes[Field.METADATA_KEY]),
),
(self._tags_field, json.dumps(tags)),
]
@ -270,7 +266,7 @@ class TableStoreVector(BaseVector):
index_name=self._index_name,
search_query=query,
columns_to_get=tablestore.ColumnsToGet(
column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED
),
)
@ -288,7 +284,7 @@ class TableStoreVector(BaseVector):
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
knn_vector_query = tablestore.KnnVectorQuery(
field_name=Field.VECTOR.value,
field_name=Field.VECTOR,
top_k=top_k,
float32_query_vector=query_vector,
)
@ -311,8 +307,8 @@ class TableStoreVector(BaseVector):
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector_str = ots_column_map.get(Field.VECTOR)
metadata_str = ots_column_map.get(Field.METADATA_KEY)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
@ -321,7 +317,7 @@ class TableStoreVector(BaseVector):
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)
@ -343,7 +339,7 @@ class TableStoreVector(BaseVector):
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY))
if document_ids_filter:
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
@ -374,10 +370,10 @@ class TableStoreVector(BaseVector):
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY)
metadata = json.loads(metadata_str) if metadata_str else {}
vector_str = ots_column_map.get(Field.VECTOR.value)
vector_str = ots_column_map.get(Field.VECTOR)
vector = json.loads(vector_str) if vector_str else None
if score:
@ -385,7 +381,7 @@ class TableStoreVector(BaseVector):
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)

View File

@ -5,9 +5,10 @@ from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union
import httpx
import qdrant_client
import requests
from flask import current_app
from httpx import DigestAuth
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
@ -19,7 +20,6 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select
from configs import dify_config
@ -141,15 +141,13 @@ class TidbOnQdrantVector(BaseVector):
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
@ -159,9 +157,7 @@ class TidbOnQdrantVector(BaseVector):
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -211,10 +207,10 @@ class TidbOnQdrantVector(BaseVector):
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "",
Field.GROUP_KEY.value,
Field.GROUP_KEY,
),
)
]
@ -349,13 +345,13 @@ class TidbOnQdrantVector(BaseVector):
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
metadata = result.payload.get(Field.METADATA_KEY) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
@ -392,7 +388,7 @@ class TidbOnQdrantVector(BaseVector):
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
return documents
@ -504,10 +500,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
}
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
response = requests.post(
response = httpx.post(
f"{tidb_config.api_url}/clusters",
json=cluster_data,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
@ -527,10 +523,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
body = {"password": new_password}
response = requests.put(
response = httpx.put(
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
json=body,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:

View File

@ -2,8 +2,8 @@ import time
import uuid
from collections.abc import Sequence
import requests
from requests.auth import HTTPDigestAuth
import httpx
from httpx import DigestAuth
from configs import dify_config
from extensions.ext_database import db
@ -49,7 +49,7 @@ class TidbService:
"rootPassword": password,
}
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
@ -83,7 +83,7 @@ class TidbService:
:return: The response from the API.
"""
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
@ -102,7 +102,7 @@ class TidbService:
:return: The response from the API.
"""
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
@ -127,10 +127,10 @@ class TidbService:
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = requests.patch(
response = httpx.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=HTTPDigestAuth(public_key, private_key),
auth=DigestAuth(public_key, private_key),
)
if response.status_code == 200:
@ -161,9 +161,7 @@ class TidbService:
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = requests.get(
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)
response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
@ -224,8 +222,8 @@ class TidbService:
clusters.append(cluster_data)
request_body = {"requests": clusters}
response = requests.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
response = httpx.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
)
if response.status_code == 200:
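
Two behavioral differences from requests are worth noting after this migration: httpx applies a default timeout of 5 seconds to every request (requests waits indefinitely by default), and it does not follow redirects unless follow_redirects=True. For a service class issuing many calls against one API, a shared client also pools connections; a hedged sketch (names are placeholders):

    import httpx
    from httpx import DigestAuth

    with httpx.Client(
        base_url="https://api.example.com",  # placeholder
        auth=DigestAuth("public_key", "private_key"),
        timeout=httpx.Timeout(30.0),         # explicit, instead of the 5 s default
        follow_redirects=True,
    ) as client:
        # list-valued params serialize as repeated query keys, as with requests
        resp = client.get("/clusters:batchGet", params={"clusterIds": ["c1", "c2"], "view": "BASIC"})
        resp.raise_for_status()
        clusters = resp.json()
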

View File

@ -55,13 +55,13 @@ class TiDBVector(BaseVector):
return Table(
self._collection_name,
self._orm_base.metadata,
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
Column(Field.PRIMARY_KEY, String(36), primary_key=True, nullable=False),
Column(
Field.VECTOR.value,
Field.VECTOR,
VectorType(dim),
nullable=False,
),
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
Column(Field.TEXT_KEY, TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column(

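Passing enum members as SQLAlchemy column names relies on the same StrEnum property: Column() accepts any str, and str(member) must render the raw value so the compiled DDL names the column `id` rather than `Field.PRIMARY_KEY`. A self-contained sketch with illustrative names:

    from enum import StrEnum
    from sqlalchemy import Column, MetaData, String, Table

    class Field(StrEnum):
        PRIMARY_KEY = "id"

    table = Table("example", MetaData(), Column(Field.PRIMARY_KEY, String(36), primary_key=True))
    assert table.c["id"].name == "id"  # the StrEnum name compiles to the raw string
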
View File

@ -71,6 +71,12 @@ class Vector:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.ALIBABACLOUD_MYSQL:
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVectorFactory,
)
return AlibabaCloudMySQLVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
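
Each case in this factory imports its backend lazily, inside the match arm, so optional vendor SDKs (mysql-connector for the new Alibaba Cloud MySQL store, the Milvus client for Milvus, and so on) are only imported once that store is actually selected. A hedged sketch of the pattern (the function name is illustrative):

    def get_vector_factory(vector_type: VectorType) -> type[AbstractVectorFactory]:
        match vector_type:
            case VectorType.ALIBABACLOUD_MYSQL:
                # deferred import: only this backend needs its SDK installed
                from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
                    AlibabaCloudMySQLVectorFactory,
                )
                return AlibabaCloudMySQLVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")
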

View File

@ -2,6 +2,7 @@ from enum import StrEnum
class VectorType(StrEnum):
ALIBABACLOUD_MYSQL = "alibabacloud_mysql"
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"
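
Because VectorType is a StrEnum, the new member maps one-to-one onto the plain string used in configuration, and an enum lookup round-trips it (reading it from a VECTOR_STORE setting is an assumption):

    vector_type = VectorType("alibabacloud_mysql")  # e.g. from a VECTOR_STORE config value
    assert vector_type is VectorType.ALIBABACLOUD_MYSQL
    assert vector_type == "alibabacloud_mysql"
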

View File

@ -76,11 +76,11 @@ class VikingDBVector(BaseVector):
if not self._has_collection():
fields = [
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=dimension),
]
self._client.create_collection(
@ -100,7 +100,7 @@ class VikingDBVector(BaseVector):
collection_name=self._collection_name,
index_name=self._index_name,
vector_index=vector_index,
partition_by=vdb_Field.GROUP_KEY.value,
partition_by=vdb_Field.GROUP_KEY,
description="Index For Dify",
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
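
The redis_client.set(..., ex=3600) call above caches a "collection exists" flag so repeated inserts skip the has-collection round-trip; the one-hour TTL bounds staleness if the collection is dropped out of band. A hedged sketch of the cache-aside pattern (key and client names are illustrative):

    cache_key = f"vector_indexing_{collection_name}"
    if not redis_client.get(cache_key):
        if not client.has_collection(collection_name):
            client.create_collection(collection_name, fields)
        redis_client.set(cache_key, 1, ex=3600)  # re-verify at most once per hour
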
@ -126,11 +126,11 @@ class VikingDBVector(BaseVector):
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
vdb_Field.GROUP_KEY.value: self._group_id,
vdb_Field.PRIMARY_KEY: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY: page_content,
vdb_Field.METADATA_KEY: json.dumps(metadata),
vdb_Field.GROUP_KEY: self._group_id,
}
)
docs.append(doc)
@ -151,7 +151,7 @@ class VikingDBVector(BaseVector):
# Note: the metadata field value is a dict, but VikingDB fields
# do not support a JSON type
results = self._client.get_index(self._collection_name, self._index_name).search(
filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]},
# max value is 5000
limit=5000,
)
@ -161,7 +161,7 @@ class VikingDBVector(BaseVector):
ids = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if metadata.get(key) == value:
@ -189,12 +189,12 @@ class VikingDBVector(BaseVector):
docs = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
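
Since VikingDB fields have no JSON type (see the note above), metadata makes a json.dumps/json.loads round trip: serialized into a string field on write, parsed back before filtering or scoring on read. A minimal sketch:

    import json

    stored = json.dumps({"doc_id": "abc", "document_id": "doc-1"})  # write path
    metadata = json.loads(stored)                                   # read path
    assert metadata["doc_id"] == "abc"
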

View File

@ -2,7 +2,6 @@ import datetime
import json
from typing import Any
import requests
import weaviate # type: ignore
from pydantic import BaseModel, model_validator
@ -45,8 +44,8 @@ class WeaviateVector(BaseVector):
client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.ConnectionError:
raise ConnectionError("Vector database connection error")
except Exception as exc:
raise ConnectionError("Vector database connection error") from exc
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
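
With the module-level requests import removed at the top of this file, the connect error can no longer be caught as requests.ConnectionError; the weaviate client may surface transport errors from its own dependencies, so the broad catch converts any of them, and `from exc` keeps the original traceback attached as __cause__. A minimal sketch (connect() is a hypothetical stand-in):

    def build_client():
        try:
            return connect()  # stand-in for weaviate.Client(...)
        except Exception as exc:
            # exception chaining: the root cause stays visible as __cause__
            raise ConnectionError("Vector database connection error") from exc
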
@ -105,7 +104,7 @@ class WeaviateVector(BaseVector):
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
data_properties = {Field.TEXT_KEY: text}
if metadatas is not None:
# metadata may be None
for key, val in (metadatas[i] or {}).items():
@ -183,7 +182,7 @@ class WeaviateVector(BaseVector):
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
properties.append(Field.TEXT_KEY)
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
@ -205,7 +204,7 @@ class WeaviateVector(BaseVector):
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
text = res.pop(Field.TEXT_KEY)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
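
The score computed here assumes the collection uses cosine distance (Weaviate's default), where distance = 1 - cosine_similarity; subtracting from one therefore recovers a similarity in [-1, 1] in which higher means closer:

    distance = res["_additional"]["distance"]
    score = 1 - distance  # cosine similarity, under the cosine-distance assumption
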
@ -233,7 +232,7 @@ class WeaviateVector(BaseVector):
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
properties.append(Field.TEXT_KEY)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
@ -251,7 +250,7 @@ class WeaviateVector(BaseVector):
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
text = res.pop(Field.TEXT_KEY)
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs