ragflow/agent/tools/retrieval.py
sxxtony 59c35100c5 Perf: push metadata filters down to Elasticsearch (#14576)
### What problem does this PR solve?

Fixes #14412.

`common.metadata_utils.meta_filter` evaluates user-defined metadata
conditions in Python after `DocMetadataService.get_flatted_meta_by_kbs`
loads the entire `meta_fields` table into memory. Past a few thousand
documents per knowledge base this becomes a memory bottleneck and a
wasted ES round-trip — every filter request currently fetches up to
10000 metadata rows even when the resulting `doc_ids` list is tiny.

This PR adds an ES push-down path that translates the same filter
language into a `bool` query and returns just the matching document IDs.

**Changes**

- `common/metadata_es_filter.py` *(new)*: pure-Python translator from
the RAGFlow filter list to ES DSL. Covers every operator the in-memory
path supports (`=`, `≠`, `>`, `<`, `≥`, `≤`, `in`, `not in`, `contains`,
`not contains`, `start with`, `end with`, `empty`, `not empty`) with
`case_insensitive: true` on `prefix` and `wildcard` for parity with the
existing lower-cased Python comparisons. User wildcard metacharacters
are escaped before being injected into `wildcard` patterns. Negative
operators (`≠`, `not in`, `not contains`) and range comparisons are
wrapped with an `exists` guard so they do not accidentally match
documents missing the key, matching the legacy `if k not in metas`
behaviour.
- `api/db/services/doc_metadata_service.py`: new
`DocMetadataService.filter_doc_ids_by_meta_pushdown(kb_ids, filters,
logic)` that returns the doc IDs ES matched, or `None` to signal the
caller should fall back to the in-memory path. Returns `None` when the
active doc store is Infinity (`meta_fields` is a JSON column, not a
dotted-object mapping), when any filter cannot be expressed in DSL
(`UnsupportedMetaFilter`), or when the ES request or metadata index
lookup errors.
- `common/metadata_utils.py`: `apply_meta_data_filter` accepts an
optional `kb_ids` argument. When supplied, conditions go through
push-down first via a new `_try_meta_pushdown` helper; on `None` the
function falls back to the original `meta_filter` call. Default
behaviour is unchanged for callers that don't pass `kb_ids`.
- Updated the call sites in all four caller modules
(`agent/tools/retrieval.py`, `api/db/services/dialog_service.py` ×2,
`api/apps/services/dataset_api_service.py`, `api/apps/sdk/session.py`)
to forward `kb_ids` so the push-down path is exercised in production.
- `test/unit_test/common/test_metadata_es_filter.py` *(new)*: 35 unit
tests covering every operator's DSL shape, value coercion
(`ast.literal_eval`, lowercasing, ISO-date pass-through), wildcard
escaping, OR-logic wrapping that protects negative clauses, and the
doc-ID extractor.
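
As a rough illustration of the translation described in the list above, the sketch below maps a handful of operators to ES DSL clauses. It is not the code in `common/metadata_es_filter.py`: the names `escape_wildcard` and `translate_filter`, the `meta_fields.<key>` field path, and the exact clause shapes are assumptions made for illustration only.

```python
def escape_wildcard(value: str) -> str:
    """Backslash-escape wildcard metacharacters so user input matches literally."""
    return "".join("\\" + ch if ch in "\\*?" else ch for ch in value)


def translate_filter(key: str, op: str, value=None) -> dict:
    """Translate one (key, op, value) filter into an ES clause (illustrative subset)."""
    field = f"meta_fields.{key}"  # assumed field layout, not confirmed by the PR
    exists = {"exists": {"field": field}}
    if op == "=":
        return {"term": {field: value}}
    if op == "≠":
        # Pair `must_not` with an `exists` guard so documents that lack the
        # key are not matched by the negation.
        return {"bool": {"must": [exists], "must_not": [{"term": {field: value}}]}}
    if op == "contains":
        pattern = f"*{escape_wildcard(str(value))}*"
        return {"wildcard": {field: {"value": pattern, "case_insensitive": True}}}
    if op == "start with":
        return {"prefix": {field: {"value": str(value), "case_insensitive": True}}}
    if op == "empty":
        return {"bool": {"must_not": [exists]}}
    # Anything else would signal "fall back to the in-memory path".
    raise ValueError(f"unsupported operator: {op}")
```

In this sketch an unsupported operator raises `ValueError`; the PR describes a dedicated `UnsupportedMetaFilter` signal for the same purpose.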

**Behaviour preserved**

- The in-memory `meta_filter` is untouched and still services every
fallback case (Infinity backend, unknown operators, ES outages).
- The eligibility / credibility / issue-multiplier semantics of the
LLM-driven `auto` and `semi_auto` modes are unchanged: those modes
still hand the LLM the full in-memory `metas` dict to choose conditions
from. Only the *evaluation* of those generated conditions is pushed down.
- Existing tests in
`test/unit_test/common/test_metadata_filter_operators.py` continue to
pass (14/14).
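
The fallback contract described above (a list of doc IDs on success, `None` to request the in-memory path) can be sketched as follows. The function name `filter_doc_ids` and the two callables are hypothetical stand-ins, not the signatures of `_try_meta_pushdown` or `meta_filter`.

```python
from typing import Callable, Optional


def filter_doc_ids(
    pushdown: Callable[[], Optional[list]],
    in_memory: Callable[[], list],
) -> list:
    """Try the push-down path first; fall back to in-memory only on None."""
    doc_ids = pushdown()
    if doc_ids is None:
        # None means "could not push down" (for example an Infinity backend,
        # an untranslatable filter, or an ES error).
        return in_memory()
    # An empty list is a valid answer ("no document matched"), not a failure,
    # so it must not trigger the fallback.
    return doc_ids
```

The distinction between `None` and an empty list matters here: testing truthiness (`if not doc_ids`) would silently re-run the expensive in-memory path for every filter that legitimately matches nothing.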

**Test plan**

- `pytest test/unit_test/common/test_metadata_es_filter.py` — 35 passed.
- `pytest test/unit_test/common/test_metadata_filter_operators.py` — 14
passed.
- `ruff check` clean on every modified file.
- Reviewers, please validate the ES query shapes against a live cluster,
particularly `case_insensitive` on `wildcard` and `prefix` (requires ES
7.10+) and the `exists` + `must_not` pairing for `≠`.

**Notes**

- The first cut caps each push-down request at 10000 results, matching
the existing `get_flatted_meta_by_kbs` limit, and logs a warning when
the cap is hit. A `search_after` follow-up would let us drop the cap
entirely once the push-down path is validated.
- Operator parity with the in-memory path is exact for the canonical
unicode operators (`≥`, `≤`, `≠`) used internally; the ASCII aliases
(`>=`, `<=`, `!=`) are normalised by `convert_conditions` before they
reach the translator.
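
A minimal sketch of the alias normalisation mentioned in the last note, assuming filters are dicts with an `op` key. The table below contains only the three pairs named above; the helper names are hypothetical and this is not the body of `convert_conditions`.

```python
# Assumed alias table: only the three pairs named in the notes above.
ASCII_ALIASES = {">=": "≥", "<=": "≤", "!=": "≠"}


def normalise_operator(op: str) -> str:
    """Map an ASCII alias to its canonical unicode operator; pass others through."""
    return ASCII_ALIASES.get(op, op)


def normalise_filters(filters: list) -> list:
    """Return copies of the filters with canonical operators (inputs are not mutated)."""
    return [{**f, "op": normalise_operator(f["op"])} for f in filters]
```

Normalising once, before translation, keeps the translator's operator table limited to the canonical forms.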

### Type of change

- [x] Performance Improvement

---------

Co-authored-by: sxxtony <sxxtony@users.noreply.github.com>
2026-05-07 21:23:43 +08:00

337 lines
14 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from functools import partial
import json
import os
import re
from abc import ABC
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from common.constants import LLMType
from api.db.services.doc_metadata_service import DocMetadataService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.memory_service import MemoryService
from api.db.joint_services import memory_message_service
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt


class RetrievalParam(ToolParamBase):
    """
    Define the Retrieval component parameters.
    """

    def __init__(self):
        self.meta: ToolMeta = {
            "name": "search_my_dateset",
            "description": "This tool can be utilized for relevant content searching in the datasets.",
            "parameters": {
                "query": {
                    "type": "string",
                    "description": "The keywords to search the dataset. The keywords should be the most important words/terms(includes synonyms) from the original request.",
                    "default": "",
                    "required": True
                }
            }
        }
        super().__init__()
        self.function_name = "search_my_dateset"
        self.description = "This tool can be utilized for relevant content searching in the datasets."
        self.similarity_threshold = 0.2
        self.keywords_similarity_weight = 0.5
        self.top_n = 8
        self.top_k = 1024
        self.dataset_ids = []
        self.kb_ids = []  # Deprecated: keep for backward compatibility
        self.memory_ids = []
        self.kb_vars = []
        self.rerank_id = ""
        self.empty_response = ""
        self.use_kg = False
        self.cross_languages = []
        self.toc_enhance = False
        self.meta_data_filter = {}

    def check(self):
        self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
        self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keyword similarity weight")
        self.check_positive_number(self.top_n, "[Retrieval] Top N")

    def get_input_form(self) -> dict[str, dict]:
        return {
            "query": {
                "name": "Query",
                "type": "line"
            }
        }


class Retrieval(ToolBase, ABC):
    component_name = "Retrieval"

    @property
    def _dataset_ids(self):
        """Get dataset IDs with backward compatibility for kb_ids."""
        return self._param.dataset_ids or getattr(self._param, "kb_ids", None) or []

    async def _retrieve_kb(self, query_text: str):
        kb_ids: list[str] = []
        for id in self._dataset_ids:
            if id.find("@") < 0:
                kb_ids.append(id)
                continue
            kb_nm = self._canvas.get_variable_value(id)
            # if kb_nm is a list
            kb_nm_list = kb_nm if isinstance(kb_nm, list) else [kb_nm]
            for nm_or_id in kb_nm_list:
                e, kb = KnowledgebaseService.get_by_name(nm_or_id,
                                                         self._canvas._tenant_id)
                if not e:
                    e, kb = KnowledgebaseService.get_by_id(nm_or_id)
                    if not e:
                        raise Exception(f"Dataset({nm_or_id}) does not exist.")
                kb_ids.append(kb.id)

        filtered_kb_ids: list[str] = list(set([kb_id for kb_id in kb_ids if kb_id]))

        kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
        if not kbs:
            raise Exception("No dataset is selected.")

        embd_nms = list(set([kb.embd_id for kb in kbs]))
        assert len(embd_nms) == 1, "Knowledge bases use different embedding models."

        embd_mdl = None
        if embd_nms:
            tenant_id = self._canvas.get_tenant_id()
            embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_nms[0])
            embd_mdl = LLMBundle(tenant_id, embd_model_config)

        rerank_mdl = None
        if self._param.rerank_id:
            rerank_model_config = get_model_config_by_type_and_name(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
            rerank_mdl = LLMBundle(kbs[0].tenant_id, rerank_model_config)

        vars = self.get_input_elements_from_text(query_text)
        vars = {k: o["value"] for k, o in vars.items()}
        query = self.string_format(query_text, vars)

        doc_ids = []
        if self._param.meta_data_filter != {}:
            # Defer the (potentially expensive) metadata table load: manual
            # filters served by ES push-down never need it. The loader is
            # invoked at most once per request by ``apply_meta_data_filter``.
            def _load_metas() -> dict:
                return DocMetadataService.get_flatted_meta_by_kbs(kb_ids)

            def _resolve_manual_filter(flt: dict) -> dict:
                pat = re.compile(self.variable_ref_patt)
                s = flt.get("value", "")
                out_parts = []
                last = 0

                for m in pat.finditer(s):
                    out_parts.append(s[last:m.start()])
                    key = m.group(1)
                    v = self._canvas.get_variable_value(key)
                    if v is None:
                        rep = ""
                    elif isinstance(v, partial):
                        buf = []
                        for chunk in v():
                            buf.append(chunk)
                        rep = "".join(buf)
                    elif isinstance(v, str):
                        rep = v
                    else:
                        rep = json.dumps(v, ensure_ascii=False)

                    out_parts.append(rep)
                    last = m.end()

                out_parts.append(s[last:])
                flt["value"] = "".join(out_parts)
                return flt

            chat_mdl = None
            if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]:
                tenant_id = self._canvas.get_tenant_id()
                chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
                chat_mdl = LLMBundle(tenant_id, chat_model_config)

            doc_ids = await apply_meta_data_filter(
                self._param.meta_data_filter,
                None,
                query,
                chat_mdl,
                doc_ids,
                _resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None,
                kb_ids=kb_ids,
                metas_loader=_load_metas,
            )

        if self._param.cross_languages:
            query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)

        if kbs:
            query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
            kbinfos = await settings.retriever.retrieval(
                query,
                embd_mdl,
                [kb.tenant_id for kb in kbs],
                filtered_kb_ids,
                1,
                self._param.top_n,
                self._param.similarity_threshold,
                1 - self._param.keywords_similarity_weight,
                doc_ids=doc_ids,
                aggs=True,
                rerank_mdl=rerank_mdl,
                rank_feature=label_question(query, kbs),
            )
            if self.check_if_canceled("Retrieval processing"):
                return

            if self._param.toc_enhance:
                tenant_id = self._canvas._tenant_id
                chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
                chat_mdl = LLMBundle(tenant_id, chat_model_config)
                cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
                                                                chat_mdl, self._param.top_n)
                if self.check_if_canceled("Retrieval processing"):
                    return
                if cks:
                    kbinfos["chunks"] = cks
            kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
                                                                         [kb.tenant_id for kb in kbs])
            if self._param.use_kg:
                tenant_id = self._canvas.get_tenant_id()
                chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
                ck = await settings.kg_retriever.retrieval(query,
                                                           [kb.tenant_id for kb in kbs],
                                                           kb_ids,
                                                           embd_mdl,
                                                           LLMBundle(tenant_id, chat_model_config))
                if self.check_if_canceled("Retrieval processing"):
                    return
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)
        else:
            kbinfos = {"chunks": [], "doc_aggs": []}

        if self._param.use_kg and kbs:
            chat_model_config = get_tenant_default_model_by_type(kbs[0].tenant_id, LLMType.CHAT)
            ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
                                                       LLMBundle(kbs[0].tenant_id, chat_model_config))
            if self.check_if_canceled("Retrieval processing"):
                return
            if ck["content_with_weight"]:
                ck["content"] = ck["content_with_weight"]
                del ck["content_with_weight"]
                kbinfos["chunks"].insert(0, ck)

        for ck in kbinfos["chunks"]:
            if "vector" in ck:
                del ck["vector"]
            if "content_ltks" in ck:
                del ck["content_ltks"]

        if not kbinfos["chunks"]:
            self.set_output("formalized_content", self._param.empty_response)
            return

        # Format the chunks for JSON output (similar to how other tools do it)
        json_output = kbinfos["chunks"].copy()

        self._canvas.add_reference(kbinfos["chunks"], kbinfos["doc_aggs"])
        form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))

        # Set both formalized content and JSON output
        self.set_output("formalized_content", form_cnt)
        self.set_output("json", json_output)

        return form_cnt

    async def _retrieve_memory(self, query_text: str):
        memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
        user_id: str = self._param.user_id if hasattr(self._param, "user_id") else None
        memory_list = MemoryService.get_by_ids(memory_ids)
        if not memory_list:
            raise Exception("No memory is selected.")

        embd_names = list({memory.embd_id for memory in memory_list})
        assert len(embd_names) == 1, "Memory use different embedding models."

        vars = self.get_input_elements_from_text(query_text)
        vars = {k: o["value"] for k, o in vars.items()}
        query = self.string_format(query_text, vars)
        # query message
        filter_dict: dict = {"memory_id": memory_ids}
        if user_id:
            # is variable (``re`` is already imported at module level)
            if re.match(r"^{.*}$", user_id):
                user_id = self._canvas.get_variable_value(user_id)
            filter_dict["user_id"] = user_id
        message_list = memory_message_service.query_message(filter_dict, {
            "query": query,
            "similarity_threshold": self._param.similarity_threshold,
            "keywords_similarity_weight": self._param.keywords_similarity_weight,
            "top_n": self._param.top_n
        })
        if not message_list:
            self.set_output("formalized_content", self._param.empty_response)
            return ""
        formatted_content = "\n".join(memory_prompt(message_list, 200000))

        # set formalized_content output
        self.set_output("formalized_content", formatted_content)

        return formatted_content

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
    async def _invoke_async(self, **kwargs):
        if self.check_if_canceled("Retrieval processing"):
            return
        if not kwargs.get("query"):
            self.set_output("formalized_content", self._param.empty_response)
            return

        if hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "dataset":
            return await self._retrieve_kb(kwargs["query"])
        elif hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "memory":
            return await self._retrieve_memory(kwargs["query"])
        elif self._dataset_ids:
            return await self._retrieve_kb(kwargs["query"])
        elif hasattr(self._param, "memory_ids") and self._param.memory_ids:
            return await self._retrieve_memory(kwargs["query"])
        else:
            self.set_output("formalized_content", self._param.empty_response)
            return

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
    def _invoke(self, **kwargs):
        return asyncio.run(self._invoke_async(**kwargs))

    def thoughts(self) -> str:
        return """
Keywords: {}
Looking for the most relevant articles.
""".format(self.get_input().get("query", "-_-!"))