From 2717ee283f3485878f8a56ecfc4410f62fa29246 Mon Sep 17 00:00:00 2001 From: CaptainTimon <279704422+CaptainTimon@users.noreply.github.com> Date: Mon, 11 May 2026 15:42:31 -1000 Subject: [PATCH] feat(raptor): add Psi tree builder with original-space ranking and safe migration (#14679) ### What problem does this PR solve? Closes #14674. This PR improves RAPTOR configuration and tree construction while preserving the existing RAPTOR behavior as the default. RAPTOR currently builds summary layers with the original UMAP + GMM clustering path. This PR keeps that default path, and adds: - A hidden backend tree-builder option: - `tree_builder="raptor"`: default, existing RAPTOR behavior. - `tree_builder="psi"`: rank-aware Psi-style tree builder using original embedding-space cosine ranking. - A user-facing clustering method option for the default RAPTOR builder: - `clustering_method="gmm"`: existing default. - `clustering_method="ahc"`: agglomerative hierarchical clustering path. - A RAPTOR UI setting for `Clustering method` and `Max cluster`. ### What changed #### Backend - Added `tree_builder` support for RAPTOR/Psi. - Added `clustering_method` support for GMM/AHC. - Kept existing RAPTOR + GMM as the default. - Added Psi tree building from original-space cosine similarity. - Added bucketed Psi building controls for large inputs: - `raptor.ext.psi_exact_max_leaves` - `raptor.ext.psi_bucket_size` - Added method-aware RAPTOR summary metadata using existing `extra.raptor_method`. - Avoided adding a dedicated DB schema field for experimental method tracking. - Added cleanup/migration logic to avoid mixing stale RAPTOR summary trees. - Added defensive checks for Psi tree construction and summary failures. #### Frontend/UI - Added `Clustering method` in RAPTOR settings with `GMM` and `AHC`. - Added/kept `Max cluster` in RAPTOR settings. - Enlarged max cluster UI limit to `1024`, matching backend validation. - Kept AHC editable even when a RAPTOR task has already finished. - Fixed the UI save payload so `clustering_method` and `tree_builder` are serialized through `parser_config.raptor.ext`, avoiding backend validation errors for extra top-level RAPTOR fields. Example saved RAPTOR config: ```json { "raptor": { "max_cluster": 317, "ext": { "clustering_method": "ahc", "tree_builder": "raptor" } } } Co-authored-by: CaptainTimon --- api/utils/validation_utils.py | 48 +- rag/raptor.py | 637 ++++++++++++++++-- rag/svr/task_executor.py | 295 ++++++-- rag/utils/ob_conn.py | 13 +- rag/utils/raptor_utils.py | 96 +++ .../test_update_dataset.py | 16 + .../rag/test_raptor_psi_tree_builder.py | 375 +++++++++++ test/unit_test/rag/utils/test_raptor_utils.py | 127 +++- .../components/chunk-method-dialog/index.tsx | 31 +- .../use-default-parser-values.ts | 19 +- .../raptor-form-fields.tsx | 95 ++- web/src/components/ui/radio.tsx | 9 +- web/src/hooks/parser-config-utils.ts | 14 +- .../hooks/tests/parser-config-utils.test.ts | 45 ++ web/src/interfaces/database/dataset.ts | 2 + web/src/interfaces/request/document.ts | 15 +- web/src/locales/en.ts | 5 + web/src/locales/zh.ts | 5 + .../dataset/dataset-setting/form-schema.ts | 11 +- .../pages/dataset/dataset-setting/index.tsx | 2 + .../dataset/use-change-document-parser.ts | 2 +- 21 files changed, 1722 insertions(+), 140 deletions(-) create mode 100644 test/unit_test/rag/test_raptor_psi_tree_builder.py create mode 100644 web/src/hooks/tests/parser-config-utils.test.ts diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 7a8a63939..1e6c0056b 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -327,10 +327,14 @@ def validate_uuid1_hex(v: Any) -> str: class Base(BaseModel): + """Strict base model that rejects unknown request fields.""" + model_config = ConfigDict(extra="forbid", strict=True) class RaptorConfig(Base): + """Dataset parser configuration for RAPTOR summary generation.""" + use_raptor: Annotated[bool, Field(default=False)] prompt: Annotated[ str, @@ -344,11 +348,15 @@ class RaptorConfig(Base): max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)] random_seed: Annotated[int, Field(default=0, ge=0)] scope: Annotated[Literal["file", "dataset"], Field(default="file")] + clustering_method: Annotated[Literal["gmm", "ahc"], Field(default="gmm")] + tree_builder: Annotated[Literal["raptor", "psi"], Field(default="raptor")] auto_disable_for_structured_data: Annotated[bool, Field(default=True)] ext: Annotated[dict, Field(default={})] class GraphragConfig(Base): + """Dataset parser configuration for GraphRAG generation.""" + use_graphrag: Annotated[bool, Field(default=False)] entity_types: Annotated[list[str], Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])] method: Annotated[Literal["light", "general", "ner"], Field(default="light")] @@ -357,6 +365,8 @@ class GraphragConfig(Base): class ParentChildConfig(Base): + """Dataset parser configuration for parent-child chunking.""" + use_parent_child: Annotated[bool, Field(default=False)] children_delimiter: Annotated[str, Field(default=r"\n", min_length=1)] @@ -381,6 +391,8 @@ TableColumnRole = Literal["indexing", "metadata", "both"] class ParserConfig(Base): + """Complete parser configuration accepted by dataset APIs.""" + auto_keywords: Annotated[int, Field(default=0, ge=0, le=32)] auto_questions: Annotated[int, Field(default=0, ge=0, le=10)] chunk_token_num: Annotated[int, Field(default=512, ge=1, le=2048)] @@ -439,6 +451,7 @@ class UpdateDocumentReq(Base): @field_validator("chunk_method", mode="after") @classmethod def validate_document_chunk_method(cls, chunk_method: str | None): + """Validate an optional document parser method.""" if chunk_method: # Validate chunk method if present valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email", "tag"} @@ -450,6 +463,7 @@ class UpdateDocumentReq(Base): @field_validator("enabled", mode="after") @classmethod def validate_document_enabled(cls, enabled: str | None): + """Validate the optional enabled flag.""" if enabled: converted = int(enabled) if converted < 0 or converted > 1: @@ -460,6 +474,7 @@ class UpdateDocumentReq(Base): @field_validator("meta_fields", mode="after") @classmethod def validate_document_meta_fields(cls, meta_fields: dict | None): + """Validate user-provided document metadata values.""" if meta_fields is None: return None @@ -475,6 +490,8 @@ class UpdateDocumentReq(Base): class CreateDatasetReq(Base): + """Request model for creating a dataset.""" + name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] avatar: Annotated[str | None, Field(default=None, max_length=65535)] description: Annotated[str | None, Field(default=None, max_length=65535)] @@ -490,6 +507,7 @@ class CreateDatasetReq(Base): @field_validator("pipeline_id", mode="before") @classmethod def handle_pipeline_id(cls, v: str | None, info: ValidationInfo): + """Drop pipeline_id when parse_type selects direct parser mode.""" if v is None: return v if info.data.get("parse_type", 0) == 1: @@ -743,6 +761,8 @@ class CreateDatasetReq(Base): class UpdateDatasetReq(CreateDatasetReq): + """Request model for updating a dataset.""" + dataset_id: Annotated[str, Field(...)] name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] pagerank: Annotated[int, Field(default=0, ge=0, le=100)] @@ -752,10 +772,13 @@ class UpdateDatasetReq(CreateDatasetReq): @field_validator("dataset_id", mode="before") @classmethod def validate_dataset_id(cls, v: Any) -> str: + """Validate and normalize the dataset id.""" return validate_uuid1_hex(v) class DeleteReq(Base): + """Base request model for batch delete APIs.""" + ids: Annotated[list[str] | None, Field(default=None)] delete_all: Annotated[bool, Field(default=False)] @@ -833,10 +856,15 @@ class DeleteReq(Base): return ids_list -class DeleteDatasetReq(DeleteReq): ... +class DeleteDatasetReq(DeleteReq): + """Request model for deleting datasets.""" + + ... class DeleteDocumentReq(DeleteReq): + """Request model for deleting documents.""" + @field_validator("ids", mode="after") @classmethod def validate_ids(cls, v_list: list[str] | None) -> list[str] | None: @@ -862,6 +890,8 @@ class DeleteDocumentReq(DeleteReq): class SearchDatasetReq(BaseModel): + """Request model for searching one dataset.""" + model_config = ConfigDict(extra="ignore") question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)] @@ -881,6 +911,8 @@ class SearchDatasetReq(BaseModel): class SearchDatasetsReq(BaseModel): + """Request model for searching multiple datasets.""" + model_config = ConfigDict(extra="ignore") dataset_ids: Annotated[list[str], Field(..., min_length=1)] @@ -901,6 +933,8 @@ class SearchDatasetsReq(BaseModel): class BaseListReq(BaseModel): + """Shared pagination and sorting fields for list APIs.""" + model_config = ConfigDict(extra="forbid") id: Annotated[str | None, Field(default=None)] @@ -913,10 +947,13 @@ class BaseListReq(BaseModel): @field_validator("id", mode="before") @classmethod def validate_id(cls, v: Any) -> str: + """Validate and normalize an optional list filter id.""" return validate_uuid1_hex(v) class ListDatasetReq(BaseListReq): + """Request model for listing datasets.""" + include_parsing_status: Annotated[bool, Field(default=False)] ext: Annotated[dict, Field(default={})] @@ -925,22 +962,29 @@ class ListDatasetReq(BaseListReq): class CreateFolderReq(Base): + """Request model for creating a folder.""" + name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)] parent_id: Annotated[str | None, Field(default=None)] type: Annotated[str | None, Field(default=None)] class DeleteFileReq(Base): + """Request model for deleting files.""" + ids: Annotated[list[str], Field(min_length=1)] class MoveFileReq(Base): + """Request model for moving or renaming files.""" + src_file_ids: Annotated[list[str], Field(min_length=1)] dest_file_id: Annotated[str | None, Field(default=None)] new_name: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(default=None)] @model_validator(mode="after") def check_operation(self): + """Require either a destination folder or a new file name.""" if not self.dest_file_id and not self.new_name: raise ValueError("At least one of dest_file_id or new_name must be provided") if self.new_name and len(self.src_file_ids) > 1: @@ -949,6 +993,8 @@ class MoveFileReq(Base): class ListFileReq(BaseModel): + """Request model for listing files.""" + model_config = ConfigDict(extra="forbid") parent_id: Annotated[str | None, Field(default=None)] diff --git a/rag/raptor.py b/rag/raptor.py index e4017319b..a7f2c782d 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -14,11 +14,13 @@ # limitations under the License. # import asyncio +from dataclasses import dataclass, field import logging import re import numpy as np import umap +from sklearn.cluster import AgglomerativeClustering from sklearn.mixture import GaussianMixture from api.db.services.task_service import has_canceled @@ -33,9 +35,127 @@ from rag.graphrag.utils import ( set_llm_cache, ) from common.misc_utils import thread_pool_exec +from rag.utils.raptor_utils import ( + AHC_CLUSTERING_METHOD, + GMM_CLUSTERING_METHOD, + PSI_TREE_BUILDER, + RAPTOR_TREE_BUILDER, + SUPPORTED_CLUSTERING_METHODS, + SUPPORTED_TREE_BUILDERS, +) + + +@dataclass +class _PsiTreeNode: + """Node used to represent the in-memory Psi merge tree.""" + + index: int + text: str = "" + embedding: np.ndarray | None = None + children: list["_PsiTreeNode"] = field(default_factory=list) + parent: "_PsiTreeNode | None" = None + + +class _PsiUnionFind: + """Build parent links for the Psi merge tree from ranked leaf pairs.""" + + def __init__(self, n: int): + """Initialize the union-find state for n leaf nodes.""" + self._rank = [0 for _ in range(n)] + self._parent_chains = [[] for _ in range(n)] + self._node_ids = [[i] for i in range(n)] + self._tree = [-1 for _ in range(max(1, 2 * n - 1))] + self._next_id = n + + @staticmethod + def _ordered_extend(target: list[int], values: list[int]): + """Append unseen values while preserving their original order.""" + for value in values: + if value not in target: + target.append(value) + + def _find(self, i: int) -> list[int]: + """Return the parent chain for a leaf, extending it lazily.""" + chain = self._parent_chains[i] + if not chain or (len(chain) == 1 and chain[0] == i): + return [i] + if chain[0] == i: + self._ordered_extend(chain, self._find(chain[1])) + else: + self._ordered_extend(chain, self._find(chain[0])) + return chain + + def _rank_bisect_right(self, chain: list[int], rank: int) -> int: + """Return the first chain index whose rank is greater than rank.""" + idx = 0 + while idx < len(chain) and self._rank[chain[idx]] <= rank: + idx += 1 + return idx + + def _build(self, i: int, j: int, insert_point: int | None = None): + """Record a merge edge in the compact parent array.""" + if insert_point is not None: + parent_ids = self._node_ids[insert_point] + parent_rank_idx = self._rank[i] + 1 + if parent_rank_idx >= len(parent_ids): + logging.warning( + "RAPTOR Psi union fallback: rank index %d is out of bounds for node %d with %d parent ids", + parent_rank_idx, + insert_point, + len(parent_ids), + ) + parent_rank_idx = len(parent_ids) - 1 + self._tree[self._node_ids[i][-1]] = parent_ids[parent_rank_idx] + return + self._tree[self._node_ids[i][-1]] = self._next_id + self._tree[self._node_ids[j][-1]] = self._next_id + self._node_ids[i].append(self._next_id) + self._next_id += 1 + + def union(self, i: int, j: int) -> bool: + """Merge two ranked leaves and return whether a new edge was added.""" + root_i = self._find(i)[-1] + root_j = self._find(j)[-1] + if root_i == root_j: + return False + + if self._rank[root_i] < self._rank[root_j]: + if not self._parent_chains[root_j]: + self._parent_chains[root_j].append(root_j) + chain = self._parent_chains[j] + higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_i]) + if higher_rank_idx >= len(chain): + higher_rank_idx = len(chain) - 1 + insert_point = chain[higher_rank_idx] + self._ordered_extend(self._parent_chains[root_i], chain[higher_rank_idx:]) + self._build(root_i, root_j, insert_point=insert_point) + elif self._rank[root_i] > self._rank[root_j]: + if not self._parent_chains[root_i]: + self._parent_chains[root_i].append(root_i) + chain = self._parent_chains[i] + higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_j]) + if higher_rank_idx >= len(chain): + higher_rank_idx = len(chain) - 1 + insert_point = chain[higher_rank_idx] + self._ordered_extend(self._parent_chains[root_j], chain[higher_rank_idx:]) + self._build(root_j, root_i, insert_point=insert_point) + else: + if not self._parent_chains[root_i]: + self._parent_chains[root_i].append(root_i) + self._ordered_extend(self._parent_chains[root_j], self._parent_chains[i][-1:]) + self._rank[root_i] += 1 + self._build(root_i, root_j) + return True + + @property + def tree(self) -> list[int]: + """Return the compact child-to-parent array for constructed nodes.""" + return self._tree[:self._next_id] class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: + """Build RAPTOR summary layers with the classic or Psi tree strategy.""" + def __init__( self, max_cluster, @@ -45,7 +165,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: max_token=512, threshold=0.1, max_errors=3, + tree_builder=RAPTOR_TREE_BUILDER, + clustering_method=GMM_CLUSTERING_METHOD, + psi_exact_max_leaves=4096, + psi_bucket_size=1024, ): + """Configure RAPTOR summarization, clustering, and Psi limits.""" self._max_cluster = max_cluster self._llm_model = llm_model self._embd_model = embd_model @@ -54,8 +179,17 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: self._max_token = max_token self._max_errors = max(1, max_errors) self._error_count = 0 - + self._tree_builder = tree_builder or RAPTOR_TREE_BUILDER + if self._tree_builder not in SUPPORTED_TREE_BUILDERS: + raise ValueError(f"Unsupported RAPTOR tree builder: {self._tree_builder}") + self._clustering_method = clustering_method or GMM_CLUSTERING_METHOD + if self._clustering_method not in SUPPORTED_CLUSTERING_METHODS: + raise ValueError(f"Unsupported RAPTOR clustering method: {self._clustering_method}") + self._psi_exact_max_leaves = max(2, int(psi_exact_max_leaves or 4096)) + self._psi_bucket_size = min(max(2, int(psi_bucket_size or 1024)), self._psi_exact_max_leaves) + def _check_task_canceled(self, task_id: str, message: str = ""): + """Raise if the current document task was canceled.""" if task_id and has_canceled(task_id): log_msg = f"Task {task_id} cancelled during RAPTOR {message}." logging.info(log_msg) @@ -63,6 +197,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: @timeout(60 * 20) async def _chat(self, system, history, gen_conf): + """Call the configured LLM with caching and short retries.""" cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf) if cached: return cached @@ -86,6 +221,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: @timeout(20) async def _embedding_encode(self, txt): + """Encode text with the configured embedding model and cache result.""" response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt) if response is not None: return response @@ -97,6 +233,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: return embds def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): + """Choose the GMM cluster count with the lowest BIC score.""" max_clusters = min(self._max_cluster, len(embeddings)) n_clusters = np.arange(1, max_clusters) bics = [] @@ -109,57 +246,422 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: optimal_clusters = n_clusters[np.argmin(bics)] return optimal_clusters + def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray: + """Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic.""" + n = len(embeddings) + if n <= 1: + return np.zeros(n, dtype=int) + if n == 2: + return np.arange(n) + + self._check_task_canceled(task_id, "_get_clusters_ahc dendrogram") + full_clust = AgglomerativeClustering( + n_clusters=None, + distance_threshold=0, + compute_distances=True, + linkage="ward", + ) + full_clust.fit(embeddings) + + distances = full_clust.distances_ + if len(distances) > 1: + gaps = np.diff(distances) + max_gap_idx = int(np.argmax(gaps)) + n_clusters = max(1, min(n - max_gap_idx - 1, self._max_cluster)) + else: + n_clusters = max(1, min(n, self._max_cluster)) + if n_clusters <= 1: + logging.info("RAPTOR AHC: _get_clusters_ahc selected one cluster for %d embeddings", n) + return np.zeros(n, dtype=int) + + logging.info("RAPTOR AHC: _get_clusters_ahc selected n_clusters=%d for %d embeddings", n_clusters, n) + self._check_task_canceled(task_id, "_get_clusters_ahc fit") + clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage="ward") + return clustering.fit_predict(embeddings) + + def _adjust_tree_nodes(self, embeddings: np.ndarray, labels: np.ndarray, max_iter: int = 5) -> np.ndarray: + """Refine AHC assignments by reassigning nodes to nearest centroids.""" + labels = labels.copy() + for _ in range(max_iter): + unique_labels = np.unique(labels) + if len(unique_labels) <= 1: + return labels + centroids = np.stack([embeddings[labels == lbl].mean(axis=0) for lbl in unique_labels]) + diffs = embeddings[:, np.newaxis, :] - centroids[np.newaxis, :, :] + sq_dists = (diffs**2).sum(axis=2) + new_label_indices = np.argmin(sq_dists, axis=1) + new_labels = unique_labels[new_label_indices] + if np.array_equal(new_labels, labels): + break + unique_new = np.unique(new_labels) + remap = {old: new for new, old in enumerate(unique_new)} + labels = np.array([remap[int(lbl)] for lbl in new_labels]) + return labels + + @timeout(60 * 20) + async def _summarize_texts(self, texts: list[str], callback=None, task_id: str = ""): + """Summarize a cluster and return text plus embedding when successful.""" + self._check_task_canceled(task_id, "summarization") + + len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) + cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) + try: + async with chat_limiter: + self._check_task_canceled(task_id, "before LLM call") + + cnt = await self._chat( + "You're a helpful assistant.", + [ + { + "role": "user", + "content": self._prompt.format(cluster_content=cluster_content), + } + ], + {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 + ) + cnt = re.sub( + "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", + "", + cnt, + ) + logging.debug(f"SUM: {cnt}") + + self._check_task_canceled(task_id, "before embedding") + + embds = await self._embedding_encode(cnt) + return cnt, embds + except TaskCanceledException: + raise + except Exception as exc: + self._error_count += 1 + warn_msg = f"[RAPTOR] Skip cluster ({len(texts)} chunks) due to error: {exc}" + logging.warning(warn_msg) + if callback: + callback(msg=warn_msg) + if self._error_count >= self._max_errors: + raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc + return None + + @staticmethod + def _root(node: _PsiTreeNode) -> _PsiTreeNode: + """Return the current root for a Psi tree node.""" + while node.parent is not None: + node = node.parent + return node + + def _rank_leaf_pairs(self, leaves: list[_PsiTreeNode]) -> np.ndarray: + """Rank all leaf pairs by original embedding-space cosine similarity.""" + node_embeddings = np.asarray([leaf.embedding for leaf in leaves], dtype=np.float64) + node_embeddings = self._normalize_embeddings(node_embeddings) + similarities = node_embeddings @ node_embeddings.T + lower = np.tril_indices(len(leaves), -1) + ordered = np.argsort(similarities[lower], axis=0)[::-1] + return np.stack([lower[0][ordered], lower[1][ordered]], axis=-1) + + @staticmethod + def _normalize_embeddings(node_embeddings: np.ndarray) -> np.ndarray: + """Normalize embeddings for cosine operations while tolerating zero vectors.""" + node_embeddings = np.asarray(node_embeddings, dtype=np.float64) + norms = np.linalg.norm(node_embeddings, axis=1, keepdims=True) + return node_embeddings / np.maximum(norms, 1e-12) + + def _split_psi_buckets(self, nodes: list[_PsiTreeNode]) -> list[list[_PsiTreeNode]]: + """Split large Psi inputs so exact pair ranking is bounded per bucket.""" + if len(nodes) <= self._psi_bucket_size: + return [nodes] + + node_embeddings = self._normalize_embeddings(np.asarray([node.embedding for node in nodes], dtype=np.float64)) + groups = [np.arange(len(nodes), dtype=int)] + buckets = [] + + while groups: + group = np.asarray(groups.pop(), dtype=int) + if len(group) <= self._psi_bucket_size: + buckets.append(group.tolist()) + continue + + fanout = min(max(2, int(np.ceil(len(group) / self._psi_bucket_size))), len(group), 32) + group_embeddings = node_embeddings[group] + center_idx = np.linspace(0, len(group_embeddings) - 1, num=fanout, dtype=int) + centers = group_embeddings[center_idx].copy() + + for _ in range(5): + labels = np.argmax(group_embeddings @ centers.T, axis=1) + for center_id in range(fanout): + mask = labels == center_id + if not np.any(mask): + continue + center = group_embeddings[mask].mean(axis=0) + norm = np.linalg.norm(center) + centers[center_id] = center / norm if norm > 0 else center + + labels = np.argmax(group_embeddings @ centers.T, axis=1) + split_groups = [group[labels == center_id].tolist() for center_id in range(fanout)] + split_groups = [bucket for bucket in split_groups if bucket] + if len(split_groups) <= 1: + split_groups = [ + group[start:start + self._psi_bucket_size].tolist() + for start in range(0, len(group), self._psi_bucket_size) + ] + groups.extend(split_groups) + + buckets = [bucket for bucket in buckets if bucket] + buckets.sort(key=lambda bucket: (len(bucket), bucket[0])) + return [[nodes[idx] for idx in bucket] for bucket in buckets] + + def _assign_prototype_embeddings(self, node: _PsiTreeNode) -> np.ndarray: + """Assign mean child embeddings to internal Psi nodes for bucket-level ranking.""" + if not node.children: + return np.asarray(node.embedding, dtype=np.float64) + embeddings = np.asarray([self._assign_prototype_embeddings(child) for child in node.children], dtype=np.float64) + node.embedding = embeddings.mean(axis=0) + return node.embedding + + @staticmethod + def _iter_nodes(root: _PsiTreeNode): + """Yield nodes in a Psi tree using a stack traversal.""" + stack = [root] + while stack: + node = stack.pop() + yield node + stack.extend(node.children) + + def _create_psi_parent(self, index: int, children: list[_PsiTreeNode]) -> _PsiTreeNode: + """Create a parent node and attach the provided children to it.""" + parent = _PsiTreeNode(index=index, children=children) + for child in children: + child.parent = parent + return parent + + def _rebalance_psi_tree(self, root: _PsiTreeNode, next_index: int) -> tuple[_PsiTreeNode, int]: + """Group oversized Psi tree nodes so fanout stays within max_cluster.""" + max_children = max(2, int(self._max_cluster or 2)) + + def rebalance(node: _PsiTreeNode): + """Recursively group children when a Psi node exceeds fanout.""" + nonlocal next_index + + for child in list(node.children): + rebalance(child) + + while len(node.children) > max_children: + original_children = len(node.children) + grouped_children = [] + for start in range(0, len(node.children), max_children): + batch = node.children[start:start + max_children] + if len(batch) == 1: + grouped_children.append(batch[0]) + batch[0].parent = node + else: + grouped_children.append(self._create_psi_parent(next_index, batch)) + grouped_children[-1].parent = node + next_index += 1 + node.children = grouped_children + logging.info( + "RAPTOR Psi rebalance: node=%s children=%d grouped_to=%d max_cluster=%d", + node.index, + original_children, + len(grouped_children), + max_children, + ) + + rebalance(root) + return self._root(root), next_index + + def _build_exact_psi_structure( + self, + nodes: list[_PsiTreeNode], + next_index: int, + task_id: str = "", + ) -> tuple[_PsiTreeNode, int, int]: + """Build an exact Psi subtree for a bounded node set.""" + if len(nodes) == 1: + return nodes[0], next_index, 0 + + ranked_pairs = self._rank_leaf_pairs(nodes) + union_find = _PsiUnionFind(len(nodes)) + merges = 0 + for left_idx, right_idx in ranked_pairs: + self._check_task_canceled(task_id, "Psi tree construction") + if union_find.union(int(left_idx), int(right_idx)): + merges += 1 + if merges == len(nodes) - 1: + break + + local_nodes = {idx: node for idx, node in enumerate(nodes)} + tree = union_find.tree + children_by_parent = {} + for child_idx, parent_idx in enumerate(tree): + if child_idx not in local_nodes: + local_nodes[child_idx] = _PsiTreeNode(index=next_index) + next_index += 1 + if parent_idx == -1: + continue + children_by_parent.setdefault(parent_idx, []).append(child_idx) + if parent_idx not in local_nodes: + local_nodes[parent_idx] = _PsiTreeNode(index=next_index) + next_index += 1 + + for parent_idx, child_indices in children_by_parent.items(): + parent = local_nodes[parent_idx] + parent.children = [local_nodes[child_idx] for child_idx in child_indices] + for child in parent.children: + child.parent = parent + + roots = [local_nodes[idx] for idx, parent_idx in enumerate(tree) if parent_idx == -1 and idx in local_nodes] + root = max(roots, key=lambda node: node.index) + return root, next_index, merges + + def _build_bucketed_psi_structure( + self, + nodes: list[_PsiTreeNode], + next_index: int, + task_id: str = "", + ) -> tuple[_PsiTreeNode, int, int]: + """Build large Psi trees by exact-ranking bounded buckets, then bucket roots.""" + buckets = self._split_psi_buckets(nodes) + logging.info( + "RAPTOR Psi bucketed build: nodes=%d buckets=%d bucket_size=%d exact_max_leaves=%d", + len(nodes), + len(buckets), + self._psi_bucket_size, + self._psi_exact_max_leaves, + ) + + bucket_roots = [] + merges = 0 + for bucket in buckets: + bucket_root, next_index, bucket_merges = self._build_psi_structure_from_nodes(bucket, next_index, task_id) + self._assign_prototype_embeddings(bucket_root) + bucket_roots.append(bucket_root) + merges += bucket_merges + + if len(bucket_roots) == 1: + return bucket_roots[0], next_index, merges + + root, next_index, root_merges = self._build_psi_structure_from_nodes(bucket_roots, next_index, task_id) + return root, next_index, merges + root_merges + + def _build_psi_structure_from_nodes( + self, + nodes: list[_PsiTreeNode], + next_index: int, + task_id: str = "", + ) -> tuple[_PsiTreeNode, int, int]: + """Build Psi structure exactly for small sets and bucket large sets.""" + if len(nodes) <= self._psi_exact_max_leaves: + return self._build_exact_psi_structure(nodes, next_index, task_id) + return self._build_bucketed_psi_structure(nodes, next_index, task_id) + + def _build_psi_structure(self, chunks, task_id: str = "") -> tuple[_PsiTreeNode, list[_PsiTreeNode]]: + """Build the Psi merge tree from original chunk embeddings.""" + leaves = [ + _PsiTreeNode(index=i, text=text, embedding=np.asarray(embd)) + for i, (text, embd) in enumerate(chunks) + ] + if len(leaves) == 1: + return leaves[0], leaves + + root, next_index, merges = self._build_psi_structure_from_nodes(leaves, len(leaves), task_id) + root, _ = self._rebalance_psi_tree(root, next_index) + logging.info( + "RAPTOR Psi tree built: leaves=%d merges=%d root_fanout=%d", + len(leaves), + merges, + len(root.children), + ) + return root, leaves + + @staticmethod + def _psi_layers(root: _PsiTreeNode) -> dict[int, list[_PsiTreeNode]]: + """Collect non-leaf Psi nodes by height for bottom-up summarization.""" + layers = {} + + def height(node: _PsiTreeNode) -> int: + """Return node height while collecting internal nodes by layer.""" + if not node.children: + return 0 + node_height = max(height(child) for child in node.children) + 1 + layers.setdefault(node_height, []).append(node) + return node_height + + height(root) + return layers + + async def _build_psi_layers(self, chunks, callback=None, task_id: str = ""): + """Materialize Psi tree layers as summary chunks.""" + layers = [(0, len(chunks))] + root, _ = self._build_psi_structure(chunks, task_id=task_id) + + for layer_idx, (_, nodes) in enumerate(sorted(self._psi_layers(root).items()), start=1): + layer_start = len(chunks) + + async def summarize_node(node: _PsiTreeNode): + """Summarize one Psi internal node if its children have text.""" + texts = [child.text for child in node.children if child.text] + if not texts: + logging.warning("RAPTOR Psi node %s skipped because it has no child text to summarize", node.index) + return None + result = await self._summarize_texts(texts, callback, task_id) + if result is None: + logging.warning("RAPTOR Psi node %s skipped because summarization failed", node.index) + return None + node.text, node.embedding = result + return node + + tasks = [asyncio.create_task(summarize_node(node)) for node in nodes] + try: + summarized_nodes = await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in RAPTOR Psi tree processing: {e}") + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + summarized_nodes = [node for node in summarized_nodes if node is not None] + for node in summarized_nodes: + chunks.append((node.text, node.embedding)) + + if len(chunks) > layer_start: + layers.append((layer_start, len(chunks))) + logging.info( + "RAPTOR Psi layer materialized: layer=%d nodes=%d summaries=%d", + layer_idx, + len(nodes), + len(chunks) - layer_start, + ) + if callback: + callback(msg="Build one Psi-RAG layer: {} -> {}".format(len(nodes), len(chunks) - layer_start)) + else: + logging.warning("RAPTOR Psi layer %d produced no summaries; stopping materialization", layer_idx) + break + + return chunks, layers + async def __call__(self, chunks, random_state, callback=None, task_id: str = ""): + """Build summary chunks and layer boundaries for RAPTOR retrieval.""" if len(chunks) <= 1: return [], [] chunks = [(s, a) for s, a in chunks if s and a is not None and len(a) > 0] + if len(chunks) <= 1: + return chunks, [(0, len(chunks))] + if self._tree_builder == PSI_TREE_BUILDER: + logging.info("RAPTOR: using %s tree builder for %d chunks", self._tree_builder, len(chunks)) + return await self._build_psi_layers(chunks, callback, task_id) + layers = [(0, len(chunks))] start, end = 0, len(chunks) @timeout(60 * 20) async def summarize(ck_idx: list[int]): + """Summarize one classic RAPTOR cluster into the chunk list.""" nonlocal chunks - self._check_task_canceled(task_id, "summarization") - texts = [chunks[i][0] for i in ck_idx] - len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) - cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) - try: - async with chat_limiter: - self._check_task_canceled(task_id, "before LLM call") - - cnt = await self._chat( - "You're a helpful assistant.", - [ - { - "role": "user", - "content": self._prompt.format(cluster_content=cluster_content), - } - ], - {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 - ) - cnt = re.sub( - "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", - "", - cnt, - ) - logging.debug(f"SUM: {cnt}") - - self._check_task_canceled(task_id, "before embedding") - - embds = await self._embedding_encode(cnt) - chunks.append((cnt, embds)) - except TaskCanceledException: - raise - except Exception as exc: - self._error_count += 1 - warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}" - logging.warning(warn_msg) - if callback: - callback(msg=warn_msg) - if self._error_count >= self._max_errors: - raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc + result = await self._summarize_texts(texts, callback, task_id) + if result is not None: + chunks.append(result) while end - start > 1: self._check_task_canceled(task_id, "layer processing") @@ -167,8 +669,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: embeddings = [embd for _, embd in chunks[start:end]] if len(embeddings) == 2: await summarize([start, start + 1]) + produced = len(chunks) - end + if produced == 0: + logging.warning("RAPTOR layer produced no summaries; stopping materialization") + break if callback: - callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) + callback(msg="Cluster one layer: {} -> {}".format(end - start, produced)) layers.append((end, len(chunks))) start = end end = len(chunks) @@ -180,15 +686,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: n_components=min(12, len(embeddings) - 2), metric="cosine", ).fit_transform(embeddings) - n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id) + if self._clustering_method == AHC_CLUSTERING_METHOD: + logging.info("RAPTOR: using clustering_method=%s before _get_clusters_ahc", self._clustering_method) + raw_labels = self._get_clusters_ahc(reduced_embeddings, task_id=task_id) + raw_cluster_count = np.unique(raw_labels).size + logging.info("RAPTOR AHC: _get_clusters_ahc produced n_clusters=%d", raw_cluster_count) + if raw_cluster_count > 1: + adjusted = self._adjust_tree_nodes(reduced_embeddings, raw_labels) + adjusted_cluster_count = np.unique(adjusted).size + logging.info("RAPTOR AHC: _adjust_tree_nodes adjusted n_clusters=%d", adjusted_cluster_count) + else: + adjusted = raw_labels + logging.warning("RAPTOR AHC: _adjust_tree_nodes skipped because _get_clusters_ahc returned one cluster") + unique_labels = np.unique(adjusted) + label_map = {old: idx for idx, old in enumerate(unique_labels)} + lbls = [label_map[int(lbl)] for lbl in adjusted] + n_clusters = len(unique_labels) + else: + n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id) + if n_clusters == 1: + lbls = [0 for _ in range(len(reduced_embeddings))] + else: + gm = GaussianMixture(n_components=n_clusters, random_state=random_state) + gm.fit(reduced_embeddings) + probs = gm.predict_proba(reduced_embeddings) + lbls = [np.where(prob > self._threshold)[0] for prob in probs] + lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] + if n_clusters == 1: lbls = [0 for _ in range(len(reduced_embeddings))] else: - gm = GaussianMixture(n_components=n_clusters, random_state=random_state) - gm.fit(reduced_embeddings) - probs = gm.predict_proba(reduced_embeddings) - lbls = [np.where(prob > self._threshold)[0] for prob in probs] - lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] + lbls = [int(lbl[0]) if isinstance(lbl, np.ndarray) else int(lbl) for lbl in lbls] tasks = [] for c in range(n_clusters): @@ -205,10 +733,21 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: await asyncio.gather(*tasks, return_exceptions=True) raise - assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) + produced = len(chunks) - end + assert produced <= n_clusters, "{} vs. {}".format(produced, n_clusters) + if produced < n_clusters: + logging.warning( + "RAPTOR layer produced %d/%d cluster summaries; skipped %d cluster(s) due to errors", + produced, + n_clusters, + n_clusters - produced, + ) + if produced == 0: + logging.warning("RAPTOR layer produced no summaries; stopping materialization") + break layers.append((end, len(chunks))) if callback: - callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) + callback(msg="Cluster one layer: {} -> {}".format(end - start, produced)) start = end end = len(chunks) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index cb4136617..492ae69e2 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -36,7 +36,15 @@ from api.db.joint_services.memory_message_service import handle_save_to_memory_t from common.connection_utils import timeout from common.metadata_utils import turn2jsonschema, update_metadata_to from rag.utils.base64_image import image2id -from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason +from rag.utils.raptor_utils import ( + collect_raptor_chunk_ids, + collect_raptor_methods, + get_raptor_clustering_method, + get_raptor_tree_builder, + get_skip_reason, + make_raptor_summary_chunk_id, + should_skip_raptor, +) from common.log_utils import init_root_logger from common.config_utils import show_configs from rag.graphrag.general.index import run_graphrag_for_kb @@ -70,7 +78,10 @@ from api.db.db_models import close_connection from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ email, tag from rag.nlp import search, rag_tokenizer, add_positions -from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor +from rag.raptor import ( + RAPTOR_TREE_BUILDER, + RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor, +) from common.token_utils import num_tokens_from_string, truncate from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock from rag.graphrag.utils import chat_limiter @@ -817,61 +828,160 @@ async def run_dataflow(task: dict): dsl=str(pipeline)) -async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str) -> bool: - """Return True if RAPTOR chunks already exist for doc_id in the doc store. +RAPTOR_METHOD_SEARCH_LIMIT = 10000 - Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading - chunk cannot produce a false-negative result. Uses thread_pool_exec so - the blocking doc-store call does not stall the event loop. - """ + +async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) -> dict: + """Return stored RAPTOR marker fields for a document.""" from common.doc_store.doc_store_base import OrderByExpr from rag.nlp import search as nlp_search - try: - condition = {"doc_id": doc_id, "raptor_kwd": ["raptor"]} + + async def search_fields(fields: list[str], condition: dict, order_by=None): + """Search chunk fields in the current knowledge base.""" res = await thread_pool_exec( settings.docStoreConn.search, - ["raptor_kwd"], [], condition, [], OrderByExpr(), - 0, 1, nlp_search.index_name(tenant_id), [kb_id] + fields, [], condition, [], order_by or OrderByExpr(), + 0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id] ) - field_map = settings.docStoreConn.get_fields(res, ["raptor_kwd"]) - found = bool(field_map) - if found: + return settings.docStoreConn.get_fields(res, fields) + + primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]}) + if collect_raptor_chunk_ids(primary): + return primary + + try: + return await search_fields( + ["raptor_kwd", "extra"], + {"doc_id": doc_id}, + OrderByExpr().desc("create_timestamp_flt"), + ) + except Exception: + logging.debug("RAPTOR fallback method lookup with extra field failed for doc %s", doc_id, exc_info=True) + return primary + + +async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> set[str]: + """Return the RAPTOR tree builders already stored for doc_id. + + Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading + chunk cannot produce a false-negative result. Legacy summary chunks that + do not have method metadata are treated as the original RAPTOR builder. + """ + try: + field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id) + methods = collect_raptor_methods(field_map) + if methods: logging.info( - "Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s) already exist", - doc_id, tenant_id, kb_id, + "Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist", + doc_id, tenant_id, kb_id, sorted(methods), ) else: logging.info( "Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)", doc_id, tenant_id, kb_id, ) - return found + return methods except Exception: logging.exception("Failed to check RAPTOR chunks for doc %s", doc_id) - return False + raise + + +async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, tree_builder: str = RAPTOR_TREE_BUILDER) -> bool: + """Return whether doc_id already has summaries for tree_builder.""" + methods = await get_raptor_chunk_methods(doc_id, tenant_id, kb_id) + return tree_builder in methods + + +async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_method: str | None = None): + """Delete RAPTOR summaries for doc_id, optionally preserving one method.""" + from rag.nlp import search as nlp_search + + if keep_method is None: + logging.info( + "delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)", + doc_id, tenant_id, kb_id, + ) + await thread_pool_exec( + settings.docStoreConn.delete, + {"doc_id": doc_id, "raptor_kwd": ["raptor"]}, + nlp_search.index_name(tenant_id), + kb_id, + ) + return 0 + + field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id) + chunk_ids = collect_raptor_chunk_ids(field_map, exclude_methods={keep_method}) + if not chunk_ids: + logging.debug( + "delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)", + doc_id, tenant_id, kb_id, keep_method, + ) + return 0 + + logging.info( + "delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)", + len(chunk_ids), doc_id, tenant_id, kb_id, keep_method, + ) + await thread_pool_exec( + settings.docStoreConn.delete, + {"id": list(chunk_ids)}, + nlp_search.index_name(tenant_id), + kb_id, + ) + return len(chunk_ids) @timeout(3600) async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]): + """Generate RAPTOR summaries for selected documents in a knowledge base.""" fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID raptor_config = kb_parser_config.get("raptor", {}) + raptor_ext_config = raptor_config.get("ext") or {} + tree_builder = get_raptor_tree_builder(raptor_config) + clustering_method = get_raptor_clustering_method(raptor_config) vctr_nm = "q_%d_vec" % vector_size res = [] tk_count = 0 + cleanup_raptor_chunks = [] max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3)) - doc_name_by_id = {} + doc_info_by_id = {} for doc_id in set(doc_ids): ok, source_doc = DocumentService.get_by_id(doc_id) if not ok or not source_doc: continue - source_name = getattr(source_doc, "name", "") - if source_name: - doc_name_by_id[doc_id] = source_name + doc_info_by_id[doc_id] = { + "name": getattr(source_doc, "name", ""), + "type": getattr(source_doc, "type", ""), + "parser_id": getattr(source_doc, "parser_id", ""), + "parser_config": getattr(source_doc, "parser_config", {}) or {}, + } + + def schedule_raptor_cleanup(doc_id: str, keep_method: str | None = None): + """Queue stale RAPTOR summaries for deletion after successful insert.""" + cleanup_plan = (doc_id, keep_method) + if cleanup_plan not in cleanup_raptor_chunks: + cleanup_raptor_chunks.append(cleanup_plan) + + def skip_raptor_doc(doc_id: str) -> bool: + """Return whether RAPTOR should be skipped for this source document.""" + doc_info = doc_info_by_id.get(doc_id, {}) + file_type = doc_info.get("type") or row.get("type", "") + parser_id = doc_info.get("parser_id") or row.get("parser_id", "") + parser_config = doc_info.get("parser_config") or row.get("parser_config", {}) + if should_skip_raptor(file_type, parser_id, parser_config, raptor_config): + skip_reason = get_skip_reason(file_type, parser_id, parser_config) + doc_name = doc_info.get("name") or doc_id + logging.info("Skipping Raptor for document %s: %s", doc_name, skip_reason) + callback(msg=f"[RAPTOR] doc:{doc_id} skipped: {skip_reason}") + return True + return False async def generate(chunks, did): + """Run RAPTOR and append generated summary chunks for one doc id.""" nonlocal tk_count, res + logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did) raptor = Raptor( raptor_config.get("max_cluster", 64), chat_mdl, @@ -880,16 +990,21 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si raptor_config["max_token"], raptor_config["threshold"], max_errors=max_errors, + tree_builder=tree_builder, + clustering_method=clustering_method, + psi_exact_max_leaves=raptor_ext_config.get("psi_exact_max_leaves", 4096), + psi_bucket_size=raptor_ext_config.get("psi_bucket_size", 1024), ) original_length = len(chunks) chunks, layers = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"]) - effective_doc_name = row["name"] if did == fake_doc_id else doc_name_by_id.get(did, row["name"]) + effective_doc_name = row["name"] if did == fake_doc_id else doc_info_by_id.get(did, {}).get("name") or row["name"] doc = { "doc_id": did, "kb_id": [str(row["kb_id"])], "docnm_kwd": effective_doc_name, "title_tks": rag_tokenizer.tokenize(effective_doc_name), - "raptor_kwd": "raptor" + "raptor_kwd": "raptor", + "extra": {"raptor_method": tree_builder}, } if row["pagerank"]: doc[PAGERANK_FLD] = int(row["pagerank"]) @@ -906,7 +1021,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si for idx, (content, vctr) in enumerate(chunks[original_length:], start=original_length): d = copy.deepcopy(doc) - d["id"] = xxhash.xxh64((content + str(fake_doc_id)).encode("utf-8")).hexdigest() + d["id"] = make_raptor_summary_chunk_id(content, did) d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_timestamp_flt"] = datetime.now().timestamp() d[vctr_nm] = vctr.tolist() @@ -918,12 +1033,28 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si tk_count += num_tokens_from_string(content) if raptor_config.get("scope", "file") == "file": + dataset_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"]) + remove_dataset_summaries = bool(dataset_methods) + has_file_level_target = False + if dataset_methods: + callback(msg="[RAPTOR] will remove dataset-level summaries after file-level summaries are available.") + for x, doc_id in enumerate(doc_ids): - # CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store - if await has_raptor_chunks(doc_id, row["tenant_id"], row["kb_id"]): - callback(msg=f"[RAPTOR] doc:{doc_id} already has RAPTOR chunks, skipping.") + if skip_raptor_doc(doc_id): callback(prog=(x + 1.) / len(doc_ids)) continue + # CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store + existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"]) + if tree_builder in existing_methods: + has_file_level_target = True + if existing_methods != {tree_builder}: + schedule_raptor_cleanup(doc_id, tree_builder) + callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.") + callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.") + callback(prog=(x + 1.) / len(doc_ids)) + continue + if existing_methods: + callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.") chunks = [] skipped_chunks = 0 @@ -945,12 +1076,52 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si callback(msg=f"[WARN] No valid chunks with vectors found for doc {doc_id}, skipping") continue + before_generate = len(res) await generate(chunks, doc_id) + if len(res) > before_generate: + has_file_level_target = True + if existing_methods: + schedule_raptor_cleanup(doc_id, tree_builder) callback(prog=(x + 1.) / len(doc_ids)) + + if remove_dataset_summaries: + if has_file_level_target: + schedule_raptor_cleanup(fake_doc_id) + else: + callback(msg="[RAPTOR] kept dataset-level summaries because no file-level summaries were built.") else: + migrated_file_docs = 0 + file_cleanup_doc_ids = [] + skipped_doc_ids = set() + for doc_id in set(doc_ids): + if skip_raptor_doc(doc_id): + skipped_doc_ids.add(doc_id) + continue + existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"]) + if existing_methods: + file_cleanup_doc_ids.append(doc_id) + migrated_file_docs += 1 + if migrated_file_docs: + callback(msg=f"[RAPTOR] will remove file-level summaries for {migrated_file_docs} docs after dataset-level build succeeds.") + + existing_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"]) + if tree_builder in existing_methods: + if existing_methods != {tree_builder}: + schedule_raptor_cleanup(fake_doc_id, tree_builder) + callback(msg="[RAPTOR] will remove old dataset-level RAPTOR summaries after insert.") + for doc_id in file_cleanup_doc_ids: + schedule_raptor_cleanup(doc_id) + callback(msg=f"[RAPTOR] dataset-level {tree_builder} summaries already exist, skipping.") + return res, tk_count, cleanup_raptor_chunks + migrate_dataset_summaries = bool(existing_methods) + if migrate_dataset_summaries: + callback(msg=f"[RAPTOR] will migrate dataset-level RAPTOR summaries to {tree_builder} after insert.") + chunks = [] skipped_chunks = 0 for doc_id in doc_ids: + if doc_id in skipped_doc_ids: + continue for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True): @@ -965,13 +1136,22 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si callback(msg=f"[WARN] Skipped {skipped_chunks} chunks without vector field '{vctr_nm}'. Consider re-parsing documents with the current embedding model.") if not chunks: + if skipped_doc_ids and len(skipped_doc_ids) == len(set(doc_ids)): + callback(msg="[RAPTOR] all documents were skipped by RAPTOR auto-disable rules.") + return res, tk_count, cleanup_raptor_chunks logging.error(f"RAPTOR: No valid chunks with vectors found in any document for kb {row['kb_id']}") callback(msg=f"[ERROR] No valid chunks with vectors found. Please ensure documents are parsed with the current embedding model (vector size: {vector_size}).") - return res, tk_count + return res, tk_count, cleanup_raptor_chunks + before_generate = len(res) await generate(chunks, fake_doc_id) + if len(res) > before_generate: + for doc_id in file_cleanup_doc_ids: + schedule_raptor_cleanup(doc_id) + if migrate_dataset_summaries: + schedule_raptor_cleanup(fake_doc_id, tree_builder) - return res, tk_count + return res, tk_count, cleanup_raptor_chunks async def delete_image(kb_id, chunk_id): @@ -1029,6 +1209,29 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre search.index_name(task_tenant_id), task_dataset_id, ) task_canceled = has_canceled(task_id) if task_canceled: + # Roll back partial RAPTOR summary inserts so the next run is not + # mistaken for a completed checkpoint by get_raptor_chunk_methods. + raptor_ids_to_rollback = [ + c["id"] for c in chunks[:b + settings.DOC_BULK_SIZE] + if c.get("raptor_kwd") == "raptor" + ] + if raptor_ids_to_rollback: + try: + await thread_pool_exec( + settings.docStoreConn.delete, + {"id": raptor_ids_to_rollback}, + search.index_name(task_tenant_id), + task_dataset_id, + ) + logging.info( + "insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)", + len(raptor_ids_to_rollback), task_id, + ) + except Exception: + logging.exception( + "insert_chunks: failed to roll back partial RAPTOR chunks after cancellation (task=%s)", + task_id, + ) progress_callback(-1, msg="Task has been canceled.") return False if b % 128 == 0: @@ -1088,6 +1291,7 @@ async def do_handle_task(task): task_parser_config = task["parser_config"] task_start_ts = timer() toc_thread = None + raptor_cleanup_chunks = [] # prepare the progress callback function progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) @@ -1135,7 +1339,9 @@ async def do_handle_task(task): "threshold": 0.1, "max_cluster": 64, "random_seed": 0, - "scope": "file" + "scope": "file", + "clustering_method": "gmm", + "tree_builder": "raptor", }, } ) @@ -1143,23 +1349,12 @@ async def do_handle_task(task): progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration") return - # Check if Raptor should be skipped for structured data - file_type = task.get("type", "") - parser_id = task.get("parser_id", "") - raptor_config = kb_parser_config.get("raptor", {}) - - if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config): - skip_reason = get_skip_reason(file_type, parser_id, task_parser_config) - logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}") - progress_callback(prog=1.0, msg=f"Raptor skipped: {skip_reason}") - return - # bind LLM for raptor chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id) chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language) # run RAPTOR async with kg_limiter: - chunks, token_count = await run_raptor_for_kb( + chunks, token_count, raptor_cleanup_chunks = await run_raptor_for_kb( row=task, kb_parser_config=kb_parser_config, chat_mdl=chat_model, @@ -1268,6 +1463,18 @@ async def do_handle_task(task): progress_callback(-1, msg="Task has been canceled.") return + if raptor_cleanup_chunks: + cleaned_chunks = 0 + for cleanup_doc_id, keep_method in raptor_cleanup_chunks: + cleaned_chunks += await delete_raptor_chunks( + cleanup_doc_id, + task_tenant_id, + task_dataset_id, + keep_method=keep_method, + ) + if cleaned_chunks: + progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.") + logging.info( "Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format( task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py index 22fbc9c7b..fde2138f0 100644 --- a/rag/utils/ob_conn.py +++ b/rag/utils/ob_conn.py @@ -46,6 +46,8 @@ column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk ord column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval") column_mom_id = Column("mom_id", String(256), nullable=True, comment="parent chunk id") column_chunk_data = Column("chunk_data", JSON, nullable=True, comment="table parser row data") +column_raptor_kwd = Column("raptor_kwd", String(256), nullable=True, comment="RAPTOR summary marker") +column_raptor_layer_int = Column("raptor_layer_int", Integer, nullable=True, comment="RAPTOR summary layer") column_definitions: list[Column] = [ Column("id", String(256), primary_key=True, comment="chunk id"), @@ -86,6 +88,8 @@ column_definitions: list[Column] = [ Column("rank_flt", Double, nullable=True, comment="rank of this entity"), Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'", comment="whether it has been deleted"), + column_raptor_kwd, + column_raptor_layer_int, column_chunk_data, Column("metadata", JSON, nullable=True, comment="metadata for this chunk"), Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"), @@ -127,7 +131,14 @@ FTS_COLUMNS_TKS: list[str] = [ ] # Extra columns to add after table creation (for migration) -EXTRA_COLUMNS: list[Column] = [column_order_id, column_group_id, column_mom_id, column_chunk_data] +EXTRA_COLUMNS: list[Column] = [ + column_order_id, + column_group_id, + column_mom_id, + column_chunk_data, + column_raptor_kwd, + column_raptor_layer_int, +] class SearchResult(BaseModel): diff --git a/rag/utils/raptor_utils.py b/rag/utils/raptor_utils.py index dd6f75dd9..91d43cd93 100644 --- a/rag/utils/raptor_utils.py +++ b/rag/utils/raptor_utils.py @@ -18,15 +18,111 @@ Utility functions for Raptor processing decisions. """ +import json import logging from typing import Optional +import xxhash + +RAPTOR_TREE_BUILDER = "raptor" +PSI_TREE_BUILDER = "psi" +SUPPORTED_TREE_BUILDERS = {RAPTOR_TREE_BUILDER, PSI_TREE_BUILDER} +GMM_CLUSTERING_METHOD = "gmm" +AHC_CLUSTERING_METHOD = "ahc" +SUPPORTED_CLUSTERING_METHODS = {GMM_CLUSTERING_METHOD, AHC_CLUSTERING_METHOD} + # File extensions for structured data types EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"} CSV_EXTENSIONS = {".csv", ".tsv"} STRUCTURED_EXTENSIONS = EXCEL_EXTENSIONS | CSV_EXTENSIONS +def get_raptor_tree_builder(raptor_config: dict | None) -> str: + """Return the configured RAPTOR tree builder with legacy ext fallback.""" + raptor_config = raptor_config or {} + ext = raptor_config.get("ext") or {} + tree_builder = ext.get("tree_builder") or raptor_config.get("tree_builder") or RAPTOR_TREE_BUILDER + if tree_builder not in SUPPORTED_TREE_BUILDERS: + raise ValueError(f"Unsupported RAPTOR tree builder: {tree_builder}") + return tree_builder + + +def get_raptor_clustering_method(raptor_config: dict | None) -> str: + """Return the configured RAPTOR clustering method with legacy ext fallback.""" + raptor_config = raptor_config or {} + ext = raptor_config.get("ext") or {} + clustering_method = ext.get("clustering_method") or raptor_config.get("clustering_method") or GMM_CLUSTERING_METHOD + if clustering_method not in SUPPORTED_CLUSTERING_METHODS: + raise ValueError(f"Unsupported RAPTOR clustering method: {clustering_method}") + return clustering_method + + +def _as_extra_dict(extra) -> dict: + """Normalize a chunk extra payload into a dictionary.""" + if isinstance(extra, dict): + return extra + if isinstance(extra, str) and extra: + try: + parsed = json.loads(extra) + except json.JSONDecodeError: + logging.warning( + "Ignoring malformed RAPTOR extra payload while collecting chunk metadata: %s", + extra[:200], + exc_info=True, + ) + return {} + return parsed if isinstance(parsed, dict) else {} + return {} + + +def _has_raptor_marker(marker) -> bool: + """Return whether a chunk marker identifies a RAPTOR summary chunk.""" + if isinstance(marker, list): + return any(str(item) == RAPTOR_TREE_BUILDER for item in marker) + return str(marker) == RAPTOR_TREE_BUILDER + + +def _raptor_methods_from_fields(fields: dict, extra: dict | None = None) -> set[str]: + """Read RAPTOR builder methods from stored chunk fields.""" + extra = extra if extra is not None else _as_extra_dict(fields.get("extra")) + method = extra.get("raptor_method") or RAPTOR_TREE_BUILDER + if isinstance(method, list): + return {str(item) for item in method if item} + return {str(method)} if method else set() + + +def collect_raptor_methods(field_map: dict) -> set[str]: + """Collect tree-builder methods from RAPTOR summary chunk fields.""" + methods = set() + for fields in field_map.values(): + extra = _as_extra_dict(fields.get("extra")) + marker = fields.get("raptor_kwd") or extra.get("raptor_kwd") + if not _has_raptor_marker(marker): + continue + + methods.update(_raptor_methods_from_fields(fields, extra)) + return methods + + +def collect_raptor_chunk_ids(field_map: dict, exclude_methods: set[str] | None = None) -> set[str]: + """Collect RAPTOR summary chunk IDs, optionally excluding some methods.""" + chunk_ids = set() + exclude_methods = exclude_methods or set() + for chunk_id, fields in field_map.items(): + extra = _as_extra_dict(fields.get("extra")) + marker = fields.get("raptor_kwd") or extra.get("raptor_kwd") + if _has_raptor_marker(marker): + if _raptor_methods_from_fields(fields, extra).issubset(exclude_methods): + continue + chunk_ids.add(chunk_id) + return chunk_ids + + +def make_raptor_summary_chunk_id(content: str, doc_id: str) -> str: + """Build the stable ID used for generated RAPTOR summary chunks.""" + return xxhash.xxh64((content + str(doc_id)).encode("utf-8")).hexdigest() + + def is_structured_file_type(file_type: Optional[str]) -> bool: """ Check if a file type is structured data (Excel, CSV, etc.) diff --git a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py index 30d19d4ac..c3cd9ac3d 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py +++ b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py @@ -583,6 +583,10 @@ class TestDatasetUpdate: {"raptor": {"max_cluster": 512}}, {"raptor": {"max_cluster": 1024}}, {"raptor": {"random_seed": 0}}, + {"raptor": {"clustering_method": "gmm"}}, + {"raptor": {"clustering_method": "ahc"}}, + {"raptor": {"tree_builder": "raptor"}}, + {"raptor": {"tree_builder": "psi"}}, ], ids=[ "auto_keywords_min", @@ -633,6 +637,10 @@ class TestDatasetUpdate: "raptor_max_cluster_mid", "raptor_max_cluster_max", "raptor_random_seed_min", + "raptor_clustering_method_gmm", + "raptor_clustering_method_ahc", + "raptor_tree_builder_raptor", + "raptor_tree_builder_psi", ], ) def test_parser_config(self, HttpApiAuth, add_dataset_func, parser_config): @@ -707,6 +715,10 @@ class TestDatasetUpdate: ({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), ({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer"), ({"raptor": {"random_seed": "string"}}, "Input should be a valid integer"), + ({"raptor": {"clustering_method": "unknown"}}, "Input should be 'gmm' or 'ahc'"), + ({"raptor": {"clustering_method": None}}, "Input should be 'gmm' or 'ahc'"), + ({"raptor": {"tree_builder": "ahc"}}, "Input should be 'raptor' or 'psi'"), + ({"raptor": {"tree_builder": None}}, "Input should be 'raptor' or 'psi'"), ({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), ], ids=[ @@ -763,6 +775,10 @@ class TestDatasetUpdate: "raptor_random_seed_min_limit", "raptor_random_seed_float_not_allowed", "raptor_random_seed_type_invalid", + "raptor_clustering_method_invalid", + "raptor_clustering_method_none_invalid", + "raptor_tree_builder_invalid", + "raptor_tree_builder_none_invalid", "parser_config_type_invalid", ], ) diff --git a/test/unit_test/rag/test_raptor_psi_tree_builder.py b/test/unit_test/rag/test_raptor_psi_tree_builder.py new file mode 100644 index 000000000..1d0af20d9 --- /dev/null +++ b/test/unit_test/rag/test_raptor_psi_tree_builder.py @@ -0,0 +1,375 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +import sys +import types + +import pytest + +np = pytest.importorskip("numpy") + +from api.utils.validation_utils import RaptorConfig +from pydantic import ValidationError + + +@pytest.fixture() +def raptor_module(monkeypatch): + class TaskCanceledException(Exception): + pass + + class DummyLimiter: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + class DummyGaussianMixture: + def __init__(self, *args, **kwargs): + pass + + def fit(self, embeddings): + return self + + def bic(self, embeddings): + return 0 + + def predict_proba(self, embeddings): + return np.ones((len(embeddings), 1)) + + class DummyAgglomerativeClustering: + def __init__(self, n_clusters=None, distance_threshold=None, compute_distances=False, linkage="ward"): + self.n_clusters = n_clusters + self.distance_threshold = distance_threshold + self.compute_distances = compute_distances + self.linkage = linkage + self.distances_ = np.array([0.1, 0.2, 1.0]) + + def fit(self, embeddings): + self.labels_ = self.fit_predict(embeddings) + return self + + def fit_predict(self, embeddings): + if self.n_clusters is None: + return np.zeros(len(embeddings), dtype=int) + return np.array([idx % self.n_clusters for idx in range(len(embeddings))]) + + class DummyUMAP: + def __init__(self, *args, **kwargs): + pass + + def fit_transform(self, embeddings): + raise AssertionError("Psi tree builder must use original embeddings, not UMAP") + + sklearn_module = types.ModuleType("sklearn") + mixture_module = types.ModuleType("sklearn.mixture") + mixture_module.GaussianMixture = DummyGaussianMixture + cluster_module = types.ModuleType("sklearn.cluster") + cluster_module.AgglomerativeClustering = DummyAgglomerativeClustering + umap_module = types.ModuleType("umap") + umap_module.UMAP = DummyUMAP + task_service_module = types.ModuleType("api.db.services.task_service") + task_service_module.has_canceled = lambda task_id: False + connection_utils_module = types.ModuleType("common.connection_utils") + connection_utils_module.timeout = lambda seconds: lambda fn: fn + exceptions_module = types.ModuleType("common.exceptions") + exceptions_module.TaskCanceledException = TaskCanceledException + token_utils_module = types.ModuleType("common.token_utils") + token_utils_module.truncate = lambda text, max_len: text[:max_len] + graphrag_utils_module = types.ModuleType("rag.graphrag.utils") + graphrag_utils_module.chat_limiter = DummyLimiter() + graphrag_utils_module.get_embed_cache = lambda *args, **kwargs: None + graphrag_utils_module.get_llm_cache = lambda *args, **kwargs: None + graphrag_utils_module.set_embed_cache = lambda *args, **kwargs: None + graphrag_utils_module.set_llm_cache = lambda *args, **kwargs: None + + async def thread_pool_exec(fn, *args, **kwargs): + return fn(*args, **kwargs) + + misc_utils_module = types.ModuleType("common.misc_utils") + misc_utils_module.thread_pool_exec = thread_pool_exec + + monkeypatch.setitem(sys.modules, "sklearn", sklearn_module) + monkeypatch.setitem(sys.modules, "sklearn.mixture", mixture_module) + monkeypatch.setitem(sys.modules, "sklearn.cluster", cluster_module) + monkeypatch.setitem(sys.modules, "umap", umap_module) + monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_module) + monkeypatch.setitem(sys.modules, "common.connection_utils", connection_utils_module) + monkeypatch.setitem(sys.modules, "common.exceptions", exceptions_module) + monkeypatch.setitem(sys.modules, "common.token_utils", token_utils_module) + monkeypatch.setitem(sys.modules, "rag.graphrag.utils", graphrag_utils_module) + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_module) + monkeypatch.delitem(sys.modules, "rag.raptor", raising=False) + module = importlib.import_module("rag.raptor") + yield module + monkeypatch.delitem(sys.modules, "rag.raptor", raising=False) + + +class FakeChatModel: + llm_name = "fake-chat" + max_length = 4096 + + def __init__(self): + self.calls = [] + + async def async_chat(self, system, history, gen_conf): + self.calls.append(history[0]["content"]) + return f"summary-{len(self.calls)}" + + +class FakeEmbeddingModel: + llm_name = "fake-embedding" + + def encode(self, texts): + embeddings = [] + for text in texts: + checksum = sum(ord(ch) for ch in text) + embeddings.append(np.array([len(text), checksum % 17 + 1], dtype=float)) + return embeddings, len(texts) + + +_DEFAULT_TREE_BUILDER = object() + + +def _make_raptor(raptor_module, max_cluster=64, tree_builder=_DEFAULT_TREE_BUILDER, **kwargs): + if tree_builder is _DEFAULT_TREE_BUILDER: + kwargs["tree_builder"] = raptor_module.PSI_TREE_BUILDER + else: + kwargs["tree_builder"] = tree_builder + return raptor_module.RecursiveAbstractiveProcessing4TreeOrganizedRetrieval( + max_cluster, + FakeChatModel(), + FakeEmbeddingModel(), + "{cluster_content}", + max_token=32, + threshold=0.1, + **kwargs, + ) + + +def _chunks(): + return [ + ("alpha first", np.array([1.0, 0.0])), + ("alpha second", np.array([0.99, 0.01])), + ("alpha third", np.array([0.98, 0.02])), + ] + + +def test_default_tree_builder_remains_original_raptor(raptor_module): + raptor = _make_raptor(raptor_module, tree_builder=None) + + assert raptor._tree_builder == raptor_module.RAPTOR_TREE_BUILDER + + +def test_unknown_tree_builder_is_rejected(raptor_module): + with pytest.raises(ValueError, match="Unsupported RAPTOR tree builder"): + _make_raptor(raptor_module, tree_builder="ahc") + + +def test_raptor_config_accepts_hidden_psi_tree_builder(): + assert RaptorConfig().tree_builder == "raptor" + assert RaptorConfig().clustering_method == "gmm" + assert RaptorConfig(clustering_method="ahc").clustering_method == "ahc" + assert RaptorConfig(tree_builder="psi").tree_builder == "psi" + + with pytest.raises(ValidationError): + RaptorConfig(tree_builder="ahc") + with pytest.raises(ValidationError): + RaptorConfig(clustering_method="psi") + + +def test_ahc_clustering_method_is_supported_in_original_tree_builder(raptor_module): + raptor = _make_raptor(raptor_module, tree_builder=raptor_module.RAPTOR_TREE_BUILDER, clustering_method="ahc") + + labels = raptor._get_clusters_ahc(np.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]])) + + assert raptor._tree_builder == raptor_module.RAPTOR_TREE_BUILDER + assert raptor._clustering_method == "ahc" + assert len(labels) == 4 + + +def test_unknown_clustering_method_is_rejected(raptor_module): + with pytest.raises(ValueError, match="Unsupported RAPTOR clustering method"): + _make_raptor(raptor_module, clustering_method="psi") + + +def test_psi_tree_builder_ranks_all_leaf_pairs_by_original_cosine_similarity(raptor_module): + raptor = _make_raptor(raptor_module) + leaves = [ + raptor_module._PsiTreeNode(index=0, embedding=np.array([1.0, 0.0])), + raptor_module._PsiTreeNode(index=1, embedding=np.array([0.0, 1.0])), + raptor_module._PsiTreeNode(index=2, embedding=np.array([0.99, 0.01])), + raptor_module._PsiTreeNode(index=3, embedding=np.array([-1.0, 0.0])), + ] + + ranked_pairs = raptor._rank_leaf_pairs(leaves) + + assert len(ranked_pairs) == 6 + assert tuple(ranked_pairs[0]) == (2, 0) + + +def test_psi_tree_builder_uses_cosine_similarity_not_vector_magnitude(raptor_module): + raptor = _make_raptor(raptor_module) + leaves = [ + raptor_module._PsiTreeNode(index=0, embedding=np.array([100.0, 0.0])), + raptor_module._PsiTreeNode(index=1, embedding=np.array([1.0, 1.0])), + raptor_module._PsiTreeNode(index=2, embedding=np.array([0.1, 0.0])), + ] + + ranked_pairs = raptor._rank_leaf_pairs(leaves) + + assert tuple(ranked_pairs[0]) == (2, 0) + + +def test_psi_tree_builder_handles_zero_vectors_in_cosine_ranking(raptor_module): + raptor = _make_raptor(raptor_module) + leaves = [ + raptor_module._PsiTreeNode(index=0, embedding=np.array([0.0, 0.0])), + raptor_module._PsiTreeNode(index=1, embedding=np.array([1.0, 0.0])), + raptor_module._PsiTreeNode(index=2, embedding=np.array([0.9, 0.1])), + ] + + ranked_pairs = raptor._rank_leaf_pairs(leaves) + + assert tuple(ranked_pairs[0]) == (2, 1) + + +def test_psi_tree_builder_collapses_leaf_into_ranked_pair_parent(raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=64) + + root, leaves = raptor._build_psi_structure(_chunks()) + + assert len(root.children) == 3 + assert {child.index for child in root.children} == {0, 1, 2} + assert all(leaf.parent is root for leaf in leaves) + + +def test_psi_tree_builder_collapses_leaf_at_matching_rank(monkeypatch, raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=64) + chunks = [ + ("node 0", np.array([1.0, 0.0])), + ("node 1", np.array([0.9, 0.1])), + ("node 2", np.array([-1.0, 0.0])), + ("node 3", np.array([-0.9, -0.1])), + ("node 4", np.array([0.8, 0.2])), + ] + monkeypatch.setattr( + raptor, + "_rank_leaf_pairs", + lambda _leaves: np.array([[0, 1], [2, 3], [0, 2], [4, 0]]), + ) + + root, leaves = raptor._build_psi_structure(chunks) + + assert leaves[4].parent is leaves[0].parent + assert leaves[4].parent is not root + assert len(root.children) == 2 + + +def test_psi_union_find_clamps_out_of_bounds_parent_rank(caplog, raptor_module): + union_find = raptor_module._PsiUnionFind(2) + union_find._node_ids[1] = [1] + union_find._rank[0] = 2 + + with caplog.at_level("WARNING"): + union_find._build(0, 1, insert_point=1) + + assert union_find.tree[0] == 1 + assert "rank index" in caplog.text + + +def test_psi_tree_builder_rebalances_nodes_over_max_children(raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=2) + + root, _ = raptor._build_psi_structure(_chunks()) + + assert all(len(node.children) <= 2 for node in raptor._iter_nodes(root)) + assert len(root.children) == 2 + assert any(child.children for child in root.children) + + +def test_psi_tree_builder_uses_bucketed_structure_for_large_inputs(monkeypatch, raptor_module): + chunks = [(f"node {idx}", np.array([float(idx), float(idx % 3 + 1)])) for idx in range(8)] + raptor = _make_raptor( + raptor_module, + max_cluster=3, + psi_exact_max_leaves=3, + psi_bucket_size=2, + ) + ranked_sizes = [] + original_rank = raptor._rank_leaf_pairs + + def track_rank(nodes): + ranked_sizes.append(len(nodes)) + return original_rank(nodes) + + monkeypatch.setattr(raptor, "_rank_leaf_pairs", track_rank) + + root, leaves = raptor._build_psi_structure(chunks) + + assert len(leaves) == len(chunks) + assert all(leaf.parent is not None for leaf in leaves) + assert all(len(node.children) <= 3 for node in raptor._iter_nodes(root)) + assert max(ranked_sizes) <= 3 + + +@pytest.mark.asyncio +async def test_psi_tree_builder_materializes_rebalanced_summary_layers_without_umap(monkeypatch, raptor_module): + def fail_umap(*args, **kwargs): + raise AssertionError("Psi tree builder must use original embeddings, not UMAP") + + monkeypatch.setattr(raptor_module.umap, "UMAP", fail_umap) + raptor = _make_raptor(raptor_module, max_cluster=2) + + chunks, layers = await raptor(_chunks(), random_state=0) + + assert len(chunks) == 5 + assert layers == [(0, 3), (3, 4), (4, 5)] + assert [chunk[0] for chunk in chunks[3:]] == ["summary-1", "summary-2"] + + +@pytest.mark.asyncio +async def test_psi_tree_builder_skips_failed_node_summary(monkeypatch, raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=2) + + async def fail_summary(*args, **kwargs): + return None + + monkeypatch.setattr(raptor, "_summarize_texts", fail_summary) + + chunks, layers = await raptor(_chunks(), random_state=0) + + assert len(chunks) == 3 + assert [chunk[0] for chunk in chunks] == [chunk[0] for chunk in _chunks()] + assert layers == [(0, 3)] + + +@pytest.mark.asyncio +async def test_original_raptor_stops_when_transient_summary_fails(monkeypatch, raptor_module): + raptor = _make_raptor(raptor_module, tree_builder=raptor_module.RAPTOR_TREE_BUILDER) + + async def fail_summary(*args, **kwargs): + return None + + monkeypatch.setattr(raptor, "_summarize_texts", fail_summary) + + input_chunks = _chunks()[:2] + chunks, layers = await raptor(input_chunks, random_state=0) + + assert len(chunks) == 2 + assert [chunk[0] for chunk in chunks] == [chunk[0] for chunk in input_chunks] + assert layers == [(0, 2)] diff --git a/test/unit_test/rag/utils/test_raptor_utils.py b/test/unit_test/rag/utils/test_raptor_utils.py index 5138ccda7..95abe2109 100644 --- a/test/unit_test/rag/utils/test_raptor_utils.py +++ b/test/unit_test/rag/utils/test_raptor_utils.py @@ -18,15 +18,22 @@ Unit tests for Raptor utility functions. """ +import logging + import pytest from rag.utils.raptor_utils import ( + CSV_EXTENSIONS, + EXCEL_EXTENSIONS, + STRUCTURED_EXTENSIONS, + collect_raptor_chunk_ids, + collect_raptor_methods, + get_raptor_clustering_method, + get_raptor_tree_builder, + get_skip_reason, is_structured_file_type, is_tabular_pdf, + make_raptor_summary_chunk_id, should_skip_raptor, - get_skip_reason, - EXCEL_EXTENSIONS, - CSV_EXTENSIONS, - STRUCTURED_EXTENSIONS ) @@ -283,5 +290,117 @@ class TestIntegrationScenarios: assert should_skip_raptor(file_type, raptor_config=raptor_config) is False +class TestRaptorTreeBuilderConfig: + """Test RAPTOR tree builder config resolution""" + + def test_defaults_to_original_raptor_builder(self): + assert get_raptor_tree_builder({}) == "raptor" + assert get_raptor_tree_builder(None) == "raptor" + + def test_reads_top_level_tree_builder(self): + assert get_raptor_tree_builder({"tree_builder": "psi"}) == "psi" + + def test_reads_legacy_ext_tree_builder(self): + assert get_raptor_tree_builder({"ext": {"tree_builder": "psi"}}) == "psi" + + def test_ext_tree_builder_overrides_stale_top_level_value(self): + assert get_raptor_tree_builder({"tree_builder": "psi", "ext": {"tree_builder": "raptor"}}) == "raptor" + + def test_rejects_unknown_tree_builder(self): + with pytest.raises(ValueError, match="Unsupported RAPTOR tree builder"): + get_raptor_tree_builder({"tree_builder": "ahc"}) + + +class TestRaptorClusteringMethodConfig: + """Test RAPTOR clustering method config resolution""" + + def test_defaults_to_gmm(self): + assert get_raptor_clustering_method({}) == "gmm" + assert get_raptor_clustering_method(None) == "gmm" + + def test_reads_top_level_clustering_method(self): + assert get_raptor_clustering_method({"clustering_method": "gmm"}) == "gmm" + assert get_raptor_clustering_method({"clustering_method": "ahc"}) == "ahc" + + def test_reads_legacy_ext_clustering_method(self): + assert get_raptor_clustering_method({"ext": {"clustering_method": "ahc"}}) == "ahc" + + def test_ext_clustering_method_overrides_stale_top_level_value(self): + assert get_raptor_clustering_method({"clustering_method": "gmm", "ext": {"clustering_method": "ahc"}}) == "ahc" + + def test_rejects_unknown_clustering_method(self): + with pytest.raises(ValueError, match="Unsupported RAPTOR clustering method"): + get_raptor_clustering_method({"clustering_method": "unknown"}) + + +class TestRaptorMethodCollection: + """Test RAPTOR summary method extraction from doc-store fields""" + + def test_legacy_summary_without_method_is_original_raptor(self): + field_map = {"chunk_1": {"raptor_kwd": "raptor"}} + + assert collect_raptor_methods(field_map) == {"raptor"} + assert collect_raptor_chunk_ids(field_map) == {"chunk_1"} + + def test_extra_method_is_preserved(self): + field_map = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}} + + assert collect_raptor_methods(field_map) == {"psi"} + assert collect_raptor_chunk_ids(field_map) == {"chunk_1"} + + def test_extra_field_supports_oceanbase_legacy_rows(self): + field_map = { + "chunk_1": { + "extra": { + "raptor_kwd": "raptor", + "raptor_method": "psi", + } + }, + "chunk_2": { + "extra": "{\"raptor_kwd\": \"raptor\"}", + }, + "chunk_3": { + "extra": {"raptor_kwd": ""}, + }, + } + + assert collect_raptor_methods(field_map) == {"psi", "raptor"} + assert collect_raptor_chunk_ids(field_map) == {"chunk_1", "chunk_2"} + + def test_non_raptor_rows_are_ignored(self): + field_map = { + "chunk_1": {"raptor_kwd": ""}, + "chunk_2": {"extra": {"raptor_kwd": "graph"}}, + "chunk_3": {}, + } + + assert collect_raptor_methods(field_map) == set() + assert collect_raptor_chunk_ids(field_map) == set() + + def test_malformed_extra_payload_is_logged_and_ignored(self, caplog): + field_map = {"chunk_1": {"extra": "{bad json"}} + + with caplog.at_level(logging.WARNING): + assert collect_raptor_methods(field_map) == set() + assert collect_raptor_chunk_ids(field_map) == set() + + assert "Ignoring malformed RAPTOR extra payload" in caplog.text + + def test_chunk_id_collection_can_preserve_current_method(self): + field_map = { + "legacy": {"raptor_kwd": "raptor"}, + "old": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}, + "current": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}, + } + + assert collect_raptor_chunk_ids(field_map, exclude_methods={"psi"}) == {"legacy", "old"} + assert collect_raptor_chunk_ids(field_map, exclude_methods={"raptor"}) == {"current"} + + def test_summary_chunk_ids_include_real_document_id(self): + content = "same generated summary" + + assert make_raptor_summary_chunk_id(content, "doc-a") != make_raptor_summary_chunk_id(content, "doc-b") + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/web/src/components/chunk-method-dialog/index.tsx b/web/src/components/chunk-method-dialog/index.tsx index aa6c23983..21650d7e6 100644 --- a/web/src/components/chunk-method-dialog/index.tsx +++ b/web/src/components/chunk-method-dialog/index.tsx @@ -17,7 +17,7 @@ import { DocumentParserType, ParseType } from '@/constants/knowledge'; import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request'; import { IModalProps } from '@/interfaces/common'; import { IParserConfig } from '@/interfaces/database/document'; -import { IChangeParserConfigRequestBody } from '@/interfaces/request/document'; +import { IChangeParserRequestBody } from '@/interfaces/request/document'; import { MetadataType } from '@/pages/dataset/components/metedata/constant'; import { AutoMetadata, @@ -28,7 +28,6 @@ import { } from '@/pages/dataset/dataset-setting/configuration/common-item'; import { zodResolver } from '@hookform/resolvers/zod'; import omit from 'lodash/omit'; -import {} from 'module'; import { useEffect, useMemo } from 'react'; import { useForm, useWatch } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; @@ -56,10 +55,7 @@ import { const FormId = 'ChunkMethodDialogForm'; -interface IProps extends IModalProps<{ - parserId: string; - parserConfig: IChangeParserConfigRequestBody; -}> { +interface IProps extends IModalProps { loading: boolean; parserId: string; pipelineId?: string; @@ -126,16 +122,19 @@ export function ChunkMethodDialog({ mineru_formula_enable: z.boolean().optional(), mineru_table_enable: z.boolean().optional(), mineru_lang: z.string().optional(), - // raptor: z - // .object({ - // use_raptor: z.boolean().optional(), - // prompt: z.string().optional().optional(), - // max_token: z.coerce.number().optional(), - // threshold: z.coerce.number().optional(), - // max_cluster: z.coerce.number().optional(), - // random_seed: z.coerce.number().optional(), - // }) - // .optional(), + raptor: z + .object({ + use_raptor: z.boolean().optional(), + prompt: z.string().optional(), + max_token: z.coerce.number().optional(), + threshold: z.coerce.number().optional(), + max_cluster: z.coerce.number().optional(), + random_seed: z.coerce.number().optional(), + scope: z.string().optional(), + clustering_method: z.enum(['gmm', 'ahc']).optional(), + tree_builder: z.enum(['raptor', 'psi']).optional(), + }) + .optional(), // graphrag: z.object({ // use_graphrag: z.boolean().optional(), // }), diff --git a/web/src/components/chunk-method-dialog/use-default-parser-values.ts b/web/src/components/chunk-method-dialog/use-default-parser-values.ts index 47af38771..84f7c9e3c 100644 --- a/web/src/components/chunk-method-dialog/use-default-parser-values.ts +++ b/web/src/components/chunk-method-dialog/use-default-parser-values.ts @@ -23,14 +23,17 @@ export function useDefaultParserValues() { mineru_formula_enable: true, mineru_table_enable: true, mineru_lang: 'English', - // raptor: { - // use_raptor: false, - // prompt: t('knowledgeConfiguration.promptText'), - // max_token: 256, - // threshold: 0.1, - // max_cluster: 64, - // random_seed: 0, - // }, + raptor: { + use_raptor: false, + prompt: t('knowledgeConfiguration.promptText'), + max_token: 256, + threshold: 0.1, + max_cluster: 64, + random_seed: 0, + scope: 'file', + clustering_method: 'gmm', + tree_builder: 'raptor', + }, // graphrag: { // use_graphrag: false, // }, diff --git a/web/src/components/parse-configuration/raptor-form-fields.tsx b/web/src/components/parse-configuration/raptor-form-fields.tsx index 531e6165d..e66ef5453 100644 --- a/web/src/components/parse-configuration/raptor-form-fields.tsx +++ b/web/src/components/parse-configuration/raptor-form-fields.tsx @@ -8,7 +8,7 @@ import { } from '@/pages/dataset/dataset/generate-button/generate'; import random from 'lodash/random'; import { Shuffle } from 'lucide-react'; -import { useCallback } from 'react'; +import { useCallback, useEffect, useMemo } from 'react'; import { useFormContext, useWatch } from 'react-hook-form'; import { SliderInputFormField } from '../slider-input-form-field'; import { @@ -50,10 +50,10 @@ export const showTagItems = (parserId: DocumentParserType) => { const UseRaptorField = 'parser_config.raptor.use_raptor'; const RandomSeedField = 'parser_config.raptor.random_seed'; -const MaxTokenField = 'parser_config.raptor.max_token'; -const ThresholdField = 'parser_config.raptor.threshold'; -const MaxCluster = 'parser_config.raptor.max_cluster'; -const Prompt = 'parser_config.raptor.prompt'; +const ClusteringMethodField = 'parser_config.raptor.clustering_method'; +const ClusteringMethodExtField = 'parser_config.raptor.ext.clustering_method'; +const TreeBuilderField = 'parser_config.raptor.tree_builder'; +const MaxClusterMax = 1024; // The three types "table", "resume" and "one" do not display this configuration. @@ -67,17 +67,48 @@ const RaptorFormFields = ({ const form = useFormContext(); const { t } = useTranslate('knowledgeConfiguration'); const useRaptor = useWatch({ name: UseRaptorField }); + const clusteringMethod = useWatch({ name: ClusteringMethodField }); + const extClusteringMethod = useWatch({ name: ClusteringMethodExtField }); + const selectedClusteringMethod = useMemo( + () => + (clusteringMethod ?? + extClusteringMethod ?? + form.getValues(ClusteringMethodField) ?? + form.getValues(ClusteringMethodExtField) ?? + 'gmm') as 'gmm' | 'ahc', + [clusteringMethod, extClusteringMethod, form], + ); const handleGenerate = useCallback(() => { form.setValue(RandomSeedField, random(10000)); }, [form]); + const handleClusteringMethodChange = useCallback( + (method: 'gmm' | 'ahc') => { + form.setValue(ClusteringMethodField, method, { + shouldDirty: true, + shouldValidate: true, + }); + form.setValue(TreeBuilderField, 'raptor', { + shouldDirty: true, + shouldValidate: true, + }); + }, + [form], + ); + + useEffect(() => { + if (!clusteringMethod && !extClusteringMethod) { + handleClusteringMethodChange('gmm'); + } + }, [clusteringMethod, extClusteringMethod, handleClusteringMethodChange]); + return ( <> { + render={() => { return ( + { + return ( + +
+ + {t('clusteringMethod')} + +
+ + + handleClusteringMethodChange(value as 'gmm' | 'ahc') + } + > +
+ + {t('clusteringMethodGmm')} + + + {t('clusteringMethodAhc')} + +
+
+
+
+
+
+
+ +
+
+ ); + }} + /> void; + testId?: string; children?: React.ReactNode; } & Omit< React.InputHTMLAttributes, @@ -25,6 +26,7 @@ function Radio({ checked, disabled, onChange, + testId, children, ...props }: RadioProps) { @@ -65,6 +67,7 @@ function Radio({ onChange={handleChange} disabled={mergedDisabled} className={cn('peer absolute size-[1px] opacity-0', className)} + data-testid={testId} {...props} name={groupContext?.name} /> @@ -151,9 +154,11 @@ const Group = React.forwardRef( )} > {React.Children.map(children, (child) => { - if (!React.isValidElement(child)) return child; + if (!React.isValidElement(child)) { + return child; + } return React.cloneElement(child, { - disabled: disabled || child.props?.disabled, + disabled: disabled || child.props.disabled, }); })} diff --git a/web/src/hooks/parser-config-utils.ts b/web/src/hooks/parser-config-utils.ts index c02a42a01..e6e7cccb4 100644 --- a/web/src/hooks/parser-config-utils.ts +++ b/web/src/hooks/parser-config-utils.ts @@ -21,10 +21,17 @@ export const extractRaptorConfigExt = ( max_cluster, random_seed, scope, + clustering_method, + tree_builder, auto_disable_for_structured_data, ext, ...raptorExt } = raptorConfig; + const extClusteringMethod = ext?.clustering_method; + const normalizedClusteringMethod = + clustering_method ?? extClusteringMethod ?? 'gmm'; + const normalizedTreeBuilder = tree_builder ?? ext?.tree_builder ?? 'raptor'; + return { use_raptor, prompt, @@ -34,7 +41,12 @@ export const extractRaptorConfigExt = ( random_seed, scope, auto_disable_for_structured_data, - ext: { ...ext, ...raptorExt }, + ext: { + ...ext, + ...raptorExt, + clustering_method: normalizedClusteringMethod, + tree_builder: normalizedTreeBuilder, + }, }; }; diff --git a/web/src/hooks/tests/parser-config-utils.test.ts b/web/src/hooks/tests/parser-config-utils.test.ts new file mode 100644 index 000000000..6bbfcf0cb --- /dev/null +++ b/web/src/hooks/tests/parser-config-utils.test.ts @@ -0,0 +1,45 @@ +import { extractParserConfigExt } from '../parser-config-utils'; + +describe('extractParserConfigExt', () => { + it('serializes RAPTOR clustering fields through ext for API compatibility', () => { + const result = extractParserConfigExt({ + raptor: { + use_raptor: true, + prompt: 'Summarize {cluster_content}', + max_token: 256, + threshold: 0.1, + max_cluster: 317, + random_seed: 0, + scope: 'file', + clustering_method: 'ahc', + tree_builder: 'raptor', + }, + }); + + expect(result?.raptor).not.toHaveProperty('clustering_method'); + expect(result?.raptor).not.toHaveProperty('tree_builder'); + expect(result?.raptor?.ext).toMatchObject({ + clustering_method: 'ahc', + tree_builder: 'raptor', + }); + }); + + it('preserves existing RAPTOR ext clustering values when the top-level field is absent', () => { + const result = extractParserConfigExt({ + raptor: { + max_cluster: 512, + ext: { + clustering_method: 'ahc', + tree_builder: 'raptor', + psi_bucket_size: 1024, + }, + }, + }); + + expect(result?.raptor?.ext).toMatchObject({ + clustering_method: 'ahc', + tree_builder: 'raptor', + psi_bucket_size: 1024, + }); + }); +}); diff --git a/web/src/interfaces/database/dataset.ts b/web/src/interfaces/database/dataset.ts index ebded8b08..b0978e0a5 100644 --- a/web/src/interfaces/database/dataset.ts +++ b/web/src/interfaces/database/dataset.ts @@ -73,11 +73,13 @@ interface Parserconfig { } interface Raptor { + clustering_method?: 'gmm' | 'ahc'; max_cluster: number; max_token: number; prompt: string; random_seed: number; threshold: number; + tree_builder?: 'raptor' | 'psi'; use_raptor: boolean; } diff --git a/web/src/interfaces/request/document.ts b/web/src/interfaces/request/document.ts index 4f16b155d..05693ca35 100644 --- a/web/src/interfaces/request/document.ts +++ b/web/src/interfaces/request/document.ts @@ -11,6 +11,17 @@ export interface IChangeParserConfigRequestBody { image_table_context_window?: number; image_context_size?: number; table_context_size?: number; + raptor?: { + use_raptor?: boolean; + prompt?: string; + max_token?: number; + threshold?: number; + max_cluster?: number; + random_seed?: number; + scope?: string; + clustering_method?: 'gmm' | 'ahc'; + tree_builder?: 'raptor' | 'psi'; + }; // Metadata fields metadata?: Array<{ key?: string; @@ -27,8 +38,8 @@ export interface IChangeParserConfigRequestBody { export interface IChangeParserRequestBody { parser_id: string; - pipeline_id: string; - doc_id: string; + pipeline_id?: string; + doc_id?: string; parser_config: IChangeParserConfigRequestBody; } diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 5c729d773..af24b9d72 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -861,6 +861,11 @@ The above is the content you need to summarize.`, thresholdTip: 'In RAPTOR, chunks are clustered by their semantic similarity. The Threshold parameter sets the minimum similarity required for chunks to be grouped together. A higher Threshold means fewer chunks in each cluster, while a lower one means more.', thresholdMessage: 'Threshold is required', + clusteringMethod: 'Clustering method', + clusteringMethodTip: + 'Select the RAPTOR clustering method. AHC can use a larger max cluster value, but may require more memory on large inputs.', + clusteringMethodGmm: 'GMM', + clusteringMethodAhc: 'AHC', maxCluster: 'Max cluster', maxClusterTip: 'The maximum number of clusters to create.', maxClusterMessage: 'Max cluster is required', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 9de73326f..4e9b8f9ae 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -772,6 +772,11 @@ export default { maxTokenMessage: '最大token数是必填项', threshold: '阈值', thresholdMessage: '阈值是必填项', + clusteringMethod: '聚类方法', + clusteringMethodTip: + '选择 RAPTOR 聚类方法。AHC 可以使用更大的最大聚类数,但在大规模输入时可能占用更多内存。', + clusteringMethodGmm: 'GMM', + clusteringMethodAhc: 'AHC', maxCluster: '最大聚类数', maxClusterMessage: '最大聚类数是必填项', randomSeed: '随机种子', diff --git a/web/src/pages/dataset/dataset-setting/form-schema.ts b/web/src/pages/dataset/dataset-setting/form-schema.ts index 7aef591f0..03424921c 100644 --- a/web/src/pages/dataset/dataset-setting/form-schema.ts +++ b/web/src/pages/dataset/dataset-setting/form-schema.ts @@ -42,11 +42,14 @@ export const formSchema = z .object({ use_raptor: z.boolean().optional(), prompt: z.string().optional(), - max_token: z.number().optional(), - threshold: z.number().optional(), - max_cluster: z.number().optional(), - random_seed: z.number().optional(), + max_token: z.coerce.number().optional(), + threshold: z.coerce.number().optional(), + max_cluster: z.coerce.number().optional(), + random_seed: z.coerce.number().optional(), scope: z.string().optional(), + clustering_method: z.enum(['gmm', 'ahc']).optional(), + tree_builder: z.enum(['raptor', 'psi']).optional(), + ext: z.record(z.string(), z.any()).optional(), }) .refine( (data) => { diff --git a/web/src/pages/dataset/dataset-setting/index.tsx b/web/src/pages/dataset/dataset-setting/index.tsx index 36a0c3f89..930ec8f51 100644 --- a/web/src/pages/dataset/dataset-setting/index.tsx +++ b/web/src/pages/dataset/dataset-setting/index.tsx @@ -95,6 +95,8 @@ export default function DatasetSettings() { max_cluster: 64, random_seed: 0, scope: 'file', + clustering_method: 'gmm', + tree_builder: 'raptor', prompt: t('knowledgeConfiguration.promptText'), }, graphrag: { diff --git a/web/src/pages/dataset/dataset/use-change-document-parser.ts b/web/src/pages/dataset/dataset/use-change-document-parser.ts index cfa358cc1..9806e1708 100644 --- a/web/src/pages/dataset/dataset/use-change-document-parser.ts +++ b/web/src/pages/dataset/dataset/use-change-document-parser.ts @@ -19,7 +19,7 @@ export const useChangeDocumentParser = () => { if (record?.id && record?.dataset_id) { const ret = await setDocumentParser({ parserId: parserConfigInfo.parser_id, - pipelineId: parserConfigInfo.pipeline_id, + pipelineId: parserConfigInfo.pipeline_id || '', documentId: record?.id, datasetId: record?.dataset_id, parserConfig: parserConfigInfo.parser_config,