diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 77a39b731..2a19b74ef 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -174,7 +174,7 @@ class Retrieval(ToolBase, ABC): if kbs: query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE) - kbinfos = settings.retriever.retrieval( + kbinfos = await settings.retriever.retrieval( query, embd_mdl, [kb.tenant_id for kb in kbs], diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 731506404..20891033a 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -61,7 +61,7 @@ async def list_chunk(): } if "available_int" in req: query["available_int"] = int(req["available_int"]) - sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"]) + sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"]) res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} for id in sres.ids: d = { @@ -371,20 +371,20 @@ async def retrieval_test(): _question += await keyword_extraction(chat_mdl, _question) labels = label_question(_question, [kb]) - ranks = await asyncio.to_thread(settings.retriever.retrieval, - _question, - embd_mdl, - tenant_ids, - kb_ids, - page, - size, - float(req.get("similarity_threshold", 0.0)), - float(req.get("vector_similarity_weight", 0.3)), - doc_ids=local_doc_ids, - top=top, - rerank_mdl=rerank_mdl, - rank_feature=labels, - ) + ranks = await settings.retriever.retrieval( + _question, + embd_mdl, + tenant_ids, + kb_ids, + page, + size, + float(req.get("similarity_threshold", 0.0)), + float(req.get("vector_similarity_weight", 0.3)), + doc_ids=local_doc_ids, + top=top, + rerank_mdl=rerank_mdl, + rank_feature=labels + ) if use_kg: ck = await settings.kg_retriever.retrieval(_question, @@ -413,7 +413,7 @@ async def retrieval_test(): @manager.route('/knowledge_graph', methods=['GET']) # noqa: F821 @login_required -def knowledge_graph(): +async def knowledge_graph(): doc_id = request.args["doc_id"] tenant_id = DocumentService.get_tenant_id(doc_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) @@ -421,7 +421,7 @@ def knowledge_graph(): "doc_ids": [doc_id], "knowledge_graph_kwd": ["graph", "mind_map"] } - sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids) + sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids) obj = {"graph": {}, "mind_map": {}} for id in sres.ids[:2]: ty = sres.field[id]["knowledge_graph_kwd"] diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 26ea12f96..5ffc3040e 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -373,7 +373,7 @@ async def rename_tags(kb_id): @manager.route('//knowledge_graph', methods=['GET']) # noqa: F821 @login_required -def knowledge_graph(kb_id): +async def knowledge_graph(kb_id): if not KnowledgebaseService.accessible(kb_id, current_user.id): return get_json_result( data=False, @@ -389,7 +389,7 @@ def knowledge_graph(kb_id): obj = {"graph": {}, "mind_map": {}} if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id): return get_json_result(data=obj) - sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id]) + sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id]) if not len(sres.ids): return get_json_result(data=obj) diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 7d52c3fec..f98705de0 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -481,7 +481,7 @@ def list_datasets(tenant_id): @manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821 @token_required -def knowledge_graph(tenant_id, dataset_id): +async def knowledge_graph(tenant_id, dataset_id): if not KnowledgebaseService.accessible(dataset_id, tenant_id): return get_result( data=False, @@ -497,7 +497,7 @@ def knowledge_graph(tenant_id, dataset_id): obj = {"graph": {}, "mind_map": {}} if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): return get_result(data=obj) - sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) + sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) if not len(sres.ids): return get_result(data=obj) diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index 91f1c9a8f..0841bf7bd 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -135,7 +135,7 @@ async def retrieval(tenant_id): doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) if not doc_ids and metadata_condition: doc_ids = ["-999"] - ranks = settings.retriever.retrieval( + ranks = await settings.retriever.retrieval( question, embd_mdl, kb.tenant_id, diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index d8afe5f27..b27f972b9 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -935,7 +935,7 @@ async def stop_parsing(tenant_id, dataset_id): @manager.route("/datasets//documents//chunks", methods=["GET"]) # noqa: F821 @token_required -def list_chunks(tenant_id, dataset_id, document_id): +async def list_chunks(tenant_id, dataset_id, document_id): """ List chunks of a document. --- @@ -1081,7 +1081,7 @@ def list_chunks(tenant_id, dataset_id, document_id): _ = Chunk(**final_chunk) elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id): - sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) + sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) res["total"] = sres.total for id in sres.ids: d = { @@ -1559,7 +1559,7 @@ async def retrieval_test(tenant_id): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) question += await keyword_extraction(chat_mdl, question) - ranks = settings.retriever.retrieval( + ranks = await settings.retriever.retrieval( question, embd_mdl, tenant_ids, diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 03140b60b..80f8229be 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -1098,7 +1098,7 @@ async def retrieval_test_embedded(): _question += await keyword_extraction(chat_mdl, _question) labels = label_question(_question, [kb]) - ranks = settings.retriever.retrieval( + ranks = await settings.retriever.retrieval( _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 9935827be..ccf8474b6 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -403,17 +403,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs): yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False} await task - ''' - async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)): - if isinstance(think, str): - thought = think - knowledges = [t for t in think.split("\n") if t] - elif stream: - yield think - ''' + else: if embd_mdl: - kbinfos = await asyncio.to_thread(retriever.retrieval, + kbinfos = await retriever.retrieval( " ".join(questions), embd_mdl, tenant_ids, @@ -853,7 +846,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf metas = DocumentService.get_meta_by_kbs(kb_ids) doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) - kbinfos = retriever.retrieval( + kbinfos = await retriever.retrieval( question=question, embd_mdl=embd_mdl, tenant_ids=tenant_ids, @@ -929,7 +922,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): metas = DocumentService.get_meta_by_kbs(kb_ids) doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) - ranks = settings.retriever.retrieval( + ranks = await settings.retriever.retrieval( question=question, embd_mdl=embd_mdl, tenant_ids=tenant_ids, diff --git a/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py b/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py index 77689cab0..214485c3b 100644 --- a/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py +++ b/rag/advanced_rag/tree_structured_query_decomposition_retrieval.py @@ -36,12 +36,12 @@ class TreeStructuredQueryDecompositionRetrieval: self._kg_retrieve = kg_retrieve self._lock = asyncio.Lock() - def _retrieve_information(self, search_query): + async def _retrieve_information(self, search_query): """Retrieve information from different sources""" # 1. Knowledge base retrieval kbinfos = [] try: - kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []} + kbinfos = await self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []} except Exception as e: logging.error(f"Knowledge base retrieval error: {e}") @@ -58,7 +58,7 @@ class TreeStructuredQueryDecompositionRetrieval: # 3. Knowledge graph retrieval (if configured) try: if self.prompt_config.get("use_kg") and self._kg_retrieve: - ck = self._kg_retrieve(question=search_query) + ck = await self._kg_retrieve(question=search_query) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) except Exception as e: @@ -100,9 +100,9 @@ class TreeStructuredQueryDecompositionRetrieval: if callback: await callback(f"Searching by `{query}`...") st = timer() - ret = self._retrieve_information(query) + ret = await self._retrieve_information(query) if callback: - await callback("Retrieval %d results by %.1fms"%(len(ret["chunks"]), (timer()-st)*1000)) + await callback("Retrieval %d results in %.1fms"%(len(ret["chunks"]), (timer()-st)*1000)) await self._async_update_chunk_info(chunk_info, ret) ret = kb_prompt(ret, self.chat_mdl.max_length*0.5) @@ -111,14 +111,14 @@ class TreeStructuredQueryDecompositionRetrieval: suff = await sufficiency_check(self.chat_mdl, question, ret) if suff["is_sufficient"]: if callback: - await callback("Yes, it's sufficient.") + await callback(f"Yes, the retrieved information is sufficient for '{question}'.") return ret #if callback: # await callback("The retrieved information is not sufficient. Planing next steps...") succ_question_info = await multi_queries_gen(self.chat_mdl, question, query, suff["missing_information"], ret) if callback: - await callback("Next step is to search for the following questions:\n" + "\n - ".join(step["question"] for step in succ_question_info["questions"])) + await callback("Next step is to search for the following questions:
- " + "
- ".join(step["question"] for step in succ_question_info["questions"])) steps = [] for step in succ_question_info["questions"]: steps.append(asyncio.create_task(self._research(chunk_info, step["question"], step["query"], depth-1, callback))) diff --git a/rag/benchmark.py b/rag/benchmark.py index c19785db3..93b93adcf 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import os import sys @@ -52,8 +53,8 @@ class Benchmark: run = defaultdict(dict) query_list = list(qrels.keys()) for query in query_list: - ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, - 0.0, self.vector_similarity_weight) + ranks = asyncio.run(settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, + 0.0, self.vector_similarity_weight)) if len(ranks["chunks"]) == 0: print(f"deleted query: {query}") del qrels[query] diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 46b8b5b0a..54d46b9c8 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import re @@ -49,8 +50,8 @@ class Dealer: keywords: list[str] | None = None group_docs: list[list] | None = None - def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1): - qv, _ = emb_mdl.encode_queries(txt) + async def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1): + qv, _ = await asyncio.to_thread(emb_mdl.encode_queries, txt) shape = np.array(qv).shape if len(shape) > 1: raise Exception( @@ -71,7 +72,7 @@ class Dealer: condition[key] = req[key] return condition - def search(self, req, idx_names: str | list[str], + async def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=None, highlight: bool | list | None = None, @@ -114,12 +115,12 @@ class Dealer: matchText, keywords = self.qryr.question(qst, min_match=0.3) if emb_mdl is None: matchExprs = [matchText] - res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, + res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) total = self.dataStore.get_total(res) logging.debug("Dealer.search TOTAL: {}".format(total)) else: - matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) + matchDense = await self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) q_vec = matchDense.embedding_data if not settings.DOC_ENGINE_INFINITY: src.append(f"q_{len(q_vec)}_vec") @@ -127,7 +128,7 @@ class Dealer: fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"}) matchExprs = [matchText, matchDense, fusionExpr] - res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, + res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) total = self.dataStore.get_total(res) logging.debug("Dealer.search TOTAL: {}".format(total)) @@ -135,12 +136,12 @@ class Dealer: # If result is empty, try again with lower min_match if total == 0: if filters.get("doc_id"): - res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids) + res = await asyncio.to_thread(self.dataStore.search, src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids) total = self.dataStore.get_total(res) else: matchText, _ = self.qryr.question(qst, min_match=0.1) matchDense.extra_options["similarity"] = 0.17 - res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], + res = await asyncio.to_thread(self.dataStore.search, src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) total = self.dataStore.get_total(res) @@ -359,7 +360,7 @@ class Dealer: rag_tokenizer.tokenize(ans).split(), rag_tokenizer.tokenize(inst).split()) - def retrieval( + async def retrieval( self, question, embd_mdl, @@ -398,7 +399,7 @@ class Dealer: if isinstance(tenant_ids, str): tenant_ids = tenant_ids.split(",") - sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, + sres = await self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, rank_feature=rank_feature) if rerank_mdl and sres.total > 0: