From 24af0875e532ec2f9d7edd5c9a0b50a27f379bfa Mon Sep 17 00:00:00 2001 From: Attili-sys Date: Thu, 30 Apr 2026 18:13:27 +0300 Subject: [PATCH] Feat/configurable metadata display (#13464) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Currently, RAGFlow's Search and Chat interfaces display only raw vectorized text chunks during retrieval, without contextual information about their source documents. Users cannot see document titles, page numbers, upload dates, or custom metadata fields that would help them understand and trust the retrieved results. This PR introduces an **optional metadata display feature** that enriches retrieved chunks with document-level metadata in both the Search tab and Chatbot interface. **Key improvements:** - **Search results**: Display document metadata as styled badges beneath chunk snippets - **Chat citations**: Show metadata in citation popovers and reference lists for better source context - **LLM context**: Metadata is injected into the LLM prompt to enable more accurate, citation-aware responses - **External API support**: Applications using RAGFlow's SDK retrieval endpoints (`/v1/retrieval`, `/v1/searchbots/retrieval_test`) can opt-in via request parameters - **User control**: Multi-select dropdown UI allows users to choose which metadata fields to display **Implementation approach:** - ✅ Reuses existing `DocMetadataService` infrastructure (no new database tables or indices) - ✅ Settings stored in existing JSON configuration fields (`search_config.reference_metadata`, `prompt_config.reference_metadata`) - ✅ No database migrations required - ✅ Disabled by default (fully opt-in and backward-compatible) - ✅ Dynamic metadata field selection populated from actual document metadata keys - ✅ Fixed critical bug where Python's builtin `set()` was shadowed by a route handler function **Modified endpoints (all backward-compatible):** - `POST /v1/retrieval` (Public 
SDK) - `POST /v1/searchbots/retrieval_test` (Searchbots) - `POST /v1/chunk/retrieval_test` (UI/Internal) - Chat completions endpoints (via `extra_body.reference_metadata` or `prompt_config`) ### Type of change - [x] New Feature (non-breaking change which adds functionality) ### Images - image



image



image --------- Co-authored-by: Cursor Agent Co-authored-by: Attili-sys Co-authored-by: Ahmad Intisar --- api/apps/restful_apis/openai_api.py | 51 ++--- api/apps/sdk/doc.py | 23 +++ api/apps/sdk/session.py | 20 +- api/db/services/dialog_service.py | 89 +++++++- api/db/services/doc_metadata_service.py | 30 +++ api/utils/reference_metadata_utils.py | 125 ++++++++++++ rag/prompts/generator.py | 17 +- run_tests.py | 28 ++- sdk/python/ragflow_sdk/modules/dataset.py | 6 +- .../test_doc_sdk_routes_unit.py | 148 +++++++++++++- .../test_session_sdk_routes_unit.py | 191 +++++++++++++++++- .../components/fallback-component/index.tsx | 31 ++- web/src/components/markdown-content/index.tsx | 22 ++ web/src/hooks/use-knowledge-request.ts | 19 ++ web/src/interfaces/database/chat.ts | 5 + .../form/doc-generator-form/use-values.ts | 2 +- .../chat/app-settings/chat-basic-settings.tsx | 84 +++++++- .../chat/app-settings/chat-settings.tsx | 24 +++ .../app-settings/use-chat-setting-schema.tsx | 6 + web/src/pages/next-search/search-setting.tsx | 115 +++++++++++ web/src/pages/next-search/search-view.tsx | 26 +++ web/src/pages/next-searches/hooks.ts | 4 + web/src/services/knowledge-service.ts | 5 + 23 files changed, 1004 insertions(+), 67 deletions(-) create mode 100644 api/utils/reference_metadata_utils.py diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py index 320ecd09d..baa011f32 100644 --- a/api/apps/restful_apis/openai_api.py +++ b/api/apps/restful_apis/openai_api.py @@ -48,44 +48,35 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): return None +import logging +from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata + def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None): chunks = chunks_format(reference) if not include_metadata: + logging.debug("Skipping document metadata enrichment (include_metadata=False)") return chunks - doc_ids_by_kb = {} - for chunk in chunks: - kb_id = 
chunk.get("dataset_id") - doc_id = chunk.get("document_id") - if not kb_id or not doc_id: - continue - doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id) - - if not doc_ids_by_kb: - return chunks - - meta_by_doc = {} - for kb_id, doc_ids in doc_ids_by_kb.items(): - meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id) - if meta_map: - meta_by_doc.update(meta_map) - + normalized_fields = None if metadata_fields is not None: - metadata_fields = {f for f in metadata_fields if isinstance(f, str)} - if not metadata_fields: + if not isinstance(metadata_fields, list): + return chunks + normalized_fields = {f for f in metadata_fields if isinstance(f, str)} + if not normalized_fields: return chunks - for chunk in chunks: - doc_id = chunk.get("document_id") - if not doc_id: - continue - meta = meta_by_doc.get(doc_id) - if not meta: - continue - if metadata_fields is not None: - meta = {k: v for k, v in meta.items() if k in metadata_fields} - if meta: - chunk["document_metadata"] = meta + logging.debug( + "Enriching %d chunks with document metadata (fields: %s)", + len(chunks), + "ALL" if normalized_fields is None else list(normalized_fields), + ) + + enrich_chunks_with_document_metadata( + chunks, + normalized_fields, + kb_field="dataset_id", + doc_field="document_id", + ) return chunks diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index dbb8f9203..9aa641ccf 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import logging from io import BytesIO from quart import request, send_file @@ -37,6 +38,18 @@ from rag.prompts.generator import cross_languages, keyword_extraction MAXIMUM_OF_UPLOADING_FILES = 256 +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) + +def _resolve_reference_metadata(req: dict, search_config: dict | None = None): + return resolve_reference_metadata_preferences(req, search_config) + +def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=None) -> None: + enrich_chunks_with_document_metadata(chunks, metadata_fields) + + @manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821 @token_required async def download(tenant_id, dataset_id, document_id): @@ -450,6 +463,7 @@ async def retrieval_test(tenant_id): return get_error_data_result("`highlight` should be a boolean") else: return get_error_data_result("`highlight` should be a boolean") + include_metadata, metadata_fields = _resolve_reference_metadata(req) try: tenant_ids = list(set([kb.tenant_id for kb in kbs])) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) @@ -508,6 +522,15 @@ async def retrieval_test(tenant_id): for c in ranks["chunks"]: c.pop("vector", None) + if include_metadata: + logging.info( + "sdk.retrieval reference_metadata enabled dataset_ids=%s fields=%s chunks=%s", + kb_ids, + sorted(metadata_fields) if metadata_fields else None, + len(ranks["chunks"]), + ) + enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields) + ##rename keys renamed_chunks = [] for chunk in ranks["chunks"]: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 8b6a777ba..2cb431299 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -44,6 +44,10 @@ from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, keyword_extraction from common.constants import RetCode, LLMType from common import settings +from 
api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) @token_required @@ -327,6 +331,7 @@ async def retrieval_test_embedded(): tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") + search_config = {} async def _retrieval(): nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id @@ -337,8 +342,11 @@ async def retrieval_test_embedded(): meta_data_filter = {} chat_mdl = None if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) + nonlocal search_config + detail = SearchService.get_detail(req.get("search_id", "")) + if detail: + search_config = detail.get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: @@ -414,6 +422,11 @@ async def retrieval_test_embedded(): for c in ranks["chunks"]: c.pop("vector", None) + + include_metadata, metadata_fields = _resolve_reference_metadata(req, search_config) + if include_metadata: + enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields) + ranks["labels"] = labels return get_json_result(data=ranks) @@ -529,3 +542,6 @@ async def mindmap(): return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) + +def _resolve_reference_metadata(req, search_config=None): + return resolve_reference_metadata_preferences(req, search_config) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 608391405..09ca70c43 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -33,6 +33,10 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import 
TenantLangfuseService from api.db.services.llm_service import LLMBundle from common.metadata_utils import apply_meta_data_filter +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) from api.db.services.tenant_llm_service import TenantLLMService from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type from common.time_utils import current_timestamp, datetime_format @@ -48,6 +52,16 @@ from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces from common import settings +def _resolve_reference_metadata(request_payload=None, config=None): + return resolve_reference_metadata_preferences(request_payload or {}, config) + +def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None): + enrich_chunks_with_document_metadata(chunks, metadata_fields) + +def _chunk_kb_id_for_doc(row_dict, kb_ids, doc_id): + if len(kb_ids or []) == 1: + return kb_ids[0] + return row_dict.get("kb_id") or row_dict.get("kb_id_kwd") def _normalize_internet_flag(value): if isinstance(value, bool): @@ -70,6 +84,15 @@ def _should_use_web_search(prompt_config, internet=None): return normalized is True +def _resolve_reference_metadata(config, request_payload=None): + return resolve_reference_metadata_preferences(request_payload or {}, config) + + +def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None): + enrich_chunks_with_document_metadata(chunks, metadata_fields) + + + class DialogService(CommonService): model = Dialog @@ -547,6 +570,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): attachments_ = "\n\n".join(text_attachments) prompt_config = dialog.prompt_config + include_reference_metadata, metadata_fields = _resolve_reference_metadata(prompt_config, request_payload=kwargs) field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) 
logging.debug(f"field_map retrieved: {field_map}") # try to use sql if field mapping is good to go @@ -555,6 +579,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")): + if include_reference_metadata and ans.get("reference", {}).get("chunks"): + if len(dialog.kb_ids) != 1 and any(not c.get("kb_id") for c in ans["reference"]["chunks"]): + logging.warning( + "Skipping some _enrich_chunks_with_document_metadata results because " + "dialog.kb_ids has %d entries and use_sql returned chunks without kb_id.", + len(dialog.kb_ids), + ) + _enrich_chunks_with_document_metadata(ans["reference"]["chunks"], metadata_fields) yield ans return else: @@ -675,6 +707,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) + if include_reference_metadata: + logging.debug( + "reference_metadata enrichment enabled for async_chat: chunk_count=%d metadata_fields=%s", + len(kbinfos.get("chunks", [])), + metadata_fields, + ) + _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields) + knowledges = kb_prompt(kbinfos, max_tokens) logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges))) @@ -1121,11 +1161,12 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"]) doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]]) + kb_id_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]]) logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in 
enumerate(tbl['columns'])]}") - logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}") + logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}, kb_id_idx={kb_id_idx}") - column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] + column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx | kb_id_idx)] logging.debug(f"use_sql: column_idx={column_idx}") logging.debug(f"use_sql: field_map={field_map}") @@ -1221,8 +1262,11 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE) if where_match: where_clause = where_match.group(1).strip() - # Build a query to get doc_id and docnm_kwd with the same WHERE clause - chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}" + # Build a query to get source fields with the same WHERE clause. + # Single-KB queries can derive kb_id from the dialog, while multi-KB + # ES/OS queries need the row value for metadata enrichment. 
+ chunks_kb_column = ", kb_id" if not (kb_ids and len(kb_ids) == 1) else "" + chunks_sql = f"select doc_id, {expected_doc_name_column}{chunks_kb_column} from {table_name} where {where_clause}" # Add LIMIT to avoid fetching too many chunks if "limit" not in chunks_sql.lower(): chunks_sql += " limit 20" @@ -1233,8 +1277,18 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat # Build chunks reference - use case-insensitive matching chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None) chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None) + chunks_kb_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]), None) if chunks_did_idx is not None and chunks_dn_idx is not None: - chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]] + chunks = [] + for r in chunks_tbl["rows"]: + chunk = {"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} + row_dict = {chunks_tbl["columns"][i]["name"]: r[i] for i in range(len(chunks_tbl["columns"])) if i < len(r)} + kb_id = _chunk_kb_id_for_doc(row_dict, kb_ids, chunk["doc_id"]) + if kb_id: + chunk["kb_id"] = kb_id + elif chunks_kb_idx is not None: + chunk["kb_id"] = r[chunks_kb_idx] + chunks.append(chunk) # Build doc_aggs doc_aggs = {} for r in chunks_tbl["rows"]: @@ -1264,7 +1318,22 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat result = { "answer": "\n".join([columns, line, rows]), "reference": { - "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]], + "chunks": [ + { + key: value + for key, value in { + "doc_id": r[docid_idx], + "docnm_kwd": r[doc_name_idx], + "kb_id": _chunk_kb_id_for_doc( + {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}, + kb_ids, + r[docid_idx], + ), + 
}.items() + if value + } + for r in tbl["rows"] + ], "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()], }, "prompt": sys_prompt, @@ -1414,6 +1483,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf chat_llm_name = search_config.get("chat_id", chat_llm_name) rerank_id = search_config.get("rerank_id", "") meta_data_filter = search_config.get("meta_data_filter") + include_reference_metadata, metadata_fields = _resolve_reference_metadata(search_config) kbs = KnowledgebaseService.get_by_ids(kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) @@ -1450,6 +1520,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf rerank_mdl=rerank_mdl, rank_feature=label_question(question, kbs) ) + if include_reference_metadata: + logging.debug( + "reference_metadata enrichment enabled for async_ask: chunk_count=%d metadata_fields=%s", + len(kbinfos.get("chunks", [])), + metadata_fields, + ) + _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields) knowledges = kb_prompt(kbinfos, max_tokens) sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges)) diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index 2e4b93056..db05f4bb2 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -772,6 +772,36 @@ class DocMetadataService: logging.error(f"Error getting flattened metadata for KBs {kb_ids}: {e}") return {} + @classmethod + def get_metadata_keys_by_kbs(cls, kb_ids: List[str]) -> List[str]: + """ + Get unique metadata field names across multiple knowledge bases. 
+ + Args: + kb_ids: List of knowledge base IDs + + Returns: + Sorted list of unique metadata field names + """ + if not kb_ids: + return [] + + logging.debug(f"get_metadata_keys_by_kbs start: n_kbs={len(kb_ids)}") + keys: set[str] = set() + try: + for kb_id in kb_ids: + results = cls._search_metadata(kb_id, condition={"kb_id": kb_id}) + for _doc_id, doc in cls._iter_search_results(results): + doc_meta = cls._extract_metadata(doc) + if not isinstance(doc_meta, dict): + continue + keys.update(str(k) for k in doc_meta.keys()) + logging.debug(f"get_metadata_keys_by_kbs end: n_keys={len(keys)}, kb_ids={kb_ids}") + return sorted(keys) + except Exception as e: + logging.error(f"Error getting metadata keys for KBs {kb_ids}: {e}") + return [] + @classmethod def get_metadata_for_documents(cls, doc_ids: Optional[List[str]], kb_id: str) -> Dict[str, Dict]: """ diff --git a/api/utils/reference_metadata_utils.py b/api/utils/reference_metadata_utils.py new file mode 100644 index 000000000..58d5beffb --- /dev/null +++ b/api/utils/reference_metadata_utils.py @@ -0,0 +1,125 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging + +logger = logging.getLogger(__name__) + + +def resolve_reference_metadata_preferences( + request_payload: dict | None = None, + config_payload: dict | None = None, +) -> tuple[bool, set[str] | None]: + """ + Resolve metadata include/fields from request and optional config. 
+ Request values take precedence over config values. + Supports legacy request keys: include_metadata / metadata_fields. + """ + request_payload = request_payload or {} + config_payload = config_payload or {} + + config_ref = config_payload.get("reference_metadata", {}) + request_ref = request_payload.get("reference_metadata", {}) + + resolved: dict = {} + if isinstance(config_ref, dict): + resolved.update(config_ref) + if isinstance(request_ref, dict): + resolved.update(request_ref) + + if "include_metadata" in request_payload: + resolved["include"] = bool(request_payload.get("include_metadata")) + if "metadata_fields" in request_payload: + resolved["fields"] = request_payload.get("metadata_fields") + + include_metadata = bool(resolved.get("include", False)) + fields = resolved.get("fields") + if fields is None: + return include_metadata, None + if not isinstance(fields, list): + logger.warning( + "reference_metadata.fields is not a list; include_metadata=%s fields=%r type=%s resolved=%r. " + "enrich_chunks_with_document_metadata will skip enrichment.", + include_metadata, + fields, + type(fields).__name__, + resolved, + ) + return include_metadata, set() + return include_metadata, {f for f in fields if isinstance(f, str)} + + +def enrich_chunks_with_document_metadata( + chunks: list[dict], + metadata_fields: set[str] | None = None, + *, + kb_field: str = "kb_id", + doc_field: str = "doc_id", + output_field: str = "document_metadata", +) -> None: + """ + Mutates chunk payloads in-place by attaching `document_metadata`. + Field names can be customized for different chunk schemas. 
+ """ + if metadata_fields is not None and not metadata_fields: + return + + doc_ids_by_kb: dict[str, set[str]] = {} + for chunk in chunks: + kb_ids = chunk.get(kb_field) + doc_id = chunk.get(doc_field) + if not kb_ids or not doc_id: + continue + if isinstance(kb_ids, (list, tuple)): + for kid in kb_ids: + if kid: + doc_ids_by_kb.setdefault(kid, set()).add(doc_id) + else: + doc_ids_by_kb.setdefault(kb_ids, set()).add(doc_id) + + if not doc_ids_by_kb: + return + + # Resolve service lazily so callers/tests that swap service modules at runtime + # (e.g. via monkeypatch) don't get stuck with a stale class reference. + from api.db.services.doc_metadata_service import DocMetadataService + metadata_getter = getattr(DocMetadataService, "get_metadata_for_documents", None) + if not callable(metadata_getter): + logging.warning( + "DocMetadataService.get_metadata_for_documents is unavailable; " + "skipping metadata enrichment." + ) + return + + meta_by_doc: dict[str, dict] = {} + for kb_id, doc_ids in doc_ids_by_kb.items(): + meta_map = metadata_getter(list(doc_ids), kb_id) + if meta_map: + meta_by_doc.update(meta_map) + logging.debug("Fetched metadata for %d docs in kb_id=%s", len(meta_map), kb_id) + + for chunk in chunks: + doc_id = chunk.get(doc_field) + if not doc_id: + continue + meta = meta_by_doc.get(doc_id) + if not meta: + continue + if metadata_fields is not None: + meta = {k: v for k, v in meta.items() if k in metadata_fields} + if meta: + chunk[output_field] = meta + logging.debug("Enriched chunk for doc_id=%s with %d metadata fields: %s", doc_id, len(meta), list(meta.keys())) diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 47c0b9f2b..2ef8b8f8c 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -58,6 +58,7 @@ def chunks_format(reference): "term_similarity": chunk.get("term_similarity"), "row_id": chunk.get("row_id"), "doc_type": get_value(chunk, "doc_type_kwd", "doc_type"), + "document_metadata": 
chunk.get("document_metadata"), } for chunk in raw_chunks if isinstance(chunk, dict) @@ -102,9 +103,6 @@ def message_fit_in(msg, max_length=4000): def kb_prompt(kbinfos, max_tokens, hash_id=False): - from api.db.services.document_service import DocumentService - from api.db.services.doc_metadata_service import DocMetadataService - knowledges = [get_value(ck, "content", "content_with_weight") for ck in kbinfos["chunks"]] kwlg_len = len(knowledges) used_token_count = 0 @@ -119,14 +117,6 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False): logging.warning(f"Not all the retrieval into prompt: {len(knowledges)}/{kwlg_len}") break - docs = DocumentService.get_by_ids([get_value(ck, "doc_id", "document_id") for ck in kbinfos["chunks"][:chunks_num]]) - - docs_with_meta = {} - for d in docs: - meta = DocMetadataService.get_document_metadata(d.id) - docs_with_meta[d.id] = meta if meta else {} - docs = docs_with_meta - def draw_node(k, line): if line is not None and not isinstance(line, str): line = str(line) @@ -138,8 +128,9 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False): for i, ck in enumerate(kbinfos["chunks"][:chunks_num]): cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500)) cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name")) - cnt += draw_node("URL", ck['url']) if "url" in ck else "" - for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items(): + cnt += draw_node("URL", ck.get('url', '')) + meta = ck.get("document_metadata", {}) + for k, v in meta.items(): cnt += draw_node(k, v) cnt += "\n└── Content:\n" cnt += get_value(ck, "content", "content_with_weight") diff --git a/run_tests.py b/run_tests.py index aee34a833..48b039187 100755 --- a/run_tests.py +++ b/run_tests.py @@ -43,6 +43,8 @@ class TestRunner: self.verbose = False self.ignore_syntax_warning = False self.markers = "" + self.test_path = "" + self.keyword = "" # Python interpreter path self.python = sys.executable @@ -100,13 
+102,20 @@ EXAMPLES: def build_pytest_command(self) -> List[str]: """Build the pytest command arguments""" - cmd = ["pytest", str(self.ut_dir)] - - # Add test path + cmd = ["pytest"] + if self.test_path: + test_target = Path(self.test_path) + if not test_target.is_absolute(): + test_target = self.project_root / test_target + cmd.append(str(test_target)) + else: + cmd.append(str(self.ut_dir)) # Add markers if self.markers: cmd.extend(["-m", self.markers]) + if self.keyword: + cmd.extend(["-k", self.keyword]) # Add verbose flag if self.verbose: @@ -161,9 +170,13 @@ EXAMPLES: self.print_info(f"Coverage: {self.coverage}") self.print_info(f"Parallel: {self.parallel}") self.print_info(f"Verbose: {self.verbose}") + if self.test_path: + self.print_info(f"Test target: {self.test_path}") if self.markers: self.print_info(f"Markers: {self.markers}") + if self.keyword: + self.print_info(f"Keyword: {self.keyword}") print(f"\n{Colors.BLUE}[EXECUTING]{Colors.NC} {' '.join(cmd)}\n") @@ -244,6 +257,13 @@ Examples: help="Run specific test file or directory" ) + parser.add_argument( + "-k", "--keyword", + type=str, + default="", + help="Run tests matching keyword expression (pytest -k)" + ) + parser.add_argument( "-m", "--markers", type=str, @@ -260,6 +280,8 @@ Examples: self.verbose = args.verbose self.markers = args.markers self.ignore_syntax_warning = args.ignore + self.test_path = args.test + self.keyword = args.keyword return True diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index fd65e6116..de520f3fe 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -14,6 +14,7 @@ # limitations under the License. 
# from typing import Any + from .base import Base from .document import Document @@ -79,7 +80,7 @@ class DataSet(Base): # Validate that id and ids are not used together if id and ids: raise ValueError("Cannot use both 'id' and 'ids' parameters at the same time.") - + params = { "id": id, "name": name, @@ -109,8 +110,7 @@ class DataSet(Base): res = res.json() if res.get("code") != 0: raise Exception(res["message"]) - - + def _get_documents_status(self, document_ids): import time terminal_states = {"DONE", "FAIL", "CANCEL"} diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index 0d3ee68d1..4a6d022c6 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -17,6 +17,7 @@ import asyncio import inspect import importlib.util import sys +from functools import wraps from pathlib import Path from types import ModuleType, SimpleNamespace @@ -26,6 +27,16 @@ import pytest from api.db import FileType +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + class _DummyManager: def route(self, *_args, **_kwargs): def decorator(func): @@ -126,6 +137,127 @@ def _load_doc_module(monkeypatch): common_pkg.__path__ = [str(repo_root / "common")] monkeypatch.setitem(sys.modules, "common", common_pkg) + common_settings_mod = ModuleType("common.settings") + common_settings_mod.retriever = SimpleNamespace() + common_settings_mod.kg_retriever = SimpleNamespace() + common_settings_mod.STORAGE_IMPL = SimpleNamespace(get=lambda *_args, **_kwargs: b"", rm=lambda *_args, **_kwargs: None) + monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) + + class _FakeExpr: + def 
__or__(self, other): + return self + + def __and__(self, other): + return self + + class _FakeField: + def __eq__(self, other): + return _FakeExpr() + + def __ne__(self, other): + return _FakeExpr() + + def is_null(self, value=True): + return _FakeExpr() + + class _StubDocumentModel: + id = _FakeField() + run = _FakeField() + + class _StubTaskModel: + doc_id = _FakeField() + + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.APIToken = SimpleNamespace(query=lambda **_kwargs: []) + db_models_mod.Document = _StubDocumentModel + db_models_mod.Task = _StubTaskModel + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [str(repo_root / "api" / "db" / "services")] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + doc_metadata_service_mod = ModuleType("api.db.services.doc_metadata_service") + doc_metadata_service_mod.DocMetadataService = SimpleNamespace( + get_flatted_meta_by_kbs=lambda *_args, **_kwargs: [], + get_metadata_for_documents=lambda *_args, **_kwargs: {}, + ) + monkeypatch.setitem(sys.modules, "api.db.services.doc_metadata_service", doc_metadata_service_mod) + + document_service_mod = ModuleType("api.db.services.document_service") + document_service_mod.DocumentService = SimpleNamespace( + query=lambda **_kwargs: [], + filter_update=lambda *_args, **_kwargs: 0, + get_by_id=lambda *_args, **_kwargs: (False, None), + update_by_id=lambda *_args, **_kwargs: True, + decrement_chunk_num=lambda *_args, **_kwargs: None, + get_embd_id=lambda *_args, **_kwargs: "", + get_tenant_embd_id=lambda *_args, **_kwargs: None, + ) + monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod) + + file2document_service_mod = ModuleType("api.db.services.file2document_service") + file2document_service_mod.File2DocumentService = SimpleNamespace( + get_storage_address=lambda **_kwargs: ("", ""), + ) + 
monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_service_mod) + + knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service") + knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace( + accessible=lambda **_kwargs: False, + get_by_id=lambda *_args, **_kwargs: (False, None), + get_by_ids=lambda *_args, **_kwargs: [], + list_documents_by_ids=lambda *_args, **_kwargs: [], + query=lambda **_kwargs: [], + ) + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod) + + task_service_mod = ModuleType("api.db.services.task_service") + task_service_mod.TaskService = SimpleNamespace(filter_delete=lambda *_args, **_kwargs: None) + task_service_mod.cancel_all_task_of = lambda *_args, **_kwargs: None + task_service_mod.queue_tasks = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.check_duplicate_ids = lambda ids, _kind="item": (ids, []) + api_utils_mod.construct_json_result = lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data} + api_utils_mod.get_error_data_result = lambda message="Sorry! 
Data missing!", code=102: {"code": code, "message": message} + api_utils_mod.get_request_json = lambda: _AwaitableValue({}) + api_utils_mod.get_result = lambda code=0, message="", data=None, total=None: { + key: value + for key, value in {"code": code, "message": message, "data": data, "total": total}.items() + if value is not None + } + api_utils_mod.server_error_response = lambda e: {"code": 500, "message": str(e)} + def _token_required(func): + @wraps(func) + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + + return wrapper + + api_utils_mod.token_required = _token_required + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + common_metadata_utils_mod = ModuleType("common.metadata_utils") + common_metadata_utils_mod.convert_conditions = lambda conditions: conditions + common_metadata_utils_mod.meta_filter = lambda *_args, **_kwargs: [] + monkeypatch.setitem(sys.modules, "common.metadata_utils", common_metadata_utils_mod) + + rag_app_tag_mod = ModuleType("rag.app.tag") + rag_app_tag_mod.label_question = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "rag.app.tag", rag_app_tag_mod) + + rag_prompts_generator_mod = ModuleType("rag.prompts.generator") + rag_prompts_generator_mod.cross_languages = lambda *_args, **_kwargs: "" + rag_prompts_generator_mod.keyword_extraction = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts_generator_mod) + + rag_nlp_mod = ModuleType("rag.nlp") + rag_nlp_mod.search = SimpleNamespace(index_name=lambda tenant_id: f"idx_{tenant_id}") + monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_mod) + monkeypatch.setitem(sys.modules, "rag.nlp.search", rag_nlp_mod.search) + deepdoc_pkg = ModuleType("deepdoc") deepdoc_parser_pkg = ModuleType("deepdoc.parser") deepdoc_parser_pkg.__path__ = [] @@ -344,7 +476,7 @@ def _patch_docstore(monkeypatch, module, **kwargs): "index_exist": lambda *_args, **_kwargs: False, } 
defaults.update(kwargs) - monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(**defaults)) + monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(**defaults), raising=False) @pytest.mark.p2 @@ -643,7 +775,7 @@ class TestDocRoutesUnit: res = _run(_route_core(module.update_chunk)("tenant-1", "ds-1", "doc-1", "chunk-1")) assert res["code"] == 0 - def test_retrieval_validation_matrix(self, monkeypatch): + def test_retrieval_metadata_validation_matrix(self, monkeypatch): module = _load_doc_module(monkeypatch) monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"dataset_ids": "bad"})) res = _run(module.retrieval_test.__wrapped__("tenant-1")) @@ -825,6 +957,7 @@ class TestDocRoutesUnit: "keyword": True, "toc_enhance": True, "use_kg": True, + "reference_metadata": {"include": True, "fields": ["author"]}, } ), ) @@ -835,6 +968,16 @@ class TestDocRoutesUnit: monkeypatch.setattr(module.settings, "kg_retriever", _FeatureKgRetriever()) monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: {}) monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: SimpleNamespace()) + monkeypatch.setattr( + module.DocMetadataService, + "get_metadata_for_documents", + lambda _doc_ids, _kb_id: { + "doc-1": {"author": "alice", "year": "2025"}, + "doc-toc": {"author": "bob"}, + "doc-child": {"author": "carol"}, + "doc-kg": {"author": "kg-author"}, + }, + ) res = _run(module.retrieval_test.__wrapped__("tenant-1")) assert res["code"] == 0, res["message"] assert feature_calls["cross"] == ("fr",) @@ -842,6 +985,7 @@ class TestDocRoutesUnit: assert feature_calls["retrieval_question"] == "q-xl-kw" assert res["data"]["chunks"][0]["id"] == "kg-1" assert res["data"]["chunks"][0]["content"] == "kg content" + assert res["data"]["chunks"][0]["document_metadata"]["author"] == "kg-author" assert any(chunk["id"] == "toc-1" for chunk in res["data"]["chunks"]) assert any(chunk["id"] == "child-1" for chunk in res["data"]["chunks"]) 
diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index f442db519..6d2dcbf3a 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -251,6 +251,53 @@ def _load_session_module(monkeypatch): common_constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) + common_metadata_utils_mod = ModuleType("common.metadata_utils") + common_metadata_utils_mod.apply_meta_data_filter = lambda *_args, **_kwargs: [] + common_metadata_utils_mod.convert_conditions = lambda conditions: conditions + common_metadata_utils_mod.meta_filter = lambda *_args, **_kwargs: True + monkeypatch.setitem(sys.modules, "common.metadata_utils", common_metadata_utils_mod) + + common_settings_mod = ModuleType("common.settings") + common_settings_mod.retriever = SimpleNamespace() + common_settings_mod.kg_retriever = SimpleNamespace() + monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.add_tenant_id_to_kwargs = lambda func: func + api_utils_mod.check_duplicate_ids = lambda ids, _kind="item": (ids, []) + api_utils_mod.get_data_error_result = lambda message="Sorry! Data missing!", code=_StubRetCode.DATA_ERROR: {"code": code, "message": message} + api_utils_mod.get_error_data_result = lambda message="Sorry! 
Data missing!", code=_StubRetCode.DATA_ERROR: {"code": code, "message": message} + api_utils_mod.get_json_result = lambda code=_StubRetCode.SUCCESS, message="success", data=None: {"code": code, "message": message, "data": data} + api_utils_mod.get_result = lambda code=_StubRetCode.SUCCESS, message="", data=None, total=None: { + key: value + for key, value in {"code": code, "message": message, "data": data, "total": total}.items() + if value is not None + } + api_utils_mod.get_request_json = lambda: _AwaitableValue({}) + api_utils_mod.server_error_response = lambda e: {"code": _StubRetCode.SERVER_ERROR, "message": str(e)} + api_utils_mod.token_required = lambda func: func + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + rag_app_tag_mod = ModuleType("rag.app.tag") + rag_app_tag_mod.label_question = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "rag.app.tag", rag_app_tag_mod) + + rag_prompts_generator_mod = ModuleType("rag.prompts.generator") + rag_prompts_generator_mod.cross_languages = lambda *_args, **_kwargs: "" + rag_prompts_generator_mod.keyword_extraction = lambda *_args, **_kwargs: "" + rag_prompts_generator_mod.chunks_format = lambda chunks: chunks + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts_generator_mod) + + rag_prompts_template_mod = ModuleType("rag.prompts.template") + rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.template", rag_prompts_template_mod) + + rag_nlp_mod = ModuleType("rag.nlp") + rag_nlp_mod.search = SimpleNamespace(index_name=lambda tenant_id: f"idx_{tenant_id}") + monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_mod) + monkeypatch.setitem(sys.modules, "rag.nlp.search", rag_nlp_mod.search) + deepdoc_pkg = ModuleType("deepdoc") deepdoc_parser_pkg = ModuleType("deepdoc.parser") deepdoc_parser_pkg.__path__ = [] @@ -508,8 
+555,128 @@ def _load_session_module(monkeypatch): quart_mod.jsonify = lambda payload: payload quart_mod.current_app = SimpleNamespace() quart_mod.has_app_context = lambda: False + quart_mod.has_request_context = lambda: False + quart_mod.has_websocket_context = lambda: False + quart_mod.websocket = SimpleNamespace() monkeypatch.setitem(sys.modules, "quart", quart_mod) + quart_auth_mod = ModuleType("quart_auth") + + class _StubAuthUser: + pass + + quart_auth_mod.AuthUser = _StubAuthUser + monkeypatch.setitem(sys.modules, "quart_auth", quart_auth_mod) + + class _FakeExpr: + def __or__(self, other): + return self + + def __and__(self, other): + return self + + class _FakeField: + def __eq__(self, other): + return _FakeExpr() + + def __ne__(self, other): + return _FakeExpr() + + def is_null(self, value=True): + return _FakeExpr() + + class _StubTaskModel: + id = _FakeField() + doc_id = _FakeField() + + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.APIToken = SimpleNamespace(query=lambda **_kwargs: []) + db_models_mod.Task = _StubTaskModel + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [str(repo_root / "api" / "db" / "services")] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + api_service_mod = ModuleType("api.db.services.api_service") + api_service_mod.API4ConversationService = SimpleNamespace( + get_names=lambda *_args, **_kwargs: [], + get_list=lambda *_args, **_kwargs: (0, []), + save=lambda **_kwargs: True, + get_by_id=lambda _session_id: (True, SimpleNamespace(to_dict=lambda: {"id": _session_id})), + delete_by_id=lambda *_args, **_kwargs: True, + query=lambda **_kwargs: [], + ) + monkeypatch.setitem(sys.modules, "api.db.services.api_service", api_service_mod) + + canvas_service_mod = ModuleType("api.db.services.canvas_service") + canvas_service_mod.CanvasTemplateService = SimpleNamespace(get_all=lambda *_args, **_kwargs: []) + 
canvas_service_mod.UserCanvasService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + accessible=lambda *_args, **_kwargs: False, + get_agent_dsl_with_release=lambda *_args, **_kwargs: (SimpleNamespace(id="agent-1"), "{}"), + ) + + async def _empty_agent_completion(*_args, **_kwargs): + if False: + yield None + + canvas_service_mod.completion = _empty_agent_completion + canvas_service_mod.completion_openai = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "api.db.services.canvas_service", canvas_service_mod) + + conversation_service_mod = ModuleType("api.db.services.conversation_service") + conversation_service_mod.ConversationService = SimpleNamespace(query=lambda **_kwargs: []) + conversation_service_mod.async_iframe_completion = lambda *_args, **_kwargs: None + conversation_service_mod.async_completion = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.conversation_service", conversation_service_mod) + + dialog_service_mod = ModuleType("api.db.services.dialog_service") + dialog_service_mod.DialogService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + ) + dialog_service_mod.async_ask = lambda *_args, **_kwargs: None + dialog_service_mod.async_chat = lambda *_args, **_kwargs: None + dialog_service_mod.gen_mindmap = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod) + + doc_metadata_service_mod = ModuleType("api.db.services.doc_metadata_service") + doc_metadata_service_mod.DocMetadataService = SimpleNamespace( + get_flatted_meta_by_kbs=lambda *_args, **_kwargs: [], + get_metadata_for_documents=lambda *_args, **_kwargs: {}, + ) + monkeypatch.setitem(sys.modules, "api.db.services.doc_metadata_service", doc_metadata_service_mod) + + knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service") + 
knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + ) + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod) + + search_service_mod = ModuleType("api.db.services.search_service") + search_service_mod.SearchService = SimpleNamespace( + query=lambda **_kwargs: [], + get_detail=lambda *_args, **_kwargs: None, + ) + monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + user_service_mod.UserTenantService = SimpleNamespace(query=lambda **_kwargs: []) + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + user_canvas_version_mod = ModuleType("api.db.services.user_canvas_version") + user_canvas_version_mod.UserCanvasVersionService = SimpleNamespace( + list_by_canvas_id=lambda *_args, **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + get_latest_version_title=lambda *_args, **_kwargs: "", + save_or_replace_latest=lambda **_kwargs: True, + build_version_title=lambda *_args, **_kwargs: "v1", + ) + monkeypatch.setitem(sys.modules, "api.db.services.user_canvas_version", user_canvas_version_mod) + module_path = repo_root / "api" / "apps" / "sdk" / "session.py" spec = importlib.util.spec_from_file_location("test_session_sdk_routes_unit_module", module_path) module = importlib.util.module_from_spec(spec) @@ -612,7 +779,10 @@ def _load_agent_api_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod) knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service") - knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace(query=lambda **_kwargs: []) + knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + ) 
monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod) task_service_mod = ModuleType("api.db.services.task_service") @@ -1352,7 +1522,7 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch): "rank_feature": rank_feature, } ) - return {"chunks": [{"id": "chunk-1", "vector": [0.1]}]} + return {"chunks": [{"id": "chunk-1", "doc_id": "doc-1", "kb_id": "kb-1", "vector": [0.1]}]} async def _translate(_tenant_id, _chat_id, question, _langs): return question + "-translated" @@ -1384,10 +1554,16 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch): "vector_similarity_weight": 0.8, "top_k": 7, "rerank_id": "reranker-model", + "reference_metadata": {"include": True, "fields": ["author"]}, } }, ) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"id": "doc-2"}]) + monkeypatch.setattr( + module.DocMetadataService, + "get_metadata_for_documents", + lambda _doc_ids, _kb_id: {"doc-1": {"author": "alice", "year": "2025"}}, + ) monkeypatch.setattr(module, "apply_meta_data_filter", _apply_filter) monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")]) monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) @@ -1409,6 +1585,8 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch): assert retrieval_capture["local_doc_ids"] == ["doc-filtered"] assert retrieval_capture["rank_feature"] == ["label-1"] assert retrieval_capture["rerank_mdl"] is not None + assert res["data"]["chunks"][0]["document_metadata"]["author"] == "alice" + assert "year" not in res["data"]["chunks"][0]["document_metadata"] assert any(call[1] == module.LLMType.EMBEDDING.value and call[2] == "embd-model" for call in llm_calls) llm_calls.clear() @@ -1621,9 +1799,18 @@ def test_build_reference_chunks_metadata_matrix_unit(monkeypatch): monkeypatch.setattr(module, 
"chunks_format", lambda _reference: [{"dataset_id": "kb-1", "document_id": "doc-1"}]) monkeypatch.setattr(module.DocMetadataService, "get_metadata_for_documents", lambda _doc_ids, _kb_id: {"doc-1": {"author": "alice"}}) + res = module._build_reference_chunks([], include_metadata=True, metadata_fields=None) + assert res[0]["document_metadata"] == {"author": "alice"} + + res = module._build_reference_chunks([], include_metadata=True, metadata_fields=[]) + assert "document_metadata" not in res[0] + res = module._build_reference_chunks([], include_metadata=True, metadata_fields=[1, None]) assert "document_metadata" not in res[0] + res = module._build_reference_chunks([], include_metadata=True, metadata_fields="author") + assert "document_metadata" not in res[0] + source_chunks = [ {"dataset_id": "kb-1", "document_id": "doc-1"}, {"dataset_id": "kb-2", "document_id": "doc-2"}, diff --git a/web/src/components/fallback-component/index.tsx b/web/src/components/fallback-component/index.tsx index 131051827..7f4f1b795 100644 --- a/web/src/components/fallback-component/index.tsx +++ b/web/src/components/fallback-component/index.tsx @@ -1,5 +1,6 @@ import React from 'react'; import { useTranslation } from 'react-i18next'; +import { isRouteErrorResponse, useRouteError } from 'react-router'; interface FallbackComponentProps { error?: Error; @@ -7,10 +8,32 @@ interface FallbackComponentProps { } const FallbackComponent: React.FC = ({ - error, + error: errorProp, reset, }) => { const { t } = useTranslation(); + const routeError = useRouteError(); + const error = + errorProp ?? (routeError instanceof Error ? 
routeError : undefined); + + let routeErrorDataStr = ''; + if (isRouteErrorResponse(routeError)) { + if (typeof routeError.data === 'string') { + routeErrorDataStr = routeError.data; + } else if (routeError.data == null) { + routeErrorDataStr = 'no body'; + } else { + try { + routeErrorDataStr = JSON.stringify(routeError.data); + } catch { + routeErrorDataStr = String(routeError.data); + } + } + } + + const errorMessage = isRouteErrorResponse(routeError) + ? `${routeError.status} ${routeError.statusText}${routeErrorDataStr ? `: ${routeErrorDataStr}` : ''}` + : (error?.toString() ?? (routeError ? String(routeError) : undefined)); return (
@@ -21,10 +44,10 @@ const FallbackComponent: React.FC = ({ 'Sorry, an error occurred while loading the page.', )}

- {error && ( -
+ {errorMessage && ( +
{t('error_boundary.details', 'Error details')} - {error.toString()} + {errorMessage}
)}
diff --git a/web/src/components/markdown-content/index.tsx b/web/src/components/markdown-content/index.tsx index 72247674f..846d25d5a 100644 --- a/web/src/components/markdown-content/index.tsx +++ b/web/src/components/markdown-content/index.tsx @@ -40,6 +40,13 @@ import styles from './index.module.less'; const getChunkIndex = (match: string) => parseCitationIndex(match); +const formatMetadataValue = (value: unknown) => { + if (Array.isArray(value)) return value.join(', '); + if (value === null || value === undefined) return ''; + if (typeof value === 'object') return JSON.stringify(value); + return String(value); +}; + // TODO: The display of the table is inconsistent with the display previously placed in the MessageItem. const MarkdownContent = ({ reference, @@ -174,6 +181,21 @@ const MarkdownContent = ({ className={classNames(styles.chunkContentText)} dir="auto" >
+ {chunkItem?.document_metadata && + Object.keys(chunkItem.document_metadata).length > 0 && ( +
+ {Object.entries(chunkItem.document_metadata).map( + ([key, value]) => ( +
+ {key}:{' '} + + {formatMetadataValue(value)} + +
+ ), + )} +
+ )} {documentId && (
{fileThumbnail ? ( diff --git a/web/src/hooks/use-knowledge-request.ts b/web/src/hooks/use-knowledge-request.ts index 782b1282f..5bd5a796b 100644 --- a/web/src/hooks/use-knowledge-request.ts +++ b/web/src/hooks/use-knowledge-request.ts @@ -48,6 +48,7 @@ export const enum KnowledgeApiAction { FetchKnowledgeDetail = 'fetchKnowledgeDetail', FetchKnowledgeGraph = 'fetchKnowledgeGraph', FetchMetadata = 'fetchMetadata', + FetchMetadataKeys = 'fetchMetadataKeys', FetchKnowledgeList = 'fetchKnowledgeList', RemoveKnowledgeGraph = 'removeKnowledgeGraph', } @@ -378,6 +379,24 @@ export function useFetchKnowledgeMetadata(kbIds: string[] = []) { return { data, loading }; } +export function useFetchKnowledgeMetadataKeys(kbIds: string[] = []) { + const sortedKbIds = useMemo(() => [...kbIds].sort(), [kbIds]); + const { data, isFetching: loading } = useQuery({ + queryKey: [KnowledgeApiAction.FetchMetadataKeys, sortedKbIds], + initialData: [], + enabled: sortedKbIds.length > 0, + gcTime: 0, + queryFn: async () => { + const { data } = await kbService.getMetaKeys({ + kb_ids: sortedKbIds.join(','), + }); + return data?.data ?? 
[]; + }, + }); + + return { data, loading }; +} + export const useRemoveKnowledgeGraph = () => { const knowledgeBaseId = useKnowledgeBaseId(); diff --git a/web/src/interfaces/database/chat.ts b/web/src/interfaces/database/chat.ts index 5cce383f5..447409bcf 100644 --- a/web/src/interfaces/database/chat.ts +++ b/web/src/interfaces/database/chat.ts @@ -22,6 +22,10 @@ export interface PromptConfig { cross_languages?: Array; tavily_api_key?: string; toc_enhance?: boolean; + reference_metadata?: { + include?: boolean; + fields?: string[]; + }; } export interface Parameter { @@ -126,6 +130,7 @@ export interface IReferenceChunk { term_similarity: number; positions: number[]; doc_type?: string; + document_metadata?: Record; } export interface IReference { diff --git a/web/src/pages/agent/form/doc-generator-form/use-values.ts b/web/src/pages/agent/form/doc-generator-form/use-values.ts index e4426ae8a..b4df1809a 100644 --- a/web/src/pages/agent/form/doc-generator-form/use-values.ts +++ b/web/src/pages/agent/form/doc-generator-form/use-values.ts @@ -1,5 +1,5 @@ +import { type Node } from '@xyflow/react'; import { useMemo } from 'react'; -import { Node } from 'reactflow'; import { initialDocGeneratorValues } from '../../constant'; export const useValues = (node?: Node) => { diff --git a/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx b/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx index 367748cef..5794787d9 100644 --- a/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx +++ b/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx @@ -13,16 +13,45 @@ import { FormLabel, FormMessage, } from '@/components/ui/form'; +import { MultiSelect } from '@/components/ui/multi-select'; +import { Switch } from '@/components/ui/switch'; import { Textarea } from '@/components/ui/textarea'; import { useTranslate } from '@/hooks/common-hooks'; +import { useFetchKnowledgeMetadataKeys } from '@/hooks/use-knowledge-request'; import 
{ getDirAttribute } from '@/utils/text-direction'; -import { useFormContext } from 'react-hook-form'; +import { useEffect, useMemo } from 'react'; +import { useFormContext, useWatch } from 'react-hook-form'; export default function ChatBasicSetting() { const { t } = useTranslate('chat'); const form = useFormContext(); const emptyResponseValue = form.watch('prompt_config.empty_response'); const prologueValue = form.watch('prompt_config.prologue'); + const kbIds = (useWatch({ control: form.control, name: 'dataset_ids' }) || + []) as string[]; + const metadataInclude = useWatch({ + control: form.control, + name: 'prompt_config.reference_metadata.include', + }); + const { data: metadataKeys } = useFetchKnowledgeMetadataKeys(kbIds); + const metadataFieldOptions = useMemo(() => { + return (metadataKeys || []).map((key) => ({ + label: key, + value: key, + })); + }, [metadataKeys]); + + useEffect(() => { + const currentFields = form.getValues('prompt_config.reference_metadata.fields'); + if (metadataInclude && Array.isArray(currentFields) && currentFields.length > 0 && metadataKeys) { + const validFields = currentFields.filter((field) => metadataKeys.includes(field)); + if (validFields.length !== currentFields.length) { + form.setValue('prompt_config.reference_metadata.fields', validFields); + } + } else if (!metadataInclude) { + form.setValue('prompt_config.reference_metadata.fields', undefined); + } + }, [kbIds, metadataKeys, metadataInclude, form]); return (
@@ -83,6 +112,59 @@ export default function ChatBasicSetting() { + ( + + + { + field.onChange(value); + if (!value) { + form.setValue( + 'prompt_config.reference_metadata.fields', + undefined, + ); + } + }} + /> + + + Show chunk metadata + + + )} + /> + {metadataInclude && ( + ( + + + {t('metadataKeys')} + + + + + + + )} + /> + )}
); } diff --git a/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx b/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx index 028cc0147..f3079a314 100644 --- a/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx +++ b/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx @@ -57,6 +57,10 @@ export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { reasoning: false, cross_languages: [], toc_enhance: false, + reference_metadata: { + include: false, + fields: undefined, + }, }, top_n: 8, similarity_threshold: 0.2, @@ -74,6 +78,14 @@ export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { values, 'llm_setting.', ); + const referenceMetadata = nextValues?.prompt_config?.reference_metadata; + if ( + referenceMetadata && + Array.isArray(referenceMetadata.fields) && + referenceMetadata.fields.length === 0 + ) { + referenceMetadata.fields = undefined; + } updateChat({ chatId: id!, @@ -101,8 +113,20 @@ export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { const llmSettingEnabledValues = setLLMSettingEnabledValues( data.llm_setting, ); + const referenceMetadata = data?.prompt_config?.reference_metadata; + const normalizedReferenceMetadata = + referenceMetadata && + Array.isArray(referenceMetadata.fields) && + referenceMetadata.fields.length === 0 + ? 
{ ...referenceMetadata, fields: undefined } + : referenceMetadata; + const nextData = { ...data, + prompt_config: { + ...data.prompt_config, + reference_metadata: normalizedReferenceMetadata, + }, ...llmSettingEnabledValues, }; diff --git a/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx b/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx index ba29383f9..f80ab79b7 100644 --- a/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx +++ b/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx @@ -36,6 +36,12 @@ export function useChatSettingSchema() { reasoning: z.boolean().optional(), cross_languages: z.array(z.string()).optional(), toc_enhance: z.boolean().optional(), + reference_metadata: z + .object({ + include: z.boolean().optional(), + fields: z.array(z.string()).optional(), + }) + .optional(), }); const formSchema = z.object({ diff --git a/web/src/pages/next-search/search-setting.tsx b/web/src/pages/next-search/search-setting.tsx index 7b8203bf0..d9a381782 100644 --- a/web/src/pages/next-search/search-setting.tsx +++ b/web/src/pages/next-search/search-setting.tsx @@ -22,9 +22,15 @@ import { FormMessage, } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; +import { MultiSelect } from '@/components/ui/multi-select'; import { RAGFlowSelect } from '@/components/ui/select'; import { Spin } from '@/components/ui/spin'; import { Switch } from '@/components/ui/switch'; +import { Textarea } from '@/components/ui/textarea'; +import { + useFetchKnowledgeList, + useFetchKnowledgeMetadataKeys, +} from '@/hooks/use-knowledge-request'; import { useComposeLlmOptionsByModelTypes, useSelectLlmOptionsByModelType, @@ -79,6 +85,12 @@ const SearchSettingFormSchema = z highlight: z.boolean(), keyword: z.boolean(), chat_settingcross_languages: z.array(z.string()), + reference_metadata: z + .object({ + include: z.boolean().optional(), + fields: z.array(z.string()).optional(), + }) + 
.optional(), ...MetadataFilterSchema, }), }) @@ -156,6 +168,14 @@ const SearchSetting: React.FC = ({ related_search: search_config?.related_search || false, query_mindmap: search_config?.query_mindmap || false, meta_data_filter: search_config?.meta_data_filter, + reference_metadata: { + include: search_config?.reference_metadata?.include || false, + fields: + search_config?.reference_metadata?.fields && + search_config.reference_metadata.fields.length > 0 + ? search_config.reference_metadata.fields + : undefined, + }, }, }); }, [data, search_config, llm_setting, formMethods, descriptionDefaultValue]); @@ -193,6 +213,35 @@ const SearchSetting: React.FC = ({ control: formMethods.control, name: 'search_config.summary', }); + const selectedKbIds = useWatch({ + control: formMethods.control, + name: 'search_config.kb_ids', + }); + const referenceMetadataEnabled = useWatch({ + control: formMethods.control, + name: 'search_config.reference_metadata.include', + }); + const { data: metadataKeys } = useFetchKnowledgeMetadataKeys( + selectedKbIds || [], + ); + const metadataFieldOptions = useMemo(() => { + return (metadataKeys || []).map((key) => ({ + label: key, + value: key, + })); + }, [metadataKeys]); + + useEffect(() => { + const currentFields = formMethods.getValues('search_config.reference_metadata.fields'); + if (referenceMetadataEnabled && Array.isArray(currentFields) && currentFields.length > 0 && metadataKeys) { + const validFields = currentFields.filter((field) => metadataKeys.includes(field)); + if (validFields.length !== currentFields.length) { + formMethods.setValue('search_config.reference_metadata.fields', validFields); + } + } else if (!referenceMetadataEnabled) { + formMethods.setValue('search_config.reference_metadata.fields', undefined); + } + }, [selectedKbIds, metadataKeys, referenceMetadataEnabled, formMethods]); // Reset top_k to 1024 only when user actively disables rerank (from true to false) const prevRerankEnabled = useRef(undefined); @@ -227,11 
+276,22 @@ const SearchSetting: React.FC = ({ frequency_penalty: llm_setting.frequency_penalty, presence_penalty: llm_setting.presence_penalty, } as IllmSettingProps; + const referenceMetadata = other_config.reference_metadata; + const normalizedReferenceMetadata = referenceMetadata + ? { + ...referenceMetadata, + ...(Array.isArray(referenceMetadata.fields) && + referenceMetadata.fields.length === 0 + ? { fields: undefined } + : {}), + } + : referenceMetadata; await updateSearch({ ...other_formdata, search_config: { ...other_config, + reference_metadata: normalizedReferenceMetadata, chat_id: llm_setting.llm_id, vector_similarity_weight: 1 - vector_similarity_weight, rerank_id: use_rerank ? rerank_id : '', @@ -288,6 +348,61 @@ const SearchSetting: React.FC = ({ required > + ( + + + { + field.onChange(value); + if (!value) { + formMethods.setValue( + 'search_config.reference_metadata.fields', + undefined, + ); + } + }} + /> + + + Show chunk metadata + + + )} + /> + {referenceMetadataEnabled && ( + ( + + + Metadata fields + + + + + + + )} + /> + )} { + if (Array.isArray(value)) return value.join(', '); + if (value === null || value === undefined) return ''; + if (typeof value === 'object') return JSON.stringify(value); + return String(value); +}; export default function SearchingView({ setIsSearching, searchData, @@ -208,6 +214,26 @@ export default function SearchingView({ {chunk.content_with_weight}
+ {chunk.document_metadata && + Object.keys(chunk.document_metadata).length > 0 && ( +
+ {Object.entries(chunk.document_metadata).map( + ([key, value]) => ( +
+ + {key}: + {' '} + + {formatMetadataValue(value)} + +
+ ), + )} +
+ )}
diff --git a/web/src/pages/next-searches/hooks.ts b/web/src/pages/next-searches/hooks.ts index 89bdd88c5..e8358e20e 100644 --- a/web/src/pages/next-searches/hooks.ts +++ b/web/src/pages/next-searches/hooks.ts @@ -185,6 +185,10 @@ export interface ISearchAppDetailProps { method: string; manual: { key: string; op: string; value: string }[]; }; + reference_metadata?: { + include?: boolean; + fields?: string[]; + }; }; tenant_id: string; update_time: number; diff --git a/web/src/services/knowledge-service.ts b/web/src/services/knowledge-service.ts index 47e674e45..dfd31c341 100644 --- a/web/src/services/knowledge-service.ts +++ b/web/src/services/knowledge-service.ts @@ -24,6 +24,7 @@ const { listTagByKnowledgeIds, setMeta, getMeta, + getMetaKeys, retrievalTestShare, } = api; @@ -81,6 +82,10 @@ const methods = { url: getMeta, method: 'get', }, + getMetaKeys: { + url: getMetaKeys, + method: 'get', + }, retrievalTestShare: { url: retrievalTestShare, method: 'post',