Fix: add soft limit for graph RAG size (#13252)

### What problem does this PR solve?

Fix: add a soft limit for graph RAG size (#13258, Q2). Retrieval batches are now capped at the remaining result budget, and queries that would page past Elasticsearch's `max_result_window` fall back to `search_after` pagination.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Commit: daec36e935 (parent: 8a6b5ced6b)
Author: Magicbook1108
Date: 2026-03-02 14:02:36 +08:00 (committed by GitHub)

2 changed files with 98 additions and 11 deletions

```diff
@@ -540,15 +540,18 @@ class Dealer:
         res = []
         bs = 128
         for p in range(offset, max_count, bs):
-            es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
+            limit = min(bs, max_count - p)
+            if limit <= 0:
+                break
+            es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, limit, index_name(tenant_id),
                                            kb_ids)
             dict_chunks = self.dataStore.get_fields(es_res, fields)
             for id, doc in dict_chunks.items():
                 doc["id"] = id
             if dict_chunks:
                 res.extend(dict_chunks.values())
-            # FIX: only stop when there are no chunks, not when there are fewer than bs
-            if len(dict_chunks.values()) == 0:
+            chunk_count = len(dict_chunks)
+            if chunk_count == 0 or chunk_count < limit:
                 break
         return res
```
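This hunk caps each batch at the remaining result budget instead of always requesting a full `bs`-sized page, and stops as soon as a batch comes back short. Before the fix, the final iteration could ask for `bs` documents even when fewer than `bs` remained under `max_count`, overshooting the soft limit. A minimal standalone sketch of the same pattern, where `fetch_page` is a hypothetical stand-in for `self.dataStore.search`:

```python
from typing import Callable

def fetch_up_to(fetch_page: Callable[[int, int], list[dict]],
                offset: int, max_count: int, bs: int = 128) -> list[dict]:
    """Page through a store without ever requesting past max_count."""
    res: list[dict] = []
    for p in range(offset, max_count, bs):
        limit = min(bs, max_count - p)   # soft limit: shrink the final page
        if limit <= 0:
            break
        chunks = fetch_page(p, limit)    # ask for at most `limit` results
        res.extend(chunks)
        if len(chunks) < limit:          # short page: nothing left to fetch
            break
    return res
```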


```diff
@@ -28,6 +28,8 @@
 from common.float_utils import get_float
 from common.constants import PAGERANK_FLD, TAG_FLD

 ATTEMPT_TIME = 2
+MAX_RESULT_WINDOW = 10000
+SEARCH_AFTER_BATCH_SIZE = 1000

 @singleton
```
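The two new constants mirror Elasticsearch's pagination rules: `from + size` requests are rejected once `from + size` exceeds `index.max_result_window` (10000 by default), so anything deeper has to be cursored with `search_after` in batches. For illustration only, a hypothetical page query of the shape the new code builds (the sort fields and cursor values here are made up):

```python
# Hypothetical search_after page: "size" stays at or below SEARCH_AFTER_BATCH_SIZE,
# and "from" is omitted entirely, since the cursor replaces offset-based paging.
page_query = {
    "query": {"match_all": {}},
    "size": 1000,
    "sort": [{"create_time": "asc"}, {"id": "asc"}],       # a total order gives a stable cursor
    "search_after": ["2026-03-02T14:02:36", "chunk_0041"],  # sort values of the previous page's last hit
}
```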
```diff
@@ -36,6 +38,81 @@ class ESConnection(ESConnectionBase):
     CRUD operations
     """

+    def _es_search_once(self, index_names: list[str], query: dict, track_total_hits: bool):
+        return self.es.search(
+            index=index_names,
+            body=query,
+            timeout="600s",
+            track_total_hits=track_total_hits,
+            _source=True,
+        )
+
+    def _search_with_search_after(self, index_names: list[str], query: dict, offset: int, limit: int):
+        q_base = copy.deepcopy(query)
+        q_base.pop("from", None)
+        q_base.pop("size", None)
+        search_after = None
+        template_res = None
+        collected_hits = []
+        remaining_skip = max(0, offset)
+        remaining_take = max(0, limit)
+        with_aggs = True
+
+        while remaining_skip > 0:
+            batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_skip)
+            q_iter = copy.deepcopy(q_base)
+            q_iter["size"] = batch
+            if search_after is not None:
+                q_iter["search_after"] = search_after
+            if not with_aggs:
+                q_iter.pop("aggs", None)
+            res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None)
+            if template_res is None:
+                template_res = res
+            hits = res.get("hits", {}).get("hits", [])
+            if not hits:
+                break
+            next_search_after = hits[-1].get("sort")
+            if not next_search_after or next_search_after == search_after:
+                break
+            search_after = next_search_after
+            remaining_skip -= len(hits)
+            with_aggs = False
+            if len(hits) < batch:
+                break
+
+        while remaining_skip <= 0 and remaining_take > 0:
+            batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_take)
+            q_iter = copy.deepcopy(q_base)
+            q_iter["size"] = batch
+            if search_after is not None:
+                q_iter["search_after"] = search_after
+            if not with_aggs:
+                q_iter.pop("aggs", None)
+            res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None)
+            if template_res is None:
+                template_res = res
+            hits = res.get("hits", {}).get("hits", [])
+            if not hits:
+                break
+            collected_hits.extend(hits)
+            remaining_take -= len(hits)
+            next_search_after = hits[-1].get("sort")
+            if not next_search_after or next_search_after == search_after:
+                break
+            search_after = next_search_after
+            with_aggs = False
+            if len(hits) < batch:
+                break
+
+        if template_res is None:
+            q_count = copy.deepcopy(q_base)
+            q_count["size"] = 0
+            template_res = self._es_search_once(index_names, q_count, track_total_hits=True)
+
+        template_res["hits"]["hits"] = collected_hits
+        return template_res
+
     def search(
             self, select_fields: list[str],
             highlight_fields: list[str],
```
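`_search_with_search_after` walks the cursor in two phases: it first burns through `offset` hits, keeping only the last hit's `sort` values as the resume point, and only then starts collecting `limit` hits. The first response is kept as a template so aggregations and total-hit counts survive, and the `aggs` clause is dropped after that first round trip so the potentially expensive aggregations are computed only once. A condensed sketch of the control flow, with a hypothetical `run_page(size, cursor)` standing in for the ES round trip:

```python
def skip_then_take(run_page, offset: int, limit: int, batch: int = 1000) -> list[dict]:
    """Two-phase search_after walk: skip `offset` hits, then collect `limit`."""
    cursor, taken = None, []
    skip, take = max(0, offset), max(0, limit)
    while skip > 0:                       # phase 1: advance the cursor, discard hits
        hits = run_page(min(batch, skip), cursor)
        if not hits:
            return taken                  # fewer than `offset` docs exist at all
        cursor = hits[-1]["sort"]         # resume point for the next request
        skip -= len(hits)
    while take > 0:                       # phase 2: collect the requested window
        hits = run_page(min(batch, take), cursor)
        if not hits:
            break
        taken.extend(hits)
        cursor = hits[-1]["sort"]
        take -= len(hits)
    return taken
```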
```diff
@@ -139,20 +216,27 @@ class ESConnection(ESConnectionBase):
             for fld in agg_fields:
                 s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)

-        if limit > 0:
+        has_dense = any(isinstance(m, MatchDenseExpr) for m in match_expressions)
+        has_explicit_sort = bool(order_by and order_by.fields)
+        use_search_after = (
+            limit > 0
+            and (offset + limit > MAX_RESULT_WINDOW)
+            and has_explicit_sort
+            and not has_dense
+        )
+        if limit > 0 and not use_search_after:
             s = s[offset:offset + limit]

         q = s.to_dict()
         self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))

         for i in range(ATTEMPT_TIME):
             try:
-                # print(json.dumps(q, ensure_ascii=False))
-                res = self.es.search(index=index_names,
-                                     body=q,
-                                     timeout="600s",
-                                     # search_type="dfs_query_then_fetch",
-                                     track_total_hits=True,
-                                     _source=True)
+                if use_search_after:
+                    res = self._search_with_search_after(index_names, q, offset, limit)
+                else:
+                    # print(json.dumps(q, ensure_ascii=False))
+                    res = self._es_search_once(index_names, q, track_total_hits=True)
                 if str(res.get("timed_out", "")).lower() == "true":
                     raise Exception("Es Timeout.")
                 self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
```
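The gate is deliberately conservative: `search_after` needs an explicit, total sort order to form a stable cursor, and dense-vector (kNN) matches are top-k by construction, so deep cursoring does not apply to them. A few hypothetical inputs through the same predicate (the case values are illustrative, not from the commit):

```python
MAX_RESULT_WINDOW = 10000  # mirrors the module constant above

# (offset, limit, has_explicit_sort, has_dense) -> chosen strategy
cases = [
    (0,    128, True,  False),  # shallow page              -> from+size
    (9950, 128, True,  False),  # crosses the window        -> search_after
    (9950, 128, False, False),  # no stable sort to cursor  -> from+size
    (9950, 128, True,  True),   # dense/kNN match present   -> from+size
]
for offset, limit, has_sort, has_dense in cases:
    use_search_after = (limit > 0 and offset + limit > MAX_RESULT_WINDOW
                        and has_sort and not has_dense)
    print(f"{offset=} {limit=} {has_sort=} {has_dense=} -> "
          f"{'search_after' if use_search_after else 'from+size'}")
```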