From cfee2bc9dbcaceedc6b09e6e57fffbc240329b09 Mon Sep 17 00:00:00 2001 From: MkDev11 <94194147+MkDev11@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:52:18 -0700 Subject: [PATCH] feat: Auto-adjust chunk recall weights based on user feedback (#12689) ### What problem does this PR solve? Implements automatic adjustment of knowledge base chunk recall weights based on user feedback (upvotes/downvotes). When users upvote or downvote a response, the system locates the corresponding knowledge snippets and adjusts their recall weight to improve future retrieval quality. **Closes #12670** **How it works:** 1. User upvotes/downvotes a response via `POST /thumbup` 2. System extracts chunk IDs from the conversation reference 3. For each referenced chunk: - Reads current `pagerank_fea` value from document store - Increments (+1) for upvote or decrements (-1) for downvote - Clamps weight to [0, 100] range - Updates chunk in ES/Infinity/OceanBase 4. Future retrievals score these chunks higher/lower based on accumulated feedback **Files changed:** - `api/db/services/chunk_feedback_service.py` - New service for updating chunk pagerank weights - `api/apps/conversation_app.py` - Integrated feedback service into thumbup endpoint - `test/testcases/test_web_api/test_chunk_feedback/` - Unit tests ### Type of change - [x] New Feature (non-breaking change which adds functionality) ## Summary by CodeRabbit * **New Features** * Chat message feedback now updates per-chunk relevance weights (feature-flag gated), with configurable weighting and atomic updates across storage backends. * **Bug Fixes** * Stricter validation for message feedback inputs and more robust handling of feedback transitions. * **Tests** * Expanded test coverage for chunk-feedback behavior, weighting strategies, storage backends, and thumb-flip scenarios. * **Chores** * CI workflow extended to run the new chunk-feedback web API tests. --------- Co-authored-by: mkdev11 Co-authored-by: mkdev11 --- .github/workflows/tests.yml | 2 +- api/apps/restful_apis/chat_api.py | 51 +- api/db/services/chunk_feedback_service.py | 321 ++++++++++ rag/prompts/generator.py | 1 + rag/utils/es_conn.py | 115 +++- rag/utils/infinity_conn.py | 78 +++ rag/utils/ob_conn.py | 26 + rag/utils/opensearch_conn.py | 102 ++- .../test_chat_sdk_routes_unit.py | 11 + .../test_chunk_feedback/__init__.py | 15 + .../test_chunk_feedback_service.py | 584 ++++++++++++++++++ 11 files changed, 1293 insertions(+), 13 deletions(-) create mode 100644 api/db/services/chunk_feedback_service.py create mode 100644 test/testcases/test_web_api/test_chunk_feedback/__init__.py create mode 100644 test/testcases/test_web_api/test_chunk_feedback/test_chunk_feedback_service.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0d3aec74a..dae211471 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -247,7 +247,7 @@ jobs: echo "Waiting for service to be available... (last exit code: $?)" sleep 5 done - source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/test_api_app 2>&1 | tee infinity_web_api_test.log + source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/test_api_app test/testcases/test_web_api/test_chunk_feedback 2>&1 | tee infinity_web_api_test.log - name: Run http api tests against Infinity run: | diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index 6cda09501..7c311ae4b 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -28,6 +28,7 @@ from api.db.joint_services.tenant_model_service import ( get_model_config_by_type_and_name, get_tenant_default_model_by_type, ) +from api.db.services.chunk_feedback_service import ChunkFeedbackService from api.db.services.conversation_service import ConversationService, structure_answer from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap from api.db.services.knowledgebase_service import KnowledgebaseService @@ -769,28 +770,64 @@ async def delete_session_message(chat_id, session_id, msg_id): @manager.route("/chats//sessions//messages//feedback", methods=["PUT"]) # noqa: F821 @login_required async def update_message_feedback(chat_id, session_id, msg_id): - if not _ensure_owned_chat(chat_id): + owned = _ensure_owned_chat(chat_id) + if not owned: return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() ok, conv = ConversationService.get_by_id(session_id) if not ok or conv.dialog_id != chat_id: return get_data_error_result(message="Session not found!") - up_down = req.get("thumbup") + thumb_raw = req.get("thumbup") + if not isinstance(thumb_raw, bool): + return get_data_error_result(message="thumbup must be a boolean") feedback = req.get("feedback", "") - conv = conv.to_dict() - for msg in conv["message"]: + conv_dict = conv.to_dict() + message_index = None + apply_chunk_feedback = False + prior_thumb = None + for i, msg in enumerate(conv_dict["message"]): if msg_id == msg.get("id", "") and msg.get("role", "") == "assistant": - if up_down: + prior_thumb = msg.get("thumbup") + if thumb_raw is True: msg["thumbup"] = True msg.pop("feedback", None) + apply_chunk_feedback = prior_thumb is not True else: msg["thumbup"] = False if feedback: msg["feedback"] = feedback + apply_chunk_feedback = prior_thumb is not False + message_index = i break - ConversationService.update_by_id(conv["id"], conv) - return get_json_result(data=_build_session_response(conv)) + + if message_index is not None and apply_chunk_feedback: + try: + ref_index = (message_index - 1) // 2 + if 0 <= ref_index < len(conv_dict.get("reference", [])): + reference = conv_dict["reference"][ref_index] + if reference: + if isinstance(prior_thumb, bool) and prior_thumb != thumb_raw: + ChunkFeedbackService.apply_feedback( + tenant_id=current_user.id, + reference=reference, + is_positive=not prior_thumb, + ) + feedback_result = ChunkFeedbackService.apply_feedback( + tenant_id=current_user.id, + reference=reference, + is_positive=thumb_raw is True, + ) + logging.debug( + "Chunk feedback applied: %s succeeded, %s failed", + feedback_result["success_count"], + feedback_result["fail_count"], + ) + except Exception as e: + logging.warning("Failed to apply chunk feedback: %s", e) + + ConversationService.update_by_id(conv_dict["id"], conv_dict) + return get_json_result(data=_build_session_response(conv_dict)) except Exception as ex: return server_error_response(ex) diff --git a/api/db/services/chunk_feedback_service.py b/api/db/services/chunk_feedback_service.py new file mode 100644 index 000000000..1d9fe23f4 --- /dev/null +++ b/api/db/services/chunk_feedback_service.py @@ -0,0 +1,321 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Service for adjusting chunk recall weights based on user feedback. + +When users upvote or downvote responses, this service updates the pagerank_fea +field of the referenced chunks to improve future retrieval quality. + +This feature is disabled by default. Enable it by setting the environment +variable CHUNK_FEEDBACK_ENABLED=true. + +Weighting modes (CHUNK_FEEDBACK_WEIGHTING): +- relevance (default): one small budget per feedback event is split across + cited chunks using retrieval scores (similarity / vector_similarity / + term_similarity) from the reference payload, so chunks that drove the answer + move more than weak tail context. +- uniform: legacy behavior — each cited chunk receives the full increment or + decrement (stronger total effect when many chunks are cited). + +Budget per feedback event is a small integer (1) applied to pagerank_fea +(0–100, integer in Infinity/OB/ES mappings). Relevance mode splits that unit +across cited chunks; uniform mode applies one unit per chunk (legacy, stronger +when many chunks are cited). + +Infinity uses row_id (returned by search results since PR #13901) for targeted +single-row updates. If a concurrent update changes the row_id, the Infinity +connector retries with a fresh row_id lookup. +""" +import logging +import math +import os +from typing import List, Tuple + +from common.constants import PAGERANK_FLD +from common import settings +from rag.nlp.search import index_name + + +# Feature flag - disabled by default to prevent unintended side effects +CHUNK_FEEDBACK_ENABLED = os.getenv("CHUNK_FEEDBACK_ENABLED", "false").lower() == "true" + +# relevance: fixed budget split by retrieval signals; uniform: delta per chunk +CHUNK_FEEDBACK_WEIGHTING = os.getenv("CHUNK_FEEDBACK_WEIGHTING", "relevance").strip().lower() + +# Integer units — matches pagerank_fea integer columns in doc stores +UPVOTE_WEIGHT_INCREMENT = 1 +DOWNVOTE_WEIGHT_DECREMENT = 1 +MIN_PAGERANK_WEIGHT = 0 +MAX_PAGERANK_WEIGHT = 100 + +_SCORE_KEYS = ("similarity", "vector_similarity", "term_similarity") + + +def _retrieval_signal(chunk: dict) -> float: + """Best available retrieval score for feedback allocation; 0 if none.""" + best = 0.0 + for key in _SCORE_KEYS: + raw = chunk.get(key) + if raw is None: + continue + try: + val = float(raw) + except (TypeError, ValueError): + continue + if math.isfinite(val) and val > best: + best = val + return best + + +def _split_integer_budget(magnitudes: List[float], budget: int) -> List[int]: + """Split nonnegative integer budget across positive magnitudes (largest remainder).""" + n = len(magnitudes) + if n == 0 or budget == 0: + return [0] * n + total = sum(magnitudes) + if total <= 0: + base = budget // n + rem = budget % n + out = [base] * n + for i in range(rem): + out[i] += 1 + return out + raw = [budget * m / total for m in magnitudes] + floors = [int(math.floor(r)) for r in raw] + remainder = budget - sum(floors) + order = sorted(range(n), key=lambda i: raw[i] - floors[i], reverse=True) + for j in range(remainder): + floors[order[j]] += 1 + return floors + + +def _allocate_deltas_uniform( + chunk_rows: List[Tuple[str, str]], + signed_budget: int, +) -> List[Tuple[str, str, int]]: + """Each row gets the full signed step (legacy: one unit per cited chunk).""" + step = UPVOTE_WEIGHT_INCREMENT if signed_budget > 0 else -DOWNVOTE_WEIGHT_DECREMENT + return [(cid, kb, step) for cid, kb in chunk_rows] + + +def _allocate_deltas_relevance( + chunk_rows: List[Tuple[str, str, dict]], + signed_budget: int, +) -> List[Tuple[str, str, int]]: + """ + Split |signed_budget| integer units across chunks using retrieval_signal weights. + chunk_rows: (chunk_id, kb_id, original_chunk_dict) + """ + if not chunk_rows: + return [] + + magnitudes = [] + for _cid, _kb, ch in chunk_rows: + s = _retrieval_signal(ch) + magnitudes.append(s if s > 0 else 1.0) + + total = sum(magnitudes) + if total <= 0: + magnitudes = [1.0] * len(chunk_rows) + + sign = 1 if signed_budget > 0 else -1 + budget_abs = abs(signed_budget) + parts = _split_integer_budget(magnitudes, budget_abs) + out: List[Tuple[str, str, int]] = [] + for (cid, kb, _ch), p in zip(chunk_rows, parts, strict=True): + out.append((cid, kb, sign * p)) + return out + + +class ChunkFeedbackService: + """Service to update chunk weights based on user feedback.""" + + @staticmethod + def _feedback_rows_from_reference(reference: dict) -> List[Tuple[str, str, dict]]: + """(chunk_id, kb_id, raw_chunk) for chunks that can be updated (single pass). + + raw_chunk is kept for retrieval-signal weighting and optional row_id. + """ + if not reference: + return [] + rows: List[Tuple[str, str, dict]] = [] + for chunk in reference.get("chunks", []): + chunk_id = chunk.get("id") or chunk.get("chunk_id") + kb_id = chunk.get("dataset_id") or chunk.get("kb_id") + if chunk_id and kb_id: + rows.append((chunk_id, kb_id, chunk)) + return rows + + @staticmethod + def update_chunk_weight( + tenant_id: str, + chunk_id: str, + kb_id: str, + delta: int, + row_id: int | None = None, + ) -> bool: + """ + Update the pagerank weight of a single chunk. + + Elasticsearch, OpenSearch, OceanBase/SeekDB, and Infinity use an + atomic adjust on the doc store when supported. Infinity passes + row_id (from retrieval results) for targeted single-row updates. + + Args: + tenant_id: The tenant ID for index naming + chunk_id: The chunk ID to update + kb_id: The knowledgebase ID + delta: Signed integer weight change (pagerank_fea is stored as int) + + Returns: + True if update succeeded, False otherwise + """ + try: + idx_name = index_name(tenant_id) + conn = settings.docStoreConn + adjust = getattr(conn, "adjust_chunk_pagerank_fea", None) + if callable(adjust): + kwargs: dict = {} + if row_id is not None: + kwargs["row_id"] = row_id + success = adjust( + chunk_id, + idx_name, + kb_id, + float(delta), + MIN_PAGERANK_WEIGHT, + MAX_PAGERANK_WEIGHT, + **kwargs, + ) + if success: + logging.info( + "Adjusted chunk %s pagerank by %s (atomic)", + chunk_id, + delta, + ) + else: + logging.warning("Failed atomic pagerank adjust for chunk %s", chunk_id) + return success + + chunk = conn.get(chunk_id, idx_name, [kb_id]) + if not chunk: + logging.warning("Chunk %s not found in index %s", chunk_id, idx_name) + return False + + current_weight = float(chunk.get(PAGERANK_FLD, 0) or 0) + new_weight = current_weight + float(delta) + new_weight = max(float(MIN_PAGERANK_WEIGHT), min(float(MAX_PAGERANK_WEIGHT), new_weight)) + + condition = {"id": chunk_id} + doc_engine = settings.DOC_ENGINE.lower() + if new_weight <= 0.0 and doc_engine in ("elasticsearch", "opensearch"): + new_value = {"remove": PAGERANK_FLD} + else: + new_value = {PAGERANK_FLD: new_weight} + + success = conn.update(condition, new_value, idx_name, kb_id) + + if success: + logging.info( + "Updated chunk %s pagerank: %s -> %s", + chunk_id, + current_weight, + new_weight, + ) + else: + logging.warning("Failed to update chunk %s pagerank", chunk_id) + + return success + + except Exception as e: + logging.exception("Error updating chunk %s weight: %s", chunk_id, e) + return False + + @classmethod + def apply_feedback( + cls, + tenant_id: str, + reference: dict, + is_positive: bool + ) -> dict: + """ + Apply user feedback to all chunks referenced in a response. + + Args: + tenant_id: The tenant ID + reference: The reference dict from the conversation message + is_positive: True for upvote (thumbup), False for downvote + + Returns: + Dict with 'success_count', 'fail_count', and 'chunk_ids' processed + """ + # Check if feature is enabled + if not CHUNK_FEEDBACK_ENABLED: + logging.debug("Chunk feedback feature is disabled") + return {"success_count": 0, "fail_count": 0, "chunk_ids": [], "disabled": True} + + rows = cls._feedback_rows_from_reference(reference) + chunk_ids = [r[0] for r in rows] + + if not chunk_ids: + logging.debug("No chunk IDs found in reference for feedback") + return {"success_count": 0, "fail_count": 0, "chunk_ids": []} + + signed_budget = ( + UPVOTE_WEIGHT_INCREMENT if is_positive else -DOWNVOTE_WEIGHT_DECREMENT + ) + weighting = CHUNK_FEEDBACK_WEIGHTING if CHUNK_FEEDBACK_WEIGHTING in ( + "uniform", + "relevance", + ) else "relevance" + + if weighting == "uniform": + deltas = _allocate_deltas_uniform([(r[0], r[1]) for r in rows], signed_budget) + else: + deltas = _allocate_deltas_relevance(rows, signed_budget) + + success_count = 0 + fail_count = 0 + + row_by_chunk = {r[0]: r[2].get("row_id") for r in rows} + for chunk_id, kb_id, delta in deltas: + if delta == 0: + continue + rid = row_by_chunk.get(chunk_id) + rid_int = None + if rid is not None: + try: + rid_int = int(rid) + except (TypeError, ValueError): + pass + if cls.update_chunk_weight(tenant_id, chunk_id, kb_id, delta, row_id=rid_int): + success_count += 1 + else: + fail_count += 1 + + logging.info( + "Applied %s feedback (%s) to %s/%s chunks", + "positive" if is_positive else "negative", + weighting, + success_count, + len(chunk_ids), + ) + + return { + "success_count": success_count, + "fail_count": fail_count, + "chunk_ids": chunk_ids + } diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 6896ce7a4..e363fe180 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -56,6 +56,7 @@ def chunks_format(reference): "similarity": chunk.get("similarity"), "vector_similarity": chunk.get("vector_similarity"), "term_similarity": chunk.get("term_similarity"), + "row_id": chunk.get("row_id"), "doc_type": get_value(chunk, "doc_type_kwd", "doc_type"), } for chunk in raw_chunks diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 6a3d35eec..fb7ce225f 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -31,6 +31,32 @@ ATTEMPT_TIME = 2 MAX_RESULT_WINDOW = 10000 SEARCH_AFTER_BATCH_SIZE = 1000 +# Single-document atomic pagerank_fea adjust (chunk feedback). Clamps using params.min_w / max_w; +# removes field at zero for rank_feature compatibility. +_PAGERANK_FEA_ADJUST_SCRIPT = """ +double cur = 0.0; +if (ctx._source.containsKey(params.pf)) { + Object v = ctx._source[params.pf]; + if (v != null) { + if (v instanceof Number) { + cur = ((Number)v).doubleValue(); + } else { + try { cur = Double.parseDouble(v.toString()); } catch (Exception e) { cur = 0.0; } + } + } +} +double nw = cur + params.delta; +if (nw < params.min_w) { nw = params.min_w; } +if (nw > params.max_w) { nw = params.max_w; } +if (nw <= 0.0) { + if (ctx._source.containsKey(params.pf)) { + ctx._source.remove(params.pf); + } +} else { + ctx._source[params.pf] = nw; +} +""" + @singleton class ESConnection(ESConnectionBase): @@ -303,7 +329,11 @@ class ESConnection(ESConnectionBase): # update specific single document chunk_id = condition["id"] for i in range(ATTEMPT_TIME): - for k in doc.keys(): + doc_part = copy.deepcopy(doc) + remove_value = doc_part.pop("remove", None) + remove_field = remove_value if isinstance(remove_value, str) else None + remove_dict = remove_value if isinstance(remove_value, dict) else None + for k in doc_part.keys(): if "feas" != k.split("_")[-1]: continue try: @@ -312,8 +342,32 @@ class ESConnection(ESConnectionBase): self.logger.exception( f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") try: - self.es.update(index=index_name, id=chunk_id, doc=doc) - return True + if remove_field is not None: + self.es.update( + index=index_name, + id=chunk_id, + script=f"ctx._source.remove('{remove_field}');", + ) + if remove_dict is not None: + scripts = [] + params = {} + for kk, vv in remove_dict.items(): + scripts.append( + f"if (ctx._source.containsKey('{kk}') && ctx._source.{kk} != null) " + f"{{ int i = ctx._source.{kk}.indexOf(params.p_{kk}); " + f"if (i >= 0) {{ ctx._source.{kk}.remove(i); }} }}" + ) + params[f"p_{kk}"] = vv + if scripts: + self.es.update( + index=index_name, + id=chunk_id, + script={"source": "".join(scripts), "params": params}, + ) + if doc_part: + self.es.update(index=index_name, id=chunk_id, doc=doc_part) + if remove_field is not None or remove_dict is not None or doc_part: + return True except Exception as e: self.logger.exception( f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str( @@ -389,6 +443,61 @@ class ESConnection(ESConnectionBase): break return False + def adjust_chunk_pagerank_fea( + self, + chunk_id: str, + index_name: str, + knowledgebase_id: str, + delta: float, + min_w: float = 0.0, + max_w: float = 100.0, + row_id: int | None = None, + ) -> bool: + """Atomically adjust pagerank_fea on one chunk (painless script).""" + _ = row_id + for _ in range(ATTEMPT_TIME): + try: + self.es.update( + index=index_name, + id=chunk_id, + retry_on_conflict=3, + script={ + "source": _PAGERANK_FEA_ADJUST_SCRIPT.strip(), + "lang": "painless", + "params": { + "pf": PAGERANK_FLD, + "delta": float(delta), + "min_w": float(min_w), + "max_w": float(max_w), + }, + }, + ) + self.logger.debug( + "ESConnection.adjust_chunk_pagerank_fea(index=%s, id=%s, delta=%s) succeeded", + index_name, + chunk_id, + delta, + ) + return True + except ConnectionTimeout: + self.logger.exception("ES request timeout") + time.sleep(3) + self._connect() + continue + except Exception as e: + self.logger.exception( + "ESConnection.adjust_chunk_pagerank_fea(index=%s, id=%s): %s", + index_name, + chunk_id, + e, + ) + if re.search(r"connection", str(e).lower()): + time.sleep(3) + self._connect() + continue + break + return False + def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int: assert "_id" not in condition condition["kb_id"] = knowledgebase_id diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index cdf6978e3..d68cd8800 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -597,6 +597,84 @@ class InfinityConnection(InfinityConnectionBase): self.connPool.release_conn(inf_conn) return True + def adjust_chunk_pagerank_fea( + self, + chunk_id: str, + index_name: str, + knowledgebase_id: str, + delta: int, + min_weight: int, + max_weight: int, + row_id: int | None = None, + max_retries: int = 2, + ) -> bool: + """Adjust pagerank_fea on one chunk row in Infinity. + + Uses row_id for a targeted update when available. If the row_id is + stale (concurrent update changed it), re-reads the current row_id and + retries up to *max_retries* times. + """ + table_name = f"{index_name}_{knowledgebase_id}" + for attempt in range(max_retries + 1): + inf_conn = self.connPool.get_conn() + try: + db_instance = inf_conn.get_database(self.dbName) + table_instance = db_instance.get_table(table_name) + + if row_id is None: + df, _ = table_instance.output( + [PAGERANK_FLD, "row_id()"] + ).filter(f"id = '{chunk_id}'").to_df() + if df.empty: + self.logger.warning( + "adjust_chunk_pagerank_fea: chunk %s not found in %s", + chunk_id, table_name, + ) + return False + current_weight = int(float(df[PAGERANK_FLD].iloc[0] or 0)) + row_id = int(df["row_id"].iloc[0]) + else: + df, _ = table_instance.output( + [PAGERANK_FLD] + ).filter(f"id = '{chunk_id}'").to_df() + if df.empty: + return False + current_weight = int(float(df[PAGERANK_FLD].iloc[0] or 0)) + + new_weight = max(min_weight, min(max_weight, current_weight + delta)) + + table_instance.update( + f"_row_id = {row_id}", + {PAGERANK_FLD: new_weight}, + ) + self.logger.info( + "adjust_chunk_pagerank_fea(chunk=%s, table=%s): %s -> %s via row_id=%s", + chunk_id, table_name, current_weight, new_weight, row_id, + ) + return True + + except InfinityException as e: + if attempt < max_retries: + self.logger.warning( + "adjust_chunk_pagerank_fea stale row_id=%s for chunk %s (attempt %s/%s): %s", + row_id, chunk_id, attempt + 1, max_retries, e, + ) + row_id = None + continue + self.logger.error( + "adjust_chunk_pagerank_fea failed for chunk %s after %s attempts: %s", + chunk_id, max_retries + 1, e, + ) + return False + except Exception as e: + self.logger.error( + "adjust_chunk_pagerank_fea error for chunk %s: %s", chunk_id, e, + ) + return False + finally: + self.connPool.release_conn(inf_conn) + return False + """ Helper functions for search result """ diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py index 916425e7a..10e033400 100644 --- a/rag/utils/ob_conn.py +++ b/rag/utils/ob_conn.py @@ -1213,6 +1213,32 @@ class OBConnection(OBConnectionBase): logger.error(f"OBConnection.update error: {str(e)}") return False + def adjust_chunk_pagerank_fea( + self, + chunk_id: str, + index_name: str, + knowledgebase_id: str, + delta: int, + min_w: int = 0, + max_w: int = 100, + ) -> bool: + """Atomically adjust pagerank_fea on one chunk row (single UPDATE).""" + if not self._check_table_exists_cached(index_name): + return True + d = int(delta) + sql = ( + f"UPDATE {index_name} SET {PAGERANK_FLD} = " + f"GREATEST({int(min_w)}, LEAST({int(max_w)}, COALESCE({PAGERANK_FLD}, 0) + ({d}))) " + f"WHERE id = {get_value_str(chunk_id)} AND kb_id = {get_value_str(knowledgebase_id)}" + ) + logger.debug("OBConnection.adjust_chunk_pagerank_fea sql: %s", sql) + try: + self.client.perform_raw_text_sql(sql) + return True + except Exception as e: + logger.error("OBConnection.adjust_chunk_pagerank_fea error: %s", e) + return False + def _row_to_entity(self, data: Row, fields: list[str]) -> dict: entity = {} for i, field in enumerate(fields): diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py index ad9799400..cb8b70ac2 100644 --- a/rag/utils/opensearch_conn.py +++ b/rag/utils/opensearch_conn.py @@ -34,6 +34,30 @@ from common import settings ATTEMPT_TIME = 2 +_PAGERANK_FEA_ADJUST_SCRIPT = """ +double cur = 0.0; +if (ctx._source.containsKey(params.pf)) { + Object v = ctx._source[params.pf]; + if (v != null) { + if (v instanceof Number) { + cur = ((Number)v).doubleValue(); + } else { + try { cur = Double.parseDouble(v.toString()); } catch (Exception e) { cur = 0.0; } + } + } +} +double nw = cur + params.delta; +if (nw < params.min_w) { nw = params.min_w; } +if (nw > params.max_w) { nw = params.max_w; } +if (nw <= 0.0) { + if (ctx._source.containsKey(params.pf)) { + ctx._source.remove(params.pf); + } +} else { + ctx._source[params.pf] = nw; +} +""" + logger = logging.getLogger('ragflow.opensearch_conn') @@ -329,9 +353,37 @@ class OSConnection(DocStoreConnection): # update specific single document chunkId = condition["id"] for i in range(ATTEMPT_TIME): + doc_part = copy.deepcopy(doc) + remove_value = doc_part.pop("remove", None) + remove_field = remove_value if isinstance(remove_value, str) else None + remove_dict = remove_value if isinstance(remove_value, dict) else None try: - self.os.update(index=indexName, id=chunkId, body={"doc": doc}) - return True + if remove_field is not None: + self.os.update( + index=indexName, + id=chunkId, + body={"script": {"source": f"ctx._source.remove('{remove_field}');"}}, + ) + if remove_dict is not None: + scripts = [] + params = {} + for kk, vv in remove_dict.items(): + scripts.append( + f"if (ctx._source.containsKey('{kk}') && ctx._source.{kk} != null) " + f"{{ int i = ctx._source.{kk}.indexOf(params.p_{kk}); " + f"if (i >= 0) {{ ctx._source.{kk}.remove(i); }} }}" + ) + params[f"p_{kk}"] = vv + if scripts: + self.os.update( + index=indexName, + id=chunkId, + body={"script": {"source": "".join(scripts), "params": params}}, + ) + if doc_part: + self.os.update(index=indexName, id=chunkId, body={"doc": doc_part}) + if remove_field is not None or remove_dict is not None or doc_part: + return True except Exception as e: logger.exception( f"OSConnection.update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") @@ -405,6 +457,52 @@ class OSConnection(DocStoreConnection): break return False + def adjust_chunk_pagerank_fea( + self, + chunk_id: str, + indexName: str, + knowledgebaseId: str, + delta: float, + min_w: float = 0.0, + max_w: float = 100.0, + row_id: int | None = None, + ) -> bool: + """Atomically adjust pagerank_fea on one chunk (painless script).""" + _ = row_id + try: + self.os.update( + index=indexName, + id=chunk_id, + retry_on_conflict=3, + body={ + "script": { + "source": _PAGERANK_FEA_ADJUST_SCRIPT.strip(), + "lang": "painless", + "params": { + "pf": PAGERANK_FLD, + "delta": float(delta), + "min_w": float(min_w), + "max_w": float(max_w), + }, + } + }, + ) + logger.debug( + "OSConnection.adjust_chunk_pagerank_fea(index=%s, id=%s, delta=%s) succeeded", + indexName, + chunk_id, + delta, + ) + return True + except Exception as e: + logger.exception( + "OSConnection.adjust_chunk_pagerank_fea(index=%s, id=%s): %s", + indexName, + chunk_id, + e, + ) + return False + def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: assert "_id" not in condition condition["kb_id"] = knowledgebaseId diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index f0851f4a2..359aa6159 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -346,10 +346,21 @@ def _load_chat_module(monkeypatch): def query(**_kwargs): return [] + user_service_mod.UserService = type("UserService", (), {}) user_service_mod.TenantService = _StubTenantService user_service_mod.UserTenantService = _StubUserTenantService monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + chunk_feedback_service_mod = ModuleType("api.db.services.chunk_feedback_service") + + class _StubChunkFeedbackService: + @staticmethod + def apply_feedback(**_kwargs): + return {"success_count": 0, "fail_count": 0, "chunk_ids": []} + + chunk_feedback_service_mod.ChunkFeedbackService = _StubChunkFeedbackService + monkeypatch.setitem(sys.modules, "api.db.services.chunk_feedback_service", chunk_feedback_service_mod) + api_utils_mod = ModuleType("api.utils.api_utils") def _check_duplicate_ids(ids, label): diff --git a/test/testcases/test_web_api/test_chunk_feedback/__init__.py b/test/testcases/test_web_api/test_chunk_feedback/__init__.py new file mode 100644 index 000000000..177b91dd0 --- /dev/null +++ b/test/testcases/test_web_api/test_chunk_feedback/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/test/testcases/test_web_api/test_chunk_feedback/test_chunk_feedback_service.py b/test/testcases/test_web_api/test_chunk_feedback/test_chunk_feedback_service.py new file mode 100644 index 000000000..6166f0047 --- /dev/null +++ b/test/testcases/test_web_api/test_chunk_feedback/test_chunk_feedback_service.py @@ -0,0 +1,584 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Tests for ChunkFeedbackService - adjusting chunk weights based on user feedback. + +Uses importlib to load chunk_feedback_service.py in isolation so that +test/testcases/test_web_api/common.py (a test-helper module) does not shadow +the project-level common/ package during collection. +""" +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +pytestmark = pytest.mark.p2 + +_REPO_ROOT = Path(__file__).resolve().parents[4] + + +def _load_feedback_module(monkeypatch): + """Load chunk_feedback_service.py with lightweight stubs for its deps.""" + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(_REPO_ROOT / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + constants_mod = ModuleType("common.constants") + constants_mod.PAGERANK_FLD = "pagerank_fea" + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + settings_mod = ModuleType("common.settings") + settings_mod.docStoreConn = MagicMock() + # Non-ES engines accept pagerank_fea=0; tests below override for elasticsearch/opensearch. + settings_mod.DOC_ENGINE = "infinity" + monkeypatch.setitem(sys.modules, "common.settings", settings_mod) + common_pkg.settings = settings_mod + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + + rag_nlp_pkg = ModuleType("rag.nlp") + rag_nlp_pkg.__path__ = [] + rag_nlp_pkg.search = SimpleNamespace(index_name=lambda tid: f"idx-{tid}") + monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_pkg) + + rag_nlp_search_mod = ModuleType("rag.nlp.search") + rag_nlp_search_mod.index_name = lambda tid: f"idx-{tid}" + monkeypatch.setitem(sys.modules, "rag.nlp.search", rag_nlp_search_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + spec = importlib.util.spec_from_file_location( + "api.db.services.chunk_feedback_service", + _REPO_ROOT / "api" / "db" / "services" / "chunk_feedback_service.py", + ) + mod = importlib.util.module_from_spec(spec) + monkeypatch.setitem( + sys.modules, "api.db.services.chunk_feedback_service", mod + ) + spec.loader.exec_module(mod) + + return mod, settings_mod + + +@pytest.fixture +def feedback_env(monkeypatch): + """Provide (module, settings_stub) for chunk feedback tests.""" + return _load_feedback_module(monkeypatch) + + +class TestFeedbackRowsFromReference: + """Chunk id + kb resolution via _feedback_rows_from_reference (single pass).""" + + def test_empty_reference(self, feedback_env): + mod, _ = feedback_env + assert mod.ChunkFeedbackService._feedback_rows_from_reference({}) == [] + assert mod.ChunkFeedbackService._feedback_rows_from_reference(None) == [] + + def test_reference_with_id_and_dataset(self, feedback_env): + mod, _ = feedback_env + reference = { + "chunks": [ + {"id": "chunk1", "content": "test", "dataset_id": "kb1"}, + {"id": "chunk2", "content": "test2", "dataset_id": "kb1"}, + ] + } + rows = mod.ChunkFeedbackService._feedback_rows_from_reference(reference) + assert [r[0] for r in rows] == ["chunk1", "chunk2"] + + def test_reference_with_chunk_id_and_kb_id(self, feedback_env): + mod, _ = feedback_env + reference = { + "chunks": [ + {"chunk_id": "chunk1", "content": "test", "kb_id": "kb1"}, + {"chunk_id": "chunk2", "content": "test2", "kb_id": "kb1"}, + ] + } + rows = mod.ChunkFeedbackService._feedback_rows_from_reference(reference) + assert [r[0] for r in rows] == ["chunk1", "chunk2"] + + def test_reference_skips_chunks_without_kb(self, feedback_env): + mod, _ = feedback_env + reference = { + "chunks": [ + {"id": "chunk1", "dataset_id": "kb1"}, + {"id": "chunk2", "content": "no kb"}, + ] + } + rows = mod.ChunkFeedbackService._feedback_rows_from_reference(reference) + assert [r[0] for r in rows] == ["chunk1"] + + def test_reference_with_no_chunks(self, feedback_env): + mod, _ = feedback_env + reference = {"doc_aggs": [{"doc_id": "doc1"}]} + assert mod.ChunkFeedbackService._feedback_rows_from_reference(reference) == [] + + def test_chunk_id_to_kb_map_matches_row_pairs(self, feedback_env): + mod, _ = feedback_env + reference = { + "chunks": [ + {"id": "a", "dataset_id": "kb1"}, + {"chunk_id": "b", "kb_id": "kb2"}, + ] + } + rows = mod.ChunkFeedbackService._feedback_rows_from_reference(reference) + assert {r[0]: r[1] for r in rows} == {"a": "kb1", "b": "kb2"} + + +class TestUpdateChunkWeight: + """Tests for update_chunk_weight method.""" + + def test_update_weight_success(self, feedback_env): + """Should update chunk weight successfully.""" + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "mysql" + mock_doc_store = MagicMock() + mock_doc_store.adjust_chunk_pagerank_fea = None + mock_doc_store.get.return_value = {"pagerank_fea": 10} + mock_doc_store.update.return_value = True + settings_mod.docStoreConn = mock_doc_store + + result = mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=1 + ) + + assert result is True + mock_doc_store.update.assert_called_once() + + def test_update_weight_chunk_not_found(self, feedback_env): + """Should return False if chunk not found.""" + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "mysql" + mock_doc_store = MagicMock() + mock_doc_store.adjust_chunk_pagerank_fea = None + mock_doc_store.get.return_value = None + settings_mod.docStoreConn = mock_doc_store + + result = mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=1 + ) + + assert result is False + + def test_update_weight_clamp_max(self, feedback_env): + """Should clamp weight to MAX_PAGERANK_WEIGHT.""" + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "mysql" + mock_doc_store = MagicMock() + mock_doc_store.adjust_chunk_pagerank_fea = None + mock_doc_store.get.return_value = {"pagerank_fea": mod.MAX_PAGERANK_WEIGHT} + mock_doc_store.update.return_value = True + settings_mod.docStoreConn = mock_doc_store + + mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=10 # Would exceed max + ) + + # Verify the new_value passed to update has clamped weight + call_args = mock_doc_store.update.call_args + new_value = call_args[0][1] + assert new_value["pagerank_fea"] == mod.MAX_PAGERANK_WEIGHT + + def test_update_weight_clamp_min(self, feedback_env): + """Should clamp weight to MIN_PAGERANK_WEIGHT.""" + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "mysql" + mock_doc_store = MagicMock() + mock_doc_store.adjust_chunk_pagerank_fea = None + mock_doc_store.get.return_value = {"pagerank_fea": 0} + mock_doc_store.update.return_value = True + settings_mod.docStoreConn = mock_doc_store + + mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=-10 # Would go below min + ) + + call_args = mock_doc_store.update.call_args + new_value = call_args[0][1] + assert new_value["pagerank_fea"] == mod.MIN_PAGERANK_WEIGHT + + def test_update_weight_elasticsearch_uses_atomic_adjust(self, feedback_env): + """Elasticsearch uses script-based adjust (rank_feature zero handled in script).""" + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "elasticsearch" + mock_doc_store = MagicMock() + mock_adjust = MagicMock(return_value=True) + mock_doc_store.adjust_chunk_pagerank_fea = mock_adjust + settings_mod.docStoreConn = mock_doc_store + + assert mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=-1, + ) + mock_adjust.assert_called_once_with( + "chunk1", + "idx-tenant1", + "kb1", + -1, + mod.MIN_PAGERANK_WEIGHT, + mod.MAX_PAGERANK_WEIGHT, + ) + + def test_update_weight_elasticsearch_forwards_row_id(self, feedback_env): + """Elasticsearch adjust accepts and forwards row_id without TypeError.""" + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "elasticsearch" + mock_doc_store = MagicMock() + mock_adjust = MagicMock(return_value=True) + mock_doc_store.adjust_chunk_pagerank_fea = mock_adjust + settings_mod.docStoreConn = mock_doc_store + + assert mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=-1, + row_id=42, + ) + mock_adjust.assert_called_once_with( + "chunk1", + "idx-tenant1", + "kb1", + -1, + mod.MIN_PAGERANK_WEIGHT, + mod.MAX_PAGERANK_WEIGHT, + row_id=42, + ) + + def test_update_weight_opensearch_uses_atomic_adjust(self, feedback_env): + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "opensearch" + mock_doc_store = MagicMock() + mock_adjust = MagicMock(return_value=True) + mock_doc_store.adjust_chunk_pagerank_fea = mock_adjust + settings_mod.docStoreConn = mock_doc_store + + mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=-2, + ) + mock_adjust.assert_called_once_with( + "chunk1", + "idx-tenant1", + "kb1", + -2, + mod.MIN_PAGERANK_WEIGHT, + mod.MAX_PAGERANK_WEIGHT, + ) + + def test_update_weight_opensearch_forwards_row_id(self, feedback_env): + """OpenSearch adjust accepts and forwards row_id without TypeError.""" + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "opensearch" + mock_doc_store = MagicMock() + mock_adjust = MagicMock(return_value=True) + mock_doc_store.adjust_chunk_pagerank_fea = mock_adjust + settings_mod.docStoreConn = mock_doc_store + + mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=-2, + row_id=77, + ) + mock_adjust.assert_called_once_with( + "chunk1", + "idx-tenant1", + "kb1", + -2, + mod.MIN_PAGERANK_WEIGHT, + mod.MAX_PAGERANK_WEIGHT, + row_id=77, + ) + + def test_update_weight_infinity_uses_adjust_with_row_id(self, feedback_env): + """Infinity path passes row_id to adjust_chunk_pagerank_fea.""" + mod, settings_mod = feedback_env + settings_mod.DOC_ENGINE = "infinity" + mock_doc_store = MagicMock() + mock_adjust = MagicMock(return_value=True) + mock_doc_store.adjust_chunk_pagerank_fea = mock_adjust + settings_mod.docStoreConn = mock_doc_store + + ok = mod.ChunkFeedbackService.update_chunk_weight( + tenant_id="tenant1", + chunk_id="chunk1", + kb_id="kb1", + delta=1, + row_id=42, + ) + assert ok is True + mock_adjust.assert_called_once_with( + "chunk1", + "idx-tenant1", + "kb1", + 1, + mod.MIN_PAGERANK_WEIGHT, + mod.MAX_PAGERANK_WEIGHT, + row_id=42, + ) + + +class TestApplyFeedback: + """Tests for apply_feedback method.""" + + def test_apply_feedback_disabled(self, feedback_env, monkeypatch): + """Should return early when feature is disabled.""" + mod, _ = feedback_env + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", False) + + result = mod.ChunkFeedbackService.apply_feedback( + tenant_id="tenant1", + reference={"chunks": [{"id": "chunk1", "dataset_id": "kb1"}]}, + is_positive=True + ) + + assert result["success_count"] == 0 + assert result["fail_count"] == 0 + assert result.get("disabled") is True + + def test_apply_positive_feedback(self, feedback_env, monkeypatch): + """Relevance mode splits the per-event budget across chunks (equal when no scores).""" + mod, _ = feedback_env + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) + mock_update = MagicMock(return_value=True) + monkeypatch.setattr( + mod.ChunkFeedbackService, "update_chunk_weight", mock_update + ) + + reference = { + "chunks": [ + {"id": "chunk1", "dataset_id": "kb1"}, + {"id": "chunk2", "dataset_id": "kb1"}, + ] + } + result = mod.ChunkFeedbackService.apply_feedback( + tenant_id="tenant1", + reference=reference, + is_positive=True + ) + + assert result["success_count"] == 1 + assert result["fail_count"] == 0 + assert mock_update.call_count == 1 + mock_update.assert_called_once_with("tenant1", "chunk1", "kb1", 1, row_id=None) + + def test_apply_negative_feedback(self, feedback_env, monkeypatch): + """Should apply negative feedback with full budget when only one chunk.""" + mod, _ = feedback_env + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) + mock_update = MagicMock(return_value=True) + monkeypatch.setattr( + mod.ChunkFeedbackService, "update_chunk_weight", mock_update + ) + + reference = {"chunks": [{"id": "chunk1", "dataset_id": "kb1"}]} + result = mod.ChunkFeedbackService.apply_feedback( + tenant_id="tenant1", + reference=reference, + is_positive=False + ) + + assert result["success_count"] == 1 + mock_update.assert_called_with("tenant1", "chunk1", "kb1", -1, row_id=None) + + def test_apply_feedback_no_chunks(self, feedback_env, monkeypatch): + """Should handle empty chunk list gracefully.""" + mod, _ = feedback_env + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) + + result = mod.ChunkFeedbackService.apply_feedback( + tenant_id="tenant1", + reference={}, + is_positive=True + ) + + assert result["success_count"] == 0 + assert result["fail_count"] == 0 + assert result["chunk_ids"] == [] + + def test_apply_feedback_partial_failure(self, feedback_env, monkeypatch): + """Should count failures correctly (uniform gives each chunk a unit).""" + mod, _ = feedback_env + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_WEIGHTING", "uniform") + mock_update = MagicMock(side_effect=[True, False]) + monkeypatch.setattr( + mod.ChunkFeedbackService, "update_chunk_weight", mock_update + ) + + reference = { + "chunks": [ + {"id": "chunk1", "dataset_id": "kb1"}, + {"id": "chunk2", "dataset_id": "kb1"}, + ] + } + result = mod.ChunkFeedbackService.apply_feedback( + tenant_id="tenant1", + reference=reference, + is_positive=True + ) + + assert result["success_count"] == 1 + assert result["fail_count"] == 1 + + def test_apply_positive_feedback_uniform_mode(self, feedback_env, monkeypatch): + """uniform: each cited chunk gets the full increment (legacy).""" + mod, _ = feedback_env + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_WEIGHTING", "uniform") + mock_update = MagicMock(return_value=True) + monkeypatch.setattr( + mod.ChunkFeedbackService, "update_chunk_weight", mock_update + ) + reference = { + "chunks": [ + {"id": "chunk1", "dataset_id": "kb1"}, + {"id": "chunk2", "dataset_id": "kb1"}, + ] + } + mod.ChunkFeedbackService.apply_feedback( + tenant_id="tenant1", reference=reference, is_positive=True + ) + mock_update.assert_any_call("tenant1", "chunk1", "kb1", mod.UPVOTE_WEIGHT_INCREMENT, row_id=None) + mock_update.assert_any_call("tenant1", "chunk2", "kb1", mod.UPVOTE_WEIGHT_INCREMENT, row_id=None) + + def test_apply_positive_feedback_relevance_weighted(self, feedback_env, monkeypatch): + """Higher retrieval similarity receives a larger share of the budget.""" + mod, _ = feedback_env + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_WEIGHTING", "relevance") + mock_update = MagicMock(return_value=True) + monkeypatch.setattr( + mod.ChunkFeedbackService, "update_chunk_weight", mock_update + ) + reference = { + "chunks": [ + {"id": "a", "dataset_id": "kb1", "similarity": 0.9}, + {"id": "b", "dataset_id": "kb1", "similarity": 0.1}, + ] + } + mod.ChunkFeedbackService.apply_feedback( + tenant_id="tenant1", reference=reference, is_positive=True + ) + mock_update.assert_called_once_with("tenant1", "a", "kb1", 1, row_id=None) + + def test_apply_feedback_passes_row_id_from_reference(self, feedback_env, monkeypatch): + """row_id from retrieval results flows through to update_chunk_weight.""" + mod, _ = feedback_env + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_WEIGHTING", "relevance") + mock_update = MagicMock(return_value=True) + monkeypatch.setattr( + mod.ChunkFeedbackService, "update_chunk_weight", mock_update + ) + reference = { + "chunks": [ + {"id": "c1", "dataset_id": "kb1", "similarity": 0.8, "row_id": 99}, + ] + } + mod.ChunkFeedbackService.apply_feedback( + tenant_id="tenant1", reference=reference, is_positive=True + ) + mock_update.assert_called_once_with("tenant1", "c1", "kb1", 1, row_id=99) + + +class TestThumbFlipFeedback: + """Verify that toggling thumbup↔thumbdown applies undo + new (two calls).""" + + @staticmethod + def _simulate_feedback(mod, monkeypatch, reference, prior_thumb, new_thumb): + """Reproduce the chat_api thumb-flip logic in isolation.""" + monkeypatch.setattr(mod, "CHUNK_FEEDBACK_ENABLED", True) + mock_update = MagicMock(return_value=True) + monkeypatch.setattr(mod.ChunkFeedbackService, "update_chunk_weight", mock_update) + + calls = [] + + apply_chunk_feedback = False + if new_thumb is True: + apply_chunk_feedback = prior_thumb is not True + else: + apply_chunk_feedback = prior_thumb is not False + + if apply_chunk_feedback and reference: + if isinstance(prior_thumb, bool) and prior_thumb != new_thumb: + r = mod.ChunkFeedbackService.apply_feedback( + tenant_id="t1", reference=reference, is_positive=not prior_thumb, + ) + calls.append(("undo", r)) + r = mod.ChunkFeedbackService.apply_feedback( + tenant_id="t1", reference=reference, is_positive=new_thumb is True, + ) + calls.append(("new", r)) + + return calls, mock_update + + def test_toggle_thumbup_to_thumbdown(self, feedback_env, monkeypatch): + """thumbup→thumbdown: undo (+1→-1) then apply new (-1). Two calls.""" + mod, _ = feedback_env + ref = {"chunks": [{"id": "c1", "dataset_id": "kb1"}]} + calls, mock = self._simulate_feedback(mod, monkeypatch, ref, True, False) + assert len(calls) == 2 + assert calls[0][0] == "undo" + assert calls[1][0] == "new" + + def test_toggle_thumbdown_to_thumbup(self, feedback_env, monkeypatch): + """thumbdown→thumbup: undo (-1→+1) then apply new (+1). Two calls.""" + mod, _ = feedback_env + ref = {"chunks": [{"id": "c1", "dataset_id": "kb1"}]} + calls, mock = self._simulate_feedback(mod, monkeypatch, ref, False, True) + assert len(calls) == 2 + assert calls[0][0] == "undo" + assert calls[1][0] == "new" + + def test_no_prior_to_thumbup(self, feedback_env, monkeypatch): + """None→thumbup: single apply, no undo.""" + mod, _ = feedback_env + ref = {"chunks": [{"id": "c1", "dataset_id": "kb1"}]} + calls, mock = self._simulate_feedback(mod, monkeypatch, ref, None, True) + assert len(calls) == 1 + assert calls[0][0] == "new" + + def test_same_thumb_no_op(self, feedback_env, monkeypatch): + """thumbup→thumbup: no feedback at all (apply_chunk_feedback is False).""" + mod, _ = feedback_env + ref = {"chunks": [{"id": "c1", "dataset_id": "kb1"}]} + calls, mock = self._simulate_feedback(mod, monkeypatch, ref, True, True) + assert len(calls) == 0