Files
ragflow/rag/utils/es_conn.py
as-ondewo 6fb8c31c22 Fix: Document parse status set to DONE before chunks are retrievable (#13352)
### What problem does this PR solve?

The document parse status was set to DONE before the document chunks
were actually retrievable from Elasticsearch/OpenSearch, because the
code did not wait for the index refresh. As a result, the API could
report a parse status of DONE while a request for the chunks still
returned none. Since the index refreshes every second, this was quite
likely to happen when waiting for document parsing by polling with a
short interval and then immediately trying to retrieve chunks once the
status was DONE.

I fixed this bug and added a test case that would have caught it.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-05-11 16:04:08 +08:00

624 lines
26 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import time
import copy
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
# Number of attempts made by the retry loops in this module before giving up.
ATTEMPT_TIME = 2
# search() switches from from/size paging to search_after paging when
# offset + limit exceeds this value (and the query is sorted and has no dense match).
MAX_RESULT_WINDOW = 10000
# Page size of each request issued while walking results with search_after.
SEARCH_AFTER_BATCH_SIZE = 1000
# Single-document atomic pagerank_fea adjust (chunk feedback). Clamps using params.min_w / max_w;
# removes the field when the clamped result is <= 0 (rank_feature compatibility).
# Script params: pf (name of the field to adjust), delta (added to the current value, which is
# taken as 0.0 when the field is missing, null or not parseable as a number),
# min_w / max_w (lower / upper clamp bounds).
_PAGERANK_FEA_ADJUST_SCRIPT = """
double cur = 0.0;
if (ctx._source.containsKey(params.pf)) {
Object v = ctx._source[params.pf];
if (v != null) {
if (v instanceof Number) {
cur = ((Number)v).doubleValue();
} else {
try { cur = Double.parseDouble(v.toString()); } catch (Exception e) { cur = 0.0; }
}
}
}
double nw = cur + params.delta;
if (nw < params.min_w) { nw = params.min_w; }
if (nw > params.max_w) { nw = params.max_w; }
if (nw <= 0.0) {
if (ctx._source.containsKey(params.pf)) {
ctx._source.remove(params.pf);
}
} else {
ctx._source[params.pf] = nw;
}
"""
@singleton
class ESConnection(ESConnectionBase):
    """
    CRUD operations

    Document-store connection that sends its requests through ``self.es``.
    It provides ``search``, ``insert``, ``update``, ``adjust_chunk_pagerank_fea``
    and ``delete``, plus ``get_fields`` for reading fields out of a search response.
    """
def _es_search_once(self, index_names: list[str], query: dict, track_total_hits: bool):
    """Send a single search request and return the raw response.

    Args:
        index_names: Indices to search.
        query: Request body passed to the client unchanged.
        track_total_hits: Forwarded to the client as ``track_total_hits``.

    Returns:
        Whatever ``self.es.search`` returns for this request.
    """
    request_kwargs = {
        "index": index_names,
        "body": query,
        "timeout": "600s",
        "track_total_hits": track_total_hits,
        "_source": True,
    }
    return self.es.search(**request_kwargs)
def _search_with_search_after(self, index_names: list[str], query: dict, offset: int, limit: int):
    """Page through a sorted query with ``search_after`` instead of ``from``/``size``.

    The first ``offset`` hits are walked and discarded, then up to ``limit`` hits are
    collected. Requests are issued in pages of at most ``SEARCH_AFTER_BATCH_SIZE``.

    Args:
        index_names: Indices to search.
        query: Request body; it is deep-copied and its ``from``/``size`` keys are dropped.
        offset: Number of leading hits to skip (negative values count as 0).
        limit: Number of hits to return (negative values count as 0).

    Returns:
        The response of the first request that was sent (the only one that tracks total
        hits and keeps ``aggs``), with ``hits.hits`` replaced by the collected hits.
        If no request was needed, a ``size=0`` request is sent to obtain that response.
    """
    base_query = copy.deepcopy(query)
    for paging_key in ("from", "size"):
        base_query.pop(paging_key, None)

    cursor = None            # "sort" values of the last hit seen so far
    first_response = None    # response of the very first request
    include_aggs = True      # aggregations are only requested until a page was consumed
    gathered = []
    to_skip = max(0, offset)
    to_take = max(0, limit)

    def fetch_page(page_size):
        # Issue one request continuing after `cursor` and return its hit list.
        nonlocal first_response
        page_query = copy.deepcopy(base_query)
        page_query["size"] = page_size
        if cursor is not None:
            page_query["search_after"] = cursor
        if not include_aggs:
            page_query.pop("aggs", None)
        response = self._es_search_once(index_names, page_query, track_total_hits=first_response is None)
        if first_response is None:
            first_response = response
        return response.get("hits", {}).get("hits", [])

    # Phase 1: walk past the first `offset` hits.
    while to_skip > 0:
        page_size = min(SEARCH_AFTER_BATCH_SIZE, to_skip)
        page = fetch_page(page_size)
        if not page:
            break
        next_cursor = page[-1].get("sort")
        # Without a new sort key the walk cannot advance.
        if not next_cursor or next_cursor == cursor:
            break
        cursor = next_cursor
        to_skip -= len(page)
        include_aggs = False
        if len(page) < page_size:
            break

    # Phase 2: collect hits; only entered when the skip phase consumed the whole offset.
    while to_skip <= 0 and to_take > 0:
        page_size = min(SEARCH_AFTER_BATCH_SIZE, to_take)
        page = fetch_page(page_size)
        if not page:
            break
        gathered.extend(page)
        to_take -= len(page)
        next_cursor = page[-1].get("sort")
        if not next_cursor or next_cursor == cursor:
            break
        cursor = next_cursor
        include_aggs = False
        if len(page) < page_size:
            break

    if first_response is None:
        # Nothing was requested above; fetch totals/aggregations without any hits.
        count_query = copy.deepcopy(base_query)
        count_query["size"] = 0
        first_response = self._es_search_once(index_names, count_query, track_total_hits=True)
    first_response["hits"]["hits"] = gathered
    return first_response
def search(
        self, select_fields: list[str],
        highlight_fields: list[str],
        condition: dict,
        match_expressions: list[MatchExpr],
        order_by: OrderByExpr,
        offset: int,
        limit: int,
        index_names: str | list[str],
        knowledgebase_ids: list[str],
        agg_fields: list[str] | None = None,
        rank_feature: dict | None = None
):
    """
    Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html

    Build a bool/knn query from the arguments and execute it.

    Args:
        select_fields: Fields to return; names ending in ``_vec`` are also requested
            through the top-level ``fields`` parameter.
        highlight_fields: Fields to highlight.
        condition: Filter conditions (field -> int/str/list). ``available_int`` and ``id``
            get special handling; other empty values are ignored. Must not contain ``_id``.
            The dict is not modified.
        match_expressions: Text / dense / fusion match expressions.
        order_by: Sort specification; sorting by ``id`` is skipped.
        offset: Index of the first hit to return.
        limit: Maximum number of hits; paging is only applied when ``limit > 0``.
        index_names: Index name list, or a comma separated string of names.
        knowledgebase_ids: Added to the filter as ``kb_id``.
        agg_fields: Fields to build ``terms`` aggregations on.
        rank_feature: Mapping field -> boost, added as ``rank_feature`` should-clauses.

    Returns:
        The raw search response.

    Raises:
        Exception: If the response reports ``timed_out``, if every attempt ended in a
            connection timeout, or if a condition value has an unsupported type.
            Any other error raised by the request is logged and re-raised.
    """
    if isinstance(index_names, str):
        index_names = index_names.split(",")
    assert isinstance(index_names, list) and len(index_names) > 0
    assert "_id" not in condition
    # Work on a copy so the caller's dict is left untouched.
    condition = {**condition, "kb_id": knowledgebase_ids}
    bool_query = Q("bool", must=[])
    for k, v in condition.items():
        if k == "available_int":
            if v == 0:
                bool_query.filter.append(Q("range", available_int={"lt": 1}))
            else:
                bool_query.filter.append(
                    Q("bool", must_not=Q("range", available_int={"lt": 1})))
            continue
        if k == "id":
            if not v:
                continue
            # Match either the stored "id" field or the document "_id".
            if isinstance(v, list):
                bool_query.filter.append(
                    Q("bool", should=[Q("terms", id=v), Q("terms", _id=v)], minimum_should_match=1))
            elif isinstance(v, (str, int)):
                bool_query.filter.append(
                    Q("bool", should=[Q("term", id=v), Q("term", _id=v)], minimum_should_match=1))
            continue
        if not v:
            continue
        if isinstance(v, list):
            bool_query.filter.append(Q("terms", **{k: v}))
        elif isinstance(v, (str, int)):
            bool_query.filter.append(Q("term", **{k: v}))
        else:
            raise Exception(
                f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")

    s = Search()
    vector_similarity_weight = 0.5
    for m in match_expressions:
        if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
            assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
                match_expressions[1],
                MatchDenseExpr) and isinstance(
                match_expressions[2], FusionExpr)
            weights = m.fusion_params["weights"]
            # The second comma separated weight is used as the vector similarity weight.
            vector_similarity_weight = get_float(weights.split(",")[1])
    for m in match_expressions:
        if isinstance(m, MatchTextExpr):
            minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
            if isinstance(minimum_should_match, float):
                # A float ratio is sent as a percentage string, e.g. 0.3 -> "30%".
                minimum_should_match = str(int(minimum_should_match * 100)) + "%"
            bool_query.must.append(Q("query_string", fields=m.fields,
                                     type="best_fields", query=m.matching_text,
                                     minimum_should_match=minimum_should_match,
                                     boost=1))
            bool_query.boost = 1.0 - vector_similarity_weight
        elif isinstance(m, MatchDenseExpr):
            assert (bool_query is not None)
            similarity = 0.0
            if "similarity" in m.extra_options:
                similarity = m.extra_options["similarity"]
            s = s.knn(m.vector_column_name,
                      m.topn,
                      m.topn * 2,
                      query_vector=list(m.embedding_data),
                      filter=bool_query.to_dict(),
                      similarity=similarity,
                      )

    if bool_query and rank_feature:
        for fld, sc in rank_feature.items():
            if fld != PAGERANK_FLD:
                fld = f"{TAG_FLD}.{fld}"
            bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
    if bool_query:
        s = s.query(bool_query)
    for field in highlight_fields:
        s = s.highlight(field)

    if order_by:
        orders = list()
        for field, order in order_by.fields:
            order = "asc" if order == 0 else "desc"
            if field in ["page_num_int", "top_int"]:
                order_info = {"order": order, "unmapped_type": "float",
                              "mode": "avg", "numeric_type": "double"}
            elif field.endswith("_int") or field.endswith("_flt"):
                order_info = {"order": order, "unmapped_type": "float"}
            elif field == "id":
                continue  # id as "text", not a "keyword", order by it will cause error
            else:
                order_info = {"order": order, "unmapped_type": "keyword"}
            orders.append({field: order_info})
        s = s.sort(*orders)

    if agg_fields:
        for fld in agg_fields:
            s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)

    has_dense = any(isinstance(m, MatchDenseExpr) for m in match_expressions)
    has_explicit_sort = bool(order_by and order_by.fields)
    # Deep pages of sorted, non-dense queries are walked with search_after;
    # everything else uses from/size slicing.
    use_search_after = (
        limit > 0
        and (offset + limit > MAX_RESULT_WINDOW)
        and has_explicit_sort
        and not has_dense
    )
    if limit > 0 and not use_search_after:
        s = s[offset:offset + limit]
    # Filter _source to only requested fields for efficiency, and add vector
    # fields to "fields" param so they appear in hit.fields when ES 9.x
    # exclude_source_vectors is enabled (dense_vector not in _source).
    if select_fields:
        s = s.source(select_fields)
    q = s.to_dict()
    # ES 9.x: dense_vector fields excluded from _source; request them via fields.
    # Note: knn does NOT have a "fields" parameter - adding it inside the knn
    # object causes BadRequestError on ES 9.x. We add "fields" at top level.
    vector_fields = [f for f in (select_fields or []) if f.endswith("_vec")]
    if vector_fields:
        q["fields"] = vector_fields
    self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))

    for _ in range(ATTEMPT_TIME):
        try:
            if use_search_after:
                res = self._search_with_search_after(index_names, q, offset, limit)
            else:
                res = self._es_search_once(index_names, q, track_total_hits=True)
            if str(res.get("timed_out", "")).lower() == "true":
                raise Exception("Es Timeout.")
            # Lazy %-args: the (potentially large) response is only rendered
            # to text when debug logging is actually enabled.
            self.logger.debug("ESConnection.search %s res: %s", index_names, res)
            return res
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            self._connect()
            continue
        except Exception as e:
            # Only log debug for NotFoundError(accepted when metadata index doesn't exist)
            if 'NotFound' in str(e):
                self.logger.debug(f"ESConnection.search {str(index_names)} query: " + str(q) + " - " + str(e))
            else:
                self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
            raise
    self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
    raise Exception("ESConnection.search timeout.")
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str | None = None) -> list[str]:
    """Bulk-index ``documents`` into ``index_name``.

    Each document is deep-copied, gets ``kb_id`` set to ``knowledgebase_id`` and is
    indexed with its ``id`` as the document ``_id``. The bulk request uses
    ``refresh="wait_for"``.

    Args:
        documents: Documents to index; each must contain ``id`` and must not contain ``_id``.
        index_name: Target index.
        knowledgebase_id: Value stored in each document's ``kb_id`` field.

    Returns:
        A list of error strings; an empty list means no error was reported.
        Per-item failures are reported as ``"<_id>:<error>"``. If the last attempt
        failed with an exception (including a connection timeout), the list holds
        that failure.
    """
    # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
    if not documents:
        # Nothing to index: avoid sending an empty bulk body.
        return []

    operations = []
    for d in documents:
        assert "_id" not in d
        assert "id" in d
        d_copy = copy.deepcopy(d)
        d_copy["kb_id"] = knowledgebase_id
        # Use id as _id for uniqueness, also keep "id" as a regular field for sorting
        meta_id = d_copy.get("id", "")
        operations.append(
            {"index": {"_index": index_name, "_id": meta_id}})
        operations.append(d_copy)

    res = []
    for _ in range(ATTEMPT_TIME):
        try:
            res = []  # only the outcome of the latest attempt is reported
            r = self.es.bulk(index=index_name, operations=operations,
                             refresh="wait_for", timeout="60s")
            if str(r["errors"]).lower() == "false":
                return res
            for item in r["items"]:
                for action in ["create", "delete", "index", "update"]:
                    if action in item and "error" in item[action]:
                        res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
            return res
        except ConnectionTimeout as e:
            self.logger.exception("ES request timeout")
            # Record the timeout: if every attempt times out, the caller must not
            # receive an empty list, which would read as "inserted without errors".
            res.append("ES request timeout: " + str(e))
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            res.append(str(e))
            self.logger.warning("ESConnection.insert got exception: " + str(e))
    return res
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
    """Update one document by id, or every document matching ``condition``.

    When ``condition["id"]`` is a string, that single document is updated with direct
    update requests. Otherwise an update-by-query with a generated script is run over
    all documents matching ``condition`` (plus ``kb_id == knowledgebase_id``).

    Args:
        condition: Filter conditions; ``exists`` maps to an exists-filter, lists to
            ``terms`` and int/str values to ``term``. The dict is not modified.
        new_value: Fields to set. Special keys: ``remove`` (a field name to drop, or a
            mapping field -> value to remove from that list field) and ``add``
            (mapping field -> string to append to that list field; update-by-query only).
        index_name: Target index.
        knowledgebase_id: Knowledge base the update is restricted to.

    Returns:
        True if the update was carried out, False if it failed or there was nothing to do
        (single-document path).

    Raises:
        Exception: If a condition or new value has an unsupported type (multi-document path).
    """

    def list_remove_script(field):
        # Guarded, block-scoped removal of params.p_<field> from the list field <field>:
        # documents lacking the field or the value are left unchanged, and several of
        # these snippets can be concatenated without re-declaring `i`.
        return (
            f"if (ctx._source.containsKey('{field}') && ctx._source.{field} != null) "
            f"{{ int i = ctx._source.{field}.indexOf(params.p_{field}); "
            f"if (i >= 0) {{ ctx._source.{field}.remove(i); }} }}"
        )

    doc = copy.deepcopy(new_value)
    doc.pop("id", None)
    # Work on a copy so the caller's dict is left untouched.
    condition = {**condition, "kb_id": knowledgebase_id}

    if "id" in condition and isinstance(condition["id"], str):
        # update specific single document
        chunk_id = condition["id"]
        doc_part = doc
        remove_value = doc_part.pop("remove", None)
        remove_field = remove_value if isinstance(remove_value, str) else None
        remove_dict = remove_value if isinstance(remove_value, dict) else None
        # Drop existing "*feas" fields before the partial-document update below writes
        # them again; failures here are logged and do not abort the update.
        for k in doc_part:
            if "feas" != k.split("_")[-1]:
                continue
            try:
                self.es.update(
                    index=index_name,
                    id=chunk_id,
                    script={"source": "ctx._source.remove(params.f);", "params": {"f": k}},
                )
            except Exception:
                self.logger.exception(
                    f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
        try:
            if remove_field is not None:
                self.es.update(
                    index=index_name,
                    id=chunk_id,
                    script={"source": "ctx._source.remove(params.f);", "params": {"f": remove_field}},
                )
            if remove_dict is not None:
                scripts = [list_remove_script(kk) for kk in remove_dict]
                params = {f"p_{kk}": vv for kk, vv in remove_dict.items()}
                if scripts:
                    self.es.update(
                        index=index_name,
                        id=chunk_id,
                        script={"source": "".join(scripts), "params": params},
                    )
            if doc_part:
                self.es.update(index=index_name, id=chunk_id, doc=doc_part)
            if remove_field is not None or remove_dict is not None or doc_part:
                return True
        except Exception as e:
            self.logger.exception(
                f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(
                    e))
        return False

    # update unspecific maybe-multiple documents
    bool_query = Q("bool")
    for k, v in condition.items():
        if not isinstance(k, str) or not v:
            continue
        if k == "exists":
            bool_query.filter.append(Q("exists", field=v))
            continue
        if isinstance(v, list):
            bool_query.filter.append(Q("terms", **{k: v}))
        elif isinstance(v, (str, int)):
            bool_query.filter.append(Q("term", **{k: v}))
        else:
            raise Exception(
                f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")

    # Values are handed to the script through params rather than being written into
    # the script source.
    scripts = []
    params = {}
    for k, v in new_value.items():
        if k == "remove":
            if isinstance(v, str):
                scripts.append("ctx._source.remove(params.rm_field);")
                params["rm_field"] = v
            if isinstance(v, dict):
                for kk, vv in v.items():
                    scripts.append(list_remove_script(kk))
                    params[f"p_{kk}"] = vv
            continue
        if k == "add":
            if isinstance(v, dict):
                for kk, vv in v.items():
                    scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
                    params[f"pp_{kk}"] = vv.strip()
            continue
        # Empty values are skipped, except available_int where 0 is meaningful.
        if (not isinstance(k, str) or not v) and k != "available_int":
            continue
        if isinstance(v, str):
            v = re.sub(r"(['\n\r]|\\.)", " ", v)
            params[f"pp_{k}"] = v
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
        elif isinstance(v, (int, float)):
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
            params[f"pp_{k}"] = v
        elif isinstance(v, list):
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
            params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
        else:
            raise Exception(
                f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")

    ubq = UpdateByQuery(
        index=index_name).using(
        self.es).query(bool_query)
    ubq = ubq.script(source="".join(scripts), params=params)
    ubq = ubq.params(refresh=True)
    ubq = ubq.params(slices=5)
    ubq = ubq.params(conflicts="proceed")
    for _ in range(ATTEMPT_TIME):
        try:
            _ = ubq.execute()
            return True
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
            break
    return False
def adjust_chunk_pagerank_fea(
        self,
        chunk_id: str,
        index_name: str,
        knowledgebase_id: str,
        delta: float,
        min_w: float = 0.0,
        max_w: float = 100.0,
        row_id: int | None = None,
) -> bool:
    """Atomically adjust pagerank_fea on one chunk (painless script).

    Args:
        chunk_id: Id of the document to update.
        index_name: Index holding the document.
        knowledgebase_id: Not used by this implementation.
        delta: Amount passed to the script as ``params.delta``.
        min_w: Lower clamp bound passed to the script.
        max_w: Upper clamp bound passed to the script.
        row_id: Not used by this implementation.

    Returns:
        True once the update request succeeded, False when it failed with a
        non-connection error or all attempts were used up.
    """
    _ = row_id
    for _ in range(ATTEMPT_TIME):
        try:
            painless = {
                "source": _PAGERANK_FEA_ADJUST_SCRIPT.strip(),
                "lang": "painless",
                "params": {
                    "pf": PAGERANK_FLD,
                    "delta": float(delta),
                    "min_w": float(min_w),
                    "max_w": float(max_w),
                },
            }
            self.es.update(index=index_name, id=chunk_id, retry_on_conflict=3, script=painless)
            self.logger.debug(
                "ESConnection.adjust_chunk_pagerank_fea(index=%s, id=%s, delta=%s) succeeded",
                index_name,
                chunk_id,
                delta,
            )
            return True
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
        except Exception as e:
            self.logger.exception(
                "ESConnection.adjust_chunk_pagerank_fea(index=%s, id=%s): %s",
                index_name,
                chunk_id,
                e,
            )
            # Only errors mentioning a connection problem are worth another attempt.
            if "connection" not in str(e).lower():
                break
            time.sleep(3)
            self._connect()
    return False
def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int:
    """Delete every document in ``index_name`` that matches ``condition``.

    Args:
        condition: Filter conditions. ``id`` (single id or list) becomes an ids-filter
            when non-empty; ``exists`` becomes an exists-filter; ``must_not`` accepts
            ``{"exists": field}``; lists map to ``terms`` and int/str values to ``term``.
            Must not contain ``_id``. The dict is not modified.
        index_name: Target index.
        knowledgebase_id: Added to the conditions as ``kb_id``.

    Returns:
        The number of deleted documents, or 0 when the request failed.

    Raises:
        Exception: If a condition value is neither int, str, list nor None.
    """
    assert "_id" not in condition
    # Work on a copy so the caller's dict is left untouched.
    condition = {**condition, "kb_id": knowledgebase_id}
    # Build a bool query that combines id filter with other conditions
    bool_query = Q("bool")
    # Handle chunk IDs if present
    if "id" in condition:
        chunk_ids = condition["id"]
        if not isinstance(chunk_ids, list):
            chunk_ids = [chunk_ids]
        if chunk_ids:
            # Filter by specific chunk IDs
            bool_query.filter.append(Q("ids", values=chunk_ids))
        # If chunk_ids is empty, we don't add an ids filter - rely on other conditions
    # Add all other conditions as filters
    for k, v in condition.items():
        if k == "id":
            continue  # Already handled above
        if k == "exists":
            bool_query.filter.append(Q("exists", field=v))
        elif k == "must_not":
            if isinstance(v, dict):
                for kk, vv in v.items():
                    if kk == "exists":
                        bool_query.must_not.append(Q("exists", field=vv))
        elif isinstance(v, list):
            bool_query.must.append(Q("terms", **{k: v}))
        elif isinstance(v, (str, int)):
            bool_query.must.append(Q("term", **{k: v}))
        elif v is not None:
            raise Exception("Condition value must be int, str or list.")
    # If no filters were added, use match_all (for tenant-wide operations)
    if not bool_query.filter and not bool_query.must and not bool_query.must_not:
        qry = Q("match_all")
    else:
        qry = bool_query
    self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
    for _ in range(ATTEMPT_TIME):
        try:
            res = self.es.delete_by_query(
                index=index_name,
                body=Search().query(qry).to_dict(),
                refresh=True)
            return res["deleted"]
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            self.logger.warning("ESConnection.delete got exception: " + str(e))
            # A missing index/document means there is nothing to delete.
            if re.search(r"(not_found)", str(e), re.IGNORECASE):
                return 0
    return 0
"""
Helper functions for search result
"""
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
    """Extract the requested fields of every hit in a search response.

    For each hit a field is read from ``_source`` first and, when absent or None there,
    from the hit's ``fields`` section (a one-element list is unwrapped to its element).
    Lists and strings are returned as they are, ``available_int`` keeps a numeric value,
    and any other value is converted with ``str()``.

    Args:
        res: Search response with a ``hits.hits`` list.
        fields: Names of the fields to extract.

    Returns:
        Mapping of hit ``_id`` to ``{field: value}``. Hits without any of the requested
        fields are left out; an empty ``fields`` list yields ``{}``.
    """
    if not fields:
        return {}

    extracted = {}
    for hit in res.get("hits", {}).get("hits", []):
        source = hit.get("_source", {})
        # ES "fields" response section (used by dense_vector in ES 9.x)
        response_fields = hit.get("fields", {})
        row = {}
        for name in fields:
            value = source.get(name)
            if value is None:
                if name not in response_fields:
                    continue
                value = response_fields[name]
                # ES fields response wraps dense_vector in 2 levels: [[v1,v2,...]] -> [v1,v2,...]
                if isinstance(value, list) and len(value) == 1:
                    value = value[0]
            keep_as_is = isinstance(value, (list, str)) or (
                name == "available_int" and isinstance(value, (int, float))
            )
            row[name] = value if keep_as_is else str(value)
        if row:
            extracted[hit.get("_id")] = row
    return extracted