mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-19 03:35:11 +08:00
Refa: async retrieval process. (#12629)
### Type of change

- [x] Refactoring
- [x] Performance Improvement
This commit is contained in:
@ -174,7 +174,7 @@ class Retrieval(ToolBase, ABC):
|
|||||||
|
|
||||||
if kbs:
|
if kbs:
|
||||||
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
||||||
kbinfos = settings.retriever.retrieval(
|
kbinfos = await settings.retriever.retrieval(
|
||||||
query,
|
query,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
[kb.tenant_id for kb in kbs],
|
[kb.tenant_id for kb in kbs],
|
||||||
|
|||||||
@ -61,7 +61,7 @@ async def list_chunk():
|
|||||||
}
|
}
|
||||||
if "available_int" in req:
|
if "available_int" in req:
|
||||||
query["available_int"] = int(req["available_int"])
|
query["available_int"] = int(req["available_int"])
|
||||||
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
|
sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
|
||||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||||
for id in sres.ids:
|
for id in sres.ids:
|
||||||
d = {
|
d = {
|
||||||
@ -371,20 +371,20 @@ async def retrieval_test():
|
|||||||
_question += await keyword_extraction(chat_mdl, _question)
|
_question += await keyword_extraction(chat_mdl, _question)
|
||||||
|
|
||||||
labels = label_question(_question, [kb])
|
labels = label_question(_question, [kb])
|
||||||
ranks = await asyncio.to_thread(settings.retriever.retrieval,
|
ranks = await settings.retriever.retrieval(
|
||||||
_question,
|
_question,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
tenant_ids,
|
tenant_ids,
|
||||||
kb_ids,
|
kb_ids,
|
||||||
page,
|
page,
|
||||||
size,
|
size,
|
||||||
float(req.get("similarity_threshold", 0.0)),
|
float(req.get("similarity_threshold", 0.0)),
|
||||||
float(req.get("vector_similarity_weight", 0.3)),
|
float(req.get("vector_similarity_weight", 0.3)),
|
||||||
doc_ids=local_doc_ids,
|
doc_ids=local_doc_ids,
|
||||||
top=top,
|
top=top,
|
||||||
rerank_mdl=rerank_mdl,
|
rerank_mdl=rerank_mdl,
|
||||||
rank_feature=labels,
|
rank_feature=labels
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = await settings.kg_retriever.retrieval(_question,
|
ck = await settings.kg_retriever.retrieval(_question,
|
||||||
@ -413,7 +413,7 @@ async def retrieval_test():
|
|||||||
|
|
||||||
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
|
@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def knowledge_graph():
|
async def knowledge_graph():
|
||||||
doc_id = request.args["doc_id"]
|
doc_id = request.args["doc_id"]
|
||||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||||
@ -421,7 +421,7 @@ def knowledge_graph():
|
|||||||
"doc_ids": [doc_id],
|
"doc_ids": [doc_id],
|
||||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||||
}
|
}
|
||||||
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
|
sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
|
||||||
obj = {"graph": {}, "mind_map": {}}
|
obj = {"graph": {}, "mind_map": {}}
|
||||||
for id in sres.ids[:2]:
|
for id in sres.ids[:2]:
|
||||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||||
|
|||||||
@ -373,7 +373,7 @@ async def rename_tags(kb_id):
|
|||||||
|
|
||||||
@manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
@manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def knowledge_graph(kb_id):
|
async def knowledge_graph(kb_id):
|
||||||
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
@ -389,7 +389,7 @@ def knowledge_graph(kb_id):
|
|||||||
obj = {"graph": {}, "mind_map": {}}
|
obj = {"graph": {}, "mind_map": {}}
|
||||||
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
|
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
|
||||||
return get_json_result(data=obj)
|
return get_json_result(data=obj)
|
||||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||||
if not len(sres.ids):
|
if not len(sres.ids):
|
||||||
return get_json_result(data=obj)
|
return get_json_result(data=obj)
|
||||||
|
|
||||||
|
|||||||
@ -481,7 +481,7 @@ def list_datasets(tenant_id):
|
|||||||
|
|
||||||
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
def knowledge_graph(tenant_id, dataset_id):
|
async def knowledge_graph(tenant_id, dataset_id):
|
||||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||||
return get_result(
|
return get_result(
|
||||||
data=False,
|
data=False,
|
||||||
@ -497,7 +497,7 @@ def knowledge_graph(tenant_id, dataset_id):
|
|||||||
obj = {"graph": {}, "mind_map": {}}
|
obj = {"graph": {}, "mind_map": {}}
|
||||||
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
|
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
|
||||||
return get_result(data=obj)
|
return get_result(data=obj)
|
||||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||||
if not len(sres.ids):
|
if not len(sres.ids):
|
||||||
return get_result(data=obj)
|
return get_result(data=obj)
|
||||||
|
|
||||||
|
|||||||
@ -135,7 +135,7 @@ async def retrieval(tenant_id):
|
|||||||
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
|
||||||
if not doc_ids and metadata_condition:
|
if not doc_ids and metadata_condition:
|
||||||
doc_ids = ["-999"]
|
doc_ids = ["-999"]
|
||||||
ranks = settings.retriever.retrieval(
|
ranks = await settings.retriever.retrieval(
|
||||||
question,
|
question,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
kb.tenant_id,
|
kb.tenant_id,
|
||||||
|
|||||||
@ -935,7 +935,7 @@ async def stop_parsing(tenant_id, dataset_id):
|
|||||||
|
|
||||||
@manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["GET"]) # noqa: F821
|
@manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["GET"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
def list_chunks(tenant_id, dataset_id, document_id):
|
async def list_chunks(tenant_id, dataset_id, document_id):
|
||||||
"""
|
"""
|
||||||
List chunks of a document.
|
List chunks of a document.
|
||||||
---
|
---
|
||||||
@ -1081,7 +1081,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
|||||||
_ = Chunk(**final_chunk)
|
_ = Chunk(**final_chunk)
|
||||||
|
|
||||||
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
|
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
|
||||||
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||||
res["total"] = sres.total
|
res["total"] = sres.total
|
||||||
for id in sres.ids:
|
for id in sres.ids:
|
||||||
d = {
|
d = {
|
||||||
@ -1559,7 +1559,7 @@ async def retrieval_test(tenant_id):
|
|||||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||||
question += await keyword_extraction(chat_mdl, question)
|
question += await keyword_extraction(chat_mdl, question)
|
||||||
|
|
||||||
ranks = settings.retriever.retrieval(
|
ranks = await settings.retriever.retrieval(
|
||||||
question,
|
question,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
tenant_ids,
|
tenant_ids,
|
||||||
|
|||||||
@ -1098,7 +1098,7 @@ async def retrieval_test_embedded():
|
|||||||
_question += await keyword_extraction(chat_mdl, _question)
|
_question += await keyword_extraction(chat_mdl, _question)
|
||||||
|
|
||||||
labels = label_question(_question, [kb])
|
labels = label_question(_question, [kb])
|
||||||
ranks = settings.retriever.retrieval(
|
ranks = await settings.retriever.retrieval(
|
||||||
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
_question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||||
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||||
)
|
)
|
||||||
|
|||||||
@ -403,17 +403,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
|||||||
yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
|
yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
|
||||||
|
|
||||||
await task
|
await task
|
||||||
'''
|
|
||||||
async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
|
|
||||||
if isinstance(think, str):
|
|
||||||
thought = think
|
|
||||||
knowledges = [t for t in think.split("\n") if t]
|
|
||||||
elif stream:
|
|
||||||
yield think
|
|
||||||
'''
|
|
||||||
else:
|
else:
|
||||||
if embd_mdl:
|
if embd_mdl:
|
||||||
kbinfos = await asyncio.to_thread(retriever.retrieval,
|
kbinfos = await retriever.retrieval(
|
||||||
" ".join(questions),
|
" ".join(questions),
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
tenant_ids,
|
tenant_ids,
|
||||||
@ -853,7 +846,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
|||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||||
|
|
||||||
kbinfos = retriever.retrieval(
|
kbinfos = await retriever.retrieval(
|
||||||
question=question,
|
question=question,
|
||||||
embd_mdl=embd_mdl,
|
embd_mdl=embd_mdl,
|
||||||
tenant_ids=tenant_ids,
|
tenant_ids=tenant_ids,
|
||||||
@ -929,7 +922,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
|||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||||
|
|
||||||
ranks = settings.retriever.retrieval(
|
ranks = await settings.retriever.retrieval(
|
||||||
question=question,
|
question=question,
|
||||||
embd_mdl=embd_mdl,
|
embd_mdl=embd_mdl,
|
||||||
tenant_ids=tenant_ids,
|
tenant_ids=tenant_ids,
|
||||||
|
|||||||
@ -36,12 +36,12 @@ class TreeStructuredQueryDecompositionRetrieval:
|
|||||||
self._kg_retrieve = kg_retrieve
|
self._kg_retrieve = kg_retrieve
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
def _retrieve_information(self, search_query):
|
async def _retrieve_information(self, search_query):
|
||||||
"""Retrieve information from different sources"""
|
"""Retrieve information from different sources"""
|
||||||
# 1. Knowledge base retrieval
|
# 1. Knowledge base retrieval
|
||||||
kbinfos = []
|
kbinfos = []
|
||||||
try:
|
try:
|
||||||
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
|
kbinfos = await self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Knowledge base retrieval error: {e}")
|
logging.error(f"Knowledge base retrieval error: {e}")
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class TreeStructuredQueryDecompositionRetrieval:
|
|||||||
# 3. Knowledge graph retrieval (if configured)
|
# 3. Knowledge graph retrieval (if configured)
|
||||||
try:
|
try:
|
||||||
if self.prompt_config.get("use_kg") and self._kg_retrieve:
|
if self.prompt_config.get("use_kg") and self._kg_retrieve:
|
||||||
ck = self._kg_retrieve(question=search_query)
|
ck = await self._kg_retrieve(question=search_query)
|
||||||
if ck["content_with_weight"]:
|
if ck["content_with_weight"]:
|
||||||
kbinfos["chunks"].insert(0, ck)
|
kbinfos["chunks"].insert(0, ck)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -100,9 +100,9 @@ class TreeStructuredQueryDecompositionRetrieval:
|
|||||||
if callback:
|
if callback:
|
||||||
await callback(f"Searching by `{query}`...")
|
await callback(f"Searching by `{query}`...")
|
||||||
st = timer()
|
st = timer()
|
||||||
ret = self._retrieve_information(query)
|
ret = await self._retrieve_information(query)
|
||||||
if callback:
|
if callback:
|
||||||
await callback("Retrieval %d results by %.1fms"%(len(ret["chunks"]), (timer()-st)*1000))
|
await callback("Retrieval %d results in %.1fms"%(len(ret["chunks"]), (timer()-st)*1000))
|
||||||
await self._async_update_chunk_info(chunk_info, ret)
|
await self._async_update_chunk_info(chunk_info, ret)
|
||||||
ret = kb_prompt(ret, self.chat_mdl.max_length*0.5)
|
ret = kb_prompt(ret, self.chat_mdl.max_length*0.5)
|
||||||
|
|
||||||
@ -111,14 +111,14 @@ class TreeStructuredQueryDecompositionRetrieval:
|
|||||||
suff = await sufficiency_check(self.chat_mdl, question, ret)
|
suff = await sufficiency_check(self.chat_mdl, question, ret)
|
||||||
if suff["is_sufficient"]:
|
if suff["is_sufficient"]:
|
||||||
if callback:
|
if callback:
|
||||||
await callback("Yes, it's sufficient.")
|
await callback(f"Yes, the retrieved information is sufficient for '{question}'.")
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
#if callback:
|
#if callback:
|
||||||
# await callback("The retrieved information is not sufficient. Planning next steps...")
|
# await callback("The retrieved information is not sufficient. Planning next steps...")
|
||||||
succ_question_info = await multi_queries_gen(self.chat_mdl, question, query, suff["missing_information"], ret)
|
succ_question_info = await multi_queries_gen(self.chat_mdl, question, query, suff["missing_information"], ret)
|
||||||
if callback:
|
if callback:
|
||||||
await callback("Next step is to search for the following questions:\n" + "\n - ".join(step["question"] for step in succ_question_info["questions"]))
|
await callback("Next step is to search for the following questions:</br> - " + "</br> - ".join(step["question"] for step in succ_question_info["questions"]))
|
||||||
steps = []
|
steps = []
|
||||||
for step in succ_question_info["questions"]:
|
for step in succ_question_info["questions"]:
|
||||||
steps.append(asyncio.create_task(self._research(chunk_info, step["question"], step["query"], depth-1, callback)))
|
steps.append(asyncio.create_task(self._research(chunk_info, step["question"], step["query"], depth-1, callback)))
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -52,8 +53,8 @@ class Benchmark:
|
|||||||
run = defaultdict(dict)
|
run = defaultdict(dict)
|
||||||
query_list = list(qrels.keys())
|
query_list = list(qrels.keys())
|
||||||
for query in query_list:
|
for query in query_list:
|
||||||
ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
ranks = asyncio.run(settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
||||||
0.0, self.vector_similarity_weight)
|
0.0, self.vector_similarity_weight))
|
||||||
if len(ranks["chunks"]) == 0:
|
if len(ranks["chunks"]) == 0:
|
||||||
print(f"deleted query: {query}")
|
print(f"deleted query: {query}")
|
||||||
del qrels[query]
|
del qrels[query]
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@ -49,8 +50,8 @@ class Dealer:
|
|||||||
keywords: list[str] | None = None
|
keywords: list[str] | None = None
|
||||||
group_docs: list[list] | None = None
|
group_docs: list[list] | None = None
|
||||||
|
|
||||||
def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
|
async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
|
||||||
qv, _ = emb_mdl.encode_queries(txt)
|
qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
|
||||||
shape = np.array(qv).shape
|
shape = np.array(qv).shape
|
||||||
if len(shape) > 1:
|
if len(shape) > 1:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@ -71,7 +72,7 @@ class Dealer:
|
|||||||
condition[key] = req[key]
|
condition[key] = req[key]
|
||||||
return condition
|
return condition
|
||||||
|
|
||||||
def search(self, req, idx_names: str | list[str],
|
async def search(self, req, idx_names: str | list[str],
|
||||||
kb_ids: list[str],
|
kb_ids: list[str],
|
||||||
emb_mdl=None,
|
emb_mdl=None,
|
||||||
highlight: bool | list | None = None,
|
highlight: bool | list | None = None,
|
||||||
@ -114,12 +115,12 @@ class Dealer:
|
|||||||
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
||||||
if emb_mdl is None:
|
if emb_mdl is None:
|
||||||
matchExprs = [matchText]
|
matchExprs = [matchText]
|
||||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||||
idx_names, kb_ids, rank_feature=rank_feature)
|
idx_names, kb_ids, rank_feature=rank_feature)
|
||||||
total = self.dataStore.get_total(res)
|
total = self.dataStore.get_total(res)
|
||||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||||
else:
|
else:
|
||||||
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
matchDense = await self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
||||||
q_vec = matchDense.embedding_data
|
q_vec = matchDense.embedding_data
|
||||||
if not settings.DOC_ENGINE_INFINITY:
|
if not settings.DOC_ENGINE_INFINITY:
|
||||||
src.append(f"q_{len(q_vec)}_vec")
|
src.append(f"q_{len(q_vec)}_vec")
|
||||||
@ -127,7 +128,7 @@ class Dealer:
|
|||||||
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
|
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
|
||||||
matchExprs = [matchText, matchDense, fusionExpr]
|
matchExprs = [matchText, matchDense, fusionExpr]
|
||||||
|
|
||||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||||
idx_names, kb_ids, rank_feature=rank_feature)
|
idx_names, kb_ids, rank_feature=rank_feature)
|
||||||
total = self.dataStore.get_total(res)
|
total = self.dataStore.get_total(res)
|
||||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||||
@ -135,12 +136,12 @@ class Dealer:
|
|||||||
# If result is empty, try again with lower min_match
|
# If result is empty, try again with lower min_match
|
||||||
if total == 0:
|
if total == 0:
|
||||||
if filters.get("doc_id"):
|
if filters.get("doc_id"):
|
||||||
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||||
total = self.dataStore.get_total(res)
|
total = self.dataStore.get_total(res)
|
||||||
else:
|
else:
|
||||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||||
matchDense.extra_options["similarity"] = 0.17
|
matchDense.extra_options["similarity"] = 0.17
|
||||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||||
orderBy, offset, limit, idx_names, kb_ids,
|
orderBy, offset, limit, idx_names, kb_ids,
|
||||||
rank_feature=rank_feature)
|
rank_feature=rank_feature)
|
||||||
total = self.dataStore.get_total(res)
|
total = self.dataStore.get_total(res)
|
||||||
@ -359,7 +360,7 @@ class Dealer:
|
|||||||
rag_tokenizer.tokenize(ans).split(),
|
rag_tokenizer.tokenize(ans).split(),
|
||||||
rag_tokenizer.tokenize(inst).split())
|
rag_tokenizer.tokenize(inst).split())
|
||||||
|
|
||||||
def retrieval(
|
async def retrieval(
|
||||||
self,
|
self,
|
||||||
question,
|
question,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
@ -398,7 +399,7 @@ class Dealer:
|
|||||||
if isinstance(tenant_ids, str):
|
if isinstance(tenant_ids, str):
|
||||||
tenant_ids = tenant_ids.split(",")
|
tenant_ids = tenant_ids.split(",")
|
||||||
|
|
||||||
sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
|
sres = await self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
|
||||||
rank_feature=rank_feature)
|
rank_feature=rank_feature)
|
||||||
|
|
||||||
if rerank_mdl and sres.total > 0:
|
if rerank_mdl and sres.total > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user