Fix: add soft limit for graph rag size (#13252)
### What problem does this PR solve?

Fix: add soft limit for graph rag size #13258

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
```diff
@@ -540,15 +540,18 @@ class Dealer:
         res = []
         bs = 128
         for p in range(offset, max_count, bs):
-            es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
+            limit = min(bs, max_count - p)
+            if limit <= 0:
+                break
+            es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, limit, index_name(tenant_id),
                                            kb_ids)
             dict_chunks = self.dataStore.get_fields(es_res, fields)
             for id, doc in dict_chunks.items():
                 doc["id"] = id
             if dict_chunks:
                 res.extend(dict_chunks.values())
-            # FIX: only stop when there are no chunks, not when there are fewer than bs
-            if len(dict_chunks.values()) == 0:
+            chunk_count = len(dict_chunks)
+            if chunk_count == 0 or chunk_count < limit:
                 break
         return res
 
```
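This first hunk changes the pagination loop in `Dealer`: previously every iteration asked the store for a full batch of `bs` documents even when fewer than `bs` remained before `max_count`, and the loop stopped only on a completely empty page. The new code clamps each request to `min(bs, max_count - p)` so the total never overshoots the cap, and also stops as soon as a page comes back short. A minimal sketch of the same pattern, assuming a hypothetical `search_fn(start, size)` backend stub:

```python
def paged_fetch(search_fn, offset: int, max_count: int, bs: int = 128) -> list:
    """Collect at most max_count - offset items, in batches of at most bs."""
    res = []
    for p in range(offset, max_count, bs):
        limit = min(bs, max_count - p)  # soft limit: never request past max_count
        if limit <= 0:
            break
        page = search_fn(start=p, size=limit)
        res.extend(page)
        if len(page) < limit:  # short (or empty) page: backend has no more matches
            break
    return res

# Usage against a list-backed stub with 300 matching items:
items = list(range(300))
assert len(paged_fetch(lambda start, size: items[start:start + size], 0, 1000)) == 300
```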
```diff
@@ -28,6 +28,8 @@ from common.float_utils import get_float
 from common.constants import PAGERANK_FLD, TAG_FLD
 
 ATTEMPT_TIME = 2
+MAX_RESULT_WINDOW = 10000
+SEARCH_AFTER_BATCH_SIZE = 1000
 
 
 @singleton
```
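A note on the two new constants: `MAX_RESULT_WINDOW = 10000` mirrors Elasticsearch's default `index.max_result_window` setting, beyond which plain `from`/`size` pagination is rejected, and `SEARCH_AFTER_BATCH_SIZE` is the page size the `search_after` fallback below pages with. A sketch of the two request shapes, with made-up field names (`create_time`, `tie_breaker_id`) for illustration:

```python
# Deep from/size page: rejected once from + size exceeds index.max_result_window
# (10000 by default on Elasticsearch indices).
q_deep = {"query": {"match_all": {}}, "from": 9990, "size": 50,
          "sort": [{"create_time": "desc"}, {"tie_breaker_id": "asc"}]}

# search_after equivalent: always start at from=0 and resume from the "sort"
# values of the last hit of the previous page instead of a numeric offset.
q_first = {"query": {"match_all": {}}, "size": 1000,
           "sort": [{"create_time": "desc"}, {"tie_breaker_id": "asc"}]}
q_next = dict(q_first, search_after=["2026-01-01T00:00:00", "doc-0999"])
```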
```diff
@@ -36,6 +38,81 @@ class ESConnection(ESConnectionBase):
     CRUD operations
     """
 
+    def _es_search_once(self, index_names: list[str], query: dict, track_total_hits: bool):
+        return self.es.search(
+            index=index_names,
+            body=query,
+            timeout="600s",
+            track_total_hits=track_total_hits,
+            _source=True,
+        )
+
+    def _search_with_search_after(self, index_names: list[str], query: dict, offset: int, limit: int):
+        q_base = copy.deepcopy(query)
+        q_base.pop("from", None)
+        q_base.pop("size", None)
+
+        search_after = None
+        template_res = None
+        collected_hits = []
+        remaining_skip = max(0, offset)
+        remaining_take = max(0, limit)
+        with_aggs = True
+
+        while remaining_skip > 0:
+            batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_skip)
+            q_iter = copy.deepcopy(q_base)
+            q_iter["size"] = batch
+            if search_after is not None:
+                q_iter["search_after"] = search_after
+            if not with_aggs:
+                q_iter.pop("aggs", None)
+            res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None)
+            if template_res is None:
+                template_res = res
+            hits = res.get("hits", {}).get("hits", [])
+            if not hits:
+                break
+            next_search_after = hits[-1].get("sort")
+            if not next_search_after or next_search_after == search_after:
+                break
+            search_after = next_search_after
+            remaining_skip -= len(hits)
+            with_aggs = False
+            if len(hits) < batch:
+                break
+
+        while remaining_skip <= 0 and remaining_take > 0:
+            batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_take)
+            q_iter = copy.deepcopy(q_base)
+            q_iter["size"] = batch
+            if search_after is not None:
+                q_iter["search_after"] = search_after
+            if not with_aggs:
+                q_iter.pop("aggs", None)
+            res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None)
+            if template_res is None:
+                template_res = res
+            hits = res.get("hits", {}).get("hits", [])
+            if not hits:
+                break
+            collected_hits.extend(hits)
+            remaining_take -= len(hits)
+            next_search_after = hits[-1].get("sort")
+            if not next_search_after or next_search_after == search_after:
+                break
+            search_after = next_search_after
+            with_aggs = False
+            if len(hits) < batch:
+                break
+
+        if template_res is None:
+            q_count = copy.deepcopy(q_base)
+            q_count["size"] = 0
+            template_res = self._es_search_once(index_names, q_count, track_total_hits=True)
+        template_res["hits"]["hits"] = collected_hits
+        return template_res
+
     def search(
             self, select_fields: list[str],
             highlight_fields: list[str],
```
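The new helper works in two phases: the first `while` loop pages past `offset`, discarding hits, and the second collects up to `limit` hits. The first response is kept as `template_res` so the caller still gets total-hit counts and aggregations computed once; later pages drop `aggs` (`with_aggs = False`) to avoid recomputing them, and the collected hits are spliced into the template before returning, so the response shape matches a plain search. A stripped-down sketch of the skip-then-take cursor pattern, assuming a hypothetical `fetch_page(cursor, size) -> (hits, next_cursor)` helper:

```python
def cursor_skip_take(fetch_page, offset: int, limit: int, batch: int = 1000) -> list:
    cursor, taken = None, []
    remaining_skip, remaining_take = max(0, offset), max(0, limit)
    while remaining_skip > 0:  # phase 1: page past the offset, discarding hits
        hits, cursor = fetch_page(cursor, min(batch, remaining_skip))
        if not hits:
            return taken  # fewer matches than the offset: nothing to return
        remaining_skip -= len(hits)
    while remaining_take > 0:  # phase 2: keep hits until the limit is satisfied
        hits, cursor = fetch_page(cursor, min(batch, remaining_take))
        if not hits:
            break
        taken.extend(hits)
        remaining_take -= len(hits)
    return taken
```

The real implementation additionally bails out when the cursor fails to advance (the `next_search_after == search_after` check), which prevents an infinite loop if the backend keeps returning the same sort values.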
```diff
@@ -139,20 +216,27 @@ class ESConnection(ESConnectionBase):
         for fld in agg_fields:
             s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
 
-        if limit > 0:
+        has_dense = any(isinstance(m, MatchDenseExpr) for m in match_expressions)
+        has_explicit_sort = bool(order_by and order_by.fields)
+        use_search_after = (
+            limit > 0
+            and (offset + limit > MAX_RESULT_WINDOW)
+            and has_explicit_sort
+            and not has_dense
+        )
+
+        if limit > 0 and not use_search_after:
             s = s[offset:offset + limit]
         q = s.to_dict()
         self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
 
         for i in range(ATTEMPT_TIME):
             try:
-                # print(json.dumps(q, ensure_ascii=False))
-                res = self.es.search(index=index_names,
-                                     body=q,
-                                     timeout="600s",
-                                     # search_type="dfs_query_then_fetch",
-                                     track_total_hits=True,
-                                     _source=True)
+                if use_search_after:
+                    res = self._search_with_search_after(index_names, q, offset, limit)
+                else:
+                    # print(json.dumps(q, ensure_ascii=False))
+                    res = self._es_search_once(index_names, q, track_total_hits=True)
                 if str(res.get("timed_out", "")).lower() == "true":
                     raise Exception("Es Timeout.")
                 self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
```
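Finally, the search entry point only takes the `search_after` path for deep pages: the window must exceed `MAX_RESULT_WINDOW`, there must be an explicit sort (`search_after` needs deterministic sort values to resume from), and there must be no dense-vector match expression, presumably because vector-similarity scores are not a stable sort key to page on. Shallow requests keep the original single `from`/`size` query. A few worked examples of the window check (values chosen for illustration):

```python
MAX_RESULT_WINDOW = 10000

def is_deep(offset: int, limit: int) -> bool:
    return limit > 0 and (offset + limit > MAX_RESULT_WINDOW)

assert is_deep(9990, 50)        # 10040 > 10000: eligible for search_after
assert not is_deep(0, 100)      # shallow page: plain from/size
assert not is_deep(9900, 100)   # exactly 10000 still fits the default window
```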