Refa: async retrieval process. (#12629)

### Type of change

- [x] Refactoring
- [x] Performance Improvement
Kevin Hu · 2026-01-15 12:28:49 +08:00 · committed by GitHub
parent f82628c40c · commit 9a10558f80
11 changed files with 52 additions and 57 deletions

View File

@@ -174,7 +174,7 @@ class Retrieval(ToolBase, ABC):
         if kbs:
             query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
-            kbinfos = settings.retriever.retrieval(
+            kbinfos = await settings.retriever.retrieval(
                 query,
                 embd_mdl,
                 [kb.tenant_id for kb in kbs],
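The hunk above is the pattern this commit repeats across every call site: `retrieval` (and `search`) on the retriever are now coroutines, so callers must `await` them from an async context. A minimal sketch of the before/after shape, with `FakeRetriever` standing in for `settings.retriever` (the real class is `Dealer` in rag/nlp/search.py):

```python
import asyncio

class FakeRetriever:
    """Stand-in for settings.retriever; retrieval() is now a coroutine."""

    async def retrieval(self, question, *args, **kwargs):
        await asyncio.sleep(0.01)  # simulate embedding + doc-store I/O
        return {"chunks": [f"chunk for {question!r}"], "doc_aggs": []}

async def handler():
    retriever = FakeRetriever()
    # Before: kbinfos = retriever.retrieval(...)        -> blocking call
    # After:  kbinfos = await retriever.retrieval(...)  -> cooperatively yields
    kbinfos = await retriever.retrieval("what is RAG?")
    return kbinfos

print(asyncio.run(handler()))
```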

View File

@@ -61,7 +61,7 @@ async def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
+        sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -371,7 +371,7 @@ async def retrieval_test():
                 _question += await keyword_extraction(chat_mdl, _question)
             labels = label_question(_question, [kb])
-            ranks = await asyncio.to_thread(settings.retriever.retrieval,
+            ranks = await settings.retriever.retrieval(
                 _question,
                 embd_mdl,
                 tenant_ids,
@@ -383,7 +383,7 @@ async def retrieval_test():
                 doc_ids=local_doc_ids,
                 top=top,
                 rerank_mdl=rerank_mdl,
-                rank_feature=labels,
+                rank_feature=labels
             )
             if use_kg:
@@ -413,7 +413,7 @@ async def retrieval_test():
 @manager.route('/knowledge_graph', methods=['GET'])  # noqa: F821
 @login_required
-def knowledge_graph():
+async def knowledge_graph():
     doc_id = request.args["doc_id"]
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
@@ -421,7 +421,7 @@ def knowledge_graph():
         "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
+    sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
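Two of the hunks above turn `def knowledge_graph()` into `async def`: once `settings.retriever.search` is a coroutine, the view itself must be one so the framework can await it. A hedged sketch of that route shape using Quart (an assumption here; the real `manager` route object may differ):

```python
from quart import Quart

app = Quart(__name__)

async def fake_search(req):
    """Stand-in for settings.retriever.search, which is now a coroutine."""
    return {"ids": ["chunk-1", "chunk-2"]}

@app.route("/knowledge_graph")
async def knowledge_graph():  # was: def knowledge_graph():
    sres = await fake_search({"knowledge_graph_kwd": ["graph", "mind_map"]})
    return {"total": len(sres["ids"])}  # dict return is serialized to JSON

if __name__ == "__main__":
    app.run()
```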

View File

@@ -373,7 +373,7 @@ async def rename_tags(kb_id):
 @manager.route('/<kb_id>/knowledge_graph', methods=['GET'])  # noqa: F821
 @login_required
-def knowledge_graph(kb_id):
+async def knowledge_graph(kb_id):
     if not KnowledgebaseService.accessible(kb_id, current_user.id):
         return get_json_result(
             data=False,
@@ -389,7 +389,7 @@ def knowledge_graph(kb_id):
     obj = {"graph": {}, "mind_map": {}}
     if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
         return get_json_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
+    sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
     if not len(sres.ids):
         return get_json_result(data=obj)

View File

@@ -481,7 +481,7 @@ def list_datasets(tenant_id):
 @manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET'])  # noqa: F821
 @token_required
-def knowledge_graph(tenant_id, dataset_id):
+async def knowledge_graph(tenant_id, dataset_id):
     if not KnowledgebaseService.accessible(dataset_id, tenant_id):
         return get_result(
             data=False,
@@ -497,7 +497,7 @@ def knowledge_graph(tenant_id, dataset_id):
     obj = {"graph": {}, "mind_map": {}}
     if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
         return get_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
+    sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
     if not len(sres.ids):
         return get_result(data=obj)

View File

@@ -135,7 +135,7 @@ async def retrieval(tenant_id):
             doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
         if not doc_ids and metadata_condition:
             doc_ids = ["-999"]
-    ranks = settings.retriever.retrieval(
+    ranks = await settings.retriever.retrieval(
         question,
         embd_mdl,
         kb.tenant_id,
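A detail worth noting in the context lines above: when a `metadata_condition` is present but matches no documents, `doc_ids` is set to `["-999"]`, an id no stored chunk can carry, so the retriever returns an empty result instead of silently dropping the filter. A small sketch of the same guard:

```python
def resolve_doc_ids(matched: list[str], has_condition: bool) -> list[str]:
    # A filter that matched nothing must yield nothing, not everything:
    # "-999" is a sentinel id that can never exist in the index.
    if not matched and has_condition:
        return ["-999"]
    return matched

assert resolve_doc_ids([], has_condition=True) == ["-999"]
assert resolve_doc_ids([], has_condition=False) == []
assert resolve_doc_ids(["doc-1"], has_condition=True) == ["doc-1"]
```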

View File

@@ -935,7 +935,7 @@ async def stop_parsing(tenant_id, dataset_id):
 @manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks", methods=["GET"])  # noqa: F821
 @token_required
-def list_chunks(tenant_id, dataset_id, document_id):
+async def list_chunks(tenant_id, dataset_id, document_id):
     """
     List chunks of a document.
     ---
@@ -1081,7 +1081,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
             _ = Chunk(**final_chunk)
     elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
-        sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+        sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
         res["total"] = sres.total
         for id in sres.ids:
             d = {
@@ -1559,7 +1559,7 @@ async def retrieval_test(tenant_id):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += await keyword_extraction(chat_mdl, question)
-        ranks = settings.retriever.retrieval(
+        ranks = await settings.retriever.retrieval(
             question,
             embd_mdl,
             tenant_ids,

View File

@@ -1098,7 +1098,7 @@ async def retrieval_test_embedded():
                 _question += await keyword_extraction(chat_mdl, _question)
             labels = label_question(_question, [kb])
-            ranks = settings.retriever.retrieval(
+            ranks = await settings.retriever.retrieval(
                 _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
                 local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
             )

View File

@@ -403,17 +403,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
                 yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
             await task
-            '''
-            async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
-                if isinstance(think, str):
-                    thought = think
-                    knowledges = [t for t in think.split("\n") if t]
-                elif stream:
-                    yield think
-            '''
         else:
             if embd_mdl:
-                kbinfos = await asyncio.to_thread(retriever.retrieval,
+                kbinfos = await retriever.retrieval(
                     " ".join(questions),
                     embd_mdl,
                     tenant_ids,
@@ -853,7 +846,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
         metas = DocumentService.get_meta_by_kbs(kb_ids)
         doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
-    kbinfos = retriever.retrieval(
+    kbinfos = await retriever.retrieval(
         question=question,
         embd_mdl=embd_mdl,
         tenant_ids=tenant_ids,
@@ -929,7 +922,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
     metas = DocumentService.get_meta_by_kbs(kb_ids)
     doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
-    ranks = settings.retriever.retrieval(
+    ranks = await settings.retriever.retrieval(
         question=question,
         embd_mdl=embd_mdl,
         tenant_ids=tenant_ids,
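The first hunk in this file also shows why the commit nets out simpler: previously the synchronous `retriever.retrieval` was hopped onto a worker thread with `asyncio.to_thread`; now that it is a coroutine it is awaited directly on the event loop. A sketch contrasting the two shapes (both functions here are stand-ins):

```python
import asyncio

def blocking_retrieval(question):
    return {"chunks": [question]}        # imagine synchronous network I/O

async def async_retrieval(question):
    await asyncio.sleep(0)               # imagine awaiting the doc store
    return {"chunks": [question]}

async def main():
    # Old shape: sync function pushed onto the default thread pool.
    old = await asyncio.to_thread(blocking_retrieval, "q")
    # New shape: a native coroutine, awaited directly -- no thread hop,
    # no pool contention, and exceptions propagate as usual.
    new = await async_retrieval("q")
    return old, new

print(asyncio.run(main()))
```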

View File

@@ -36,12 +36,12 @@ class TreeStructuredQueryDecompositionRetrieval:
         self._kg_retrieve = kg_retrieve
         self._lock = asyncio.Lock()
 
-    def _retrieve_information(self, search_query):
+    async def _retrieve_information(self, search_query):
         """Retrieve information from different sources"""
         # 1. Knowledge base retrieval
         kbinfos = []
         try:
-            kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
+            kbinfos = await self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
         except Exception as e:
             logging.error(f"Knowledge base retrieval error: {e}")
@@ -58,7 +58,7 @@
         # 3. Knowledge graph retrieval (if configured)
         try:
             if self.prompt_config.get("use_kg") and self._kg_retrieve:
-                ck = self._kg_retrieve(question=search_query)
+                ck = await self._kg_retrieve(question=search_query)
                 if ck["content_with_weight"]:
                     kbinfos["chunks"].insert(0, ck)
         except Exception as e:
@@ -100,9 +100,9 @@
         if callback:
             await callback(f"Searching by `{query}`...")
         st = timer()
-        ret = self._retrieve_information(query)
+        ret = await self._retrieve_information(query)
         if callback:
-            await callback("Retrieval %d results by %.1fms"%(len(ret["chunks"]), (timer()-st)*1000))
+            await callback("Retrieval %d results in %.1fms"%(len(ret["chunks"]), (timer()-st)*1000))
         await self._async_update_chunk_info(chunk_info, ret)
         ret = kb_prompt(ret, self.chat_mdl.max_length*0.5)
@@ -111,14 +111,14 @@
         suff = await sufficiency_check(self.chat_mdl, question, ret)
         if suff["is_sufficient"]:
             if callback:
-                await callback("Yes, it's sufficient.")
+                await callback(f"Yes, the retrieved information is sufficient for '{question}'.")
             return ret
         #if callback:
         #    await callback("The retrieved information is not sufficient. Planing next steps...")
         succ_question_info = await multi_queries_gen(self.chat_mdl, question, query, suff["missing_information"], ret)
         if callback:
-            await callback("Next step is to search for the following questions:\n" + "\n - ".join(step["question"] for step in succ_question_info["questions"]))
+            await callback("Next step is to search for the following questions:</br> - " + "</br> - ".join(step["question"] for step in succ_question_info["questions"]))
         steps = []
         for step in succ_question_info["questions"]:
             steps.append(asyncio.create_task(self._research(chunk_info, step["question"], step["query"], depth-1, callback)))
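The last lines of the hunk above fan each follow-up question out as its own task, so sub-searches at the next depth run concurrently. A reduced sketch of that pattern, with `research` standing in for `self._research`:

```python
import asyncio

async def research(question: str, depth: int) -> list[str]:
    await asyncio.sleep(0.01)  # simulate one retrieval round-trip
    return [f"{question} (depth {depth})"]

async def plan_next_steps(questions: list[str], depth: int) -> list[str]:
    # Create all tasks first so they start concurrently, then gather.
    steps = [asyncio.create_task(research(q, depth - 1)) for q in questions]
    results = await asyncio.gather(*steps)
    return [chunk for chunks in results for chunk in chunks]

print(asyncio.run(plan_next_steps(["What is X?", "Why Y?"], depth=2)))
```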

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import os
 import sys
@@ -52,8 +53,8 @@ class Benchmark:
         run = defaultdict(dict)
         query_list = list(qrels.keys())
         for query in query_list:
-            ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
-                                                 0.0, self.vector_similarity_weight)
+            ranks = asyncio.run(settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+                                                             0.0, self.vector_similarity_weight))
             if len(ranks["chunks"]) == 0:
                 print(f"deleted query: {query}")
                 del qrels[query]
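The benchmark is a plain synchronous script, so it bridges into the now-async retriever with `asyncio.run` per query. A sketch of that trade-off (the retriever here is a stand-in):

```python
import asyncio

async def retrieval(query: str) -> dict:
    await asyncio.sleep(0)  # simulate the async retriever
    return {"chunks": [query]}

def run_benchmark(queries: list[str]) -> dict:
    results = {}
    for query in queries:
        # Simple but heavyweight: each call builds and tears down an event
        # loop. Acceptable for an offline benchmark; a hot path would run
        # one loop and iterate inside an async driver instead.
        results[query] = asyncio.run(retrieval(query))
    return results

print(run_benchmark(["q1", "q2"]))
```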

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import re
@@ -49,8 +50,8 @@ class Dealer:
         keywords: list[str] | None = None
         group_docs: list[list] | None = None
 
-    def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
-        qv, _ = emb_mdl.encode_queries(txt)
+    async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1):
+        qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt)
         shape = np.array(qv).shape
         if len(shape) > 1:
             raise Exception(
@@ -71,7 +72,7 @@ class Dealer:
             condition[key] = req[key]
         return condition
 
-    def search(self, req, idx_names: str | list[str],
+    async def search(self, req, idx_names: str | list[str],
                kb_ids: list[str],
                emb_mdl=None,
                highlight: bool | list | None = None,
@@ -114,12 +115,12 @@ class Dealer:
         matchText, keywords = self.qryr.question(qst, min_match=0.3)
         if emb_mdl is None:
             matchExprs = [matchText]
-            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
-                                        idx_names, kb_ids, rank_feature=rank_feature)
+            res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
+                                          idx_names, kb_ids, rank_feature=rank_feature)
             total = self.dataStore.get_total(res)
             logging.debug("Dealer.search TOTAL: {}".format(total))
         else:
-            matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
+            matchDense = await self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
             q_vec = matchDense.embedding_data
             if not settings.DOC_ENGINE_INFINITY:
                 src.append(f"q_{len(q_vec)}_vec")
@@ -127,7 +128,7 @@
             fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
             matchExprs = [matchText, matchDense, fusionExpr]
-            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
+            res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                           idx_names, kb_ids, rank_feature=rank_feature)
             total = self.dataStore.get_total(res)
             logging.debug("Dealer.search TOTAL: {}".format(total))
@@ -135,12 +136,12 @@
             # If result is empty, try again with lower min_match
             if total == 0:
                 if filters.get("doc_id"):
-                    res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
+                    res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
                     total = self.dataStore.get_total(res)
                 else:
                     matchText, _ = self.qryr.question(qst, min_match=0.1)
                     matchDense.extra_options["similarity"] = 0.17
-                    res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
+                    res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr],
                                                 orderBy, offset, limit, idx_names, kb_ids,
                                                 rank_feature=rank_feature)
                     total = self.dataStore.get_total(res)
@@ -359,7 +360,7 @@
                 rag_tokenizer.tokenize(ans).split(),
                 rag_tokenizer.tokenize(inst).split())
 
-    def retrieval(
+    async def retrieval(
             self,
             question,
             embd_mdl,
@@ -398,7 +399,7 @@
         if isinstance(tenant_ids, str):
             tenant_ids = tenant_ids.split(",")
 
-        sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
+        sres = await self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight,
                            rank_feature=rank_feature)
         if rerank_mdl and sres.total > 0:
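The `Dealer` changes show the other half of the design: the public API (`get_vector`, `search`, `retrieval`) becomes async, while the still-synchronous embedder and doc-store client are offloaded with `asyncio.to_thread` so they never block the event loop. A minimal sketch of that layering, with stand-in names:

```python
import asyncio

class BlockingStore:
    """Stand-in for the synchronous dataStore client."""

    def search(self, query: str) -> list[str]:
        return [f"hit for {query!r}"]  # imagine a blocking HTTP round-trip

class Dealer:
    def __init__(self, store: BlockingStore):
        self.dataStore = store

    async def search(self, query: str) -> list[str]:
        # asyncio.to_thread keeps the event loop free while the sync
        # client blocks in the default thread pool.
        return await asyncio.to_thread(self.dataStore.search, query)

print(asyncio.run(Dealer(BlockingStore()).search("vector db")))
```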