mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 08:28:03 +08:00
Merge branch 'main' into fix/chore-fix
This commit is contained in:
@ -1,8 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||
from core.rag.datasource.keyword.keyword_type import KeyWordType
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
@ -13,16 +13,19 @@ class Keyword:
|
||||
self._keyword_processor = self._init_keyword()
|
||||
|
||||
def _init_keyword(self) -> BaseKeyword:
    """Build the keyword processor selected by the application configuration.

    Reads ``dify_config.KEYWORD_STORE``, resolves it to a keyword-index class
    through ``get_keyword_factory`` and instantiates that class with this
    object's dataset.

    :return: the keyword processor instance bound to ``self._dataset``.
    :raises ValueError: propagated from ``get_keyword_factory`` when the
        configured store is not a supported keyword type.
    """
    keyword_type = dify_config.KEYWORD_STORE
    # The factory returns a class, not an instance; instantiate it here.
    keyword_factory = self.get_keyword_factory(keyword_type)
    return keyword_factory(self._dataset)
|
||||
@staticmethod
|
||||
def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]:
|
||||
match keyword_type:
|
||||
case KeyWordType.JIEBA:
|
||||
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
||||
|
||||
if keyword_type == "jieba":
|
||||
return Jieba(dataset=self._dataset)
|
||||
else:
|
||||
raise ValueError(f"Keyword store {keyword_type} is not supported.")
|
||||
return Jieba
|
||||
case _:
|
||||
raise ValueError(f"Keyword store {keyword_type} is not supported.")
|
||||
|
||||
def create(self, texts: list[Document], **kwargs):
|
||||
self._keyword_processor.create(texts, **kwargs)
|
||||
|
||||
5
api/core/rag/datasource/keyword/keyword_type.py
Normal file
5
api/core/rag/datasource/keyword/keyword_type.py
Normal file
@ -0,0 +1,5 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class KeyWordType(str, Enum):
    """Identifiers of the supported keyword-store backends.

    Members are also ``str`` instances, so they compare equal to their plain
    string values (e.g. ``KeyWordType.JIEBA == "jieba"``).
    """

    JIEBA = "jieba"
|
||||
@ -112,7 +112,7 @@ class ElasticSearchVector(BaseVector):
|
||||
self._client.indices.delete(index=self._collection_name)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
num_candidates = math.ceil(top_k * 1.5)
|
||||
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
|
||||
|
||||
|
||||
@ -121,7 +121,7 @@ class MyScaleVector(BaseVector):
|
||||
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
|
||||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
where_str = (
|
||||
f"WHERE dist < {1 - score_threshold}"
|
||||
|
||||
@ -168,14 +168,6 @@ class OracleVector(BaseVector):
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
return docs
|
||||
|
||||
# def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
# with self._get_cursor() as cur:
|
||||
# cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
|
||||
# idss = []
|
||||
# for record in cur:
|
||||
# idss.append(record[0])
|
||||
# return idss
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
@ -192,7 +184,7 @@ class OracleVector(BaseVector):
|
||||
:param top_k: The number of nearest neighbors to return, default is 4.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||
|
||||
@ -186,7 +186,7 @@ class PGVectoRS(BaseVector):
|
||||
query_vector,
|
||||
).label("distance"),
|
||||
)
|
||||
.limit(kwargs.get("top_k", 2))
|
||||
.limit(kwargs.get("top_k", 4))
|
||||
.order_by("distance")
|
||||
)
|
||||
res = session.execute(stmt)
|
||||
@ -205,18 +205,6 @@ class PGVectoRS(BaseVector):
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# with Session(self._client) as session:
|
||||
# select_statement = sql_text(
|
||||
# f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery"
|
||||
# )
|
||||
# results = session.execute(select_statement).fetchall()
|
||||
# if results:
|
||||
# docs = []
|
||||
# for result in results:
|
||||
# doc = Document(page_content=result[0],
|
||||
# metadata=result[1])
|
||||
# docs.append(doc)
|
||||
# return docs
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@ -143,7 +143,7 @@ class PGVector(BaseVector):
|
||||
:param top_k: The number of nearest neighbors to return, default is 4.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
|
||||
@ -224,7 +224,7 @@ class RelytVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
results = self.similarity_search_with_score_by_vector(
|
||||
k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
|
||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
|
||||
)
|
||||
|
||||
# Organize results.
|
||||
|
||||
@ -184,7 +184,7 @@ class TiDBVector(BaseVector):
|
||||
self._delete_by_ids(ids)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
filter = kwargs.get("filter")
|
||||
distance = 1 - score_threshold
|
||||
|
||||
@ -173,7 +173,7 @@ class VikingDBVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
|
||||
query_vector, limit=kwargs.get("top_k", 50)
|
||||
query_vector, limit=kwargs.get("top_k", 4)
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(results, score_threshold)
|
||||
|
||||
@ -235,7 +235,7 @@ class WeaviateVector(BaseVector):
|
||||
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||
query_obj = query_obj.with_additional(["vector"])
|
||||
properties = ["text"]
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
|
||||
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
|
||||
@ -215,7 +215,7 @@ class DatasetRetrieval:
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": invoke_from.to_source(),
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
"score": document_score_list.get(segment.index_node_id, 0.0),
|
||||
}
|
||||
|
||||
if invoke_from.to_source() == "dev":
|
||||
@ -229,12 +229,12 @@ class DatasetRetrieval:
|
||||
source["content"] = segment.content
|
||||
retrieval_resource_list.append(source)
|
||||
if hit_callback and retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score"), reverse=True)
|
||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item["position"] = position
|
||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||
if document_context_list:
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score, reverse=True)
|
||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||
return str("\n".join([document_context.content for document_context in document_context_list]))
|
||||
return ""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user