diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py index 320ecd09d..baa011f32 100644 --- a/api/apps/restful_apis/openai_api.py +++ b/api/apps/restful_apis/openai_api.py @@ -48,44 +48,35 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): return None +import logging +from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata + def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None): chunks = chunks_format(reference) if not include_metadata: + logging.debug("Skipping document metadata enrichment (include_metadata=False)") return chunks - doc_ids_by_kb = {} - for chunk in chunks: - kb_id = chunk.get("dataset_id") - doc_id = chunk.get("document_id") - if not kb_id or not doc_id: - continue - doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id) - - if not doc_ids_by_kb: - return chunks - - meta_by_doc = {} - for kb_id, doc_ids in doc_ids_by_kb.items(): - meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id) - if meta_map: - meta_by_doc.update(meta_map) - + normalized_fields = None if metadata_fields is not None: - metadata_fields = {f for f in metadata_fields if isinstance(f, str)} - if not metadata_fields: + if not isinstance(metadata_fields, list): + return chunks + normalized_fields = {f for f in metadata_fields if isinstance(f, str)} + if not normalized_fields: return chunks - for chunk in chunks: - doc_id = chunk.get("document_id") - if not doc_id: - continue - meta = meta_by_doc.get(doc_id) - if not meta: - continue - if metadata_fields is not None: - meta = {k: v for k, v in meta.items() if k in metadata_fields} - if meta: - chunk["document_metadata"] = meta + logging.debug( + "Enriching %d chunks with document metadata (fields: %s)", + len(chunks), + "ALL" if normalized_fields is None else list(normalized_fields), + ) + + enrich_chunks_with_document_metadata( + chunks, + normalized_fields, + kb_field="dataset_id", + doc_field="document_id", + ) return chunks diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index dbb8f9203..9aa641ccf 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging from io import BytesIO from quart import request, send_file @@ -37,6 +38,18 @@ from rag.prompts.generator import cross_languages, keyword_extraction MAXIMUM_OF_UPLOADING_FILES = 256 +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) + +def _resolve_reference_metadata(req: dict, search_config: dict | None = None): + return resolve_reference_metadata_preferences(req, search_config) + +def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=None) -> None: + enrich_chunks_with_document_metadata(chunks, metadata_fields) + + @manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821 @token_required async def download(tenant_id, dataset_id, document_id): @@ -450,6 +463,7 @@ async def retrieval_test(tenant_id): return get_error_data_result("`highlight` should be a boolean") else: return get_error_data_result("`highlight` should be a boolean") + include_metadata, metadata_fields = _resolve_reference_metadata(req) try: tenant_ids = list(set([kb.tenant_id for kb in kbs])) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) @@ -508,6 +522,15 @@ async def retrieval_test(tenant_id): for c in ranks["chunks"]: c.pop("vector", None) + if include_metadata: + logging.info( + "sdk.retrieval reference_metadata enabled dataset_ids=%s fields=%s chunks=%s", + kb_ids, + sorted(metadata_fields) if metadata_fields else None, + len(ranks["chunks"]), + ) + enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields) + ##rename keys renamed_chunks = [] for chunk in ranks["chunks"]: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 8b6a777ba..2cb431299 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -44,6 +44,10 @@ from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, keyword_extraction from common.constants import RetCode, LLMType from common import settings +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) @token_required @@ -327,6 +331,7 @@ async def retrieval_test_embedded(): tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") + search_config = {} async def _retrieval(): nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id @@ -337,8 +342,11 @@ async def retrieval_test_embedded(): meta_data_filter = {} chat_mdl = None if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) + nonlocal search_config + detail = SearchService.get_detail(req.get("search_id", "")) + if detail: + search_config = detail.get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: @@ -414,6 +422,11 @@ async def retrieval_test_embedded(): for c in ranks["chunks"]: c.pop("vector", None) + + include_metadata, metadata_fields = _resolve_reference_metadata(req, search_config) + if include_metadata: + enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields) + ranks["labels"] = labels return get_json_result(data=ranks) @@ -529,3 +542,6 @@ async def mindmap(): return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) + +def _resolve_reference_metadata(req, search_config=None): + return resolve_reference_metadata_preferences(req, search_config) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 608391405..09ca70c43 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -33,6 +33,10 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.llm_service import LLMBundle from common.metadata_utils import apply_meta_data_filter +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) from api.db.services.tenant_llm_service import TenantLLMService from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type from common.time_utils import current_timestamp, datetime_format @@ -48,6 +52,16 @@ from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces from common import settings +def _resolve_reference_metadata(request_payload=None, config=None): + return resolve_reference_metadata_preferences(request_payload or {}, config) + +def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None): + enrich_chunks_with_document_metadata(chunks, metadata_fields) + +def _chunk_kb_id_for_doc(row_dict, kb_ids, doc_id): + if len(kb_ids or []) == 1: + return kb_ids[0] + return row_dict.get("kb_id") or row_dict.get("kb_id_kwd") def _normalize_internet_flag(value): if isinstance(value, bool): @@ -70,6 +84,15 @@ def _should_use_web_search(prompt_config, internet=None): return normalized is True +def _resolve_reference_metadata(config, request_payload=None): + return resolve_reference_metadata_preferences(request_payload or {}, config) + + +def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None): + enrich_chunks_with_document_metadata(chunks, metadata_fields) + + + class DialogService(CommonService): model = Dialog @@ -547,6 +570,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): attachments_ = "\n\n".join(text_attachments) prompt_config = dialog.prompt_config + include_reference_metadata, metadata_fields = _resolve_reference_metadata(prompt_config, request_payload=kwargs) field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) logging.debug(f"field_map retrieved: {field_map}") # try to use sql if field mapping is good to go @@ -555,6 +579,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")): + if include_reference_metadata and ans.get("reference", {}).get("chunks"): + if len(dialog.kb_ids) != 1 and any(not c.get("kb_id") for c in ans["reference"]["chunks"]): + logging.warning( + "Skipping some _enrich_chunks_with_document_metadata results because " + "dialog.kb_ids has %d entries and use_sql returned chunks without kb_id.", + len(dialog.kb_ids), + ) + _enrich_chunks_with_document_metadata(ans["reference"]["chunks"], metadata_fields) yield ans return else: @@ -675,6 +707,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) + if include_reference_metadata: + logging.debug( + "reference_metadata enrichment enabled for async_chat: chunk_count=%d metadata_fields=%s", + len(kbinfos.get("chunks", [])), + metadata_fields, + ) + _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields) + knowledges = kb_prompt(kbinfos, max_tokens) logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges))) @@ -1121,11 +1161,12 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"]) doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]]) + kb_id_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]]) logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}") - logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}") + logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}, kb_id_idx={kb_id_idx}") - column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] + column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx | kb_id_idx)] logging.debug(f"use_sql: column_idx={column_idx}") logging.debug(f"use_sql: field_map={field_map}") @@ -1221,8 +1262,11 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE) if where_match: where_clause = where_match.group(1).strip() - # Build a query to get doc_id and docnm_kwd with the same WHERE clause - chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}" + # Build a query to get source fields with the same WHERE clause. + # Single-KB queries can derive kb_id from the dialog, while multi-KB + # ES/OS queries need the row value for metadata enrichment. + chunks_kb_column = ", kb_id" if not (kb_ids and len(kb_ids) == 1) else "" + chunks_sql = f"select doc_id, {expected_doc_name_column}{chunks_kb_column} from {table_name} where {where_clause}" # Add LIMIT to avoid fetching too many chunks if "limit" not in chunks_sql.lower(): chunks_sql += " limit 20" @@ -1233,8 +1277,18 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat # Build chunks reference - use case-insensitive matching chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None) chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None) + chunks_kb_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]), None) if chunks_did_idx is not None and chunks_dn_idx is not None: - chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]] + chunks = [] + for r in chunks_tbl["rows"]: + chunk = {"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} + row_dict = {chunks_tbl["columns"][i]["name"]: r[i] for i in range(len(chunks_tbl["columns"])) if i < len(r)} + kb_id = _chunk_kb_id_for_doc(row_dict, kb_ids, chunk["doc_id"]) + if kb_id: + chunk["kb_id"] = kb_id + elif chunks_kb_idx is not None: + chunk["kb_id"] = r[chunks_kb_idx] + chunks.append(chunk) # Build doc_aggs doc_aggs = {} for r in chunks_tbl["rows"]: @@ -1264,7 +1318,22 @@ Please correct the error and write SQL again using json_extract_string(chunk_dat result = { "answer": "\n".join([columns, line, rows]), "reference": { - "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]], + "chunks": [ + { + key: value + for key, value in { + "doc_id": r[docid_idx], + "docnm_kwd": r[doc_name_idx], + "kb_id": _chunk_kb_id_for_doc( + {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}, + kb_ids, + r[docid_idx], + ), + }.items() + if value + } + for r in tbl["rows"] + ], "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()], }, "prompt": sys_prompt, @@ -1414,6 +1483,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf chat_llm_name = search_config.get("chat_id", chat_llm_name) rerank_id = search_config.get("rerank_id", "") meta_data_filter = search_config.get("meta_data_filter") + include_reference_metadata, metadata_fields = _resolve_reference_metadata(search_config) kbs = KnowledgebaseService.get_by_ids(kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) @@ -1450,6 +1520,13 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf rerank_mdl=rerank_mdl, rank_feature=label_question(question, kbs) ) + if include_reference_metadata: + logging.debug( + "reference_metadata enrichment enabled for async_ask: chunk_count=%d metadata_fields=%s", + len(kbinfos.get("chunks", [])), + metadata_fields, + ) + _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields) knowledges = kb_prompt(kbinfos, max_tokens) sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges)) diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index 2e4b93056..db05f4bb2 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -772,6 +772,36 @@ class DocMetadataService: logging.error(f"Error getting flattened metadata for KBs {kb_ids}: {e}") return {} + @classmethod + def get_metadata_keys_by_kbs(cls, kb_ids: List[str]) -> List[str]: + """ + Get unique metadata field names across multiple knowledge bases. + + Args: + kb_ids: List of knowledge base IDs + + Returns: + Sorted list of unique metadata field names + """ + if not kb_ids: + return [] + + logging.debug(f"get_metadata_keys_by_kbs start: n_kbs={len(kb_ids)}") + keys: set[str] = set() + try: + for kb_id in kb_ids: + results = cls._search_metadata(kb_id, condition={"kb_id": kb_id}) + for _doc_id, doc in cls._iter_search_results(results): + doc_meta = cls._extract_metadata(doc) + if not isinstance(doc_meta, dict): + continue + keys.update(str(k) for k in doc_meta.keys()) + logging.debug(f"get_metadata_keys_by_kbs end: n_keys={len(keys)}, kb_ids={kb_ids}") + return sorted(keys) + except Exception as e: + logging.error(f"Error getting metadata keys for KBs {kb_ids}: {e}") + return [] + @classmethod def get_metadata_for_documents(cls, doc_ids: Optional[List[str]], kb_id: str) -> Dict[str, Dict]: """ diff --git a/api/utils/reference_metadata_utils.py b/api/utils/reference_metadata_utils.py new file mode 100644 index 000000000..58d5beffb --- /dev/null +++ b/api/utils/reference_metadata_utils.py @@ -0,0 +1,125 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging + +logger = logging.getLogger(__name__) + + +def resolve_reference_metadata_preferences( + request_payload: dict | None = None, + config_payload: dict | None = None, +) -> tuple[bool, set[str] | None]: + """ + Resolve metadata include/fields from request and optional config. + Request values take precedence over config values. + Supports legacy request keys: include_metadata / metadata_fields. + """ + request_payload = request_payload or {} + config_payload = config_payload or {} + + config_ref = config_payload.get("reference_metadata", {}) + request_ref = request_payload.get("reference_metadata", {}) + + resolved: dict = {} + if isinstance(config_ref, dict): + resolved.update(config_ref) + if isinstance(request_ref, dict): + resolved.update(request_ref) + + if "include_metadata" in request_payload: + resolved["include"] = bool(request_payload.get("include_metadata")) + if "metadata_fields" in request_payload: + resolved["fields"] = request_payload.get("metadata_fields") + + include_metadata = bool(resolved.get("include", False)) + fields = resolved.get("fields") + if fields is None: + return include_metadata, None + if not isinstance(fields, list): + logger.warning( + "reference_metadata.fields is not a list; include_metadata=%s fields=%r type=%s resolved=%r. " + "enrich_chunks_with_document_metadata will skip enrichment.", + include_metadata, + fields, + type(fields).__name__, + resolved, + ) + return include_metadata, set() + return include_metadata, {f for f in fields if isinstance(f, str)} + + +def enrich_chunks_with_document_metadata( + chunks: list[dict], + metadata_fields: set[str] | None = None, + *, + kb_field: str = "kb_id", + doc_field: str = "doc_id", + output_field: str = "document_metadata", +) -> None: + """ + Mutates chunk payloads in-place by attaching `document_metadata`. + Field names can be customized for different chunk schemas. + """ + if metadata_fields is not None and not metadata_fields: + return + + doc_ids_by_kb: dict[str, set[str]] = {} + for chunk in chunks: + kb_ids = chunk.get(kb_field) + doc_id = chunk.get(doc_field) + if not kb_ids or not doc_id: + continue + if isinstance(kb_ids, (list, tuple)): + for kid in kb_ids: + if kid: + doc_ids_by_kb.setdefault(kid, set()).add(doc_id) + else: + doc_ids_by_kb.setdefault(kb_ids, set()).add(doc_id) + + if not doc_ids_by_kb: + return + + # Resolve service lazily so callers/tests that swap service modules at runtime + # (e.g. via monkeypatch) don't get stuck with a stale class reference. + from api.db.services.doc_metadata_service import DocMetadataService + metadata_getter = getattr(DocMetadataService, "get_metadata_for_documents", None) + if not callable(metadata_getter): + logging.warning( + "DocMetadataService.get_metadata_for_documents is unavailable; " + "skipping metadata enrichment." + ) + return + + meta_by_doc: dict[str, dict] = {} + for kb_id, doc_ids in doc_ids_by_kb.items(): + meta_map = metadata_getter(list(doc_ids), kb_id) + if meta_map: + meta_by_doc.update(meta_map) + logging.debug("Fetched metadata for %d docs in kb_id=%s", len(meta_map), kb_id) + + for chunk in chunks: + doc_id = chunk.get(doc_field) + if not doc_id: + continue + meta = meta_by_doc.get(doc_id) + if not meta: + continue + if metadata_fields is not None: + meta = {k: v for k, v in meta.items() if k in metadata_fields} + if meta: + chunk[output_field] = meta + logging.debug("Enriched chunk for doc_id=%s with %d metadata fields: %s", doc_id, len(meta), list(meta.keys())) diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 47c0b9f2b..2ef8b8f8c 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -58,6 +58,7 @@ def chunks_format(reference): "term_similarity": chunk.get("term_similarity"), "row_id": chunk.get("row_id"), "doc_type": get_value(chunk, "doc_type_kwd", "doc_type"), + "document_metadata": chunk.get("document_metadata"), } for chunk in raw_chunks if isinstance(chunk, dict) @@ -102,9 +103,6 @@ def message_fit_in(msg, max_length=4000): def kb_prompt(kbinfos, max_tokens, hash_id=False): - from api.db.services.document_service import DocumentService - from api.db.services.doc_metadata_service import DocMetadataService - knowledges = [get_value(ck, "content", "content_with_weight") for ck in kbinfos["chunks"]] kwlg_len = len(knowledges) used_token_count = 0 @@ -119,14 +117,6 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False): logging.warning(f"Not all the retrieval into prompt: {len(knowledges)}/{kwlg_len}") break - docs = DocumentService.get_by_ids([get_value(ck, "doc_id", "document_id") for ck in kbinfos["chunks"][:chunks_num]]) - - docs_with_meta = {} - for d in docs: - meta = DocMetadataService.get_document_metadata(d.id) - docs_with_meta[d.id] = meta if meta else {} - docs = docs_with_meta - def draw_node(k, line): if line is not None and not isinstance(line, str): line = str(line) @@ -138,8 +128,9 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False): for i, ck in enumerate(kbinfos["chunks"][:chunks_num]): cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500)) cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name")) - cnt += draw_node("URL", ck['url']) if "url" in ck else "" - for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items(): + cnt += draw_node("URL", ck.get('url', '')) + meta = ck.get("document_metadata", {}) + for k, v in meta.items(): cnt += draw_node(k, v) cnt += "\n└── Content:\n" cnt += get_value(ck, "content", "content_with_weight") diff --git a/run_tests.py b/run_tests.py index aee34a833..48b039187 100755 --- a/run_tests.py +++ b/run_tests.py @@ -43,6 +43,8 @@ class TestRunner: self.verbose = False self.ignore_syntax_warning = False self.markers = "" + self.test_path = "" + self.keyword = "" # Python interpreter path self.python = sys.executable @@ -100,13 +102,20 @@ EXAMPLES: def build_pytest_command(self) -> List[str]: """Build the pytest command arguments""" - cmd = ["pytest", str(self.ut_dir)] - - # Add test path + cmd = ["pytest"] + if self.test_path: + test_target = Path(self.test_path) + if not test_target.is_absolute(): + test_target = self.project_root / test_target + cmd.append(str(test_target)) + else: + cmd.append(str(self.ut_dir)) # Add markers if self.markers: cmd.extend(["-m", self.markers]) + if self.keyword: + cmd.extend(["-k", self.keyword]) # Add verbose flag if self.verbose: @@ -161,9 +170,13 @@ EXAMPLES: self.print_info(f"Coverage: {self.coverage}") self.print_info(f"Parallel: {self.parallel}") self.print_info(f"Verbose: {self.verbose}") + if self.test_path: + self.print_info(f"Test target: {self.test_path}") if self.markers: self.print_info(f"Markers: {self.markers}") + if self.keyword: + self.print_info(f"Keyword: {self.keyword}") print(f"\n{Colors.BLUE}[EXECUTING]{Colors.NC} {' '.join(cmd)}\n") @@ -244,6 +257,13 @@ Examples: help="Run specific test file or directory" ) + parser.add_argument( + "-k", "--keyword", + type=str, + default="", + help="Run tests matching keyword expression (pytest -k)" + ) + parser.add_argument( "-m", "--markers", type=str, @@ -260,6 +280,8 @@ Examples: self.verbose = args.verbose self.markers = args.markers self.ignore_syntax_warning = args.ignore + self.test_path = args.test + self.keyword = args.keyword return True diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index fd65e6116..de520f3fe 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -14,6 +14,7 @@ # limitations under the License. # from typing import Any + from .base import Base from .document import Document @@ -79,7 +80,7 @@ class DataSet(Base): # Validate that id and ids are not used together if id and ids: raise ValueError("Cannot use both 'id' and 'ids' parameters at the same time.") - + params = { "id": id, "name": name, @@ -109,8 +110,7 @@ class DataSet(Base): res = res.json() if res.get("code") != 0: raise Exception(res["message"]) - - + def _get_documents_status(self, document_ids): import time terminal_states = {"DONE", "FAIL", "CANCEL"} diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index 0d3ee68d1..4a6d022c6 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -17,6 +17,7 @@ import asyncio import inspect import importlib.util import sys +from functools import wraps from pathlib import Path from types import ModuleType, SimpleNamespace @@ -26,6 +27,16 @@ import pytest from api.db import FileType +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + class _DummyManager: def route(self, *_args, **_kwargs): def decorator(func): @@ -126,6 +137,127 @@ def _load_doc_module(monkeypatch): common_pkg.__path__ = [str(repo_root / "common")] monkeypatch.setitem(sys.modules, "common", common_pkg) + common_settings_mod = ModuleType("common.settings") + common_settings_mod.retriever = SimpleNamespace() + common_settings_mod.kg_retriever = SimpleNamespace() + common_settings_mod.STORAGE_IMPL = SimpleNamespace(get=lambda *_args, **_kwargs: b"", rm=lambda *_args, **_kwargs: None) + monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) + + class _FakeExpr: + def __or__(self, other): + return self + + def __and__(self, other): + return self + + class _FakeField: + def __eq__(self, other): + return _FakeExpr() + + def __ne__(self, other): + return _FakeExpr() + + def is_null(self, value=True): + return _FakeExpr() + + class _StubDocumentModel: + id = _FakeField() + run = _FakeField() + + class _StubTaskModel: + doc_id = _FakeField() + + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.APIToken = SimpleNamespace(query=lambda **_kwargs: []) + db_models_mod.Document = _StubDocumentModel + db_models_mod.Task = _StubTaskModel + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [str(repo_root / "api" / "db" / "services")] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + doc_metadata_service_mod = ModuleType("api.db.services.doc_metadata_service") + doc_metadata_service_mod.DocMetadataService = SimpleNamespace( + get_flatted_meta_by_kbs=lambda *_args, **_kwargs: [], + get_metadata_for_documents=lambda *_args, **_kwargs: {}, + ) + monkeypatch.setitem(sys.modules, "api.db.services.doc_metadata_service", doc_metadata_service_mod) + + document_service_mod = ModuleType("api.db.services.document_service") + document_service_mod.DocumentService = SimpleNamespace( + query=lambda **_kwargs: [], + filter_update=lambda *_args, **_kwargs: 0, + get_by_id=lambda *_args, **_kwargs: (False, None), + update_by_id=lambda *_args, **_kwargs: True, + decrement_chunk_num=lambda *_args, **_kwargs: None, + get_embd_id=lambda *_args, **_kwargs: "", + get_tenant_embd_id=lambda *_args, **_kwargs: None, + ) + monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod) + + file2document_service_mod = ModuleType("api.db.services.file2document_service") + file2document_service_mod.File2DocumentService = SimpleNamespace( + get_storage_address=lambda **_kwargs: ("", ""), + ) + monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_service_mod) + + knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service") + knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace( + accessible=lambda **_kwargs: False, + get_by_id=lambda *_args, **_kwargs: (False, None), + get_by_ids=lambda *_args, **_kwargs: [], + list_documents_by_ids=lambda *_args, **_kwargs: [], + query=lambda **_kwargs: [], + ) + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod) + + task_service_mod = ModuleType("api.db.services.task_service") + task_service_mod.TaskService = SimpleNamespace(filter_delete=lambda *_args, **_kwargs: None) + task_service_mod.cancel_all_task_of = lambda *_args, **_kwargs: None + task_service_mod.queue_tasks = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.check_duplicate_ids = lambda ids, _kind="item": (ids, []) + api_utils_mod.construct_json_result = lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data} + api_utils_mod.get_error_data_result = lambda message="Sorry! Data missing!", code=102: {"code": code, "message": message} + api_utils_mod.get_request_json = lambda: _AwaitableValue({}) + api_utils_mod.get_result = lambda code=0, message="", data=None, total=None: { + key: value + for key, value in {"code": code, "message": message, "data": data, "total": total}.items() + if value is not None + } + api_utils_mod.server_error_response = lambda e: {"code": 500, "message": str(e)} + def _token_required(func): + @wraps(func) + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + + return wrapper + + api_utils_mod.token_required = _token_required + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + common_metadata_utils_mod = ModuleType("common.metadata_utils") + common_metadata_utils_mod.convert_conditions = lambda conditions: conditions + common_metadata_utils_mod.meta_filter = lambda *_args, **_kwargs: [] + monkeypatch.setitem(sys.modules, "common.metadata_utils", common_metadata_utils_mod) + + rag_app_tag_mod = ModuleType("rag.app.tag") + rag_app_tag_mod.label_question = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "rag.app.tag", rag_app_tag_mod) + + rag_prompts_generator_mod = ModuleType("rag.prompts.generator") + rag_prompts_generator_mod.cross_languages = lambda *_args, **_kwargs: "" + rag_prompts_generator_mod.keyword_extraction = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts_generator_mod) + + rag_nlp_mod = ModuleType("rag.nlp") + rag_nlp_mod.search = SimpleNamespace(index_name=lambda tenant_id: f"idx_{tenant_id}") + monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_mod) + monkeypatch.setitem(sys.modules, "rag.nlp.search", rag_nlp_mod.search) + deepdoc_pkg = ModuleType("deepdoc") deepdoc_parser_pkg = ModuleType("deepdoc.parser") deepdoc_parser_pkg.__path__ = [] @@ -344,7 +476,7 @@ def _patch_docstore(monkeypatch, module, **kwargs): "index_exist": lambda *_args, **_kwargs: False, } defaults.update(kwargs) - monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(**defaults)) + monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(**defaults), raising=False) @pytest.mark.p2 @@ -643,7 +775,7 @@ class TestDocRoutesUnit: res = _run(_route_core(module.update_chunk)("tenant-1", "ds-1", "doc-1", "chunk-1")) assert res["code"] == 0 - def test_retrieval_validation_matrix(self, monkeypatch): + def test_retrieval_metadata_validation_matrix(self, monkeypatch): module = _load_doc_module(monkeypatch) monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"dataset_ids": "bad"})) res = _run(module.retrieval_test.__wrapped__("tenant-1")) @@ -825,6 +957,7 @@ class TestDocRoutesUnit: "keyword": True, "toc_enhance": True, "use_kg": True, + "reference_metadata": {"include": True, "fields": ["author"]}, } ), ) @@ -835,6 +968,16 @@ class TestDocRoutesUnit: monkeypatch.setattr(module.settings, "kg_retriever", _FeatureKgRetriever()) monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: {}) monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: SimpleNamespace()) + monkeypatch.setattr( + module.DocMetadataService, + "get_metadata_for_documents", + lambda _doc_ids, _kb_id: { + "doc-1": {"author": "alice", "year": "2025"}, + "doc-toc": {"author": "bob"}, + "doc-child": {"author": "carol"}, + "doc-kg": {"author": "kg-author"}, + }, + ) res = _run(module.retrieval_test.__wrapped__("tenant-1")) assert res["code"] == 0, res["message"] assert feature_calls["cross"] == ("fr",) @@ -842,6 +985,7 @@ class TestDocRoutesUnit: assert feature_calls["retrieval_question"] == "q-xl-kw" assert res["data"]["chunks"][0]["id"] == "kg-1" assert res["data"]["chunks"][0]["content"] == "kg content" + assert res["data"]["chunks"][0]["document_metadata"]["author"] == "kg-author" assert any(chunk["id"] == "toc-1" for chunk in res["data"]["chunks"]) assert any(chunk["id"] == "child-1" for chunk in res["data"]["chunks"]) diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index f442db519..6d2dcbf3a 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -251,6 +251,53 @@ def _load_session_module(monkeypatch): common_constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) + common_metadata_utils_mod = ModuleType("common.metadata_utils") + common_metadata_utils_mod.apply_meta_data_filter = lambda *_args, **_kwargs: [] + common_metadata_utils_mod.convert_conditions = lambda conditions: conditions + common_metadata_utils_mod.meta_filter = lambda *_args, **_kwargs: True + monkeypatch.setitem(sys.modules, "common.metadata_utils", common_metadata_utils_mod) + + common_settings_mod = ModuleType("common.settings") + common_settings_mod.retriever = SimpleNamespace() + common_settings_mod.kg_retriever = SimpleNamespace() + monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.add_tenant_id_to_kwargs = lambda func: func + api_utils_mod.check_duplicate_ids = lambda ids, _kind="item": (ids, []) + api_utils_mod.get_data_error_result = lambda message="Sorry! Data missing!", code=_StubRetCode.DATA_ERROR: {"code": code, "message": message} + api_utils_mod.get_error_data_result = lambda message="Sorry! Data missing!", code=_StubRetCode.DATA_ERROR: {"code": code, "message": message} + api_utils_mod.get_json_result = lambda code=_StubRetCode.SUCCESS, message="success", data=None: {"code": code, "message": message, "data": data} + api_utils_mod.get_result = lambda code=_StubRetCode.SUCCESS, message="", data=None, total=None: { + key: value + for key, value in {"code": code, "message": message, "data": data, "total": total}.items() + if value is not None + } + api_utils_mod.get_request_json = lambda: _AwaitableValue({}) + api_utils_mod.server_error_response = lambda e: {"code": _StubRetCode.SERVER_ERROR, "message": str(e)} + api_utils_mod.token_required = lambda func: func + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + rag_app_tag_mod = ModuleType("rag.app.tag") + rag_app_tag_mod.label_question = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "rag.app.tag", rag_app_tag_mod) + + rag_prompts_generator_mod = ModuleType("rag.prompts.generator") + rag_prompts_generator_mod.cross_languages = lambda *_args, **_kwargs: "" + rag_prompts_generator_mod.keyword_extraction = lambda *_args, **_kwargs: "" + rag_prompts_generator_mod.chunks_format = lambda chunks: chunks + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts_generator_mod) + + rag_prompts_template_mod = ModuleType("rag.prompts.template") + rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.template", rag_prompts_template_mod) + + rag_nlp_mod = ModuleType("rag.nlp") + rag_nlp_mod.search = SimpleNamespace(index_name=lambda tenant_id: f"idx_{tenant_id}") + monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_mod) + monkeypatch.setitem(sys.modules, "rag.nlp.search", rag_nlp_mod.search) + deepdoc_pkg = ModuleType("deepdoc") deepdoc_parser_pkg = ModuleType("deepdoc.parser") deepdoc_parser_pkg.__path__ = [] @@ -508,8 +555,128 @@ def _load_session_module(monkeypatch): quart_mod.jsonify = lambda payload: payload quart_mod.current_app = SimpleNamespace() quart_mod.has_app_context = lambda: False + quart_mod.has_request_context = lambda: False + quart_mod.has_websocket_context = lambda: False + quart_mod.websocket = SimpleNamespace() monkeypatch.setitem(sys.modules, "quart", quart_mod) + quart_auth_mod = ModuleType("quart_auth") + + class _StubAuthUser: + pass + + quart_auth_mod.AuthUser = _StubAuthUser + monkeypatch.setitem(sys.modules, "quart_auth", quart_auth_mod) + + class _FakeExpr: + def __or__(self, other): + return self + + def __and__(self, other): + return self + + class _FakeField: + def __eq__(self, other): + return _FakeExpr() + + def __ne__(self, other): + return _FakeExpr() + + def is_null(self, value=True): + return _FakeExpr() + + class _StubTaskModel: + id = _FakeField() + doc_id = _FakeField() + + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.APIToken = SimpleNamespace(query=lambda **_kwargs: []) + db_models_mod.Task = _StubTaskModel + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [str(repo_root / "api" / "db" / "services")] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + api_service_mod = ModuleType("api.db.services.api_service") + api_service_mod.API4ConversationService = SimpleNamespace( + get_names=lambda *_args, **_kwargs: [], + get_list=lambda *_args, **_kwargs: (0, []), + save=lambda **_kwargs: True, + get_by_id=lambda _session_id: (True, SimpleNamespace(to_dict=lambda: {"id": _session_id})), + delete_by_id=lambda *_args, **_kwargs: True, + query=lambda **_kwargs: [], + ) + monkeypatch.setitem(sys.modules, "api.db.services.api_service", api_service_mod) + + canvas_service_mod = ModuleType("api.db.services.canvas_service") + canvas_service_mod.CanvasTemplateService = SimpleNamespace(get_all=lambda *_args, **_kwargs: []) + canvas_service_mod.UserCanvasService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + accessible=lambda *_args, **_kwargs: False, + get_agent_dsl_with_release=lambda *_args, **_kwargs: (SimpleNamespace(id="agent-1"), "{}"), + ) + + async def _empty_agent_completion(*_args, **_kwargs): + if False: + yield None + + canvas_service_mod.completion = _empty_agent_completion + canvas_service_mod.completion_openai = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "api.db.services.canvas_service", canvas_service_mod) + + conversation_service_mod = ModuleType("api.db.services.conversation_service") + conversation_service_mod.ConversationService = SimpleNamespace(query=lambda **_kwargs: []) + conversation_service_mod.async_iframe_completion = lambda *_args, **_kwargs: None + conversation_service_mod.async_completion = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.conversation_service", conversation_service_mod) + + dialog_service_mod = ModuleType("api.db.services.dialog_service") + dialog_service_mod.DialogService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + ) + dialog_service_mod.async_ask = lambda *_args, **_kwargs: None + dialog_service_mod.async_chat = lambda *_args, **_kwargs: None + dialog_service_mod.gen_mindmap = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod) + + doc_metadata_service_mod = ModuleType("api.db.services.doc_metadata_service") + doc_metadata_service_mod.DocMetadataService = SimpleNamespace( + get_flatted_meta_by_kbs=lambda *_args, **_kwargs: [], + get_metadata_for_documents=lambda *_args, **_kwargs: {}, + ) + monkeypatch.setitem(sys.modules, "api.db.services.doc_metadata_service", doc_metadata_service_mod) + + knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service") + knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + ) + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod) + + search_service_mod = ModuleType("api.db.services.search_service") + search_service_mod.SearchService = SimpleNamespace( + query=lambda **_kwargs: [], + get_detail=lambda *_args, **_kwargs: None, + ) + monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + user_service_mod.UserTenantService = SimpleNamespace(query=lambda **_kwargs: []) + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + user_canvas_version_mod = ModuleType("api.db.services.user_canvas_version") + user_canvas_version_mod.UserCanvasVersionService = SimpleNamespace( + list_by_canvas_id=lambda *_args, **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + get_latest_version_title=lambda *_args, **_kwargs: "", + save_or_replace_latest=lambda **_kwargs: True, + build_version_title=lambda *_args, **_kwargs: "v1", + ) + monkeypatch.setitem(sys.modules, "api.db.services.user_canvas_version", user_canvas_version_mod) + module_path = repo_root / "api" / "apps" / "sdk" / "session.py" spec = importlib.util.spec_from_file_location("test_session_sdk_routes_unit_module", module_path) module = importlib.util.module_from_spec(spec) @@ -612,7 +779,10 @@ def _load_agent_api_module(monkeypatch): monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod) knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service") - knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace(query=lambda **_kwargs: []) + knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda *_args, **_kwargs: (False, None), + ) monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod) task_service_mod = ModuleType("api.db.services.task_service") @@ -1352,7 +1522,7 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch): "rank_feature": rank_feature, } ) - return {"chunks": [{"id": "chunk-1", "vector": [0.1]}]} + return {"chunks": [{"id": "chunk-1", "doc_id": "doc-1", "kb_id": "kb-1", "vector": [0.1]}]} async def _translate(_tenant_id, _chat_id, question, _langs): return question + "-translated" @@ -1384,10 +1554,16 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch): "vector_similarity_weight": 0.8, "top_k": 7, "rerank_id": "reranker-model", + "reference_metadata": {"include": True, "fields": ["author"]}, } }, ) monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"id": "doc-2"}]) + monkeypatch.setattr( + module.DocMetadataService, + "get_metadata_for_documents", + lambda _doc_ids, _kb_id: {"doc-1": {"author": "alice", "year": "2025"}}, + ) monkeypatch.setattr(module, "apply_meta_data_filter", _apply_filter) monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")]) monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) @@ -1409,6 +1585,8 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch): assert retrieval_capture["local_doc_ids"] == ["doc-filtered"] assert retrieval_capture["rank_feature"] == ["label-1"] assert retrieval_capture["rerank_mdl"] is not None + assert res["data"]["chunks"][0]["document_metadata"]["author"] == "alice" + assert "year" not in res["data"]["chunks"][0]["document_metadata"] assert any(call[1] == module.LLMType.EMBEDDING.value and call[2] == "embd-model" for call in llm_calls) llm_calls.clear() @@ -1621,9 +1799,18 @@ def test_build_reference_chunks_metadata_matrix_unit(monkeypatch): monkeypatch.setattr(module, "chunks_format", lambda _reference: [{"dataset_id": "kb-1", "document_id": "doc-1"}]) monkeypatch.setattr(module.DocMetadataService, "get_metadata_for_documents", lambda _doc_ids, _kb_id: {"doc-1": {"author": "alice"}}) + res = module._build_reference_chunks([], include_metadata=True, metadata_fields=None) + assert res[0]["document_metadata"] == {"author": "alice"} + + res = module._build_reference_chunks([], include_metadata=True, metadata_fields=[]) + assert "document_metadata" not in res[0] + res = module._build_reference_chunks([], include_metadata=True, metadata_fields=[1, None]) assert "document_metadata" not in res[0] + res = module._build_reference_chunks([], include_metadata=True, metadata_fields="author") + assert "document_metadata" not in res[0] + source_chunks = [ {"dataset_id": "kb-1", "document_id": "doc-1"}, {"dataset_id": "kb-2", "document_id": "doc-2"}, diff --git a/web/src/components/fallback-component/index.tsx b/web/src/components/fallback-component/index.tsx index 131051827..7f4f1b795 100644 --- a/web/src/components/fallback-component/index.tsx +++ b/web/src/components/fallback-component/index.tsx @@ -1,5 +1,6 @@ import React from 'react'; import { useTranslation } from 'react-i18next'; +import { isRouteErrorResponse, useRouteError } from 'react-router'; interface FallbackComponentProps { error?: Error; @@ -7,10 +8,32 @@ interface FallbackComponentProps { } const FallbackComponent: React.FC = ({ - error, + error: errorProp, reset, }) => { const { t } = useTranslation(); + const routeError = useRouteError(); + const error = + errorProp ?? (routeError instanceof Error ? routeError : undefined); + + let routeErrorDataStr = ''; + if (isRouteErrorResponse(routeError)) { + if (typeof routeError.data === 'string') { + routeErrorDataStr = routeError.data; + } else if (routeError.data == null) { + routeErrorDataStr = 'no body'; + } else { + try { + routeErrorDataStr = JSON.stringify(routeError.data); + } catch { + routeErrorDataStr = String(routeError.data); + } + } + } + + const errorMessage = isRouteErrorResponse(routeError) + ? `${routeError.status} ${routeError.statusText}${routeErrorDataStr ? `: ${routeErrorDataStr}` : ''}` + : (error?.toString() ?? (routeError ? String(routeError) : undefined)); return (
@@ -21,10 +44,10 @@ const FallbackComponent: React.FC = ({ 'Sorry, an error occurred while loading the page.', )}

- {error && ( -
+ {errorMessage && ( +
{t('error_boundary.details', 'Error details')} - {error.toString()} + {errorMessage}
)}
diff --git a/web/src/components/markdown-content/index.tsx b/web/src/components/markdown-content/index.tsx index 72247674f..846d25d5a 100644 --- a/web/src/components/markdown-content/index.tsx +++ b/web/src/components/markdown-content/index.tsx @@ -40,6 +40,13 @@ import styles from './index.module.less'; const getChunkIndex = (match: string) => parseCitationIndex(match); +const formatMetadataValue = (value: unknown) => { + if (Array.isArray(value)) return value.join(', '); + if (value === null || value === undefined) return ''; + if (typeof value === 'object') return JSON.stringify(value); + return String(value); +}; + // TODO: The display of the table is inconsistent with the display previously placed in the MessageItem. const MarkdownContent = ({ reference, @@ -174,6 +181,21 @@ const MarkdownContent = ({ className={classNames(styles.chunkContentText)} dir="auto" >
+ {chunkItem?.document_metadata && + Object.keys(chunkItem.document_metadata).length > 0 && ( +
+ {Object.entries(chunkItem.document_metadata).map( + ([key, value]) => ( +
+ {key}:{' '} + + {formatMetadataValue(value)} + +
+ ), + )} +
+ )} {documentId && (
{fileThumbnail ? ( diff --git a/web/src/hooks/use-knowledge-request.ts b/web/src/hooks/use-knowledge-request.ts index 782b1282f..5bd5a796b 100644 --- a/web/src/hooks/use-knowledge-request.ts +++ b/web/src/hooks/use-knowledge-request.ts @@ -48,6 +48,7 @@ export const enum KnowledgeApiAction { FetchKnowledgeDetail = 'fetchKnowledgeDetail', FetchKnowledgeGraph = 'fetchKnowledgeGraph', FetchMetadata = 'fetchMetadata', + FetchMetadataKeys = 'fetchMetadataKeys', FetchKnowledgeList = 'fetchKnowledgeList', RemoveKnowledgeGraph = 'removeKnowledgeGraph', } @@ -378,6 +379,24 @@ export function useFetchKnowledgeMetadata(kbIds: string[] = []) { return { data, loading }; } +export function useFetchKnowledgeMetadataKeys(kbIds: string[] = []) { + const sortedKbIds = useMemo(() => [...kbIds].sort(), [kbIds]); + const { data, isFetching: loading } = useQuery({ + queryKey: [KnowledgeApiAction.FetchMetadataKeys, sortedKbIds], + initialData: [], + enabled: sortedKbIds.length > 0, + gcTime: 0, + queryFn: async () => { + const { data } = await kbService.getMetaKeys({ + kb_ids: sortedKbIds.join(','), + }); + return data?.data ?? []; + }, + }); + + return { data, loading }; +} + export const useRemoveKnowledgeGraph = () => { const knowledgeBaseId = useKnowledgeBaseId(); diff --git a/web/src/interfaces/database/chat.ts b/web/src/interfaces/database/chat.ts index 5cce383f5..447409bcf 100644 --- a/web/src/interfaces/database/chat.ts +++ b/web/src/interfaces/database/chat.ts @@ -22,6 +22,10 @@ export interface PromptConfig { cross_languages?: Array; tavily_api_key?: string; toc_enhance?: boolean; + reference_metadata?: { + include?: boolean; + fields?: string[]; + }; } export interface Parameter { @@ -126,6 +130,7 @@ export interface IReferenceChunk { term_similarity: number; positions: number[]; doc_type?: string; + document_metadata?: Record; } export interface IReference { diff --git a/web/src/pages/agent/form/doc-generator-form/use-values.ts b/web/src/pages/agent/form/doc-generator-form/use-values.ts index e4426ae8a..b4df1809a 100644 --- a/web/src/pages/agent/form/doc-generator-form/use-values.ts +++ b/web/src/pages/agent/form/doc-generator-form/use-values.ts @@ -1,5 +1,5 @@ +import { type Node } from '@xyflow/react'; import { useMemo } from 'react'; -import { Node } from 'reactflow'; import { initialDocGeneratorValues } from '../../constant'; export const useValues = (node?: Node) => { diff --git a/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx b/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx index 367748cef..5794787d9 100644 --- a/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx +++ b/web/src/pages/next-chats/chat/app-settings/chat-basic-settings.tsx @@ -13,16 +13,45 @@ import { FormLabel, FormMessage, } from '@/components/ui/form'; +import { MultiSelect } from '@/components/ui/multi-select'; +import { Switch } from '@/components/ui/switch'; import { Textarea } from '@/components/ui/textarea'; import { useTranslate } from '@/hooks/common-hooks'; +import { useFetchKnowledgeMetadataKeys } from '@/hooks/use-knowledge-request'; import { getDirAttribute } from '@/utils/text-direction'; -import { useFormContext } from 'react-hook-form'; +import { useEffect, useMemo } from 'react'; +import { useFormContext, useWatch } from 'react-hook-form'; export default function ChatBasicSetting() { const { t } = useTranslate('chat'); const form = useFormContext(); const emptyResponseValue = form.watch('prompt_config.empty_response'); const prologueValue = form.watch('prompt_config.prologue'); + const kbIds = (useWatch({ control: form.control, name: 'dataset_ids' }) || + []) as string[]; + const metadataInclude = useWatch({ + control: form.control, + name: 'prompt_config.reference_metadata.include', + }); + const { data: metadataKeys } = useFetchKnowledgeMetadataKeys(kbIds); + const metadataFieldOptions = useMemo(() => { + return (metadataKeys || []).map((key) => ({ + label: key, + value: key, + })); + }, [metadataKeys]); + + useEffect(() => { + const currentFields = form.getValues('prompt_config.reference_metadata.fields'); + if (metadataInclude && Array.isArray(currentFields) && currentFields.length > 0 && metadataKeys) { + const validFields = currentFields.filter((field) => metadataKeys.includes(field)); + if (validFields.length !== currentFields.length) { + form.setValue('prompt_config.reference_metadata.fields', validFields); + } + } else if (!metadataInclude) { + form.setValue('prompt_config.reference_metadata.fields', undefined); + } + }, [kbIds, metadataKeys, metadataInclude, form]); return (
@@ -83,6 +112,59 @@ export default function ChatBasicSetting() { + ( + + + { + field.onChange(value); + if (!value) { + form.setValue( + 'prompt_config.reference_metadata.fields', + undefined, + ); + } + }} + /> + + + Show chunk metadata + + + )} + /> + {metadataInclude && ( + ( + + + {t('metadataKeys')} + + + + + + + )} + /> + )}
); } diff --git a/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx b/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx index 028cc0147..f3079a314 100644 --- a/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx +++ b/web/src/pages/next-chats/chat/app-settings/chat-settings.tsx @@ -57,6 +57,10 @@ export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { reasoning: false, cross_languages: [], toc_enhance: false, + reference_metadata: { + include: false, + fields: undefined, + }, }, top_n: 8, similarity_threshold: 0.2, @@ -74,6 +78,14 @@ export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { values, 'llm_setting.', ); + const referenceMetadata = nextValues?.prompt_config?.reference_metadata; + if ( + referenceMetadata && + Array.isArray(referenceMetadata.fields) && + referenceMetadata.fields.length === 0 + ) { + referenceMetadata.fields = undefined; + } updateChat({ chatId: id!, @@ -101,8 +113,20 @@ export function ChatSettings({ hasSingleChatBox }: ChatSettingsProps) { const llmSettingEnabledValues = setLLMSettingEnabledValues( data.llm_setting, ); + const referenceMetadata = data?.prompt_config?.reference_metadata; + const normalizedReferenceMetadata = + referenceMetadata && + Array.isArray(referenceMetadata.fields) && + referenceMetadata.fields.length === 0 + ? { ...referenceMetadata, fields: undefined } + : referenceMetadata; + const nextData = { ...data, + prompt_config: { + ...data.prompt_config, + reference_metadata: normalizedReferenceMetadata, + }, ...llmSettingEnabledValues, }; diff --git a/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx b/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx index ba29383f9..f80ab79b7 100644 --- a/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx +++ b/web/src/pages/next-chats/chat/app-settings/use-chat-setting-schema.tsx @@ -36,6 +36,12 @@ export function useChatSettingSchema() { reasoning: z.boolean().optional(), cross_languages: z.array(z.string()).optional(), toc_enhance: z.boolean().optional(), + reference_metadata: z + .object({ + include: z.boolean().optional(), + fields: z.array(z.string()).optional(), + }) + .optional(), }); const formSchema = z.object({ diff --git a/web/src/pages/next-search/search-setting.tsx b/web/src/pages/next-search/search-setting.tsx index 7b8203bf0..d9a381782 100644 --- a/web/src/pages/next-search/search-setting.tsx +++ b/web/src/pages/next-search/search-setting.tsx @@ -22,9 +22,15 @@ import { FormMessage, } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; +import { MultiSelect } from '@/components/ui/multi-select'; import { RAGFlowSelect } from '@/components/ui/select'; import { Spin } from '@/components/ui/spin'; import { Switch } from '@/components/ui/switch'; +import { Textarea } from '@/components/ui/textarea'; +import { + useFetchKnowledgeList, + useFetchKnowledgeMetadataKeys, +} from '@/hooks/use-knowledge-request'; import { useComposeLlmOptionsByModelTypes, useSelectLlmOptionsByModelType, @@ -79,6 +85,12 @@ const SearchSettingFormSchema = z highlight: z.boolean(), keyword: z.boolean(), chat_settingcross_languages: z.array(z.string()), + reference_metadata: z + .object({ + include: z.boolean().optional(), + fields: z.array(z.string()).optional(), + }) + .optional(), ...MetadataFilterSchema, }), }) @@ -156,6 +168,14 @@ const SearchSetting: React.FC = ({ related_search: search_config?.related_search || false, query_mindmap: search_config?.query_mindmap || false, meta_data_filter: search_config?.meta_data_filter, + reference_metadata: { + include: search_config?.reference_metadata?.include || false, + fields: + search_config?.reference_metadata?.fields && + search_config.reference_metadata.fields.length > 0 + ? search_config.reference_metadata.fields + : undefined, + }, }, }); }, [data, search_config, llm_setting, formMethods, descriptionDefaultValue]); @@ -193,6 +213,35 @@ const SearchSetting: React.FC = ({ control: formMethods.control, name: 'search_config.summary', }); + const selectedKbIds = useWatch({ + control: formMethods.control, + name: 'search_config.kb_ids', + }); + const referenceMetadataEnabled = useWatch({ + control: formMethods.control, + name: 'search_config.reference_metadata.include', + }); + const { data: metadataKeys } = useFetchKnowledgeMetadataKeys( + selectedKbIds || [], + ); + const metadataFieldOptions = useMemo(() => { + return (metadataKeys || []).map((key) => ({ + label: key, + value: key, + })); + }, [metadataKeys]); + + useEffect(() => { + const currentFields = formMethods.getValues('search_config.reference_metadata.fields'); + if (referenceMetadataEnabled && Array.isArray(currentFields) && currentFields.length > 0 && metadataKeys) { + const validFields = currentFields.filter((field) => metadataKeys.includes(field)); + if (validFields.length !== currentFields.length) { + formMethods.setValue('search_config.reference_metadata.fields', validFields); + } + } else if (!referenceMetadataEnabled) { + formMethods.setValue('search_config.reference_metadata.fields', undefined); + } + }, [selectedKbIds, metadataKeys, referenceMetadataEnabled, formMethods]); // Reset top_k to 1024 only when user actively disables rerank (from true to false) const prevRerankEnabled = useRef(undefined); @@ -227,11 +276,22 @@ const SearchSetting: React.FC = ({ frequency_penalty: llm_setting.frequency_penalty, presence_penalty: llm_setting.presence_penalty, } as IllmSettingProps; + const referenceMetadata = other_config.reference_metadata; + const normalizedReferenceMetadata = referenceMetadata + ? { + ...referenceMetadata, + ...(Array.isArray(referenceMetadata.fields) && + referenceMetadata.fields.length === 0 + ? { fields: undefined } + : {}), + } + : referenceMetadata; await updateSearch({ ...other_formdata, search_config: { ...other_config, + reference_metadata: normalizedReferenceMetadata, chat_id: llm_setting.llm_id, vector_similarity_weight: 1 - vector_similarity_weight, rerank_id: use_rerank ? rerank_id : '', @@ -288,6 +348,61 @@ const SearchSetting: React.FC = ({ required > + ( + + + { + field.onChange(value); + if (!value) { + formMethods.setValue( + 'search_config.reference_metadata.fields', + undefined, + ); + } + }} + /> + + + Show chunk metadata + + + )} + /> + {referenceMetadataEnabled && ( + ( + + + Metadata fields + + + + + + + )} + /> + )} { + if (Array.isArray(value)) return value.join(', '); + if (value === null || value === undefined) return ''; + if (typeof value === 'object') return JSON.stringify(value); + return String(value); +}; export default function SearchingView({ setIsSearching, searchData, @@ -208,6 +214,26 @@ export default function SearchingView({ {chunk.content_with_weight}
+ {chunk.document_metadata && + Object.keys(chunk.document_metadata).length > 0 && ( +
+ {Object.entries(chunk.document_metadata).map( + ([key, value]) => ( +
+ + {key}: + {' '} + + {formatMetadataValue(value)} + +
+ ), + )} +
+ )}
diff --git a/web/src/pages/next-searches/hooks.ts b/web/src/pages/next-searches/hooks.ts index 89bdd88c5..e8358e20e 100644 --- a/web/src/pages/next-searches/hooks.ts +++ b/web/src/pages/next-searches/hooks.ts @@ -185,6 +185,10 @@ export interface ISearchAppDetailProps { method: string; manual: { key: string; op: string; value: string }[]; }; + reference_metadata?: { + include?: boolean; + fields?: string[]; + }; }; tenant_id: string; update_time: number; diff --git a/web/src/services/knowledge-service.ts b/web/src/services/knowledge-service.ts index 47e674e45..dfd31c341 100644 --- a/web/src/services/knowledge-service.ts +++ b/web/src/services/knowledge-service.ts @@ -24,6 +24,7 @@ const { listTagByKnowledgeIds, setMeta, getMeta, + getMetaKeys, retrievalTestShare, } = api; @@ -81,6 +82,10 @@ const methods = { url: getMeta, method: 'get', }, + getMetaKeys: { + url: getMetaKeys, + method: 'get', + }, retrievalTestShare: { url: retrievalTestShare, method: 'post',