Files
ragflow/rag/utils/es_conn.py
Wang Qi f45ce00347 Not allow to sort by id (#14526)
### What problem does this PR solve?

id as "text", not a "keyword", order by it will cause error.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2026-04-30 14:52:43 +08:00

624 lines
26 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import time
import copy
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
# Number of times a failed ES request is retried before giving up.
ATTEMPT_TIME = 2
# Elasticsearch's default index.max_result_window; from/size paging past this
# limit is rejected by ES, so deeper pages go through search_after instead.
MAX_RESULT_WINDOW = 10000
# Batch size used per request when deep-paging with search_after.
SEARCH_AFTER_BATCH_SIZE = 1000
# Single-document atomic pagerank_fea adjust (chunk feedback). Clamps using params.min_w / max_w;
# removes field at zero for rank_feature compatibility.
# NOTE: this is painless source sent to ES verbatim; expects params: pf (field
# name), delta, min_w, max_w.
_PAGERANK_FEA_ADJUST_SCRIPT = """
double cur = 0.0;
if (ctx._source.containsKey(params.pf)) {
Object v = ctx._source[params.pf];
if (v != null) {
if (v instanceof Number) {
cur = ((Number)v).doubleValue();
} else {
try { cur = Double.parseDouble(v.toString()); } catch (Exception e) { cur = 0.0; }
}
}
}
double nw = cur + params.delta;
if (nw < params.min_w) { nw = params.min_w; }
if (nw > params.max_w) { nw = params.max_w; }
if (nw <= 0.0) {
if (ctx._source.containsKey(params.pf)) {
ctx._source.remove(params.pf);
}
} else {
ctx._source[params.pf] = nw;
}
"""
@singleton
class ESConnection(ESConnectionBase):
    """
    Elasticsearch-backed document store: CRUD operations (search, insert,
    update, delete) plus helpers for extracting fields from search results.
    """
def _es_search_once(self, index_names: list[str], query: dict, track_total_hits: bool):
    """Issue a single es.search call; shared by search() and the
    search_after pagination helper so retries go through one call site."""
    return self.es.search(
        index=index_names,
        body=query,
        timeout="600s",
        track_total_hits=track_total_hits,
        _source=True,
    )
def _search_with_search_after(self, index_names: list[str], query: dict, offset: int, limit: int):
    """Deep pagination using search_after.

    ES rejects from/size paging past index.max_result_window, so this walks
    the result set in batches: phase 1 skips `offset` hits, phase 2 collects
    up to `limit` hits. Returns a response dict shaped like a normal search
    response, with hits.hits replaced by the collected page; totals and
    aggregations come from the first response issued.
    """
    q_base = copy.deepcopy(query)
    # from/size are incompatible with search_after; size is set per batch.
    q_base.pop("from", None)
    q_base.pop("size", None)
    search_after = None          # sort cursor of the last hit seen
    template_res = None          # first response; reused as the envelope (total, aggs)
    collected_hits = []
    remaining_skip = max(0, offset)
    remaining_take = max(0, limit)
    with_aggs = True             # aggregations only need to be computed once
    # Phase 1: skip `offset` hits batch by batch.
    while remaining_skip > 0:
        batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_skip)
        q_iter = copy.deepcopy(q_base)
        q_iter["size"] = batch
        if search_after is not None:
            q_iter["search_after"] = search_after
        if not with_aggs:
            q_iter.pop("aggs", None)
        # track_total_hits only on the first request (it becomes the envelope).
        res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None)
        if template_res is None:
            template_res = res
        hits = res.get("hits", {}).get("hits", [])
        if not hits:
            break
        next_search_after = hits[-1].get("sort")
        # Guard against a missing or non-advancing sort cursor (would loop forever).
        if not next_search_after or next_search_after == search_after:
            break
        search_after = next_search_after
        remaining_skip -= len(hits)
        with_aggs = False
        if len(hits) < batch:
            break  # index exhausted before the offset was reached
    # Phase 2: collect up to `limit` hits after the skipped prefix.
    while remaining_skip <= 0 and remaining_take > 0:
        batch = min(SEARCH_AFTER_BATCH_SIZE, remaining_take)
        q_iter = copy.deepcopy(q_base)
        q_iter["size"] = batch
        if search_after is not None:
            q_iter["search_after"] = search_after
        if not with_aggs:
            q_iter.pop("aggs", None)
        res = self._es_search_once(index_names, q_iter, track_total_hits=template_res is None)
        if template_res is None:
            template_res = res
        hits = res.get("hits", {}).get("hits", [])
        if not hits:
            break
        collected_hits.extend(hits)
        remaining_take -= len(hits)
        next_search_after = hits[-1].get("sort")
        if not next_search_after or next_search_after == search_after:
            break
        search_after = next_search_after
        with_aggs = False
        if len(hits) < batch:
            break
    # No request was ever issued (e.g. offset and limit both 0): run a
    # size-0 query so the caller still gets totals/aggregations.
    if template_res is None:
        q_count = copy.deepcopy(q_base)
        q_count["size"] = 0
        template_res = self._es_search_once(index_names, q_count, track_total_hits=True)
    template_res["hits"]["hits"] = collected_hits
    return template_res
def search(
        self, select_fields: list[str],
        highlight_fields: list[str],
        condition: dict,
        match_expressions: list[MatchExpr],
        order_by: OrderByExpr,
        offset: int,
        limit: int,
        index_names: str | list[str],
        knowledgebase_ids: list[str],
        agg_fields: list[str] | None = None,
        rank_feature: dict | None = None
):
    """
    Run a (possibly hybrid full-text + kNN) search against the given indices.

    Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html

    Returns the raw ES response dict. When offset+limit exceeds
    index.max_result_window and the query has an explicit sort but no dense
    (kNN) clause, falls back to search_after deep pagination.
    """
    if isinstance(index_names, str):
        index_names = index_names.split(",")
    assert isinstance(index_names, list) and len(index_names) > 0
    assert "_id" not in condition
    bool_query = Q("bool", must=[])
    condition["kb_id"] = knowledgebase_ids
    for k, v in condition.items():
        if k == "available_int":
            # available_int == 0 means disabled; any other value (or a
            # missing field) counts as available.
            if v == 0:
                bool_query.filter.append(Q("range", available_int={"lt": 1}))
            else:
                bool_query.filter.append(
                    Q("bool", must_not=Q("range", available_int={"lt": 1})))
            continue
        if k == "id":
            if not v:
                continue
            # Match either the "id" field or the document _id, since
            # documents may carry only one of them.
            if isinstance(v, list):
                bool_query.filter.append(
                    Q("bool", should=[Q("terms", id=v), Q("terms", _id=v)], minimum_should_match=1))
            elif isinstance(v, str) or isinstance(v, int):
                bool_query.filter.append(
                    Q("bool", should=[Q("term", id=v), Q("term", _id=v)], minimum_should_match=1))
            continue
        if not v:
            continue
        if isinstance(v, list):
            bool_query.filter.append(Q("terms", **{k: v}))
        elif isinstance(v, str) or isinstance(v, int):
            bool_query.filter.append(Q("term", **{k: v}))
        else:
            raise Exception(
                f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
    s = Search()
    vector_similarity_weight = 0.5
    # A weighted_sum FusionExpr fixes the text/vector blend; it expects
    # exactly [MatchTextExpr, MatchDenseExpr, FusionExpr].
    for m in match_expressions:
        if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
            assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(
                match_expressions[1],
                MatchDenseExpr) and isinstance(
                match_expressions[2], FusionExpr)
            weights = m.fusion_params["weights"]
            vector_similarity_weight = get_float(weights.split(",")[1])
    for m in match_expressions:
        if isinstance(m, MatchTextExpr):
            minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
            if isinstance(minimum_should_match, float):
                minimum_should_match = str(int(minimum_should_match * 100)) + "%"
            bool_query.must.append(Q("query_string", fields=m.fields,
                                     type="best_fields", query=m.matching_text,
                                     minimum_should_match=minimum_should_match,
                                     boost=1))
            # Text score gets the complement of the vector weight.
            bool_query.boost = 1.0 - vector_similarity_weight
        elif isinstance(m, MatchDenseExpr):
            assert (bool_query is not None)
            similarity = 0.0
            if "similarity" in m.extra_options:
                similarity = m.extra_options["similarity"]
            s = s.knn(m.vector_column_name,
                      m.topn,
                      m.topn * 2,
                      query_vector=list(m.embedding_data),
                      filter=bool_query.to_dict(),
                      similarity=similarity,
                      )
    if bool_query and rank_feature:
        # Boost by rank_feature fields; tag scores live under TAG_FLD.*.
        for fld, sc in rank_feature.items():
            if fld != PAGERANK_FLD:
                fld = f"{TAG_FLD}.{fld}"
            bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
    if bool_query:
        s = s.query(bool_query)
    for field in highlight_fields:
        s = s.highlight(field)
    if order_by:
        orders = list()
        for field, order in order_by.fields:
            order = "asc" if order == 0 else "desc"
            if field in ["page_num_int", "top_int"]:
                order_info = {"order": order, "unmapped_type": "float",
                              "mode": "avg", "numeric_type": "double"}
            elif field.endswith("_int") or field.endswith("_flt"):
                order_info = {"order": order, "unmapped_type": "float"}
            elif field == "id":
                continue  # id as "text", not a "keyword", order by it will cause error
            else:
                order_info = {"order": order, "unmapped_type": "text"}
            orders.append({field: order_info})
        s = s.sort(*orders)
    if agg_fields:
        for fld in agg_fields:
            s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
    # Deep pages can't use from/size (capped at MAX_RESULT_WINDOW); fall back
    # to search_after, which requires an explicit sort and no kNN clause.
    has_dense = any(isinstance(m, MatchDenseExpr) for m in match_expressions)
    has_explicit_sort = bool(order_by and order_by.fields)
    use_search_after = (
        limit > 0
        and (offset + limit > MAX_RESULT_WINDOW)
        and has_explicit_sort
        and not has_dense
    )
    if limit > 0 and not use_search_after:
        s = s[offset:offset + limit]
    # Filter _source to only requested fields for efficiency, and add vector
    # fields to "fields" param so they appear in hit.fields when ES 9.x
    # exclude_source_vectors is enabled (dense_vector not in _source).
    if select_fields:
        s = s.source(select_fields)
    q = s.to_dict()
    # ES 9.x: dense_vector fields excluded from _source; request them via fields.
    # Note: knn does NOT have a "fields" parameter - adding it inside the knn
    # object causes BadRequestError on ES 9.x. We add "fields" at top level.
    vector_fields = [f for f in (select_fields or []) if f.endswith("_vec")]
    if vector_fields:
        q["fields"] = vector_fields
    self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
    for i in range(ATTEMPT_TIME):
        try:
            if use_search_after:
                res = self._search_with_search_after(index_names, q, offset, limit)
            else:
                res = self._es_search_once(index_names, q, track_total_hits=True)
            # A partial (timed-out) response is treated as a failure.
            if str(res.get("timed_out", "")).lower() == "true":
                raise Exception("Es Timeout.")
            self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
            return res
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            self._connect()
            continue
        except Exception as e:
            # Only log debug for NotFoundError(accepted when metadata index doesn't exist)
            if 'NotFound' in str(e):
                self.logger.debug(f"ESConnection.search {str(index_names)} query: " + str(q) + " - " + str(e))
            else:
                self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
            raise e
    self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
    raise Exception("ESConnection.search timeout.")
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
    """Bulk-index documents into `index_name`.

    Each document must carry an "id" (used as the ES _id; also kept as a
    regular field) and must not carry "_id". Returns a list of error
    strings — empty on full success.

    Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
    """
    operations = []
    for d in documents:
        assert "_id" not in d
        assert "id" in d
        d_copy = copy.deepcopy(d)
        d_copy["kb_id"] = knowledgebase_id
        # Use id as _id for uniqueness, also keep "id" as a regular field for sorting
        meta_id = d_copy.get("id", "")
        operations.append(
            {"index": {"_index": index_name, "_id": meta_id}})
        operations.append(d_copy)
    res = []
    for _ in range(ATTEMPT_TIME):
        try:
            res = []
            r = self.es.bulk(index=index_name, operations=operations,
                             refresh=False, timeout="60s")
            # r["errors"] is a boolean flag. The previous code regex-matched
            # "False" in str(r["errors"]), which is a fragile way of writing
            # this direct truthiness check.
            if not r["errors"]:
                return res
            # Collect per-item failures as "<_id>:<error>" strings.
            for item in r["items"]:
                for action in ("create", "delete", "index", "update"):
                    if action in item and "error" in item[action]:
                        res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
            return res
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            res.append(str(e))
            self.logger.warning("ESConnection.insert got exception: " + str(e))
    # All attempts exhausted (timeouts) or last attempt raised: report what we have.
    return res
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
    """Update documents matching `condition` with `new_value`.

    Two paths:
      * condition contains a single string "id": targeted per-document
        updates against that chunk.
      * otherwise: update-by-query driven by a generated painless script.

    `new_value` supports the pseudo-fields "remove" (a str drops that field;
    a dict removes a value from a list field) and, on the multi-document
    path, "add" (appends a value to a list field).
    Returns True on success, False otherwise.
    """
    doc = copy.deepcopy(new_value)
    doc.pop("id", None)
    condition["kb_id"] = knowledgebase_id
    if "id" in condition and isinstance(condition["id"], str):
        # update specific single document
        chunk_id = condition["id"]
        for i in range(ATTEMPT_TIME):
            doc_part = copy.deepcopy(doc)
            # "remove" pseudo-field: str = drop field, dict = remove value from list field.
            remove_value = doc_part.pop("remove", None)
            remove_field = remove_value if isinstance(remove_value, str) else None
            remove_dict = remove_value if isinstance(remove_value, dict) else None
            # Drop *_feas fields first so the later partial-doc update
            # replaces them instead of merging with stale entries.
            for k in doc_part.keys():
                if "feas" != k.split("_")[-1]:
                    continue
                try:
                    self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
                except Exception:
                    self.logger.exception(
                        f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
            try:
                if remove_field is not None:
                    self.es.update(
                        index=index_name,
                        id=chunk_id,
                        script=f"ctx._source.remove('{remove_field}');",
                    )
                if remove_dict is not None:
                    scripts = []
                    params = {}
                    for kk, vv in remove_dict.items():
                        # Guarded removal: only touch the list if it exists
                        # and contains the value.
                        scripts.append(
                            f"if (ctx._source.containsKey('{kk}') && ctx._source.{kk} != null) "
                            f"{{ int i = ctx._source.{kk}.indexOf(params.p_{kk}); "
                            f"if (i >= 0) {{ ctx._source.{kk}.remove(i); }} }}"
                        )
                        params[f"p_{kk}"] = vv
                    if scripts:
                        self.es.update(
                            index=index_name,
                            id=chunk_id,
                            script={"source": "".join(scripts), "params": params},
                        )
                if doc_part:
                    self.es.update(index=index_name, id=chunk_id, doc=doc_part)
                # Only report success if at least one operation was issued.
                if remove_field is not None or remove_dict is not None or doc_part:
                    return True
            except Exception as e:
                self.logger.exception(
                    f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(
                        e))
                break
        return False
    # update unspecific maybe-multiple documents
    bool_query = Q("bool")
    for k, v in condition.items():
        if not isinstance(k, str) or not v:
            continue
        if k == "exists":
            bool_query.filter.append(Q("exists", field=v))
            continue
        if isinstance(v, list):
            bool_query.filter.append(Q("terms", **{k: v}))
        elif isinstance(v, str) or isinstance(v, int):
            bool_query.filter.append(Q("term", **{k: v}))
        else:
            raise Exception(
                f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
    # Build one painless script applying every field change in a single pass.
    scripts = []
    params = {}
    for k, v in new_value.items():
        if k == "remove":
            if isinstance(v, str):
                scripts.append(f"ctx._source.remove('{v}');")
            if isinstance(v, dict):
                # NOTE(review): unlike the single-document path above, there is
                # no containsKey/indexOf guard here — documents lacking the
                # field would presumably fail the script; confirm intended.
                for kk, vv in v.items():
                    scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
                    params[f"p_{kk}"] = vv
            continue
        if k == "add":
            if isinstance(v, dict):
                for kk, vv in v.items():
                    scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
                    params[f"pp_{kk}"] = vv.strip()
            continue
        # Skip empty values, except available_int where 0 is meaningful.
        if (not isinstance(k, str) or not v) and k != "available_int":
            continue
        if isinstance(v, str):
            # Strip quotes/newlines/escapes that would break the inline script.
            v = re.sub(r"(['\n\r]|\\.)", " ", v)
            params[f"pp_{k}"] = v
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
        elif isinstance(v, int) or isinstance(v, float):
            scripts.append(f"ctx._source.{k}={v};")
        elif isinstance(v, list):
            scripts.append(f"ctx._source.{k}=params.pp_{k};")
            params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
        else:
            raise Exception(
                f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
    ubq = UpdateByQuery(
        index=index_name).using(
        self.es).query(bool_query)
    ubq = ubq.script(source="".join(scripts), params=params)
    ubq = ubq.params(refresh=True)
    ubq = ubq.params(slices=5)
    ubq = ubq.params(conflicts="proceed")
    for _ in range(ATTEMPT_TIME):
        try:
            _ = ubq.execute()
            return True
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
            break
    return False
def adjust_chunk_pagerank_fea(
    self,
    chunk_id: str,
    index_name: str,
    knowledgebase_id: str,
    delta: float,
    min_w: float = 0.0,
    max_w: float = 100.0,
    row_id: int | None = None,
) -> bool:
    """Atomically adjust pagerank_fea on one chunk (painless script).

    Applies `delta` to the chunk's current pagerank_fea value, clamped to
    [min_w, max_w] by the script. Retries on connection problems; returns
    True on success, False once the attempts are exhausted.
    """
    _ = row_id  # accepted for interface parity; unused by the ES backend
    # The script and its params are identical on every attempt, so build once.
    update_script = {
        "source": _PAGERANK_FEA_ADJUST_SCRIPT.strip(),
        "lang": "painless",
        "params": {
            "pf": PAGERANK_FLD,
            "delta": float(delta),
            "min_w": float(min_w),
            "max_w": float(max_w),
        },
    }
    for _attempt in range(ATTEMPT_TIME):
        try:
            self.es.update(
                index=index_name,
                id=chunk_id,
                retry_on_conflict=3,
                script=update_script,
            )
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            self.logger.exception(
                "ESConnection.adjust_chunk_pagerank_fea(index=%s, id=%s): %s",
                index_name,
                chunk_id,
                e,
            )
            # Connection-flavored failures get a reconnect-and-retry;
            # anything else is fatal for this call.
            if re.search(r"connection", str(e).lower()):
                time.sleep(3)
                self._connect()
                continue
            break
        else:
            self.logger.debug(
                "ESConnection.adjust_chunk_pagerank_fea(index=%s, id=%s, delta=%s) succeeded",
                index_name,
                chunk_id,
                delta,
            )
            return True
    return False
def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int:
    """Delete documents matching `condition`, scoped to `knowledgebase_id`.

    Supports "id" (single value or list, matched against _id), "exists",
    and a "must_not" dict (currently only {"exists": field}). Returns the
    number of deleted documents, or 0 on failure / index-not-found.
    """
    assert "_id" not in condition
    condition["kb_id"] = knowledgebase_id
    # Build a bool query that combines id filter with other conditions
    bool_query = Q("bool")
    # Handle chunk IDs if present
    if "id" in condition:
        chunk_ids = condition["id"]
        if not isinstance(chunk_ids, list):
            chunk_ids = [chunk_ids]
        if chunk_ids:
            # Filter by specific chunk IDs
            bool_query.filter.append(Q("ids", values=chunk_ids))
        # If chunk_ids is empty, we don't add an ids filter - rely on other conditions
    # Add all other conditions as filters
    for k, v in condition.items():
        if k == "id":
            continue  # Already handled above
        if k == "exists":
            bool_query.filter.append(Q("exists", field=v))
        elif k == "must_not":
            if isinstance(v, dict):
                for kk, vv in v.items():
                    if kk == "exists":
                        bool_query.must_not.append(Q("exists", field=vv))
        elif isinstance(v, list):
            bool_query.must.append(Q("terms", **{k: v}))
        elif isinstance(v, str) or isinstance(v, int):
            bool_query.must.append(Q("term", **{k: v}))
        elif v is not None:
            raise Exception("Condition value must be int, str or list.")
    # If no filters were added, use match_all (for tenant-wide operations)
    if not bool_query.filter and not bool_query.must and not bool_query.must_not:
        qry = Q("match_all")
    else:
        qry = bool_query
    self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
    for _ in range(ATTEMPT_TIME):
        try:
            res = self.es.delete_by_query(
                index=index_name,
                body=Search().query(qry).to_dict(),
                refresh=True)
            return res["deleted"]
        except ConnectionTimeout:
            self.logger.exception("ES request timeout")
            time.sleep(3)
            self._connect()
            continue
        except Exception as e:
            self.logger.warning("ESConnection.delete got exception: " + str(e))
            # not_found means the index is already gone: nothing to delete.
            if re.search(r"(not_found)", str(e), re.IGNORECASE):
                return 0
    return 0
"""
Helper functions for search result
"""
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
    """Extract the requested fields from every hit of an ES response.

    Returns {doc_id: {field: value}}. Values are taken from _source first,
    then from the hit-level "fields" section (where ES 9.x returns
    dense_vector values, wrapped in a single-element list). Lists and a
    numeric available_int are kept as-is; every other non-str scalar is
    normalized to str. Hits yielding no fields are omitted.
    """
    if not fields:
        return {}
    result: dict[str, dict] = {}
    for hit in res.get("hits", {}).get("hits", []):
        source = hit.get("_source", {})
        # ES 9.x puts dense_vector values here instead of _source.
        extra = hit.get("fields", {})
        row = {}
        for name in fields:
            value = source.get(name)
            if value is not None:
                row[name] = value
                continue
            if name not in extra:
                continue
            wrapped = extra[name]
            # "fields" wraps values one level deep: [[v1, v2, ...]] -> [v1, v2, ...]
            if isinstance(wrapped, list) and len(wrapped) == 1:
                wrapped = wrapped[0]
            row[name] = wrapped
        for name, value in row.items():
            if isinstance(value, list):
                continue
            if name == "available_int" and isinstance(value, (int, float)):
                continue
            if not isinstance(value, str):
                row[name] = str(row[name])
        if row:
            result[hit.get("_id")] = row
    return result