mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 08:28:03 +08:00
Merge branch main into feat/rag-2
This commit is contained in:
@ -94,11 +94,11 @@ class Jieba(BaseKeyword):
|
||||
|
||||
documents = []
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment_query = db.session.query(DocumentSegment).filter(
|
||||
segment_query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment = segment_query.first()
|
||||
|
||||
if segment:
|
||||
@ -215,7 +215,7 @@ class Jieba(BaseKeyword):
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
|
||||
document_segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
|
||||
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
|
||||
.first()
|
||||
)
|
||||
if document_segment:
|
||||
|
||||
@ -127,7 +127,7 @@ class RetrievalService:
|
||||
external_retrieval_model: Optional[dict] = None,
|
||||
metadata_filtering_conditions: Optional[dict] = None,
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
@ -145,7 +145,7 @@ class RetrievalService:
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
|
||||
with Session(db.engine) as session:
|
||||
return session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
return session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
@ -294,7 +294,7 @@ class RetrievalService:
|
||||
dataset_documents = {
|
||||
doc.id: doc
|
||||
for doc in db.session.query(DatasetDocument)
|
||||
.filter(DatasetDocument.id.in_(document_ids))
|
||||
.where(DatasetDocument.id.in_(document_ids))
|
||||
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
|
||||
.all()
|
||||
}
|
||||
@ -318,7 +318,7 @@ class RetrievalService:
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
|
||||
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
|
||||
)
|
||||
|
||||
if not child_chunk:
|
||||
@ -326,7 +326,7 @@ class RetrievalService:
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
@ -381,7 +381,7 @@ class RetrievalService:
|
||||
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
|
||||
@ -6,7 +6,7 @@ from uuid import UUID, uuid4
|
||||
from numpy import ndarray
|
||||
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Float, String, create_engine, insert, select, text
|
||||
from sqlalchemy import Float, create_engine, insert, select, text
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
@ -67,7 +67,7 @@ class PGVectoRS(BaseVector):
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
)
|
||||
text: Mapped[str] = mapped_column(String)
|
||||
text: Mapped[str]
|
||||
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
|
||||
vector: Mapped[ndarray] = mapped_column(VECTOR(dim))
|
||||
|
||||
|
||||
@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_collection_binding:
|
||||
|
||||
@ -118,10 +118,21 @@ class TableStoreVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
return self._search_by_vector(query_vector, top_k)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filtered_list = None
|
||||
if document_ids_filter:
|
||||
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._search_by_full_text(query)
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filtered_list = None
|
||||
if document_ids_filter:
|
||||
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||
|
||||
return self._search_by_full_text(query, filtered_list, top_k)
|
||||
|
||||
def delete(self) -> None:
|
||||
self._delete_table_if_exist()
|
||||
@ -230,32 +241,51 @@ class TableStoreVector(BaseVector):
|
||||
primary_key = [("id", id)]
|
||||
row = tablestore.Row(primary_key)
|
||||
self._tablestore_client.delete_row(self._table_name, row, None)
|
||||
logging.info("Tablestore delete row successfully. id:%s", id)
|
||||
|
||||
def _search_by_metadata(self, key: str, value: str) -> list[str]:
|
||||
query = tablestore.SearchQuery(
|
||||
tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
|
||||
limit=100,
|
||||
limit=1000,
|
||||
get_total_count=False,
|
||||
)
|
||||
rows: list[str] = []
|
||||
next_token = None
|
||||
while True:
|
||||
if next_token is not None:
|
||||
query.next_token = next_token
|
||||
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
index_name=self._index_name,
|
||||
search_query=query,
|
||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||
)
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
index_name=self._index_name,
|
||||
search_query=query,
|
||||
columns_to_get=tablestore.ColumnsToGet(
|
||||
column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
|
||||
),
|
||||
)
|
||||
|
||||
return [row[0][0][1] for row in search_response.rows]
|
||||
if search_response is not None:
|
||||
rows.extend([row[0][0][1] for row in search_response.rows])
|
||||
|
||||
def _search_by_vector(self, query_vector: list[float], top_k: int) -> list[Document]:
|
||||
ots_query = tablestore.KnnVectorQuery(
|
||||
if search_response is None or search_response.next_token == b"":
|
||||
break
|
||||
else:
|
||||
next_token = search_response.next_token
|
||||
|
||||
return rows
|
||||
|
||||
def _search_by_vector(
|
||||
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||
) -> list[Document]:
|
||||
knn_vector_query = tablestore.KnnVectorQuery(
|
||||
field_name=Field.VECTOR.value,
|
||||
top_k=top_k,
|
||||
float32_query_vector=query_vector,
|
||||
)
|
||||
if document_ids_filter:
|
||||
knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter)
|
||||
|
||||
sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
|
||||
search_query = tablestore.SearchQuery(ots_query, limit=top_k, get_total_count=False, sort=sort)
|
||||
search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort)
|
||||
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
@ -263,30 +293,42 @@ class TableStoreVector(BaseVector):
|
||||
search_query=search_query,
|
||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||
)
|
||||
logging.info(
|
||||
"Tablestore search successfully. request_id:%s",
|
||||
search_response.request_id,
|
||||
)
|
||||
return self._to_query_result(search_response)
|
||||
|
||||
def _to_query_result(self, search_response: tablestore.SearchResponse) -> list[Document]:
|
||||
documents = []
|
||||
for row in search_response.rows:
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=row[1][2][1],
|
||||
vector=json.loads(row[1][3][1]),
|
||||
metadata=json.loads(row[1][0][1]),
|
||||
)
|
||||
)
|
||||
for search_hit in search_response.search_hits:
|
||||
if search_hit.score > score_threshold:
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
|
||||
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
metadata["score"] = search_hit.score
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
def _search_by_full_text(self, query: str) -> list[Document]:
|
||||
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
|
||||
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
|
||||
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
|
||||
|
||||
if document_ids_filter:
|
||||
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
|
||||
|
||||
search_query = tablestore.SearchQuery(
|
||||
query=tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value),
|
||||
query=bool_query,
|
||||
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
|
||||
limit=100,
|
||||
limit=top_k,
|
||||
)
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
@ -295,7 +337,25 @@ class TableStoreVector(BaseVector):
|
||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||
)
|
||||
|
||||
return self._to_query_result(search_response)
|
||||
documents = []
|
||||
for search_hit in search_response.search_hits:
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return documents
|
||||
|
||||
|
||||
class TableStoreVectorFactory(AbstractVectorFactory):
|
||||
|
||||
@ -284,7 +284,8 @@ class TencentVector(BaseVector):
|
||||
# Compatible with version 1.1.3 and below.
|
||||
meta = json.loads(meta)
|
||||
score = 1 - result.get("score", 0.0)
|
||||
score = result.get("score", 0.0)
|
||||
else:
|
||||
score = result.get("score", 0.0)
|
||||
if score > score_threshold:
|
||||
meta["score"] = score
|
||||
doc = Document(page_content=result.get(self.field_text), metadata=meta)
|
||||
|
||||
@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector):
|
||||
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
if not tidb_auth_binding:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
else:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
@ -47,7 +47,7 @@ class Vector:
|
||||
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
|
||||
whitelist = (
|
||||
db.session.query(Whitelist)
|
||||
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
|
||||
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
|
||||
.one_or_none()
|
||||
)
|
||||
if whitelist:
|
||||
|
||||
Reference in New Issue
Block a user