Merge branch 'main' into feat/rag-2

This commit is contained in:
twwu
2025-08-27 11:16:27 +08:00
438 changed files with 17986 additions and 7846 deletions

View File

@ -259,8 +259,16 @@ class MilvusVector(BaseVector):
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
if not self._hybrid_search_enabled:
logger.warning(
"Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
)
return []
if not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning(
"Full-text search unavailable: collection missing 'sparse_vector' field; "
"recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
)
return []
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""

View File

@ -15,6 +15,8 @@ from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MyScaleConfig(BaseModel):
host: str
@ -53,7 +55,7 @@ class MyScaleVector(BaseVector):
return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
def _create_collection(self, dimension: int):
logging.info("create MyScale collection %s with dimension %s", self._collection_name, dimension)
logger.info("create MyScale collection %s with dimension %s", self._collection_name, dimension)
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
sql = f"""
@ -151,7 +153,7 @@ class MyScaleVector(BaseVector):
for r in self._client.query(sql).named_results()
]
except Exception as e:
logging.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401
logger.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401
return []
def delete(self) -> None:

View File

@ -188,14 +188,17 @@ class OracleVector(BaseVector):
def text_exists(self, id: str) -> bool:
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
return cur.fetchone() is not None
conn.close()
def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
@ -208,14 +211,15 @@ class OracleVector(BaseVector):
return
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
conn.commit()
conn.close()
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
conn.commit()
conn.close()
@ -227,12 +231,20 @@ class OracleVector(BaseVector):
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 4 # Use default if invalid
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = [numpy.array(query_vector)]
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)
with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
@ -241,7 +253,7 @@ class OracleVector(BaseVector):
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
AS distance FROM {self.table_name}
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
[numpy.array(query_vector)],
params,
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@ -259,7 +271,10 @@ class OracleVector(BaseVector):
import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 5 # Use default if invalid
# just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
@ -297,14 +312,21 @@ class OracleVector(BaseVector):
with conn.cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
placeholders = []
for i, doc_id in enumerate(document_ids_filter):
param_name = f"doc_id_{i}"
placeholders.append(f":{param_name}")
params[param_name] = doc_id
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "
cur.execute(
f"""select meta, text, embedding FROM {self.table_name}
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
order by score(1) desc fetch first {top_k} rows only""",
kk=" ACCUM ".join(entities),
params,
)
docs = []
for record in cur:

View File

@ -19,6 +19,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class PGVectorConfig(BaseModel):
host: str
@ -155,7 +157,7 @@ class PGVector(BaseVector):
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
except psycopg2.errors.UndefinedTable:
# table not exists
logging.warning("Table %s not found, skipping delete operation.", self.table_name)
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
return
except Exception as e:
raise e

View File

@ -17,6 +17,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models import Dataset
logger = logging.getLogger(__name__)
class TableStoreConfig(BaseModel):
access_key_id: Optional[str] = None
@ -145,7 +147,7 @@ class TableStoreVector(BaseVector):
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logging.info("Collection %s already exists.", self._collection_name)
logger.info("Collection %s already exists.", self._collection_name)
return
self._create_table_if_not_exist()
@ -155,7 +157,7 @@ class TableStoreVector(BaseVector):
def _create_table_if_not_exist(self) -> None:
table_list = self._tablestore_client.list_table()
if self._table_name in table_list:
logging.info("Tablestore system table[%s] already exists", self._table_name)
logger.info("Tablestore system table[%s] already exists", self._table_name)
return None
schema_of_primary_key = [("id", "STRING")]
@ -163,12 +165,12 @@ class TableStoreVector(BaseVector):
table_options = tablestore.TableOptions()
reserved_throughput = tablestore.ReservedThroughput(tablestore.CapacityUnit(0, 0))
self._tablestore_client.create_table(table_meta, table_options, reserved_throughput)
logging.info("Tablestore create table[%s] successfully.", self._table_name)
logger.info("Tablestore create table[%s] successfully.", self._table_name)
def _create_search_index_if_not_exist(self, dimension: int) -> None:
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
if self._index_name in [t[1] for t in search_index_list]:
logging.info("Tablestore system index[%s] already exists", self._index_name)
logger.info("Tablestore system index[%s] already exists", self._index_name)
return None
field_schemas = [
@ -206,20 +208,20 @@ class TableStoreVector(BaseVector):
index_meta = tablestore.SearchIndexMeta(field_schemas)
self._tablestore_client.create_search_index(self._table_name, self._index_name, index_meta)
logging.info("Tablestore create system index[%s] successfully.", self._index_name)
logger.info("Tablestore create system index[%s] successfully.", self._index_name)
def _delete_table_if_exist(self):
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
for resp_tuple in search_index_list:
self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1])
logging.info("Tablestore delete index[%s] successfully.", self._index_name)
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
self._tablestore_client.delete_table(self._table_name)
logging.info("Tablestore delete system table[%s] successfully.", self._index_name)
logger.info("Tablestore delete system table[%s] successfully.", self._index_name)
def _delete_search_index(self) -> None:
self._tablestore_client.delete_search_index(self._table_name, self._index_name)
logging.info("Tablestore delete index[%s] successfully.", self._index_name)
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
def _write_row(self, primary_key: str, attributes: dict[str, Any]) -> None:
pk = [("id", primary_key)]

View File

@ -83,14 +83,14 @@ class TiDBVector(BaseVector):
self._dimension = 1536
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
logger.info("create collection and add texts, collection_name: " + self._collection_name)
logger.info("create collection and add texts, collection_name: %s", self._collection_name)
self._create_collection(len(embeddings[0]))
self.add_texts(texts, embeddings)
self._dimension = len(embeddings[0])
pass
def _create_collection(self, dimension: int):
logger.info("_create_collection, collection_name " + self._collection_name)
logger.info("_create_collection, collection_name %s", self._collection_name)
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"