Mirror of https://github.com/langgenius/dify.git (synced 2026-05-03 08:58:09 +08:00)
merge main
This commit is contained in:
@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI:
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
|
||||
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
vector=query_vector,
|
||||
content=None,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=None,
|
||||
filter=where_clause,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI:
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
vector=None,
|
||||
content=query,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=None,
|
||||
filter=where_clause,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
|
||||
@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
if document_ids_filter:
|
||||
query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore
|
||||
query_str = {
|
||||
"bool": {
|
||||
"must": {"match": {Field.CONTENT_KEY.value: query}},
|
||||
"filter": {"terms": {"metadata.document_id": document_ids_filter}},
|
||||
}
|
||||
}
|
||||
|
||||
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
|
||||
docs = []
|
||||
for hit in results["hits"]["hits"]:
|
||||
|
||||
@ -102,6 +102,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
splits = text.split()
|
||||
else:
|
||||
splits = text.split(separator)
|
||||
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
|
||||
else:
|
||||
splits = list(text)
|
||||
splits = [s for s in splits if (s not in {"", "\n"})]
|
||||
|
||||
Reference in New Issue
Block a user