feat(raptor): add Psi tree builder with original-space ranking and safe migration (#14679)

### What problem does this PR solve?

Closes #14674.

This PR improves RAPTOR configuration and tree construction while
preserving the existing RAPTOR behavior as the default.

RAPTOR currently builds summary layers with the original UMAP + GMM
clustering path. This PR keeps that default path, and adds:

- A hidden backend tree-builder option:
  - `tree_builder="raptor"`: default, existing RAPTOR behavior.
- `tree_builder="psi"`: rank-aware Psi-style tree builder using original
embedding-space cosine ranking.
- A user-facing clustering method option for the default RAPTOR builder:
  - `clustering_method="gmm"`: existing default.
- `clustering_method="ahc"`: agglomerative hierarchical clustering path.
- A RAPTOR UI setting for `Clustering method` and `Max cluster`.

### What changed

#### Backend

- Added `tree_builder` support for RAPTOR/Psi.
- Added `clustering_method` support for GMM/AHC.
- Kept existing RAPTOR + GMM as the default.
- Added Psi tree building from original-space cosine similarity.
- Added bucketed Psi building controls for large inputs:
  - `raptor.ext.psi_exact_max_leaves`
  - `raptor.ext.psi_bucket_size`
- Added method-aware RAPTOR summary metadata using existing
`extra.raptor_method`.
- Avoided adding a dedicated DB schema field for experimental method
tracking.
- Added cleanup/migration logic to avoid mixing stale RAPTOR summary
trees.
- Added defensive checks for Psi tree construction and summary failures.

#### Frontend/UI

- Added `Clustering method` in RAPTOR settings with `GMM` and `AHC`.
- Added/kept `Max cluster` in RAPTOR settings.
- Enlarged max cluster UI limit to `1024`, matching backend validation.
- Kept AHC editable even when a RAPTOR task has already finished.
- Fixed the UI save payload so `clustering_method` and `tree_builder`
are serialized through `parser_config.raptor.ext`, avoiding backend
validation errors for extra top-level RAPTOR fields.

Example saved RAPTOR config:

```json
{
  "raptor": {
    "max_cluster": 317,
    "ext": {
      "clustering_method": "ahc",
      "tree_builder": "raptor"
    }
  }
}

Co-authored-by: CaptainTimon <CaptainTimon@users.noreply.github.com>
This commit is contained in:
CaptainTimon
2026-05-11 15:42:31 -10:00
committed by GitHub
parent 415169d497
commit 2717ee283f
21 changed files with 1722 additions and 140 deletions

View File

@ -327,10 +327,14 @@ def validate_uuid1_hex(v: Any) -> str:
class Base(BaseModel):
"""Strict base model that rejects unknown request fields."""
model_config = ConfigDict(extra="forbid", strict=True)
class RaptorConfig(Base):
"""Dataset parser configuration for RAPTOR summary generation."""
use_raptor: Annotated[bool, Field(default=False)]
prompt: Annotated[
str,
@ -344,11 +348,15 @@ class RaptorConfig(Base):
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
random_seed: Annotated[int, Field(default=0, ge=0)]
scope: Annotated[Literal["file", "dataset"], Field(default="file")]
clustering_method: Annotated[Literal["gmm", "ahc"], Field(default="gmm")]
tree_builder: Annotated[Literal["raptor", "psi"], Field(default="raptor")]
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
ext: Annotated[dict, Field(default={})]
class GraphragConfig(Base):
"""Dataset parser configuration for GraphRAG generation."""
use_graphrag: Annotated[bool, Field(default=False)]
entity_types: Annotated[list[str], Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])]
method: Annotated[Literal["light", "general", "ner"], Field(default="light")]
@ -357,6 +365,8 @@ class GraphragConfig(Base):
class ParentChildConfig(Base):
"""Dataset parser configuration for parent-child chunking."""
use_parent_child: Annotated[bool, Field(default=False)]
children_delimiter: Annotated[str, Field(default=r"\n", min_length=1)]
@ -381,6 +391,8 @@ TableColumnRole = Literal["indexing", "metadata", "both"]
class ParserConfig(Base):
"""Complete parser configuration accepted by dataset APIs."""
auto_keywords: Annotated[int, Field(default=0, ge=0, le=32)]
auto_questions: Annotated[int, Field(default=0, ge=0, le=10)]
chunk_token_num: Annotated[int, Field(default=512, ge=1, le=2048)]
@ -439,6 +451,7 @@ class UpdateDocumentReq(Base):
@field_validator("chunk_method", mode="after")
@classmethod
def validate_document_chunk_method(cls, chunk_method: str | None):
"""Validate an optional document parser method."""
if chunk_method:
# Validate chunk method if present
valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email", "tag"}
@ -450,6 +463,7 @@ class UpdateDocumentReq(Base):
@field_validator("enabled", mode="after")
@classmethod
def validate_document_enabled(cls, enabled: str | None):
"""Validate the optional enabled flag."""
if enabled:
converted = int(enabled)
if converted < 0 or converted > 1:
@ -460,6 +474,7 @@ class UpdateDocumentReq(Base):
@field_validator("meta_fields", mode="after")
@classmethod
def validate_document_meta_fields(cls, meta_fields: dict | None):
"""Validate user-provided document metadata values."""
if meta_fields is None:
return None
@ -475,6 +490,8 @@ class UpdateDocumentReq(Base):
class CreateDatasetReq(Base):
"""Request model for creating a dataset."""
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
avatar: Annotated[str | None, Field(default=None, max_length=65535)]
description: Annotated[str | None, Field(default=None, max_length=65535)]
@ -490,6 +507,7 @@ class CreateDatasetReq(Base):
@field_validator("pipeline_id", mode="before")
@classmethod
def handle_pipeline_id(cls, v: str | None, info: ValidationInfo):
"""Drop pipeline_id when parse_type selects direct parser mode."""
if v is None:
return v
if info.data.get("parse_type", 0) == 1:
@ -743,6 +761,8 @@ class CreateDatasetReq(Base):
class UpdateDatasetReq(CreateDatasetReq):
"""Request model for updating a dataset."""
dataset_id: Annotated[str, Field(...)]
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
pagerank: Annotated[int, Field(default=0, ge=0, le=100)]
@ -752,10 +772,13 @@ class UpdateDatasetReq(CreateDatasetReq):
@field_validator("dataset_id", mode="before")
@classmethod
def validate_dataset_id(cls, v: Any) -> str:
"""Validate and normalize the dataset id."""
return validate_uuid1_hex(v)
class DeleteReq(Base):
"""Base request model for batch delete APIs."""
ids: Annotated[list[str] | None, Field(default=None)]
delete_all: Annotated[bool, Field(default=False)]
@ -833,10 +856,15 @@ class DeleteReq(Base):
return ids_list
class DeleteDatasetReq(DeleteReq): ...
class DeleteDatasetReq(DeleteReq):
"""Request model for deleting datasets."""
...
class DeleteDocumentReq(DeleteReq):
"""Request model for deleting documents."""
@field_validator("ids", mode="after")
@classmethod
def validate_ids(cls, v_list: list[str] | None) -> list[str] | None:
@ -862,6 +890,8 @@ class DeleteDocumentReq(DeleteReq):
class SearchDatasetReq(BaseModel):
"""Request model for searching one dataset."""
model_config = ConfigDict(extra="ignore")
question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)]
@ -881,6 +911,8 @@ class SearchDatasetReq(BaseModel):
class SearchDatasetsReq(BaseModel):
"""Request model for searching multiple datasets."""
model_config = ConfigDict(extra="ignore")
dataset_ids: Annotated[list[str], Field(..., min_length=1)]
@ -901,6 +933,8 @@ class SearchDatasetsReq(BaseModel):
class BaseListReq(BaseModel):
"""Shared pagination and sorting fields for list APIs."""
model_config = ConfigDict(extra="forbid")
id: Annotated[str | None, Field(default=None)]
@ -913,10 +947,13 @@ class BaseListReq(BaseModel):
@field_validator("id", mode="before")
@classmethod
def validate_id(cls, v: Any) -> str:
"""Validate and normalize an optional list filter id."""
return validate_uuid1_hex(v)
class ListDatasetReq(BaseListReq):
"""Request model for listing datasets."""
include_parsing_status: Annotated[bool, Field(default=False)]
ext: Annotated[dict, Field(default={})]
@ -925,22 +962,29 @@ class ListDatasetReq(BaseListReq):
class CreateFolderReq(Base):
"""Request model for creating a folder."""
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)]
parent_id: Annotated[str | None, Field(default=None)]
type: Annotated[str | None, Field(default=None)]
class DeleteFileReq(Base):
"""Request model for deleting files."""
ids: Annotated[list[str], Field(min_length=1)]
class MoveFileReq(Base):
"""Request model for moving or renaming files."""
src_file_ids: Annotated[list[str], Field(min_length=1)]
dest_file_id: Annotated[str | None, Field(default=None)]
new_name: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(default=None)]
@model_validator(mode="after")
def check_operation(self):
"""Require either a destination folder or a new file name."""
if not self.dest_file_id and not self.new_name:
raise ValueError("At least one of dest_file_id or new_name must be provided")
if self.new_name and len(self.src_file_ids) > 1:
@ -949,6 +993,8 @@ class MoveFileReq(Base):
class ListFileReq(BaseModel):
"""Request model for listing files."""
model_config = ConfigDict(extra="forbid")
parent_id: Annotated[str | None, Field(default=None)]

View File

@ -14,11 +14,13 @@
# limitations under the License.
#
import asyncio
from dataclasses import dataclass, field
import logging
import re
import numpy as np
import umap
from sklearn.cluster import AgglomerativeClustering
from sklearn.mixture import GaussianMixture
from api.db.services.task_service import has_canceled
@ -33,9 +35,127 @@ from rag.graphrag.utils import (
set_llm_cache,
)
from common.misc_utils import thread_pool_exec
from rag.utils.raptor_utils import (
AHC_CLUSTERING_METHOD,
GMM_CLUSTERING_METHOD,
PSI_TREE_BUILDER,
RAPTOR_TREE_BUILDER,
SUPPORTED_CLUSTERING_METHODS,
SUPPORTED_TREE_BUILDERS,
)
@dataclass
class _PsiTreeNode:
"""Node used to represent the in-memory Psi merge tree."""
index: int
text: str = ""
embedding: np.ndarray | None = None
children: list["_PsiTreeNode"] = field(default_factory=list)
parent: "_PsiTreeNode | None" = None
class _PsiUnionFind:
"""Build parent links for the Psi merge tree from ranked leaf pairs."""
def __init__(self, n: int):
"""Initialize the union-find state for n leaf nodes."""
self._rank = [0 for _ in range(n)]
self._parent_chains = [[] for _ in range(n)]
self._node_ids = [[i] for i in range(n)]
self._tree = [-1 for _ in range(max(1, 2 * n - 1))]
self._next_id = n
@staticmethod
def _ordered_extend(target: list[int], values: list[int]):
"""Append unseen values while preserving their original order."""
for value in values:
if value not in target:
target.append(value)
def _find(self, i: int) -> list[int]:
"""Return the parent chain for a leaf, extending it lazily."""
chain = self._parent_chains[i]
if not chain or (len(chain) == 1 and chain[0] == i):
return [i]
if chain[0] == i:
self._ordered_extend(chain, self._find(chain[1]))
else:
self._ordered_extend(chain, self._find(chain[0]))
return chain
def _rank_bisect_right(self, chain: list[int], rank: int) -> int:
"""Return the first chain index whose rank is greater than rank."""
idx = 0
while idx < len(chain) and self._rank[chain[idx]] <= rank:
idx += 1
return idx
def _build(self, i: int, j: int, insert_point: int | None = None):
"""Record a merge edge in the compact parent array."""
if insert_point is not None:
parent_ids = self._node_ids[insert_point]
parent_rank_idx = self._rank[i] + 1
if parent_rank_idx >= len(parent_ids):
logging.warning(
"RAPTOR Psi union fallback: rank index %d is out of bounds for node %d with %d parent ids",
parent_rank_idx,
insert_point,
len(parent_ids),
)
parent_rank_idx = len(parent_ids) - 1
self._tree[self._node_ids[i][-1]] = parent_ids[parent_rank_idx]
return
self._tree[self._node_ids[i][-1]] = self._next_id
self._tree[self._node_ids[j][-1]] = self._next_id
self._node_ids[i].append(self._next_id)
self._next_id += 1
def union(self, i: int, j: int) -> bool:
"""Merge two ranked leaves and return whether a new edge was added."""
root_i = self._find(i)[-1]
root_j = self._find(j)[-1]
if root_i == root_j:
return False
if self._rank[root_i] < self._rank[root_j]:
if not self._parent_chains[root_j]:
self._parent_chains[root_j].append(root_j)
chain = self._parent_chains[j]
higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_i])
if higher_rank_idx >= len(chain):
higher_rank_idx = len(chain) - 1
insert_point = chain[higher_rank_idx]
self._ordered_extend(self._parent_chains[root_i], chain[higher_rank_idx:])
self._build(root_i, root_j, insert_point=insert_point)
elif self._rank[root_i] > self._rank[root_j]:
if not self._parent_chains[root_i]:
self._parent_chains[root_i].append(root_i)
chain = self._parent_chains[i]
higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_j])
if higher_rank_idx >= len(chain):
higher_rank_idx = len(chain) - 1
insert_point = chain[higher_rank_idx]
self._ordered_extend(self._parent_chains[root_j], chain[higher_rank_idx:])
self._build(root_j, root_i, insert_point=insert_point)
else:
if not self._parent_chains[root_i]:
self._parent_chains[root_i].append(root_i)
self._ordered_extend(self._parent_chains[root_j], self._parent_chains[i][-1:])
self._rank[root_i] += 1
self._build(root_i, root_j)
return True
@property
def tree(self) -> list[int]:
"""Return the compact child-to-parent array for constructed nodes."""
return self._tree[:self._next_id]
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
"""Build RAPTOR summary layers with the classic or Psi tree strategy."""
def __init__(
self,
max_cluster,
@ -45,7 +165,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
max_token=512,
threshold=0.1,
max_errors=3,
tree_builder=RAPTOR_TREE_BUILDER,
clustering_method=GMM_CLUSTERING_METHOD,
psi_exact_max_leaves=4096,
psi_bucket_size=1024,
):
"""Configure RAPTOR summarization, clustering, and Psi limits."""
self._max_cluster = max_cluster
self._llm_model = llm_model
self._embd_model = embd_model
@ -54,8 +179,17 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
self._tree_builder = tree_builder or RAPTOR_TREE_BUILDER
if self._tree_builder not in SUPPORTED_TREE_BUILDERS:
raise ValueError(f"Unsupported RAPTOR tree builder: {self._tree_builder}")
self._clustering_method = clustering_method or GMM_CLUSTERING_METHOD
if self._clustering_method not in SUPPORTED_CLUSTERING_METHODS:
raise ValueError(f"Unsupported RAPTOR clustering method: {self._clustering_method}")
self._psi_exact_max_leaves = max(2, int(psi_exact_max_leaves or 4096))
self._psi_bucket_size = min(max(2, int(psi_bucket_size or 1024)), self._psi_exact_max_leaves)
def _check_task_canceled(self, task_id: str, message: str = ""):
"""Raise if the current document task was canceled."""
if task_id and has_canceled(task_id):
log_msg = f"Task {task_id} cancelled during RAPTOR {message}."
logging.info(log_msg)
@ -63,6 +197,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
"""Call the configured LLM with caching and short retries."""
cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
if cached:
return cached
@ -86,6 +221,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(20)
async def _embedding_encode(self, txt):
"""Encode text with the configured embedding model and cache result."""
response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
if response is not None:
return response
@ -97,6 +233,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
return embds
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
"""Choose the GMM cluster count with the lowest BIC score."""
max_clusters = min(self._max_cluster, len(embeddings))
n_clusters = np.arange(1, max_clusters)
bics = []
@ -109,57 +246,422 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters
def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray:
"""Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic."""
n = len(embeddings)
if n <= 1:
return np.zeros(n, dtype=int)
if n == 2:
return np.arange(n)
self._check_task_canceled(task_id, "_get_clusters_ahc dendrogram")
full_clust = AgglomerativeClustering(
n_clusters=None,
distance_threshold=0,
compute_distances=True,
linkage="ward",
)
full_clust.fit(embeddings)
distances = full_clust.distances_
if len(distances) > 1:
gaps = np.diff(distances)
max_gap_idx = int(np.argmax(gaps))
n_clusters = max(1, min(n - max_gap_idx - 1, self._max_cluster))
else:
n_clusters = max(1, min(n, self._max_cluster))
if n_clusters <= 1:
logging.info("RAPTOR AHC: _get_clusters_ahc selected one cluster for %d embeddings", n)
return np.zeros(n, dtype=int)
logging.info("RAPTOR AHC: _get_clusters_ahc selected n_clusters=%d for %d embeddings", n_clusters, n)
self._check_task_canceled(task_id, "_get_clusters_ahc fit")
clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage="ward")
return clustering.fit_predict(embeddings)
def _adjust_tree_nodes(self, embeddings: np.ndarray, labels: np.ndarray, max_iter: int = 5) -> np.ndarray:
"""Refine AHC assignments by reassigning nodes to nearest centroids."""
labels = labels.copy()
for _ in range(max_iter):
unique_labels = np.unique(labels)
if len(unique_labels) <= 1:
return labels
centroids = np.stack([embeddings[labels == lbl].mean(axis=0) for lbl in unique_labels])
diffs = embeddings[:, np.newaxis, :] - centroids[np.newaxis, :, :]
sq_dists = (diffs**2).sum(axis=2)
new_label_indices = np.argmin(sq_dists, axis=1)
new_labels = unique_labels[new_label_indices]
if np.array_equal(new_labels, labels):
break
unique_new = np.unique(new_labels)
remap = {old: new for new, old in enumerate(unique_new)}
labels = np.array([remap[int(lbl)] for lbl in new_labels])
return labels
@timeout(60 * 20)
async def _summarize_texts(self, texts: list[str], callback=None, task_id: str = ""):
"""Summarize a cluster and return text plus embedding when successful."""
self._check_task_canceled(task_id, "summarization")
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
self._check_task_canceled(task_id, "before LLM call")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
self._check_task_canceled(task_id, "before embedding")
embds = await self._embedding_encode(cnt)
return cnt, embds
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(texts)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
return None
@staticmethod
def _root(node: _PsiTreeNode) -> _PsiTreeNode:
"""Return the current root for a Psi tree node."""
while node.parent is not None:
node = node.parent
return node
def _rank_leaf_pairs(self, leaves: list[_PsiTreeNode]) -> np.ndarray:
"""Rank all leaf pairs by original embedding-space cosine similarity."""
node_embeddings = np.asarray([leaf.embedding for leaf in leaves], dtype=np.float64)
node_embeddings = self._normalize_embeddings(node_embeddings)
similarities = node_embeddings @ node_embeddings.T
lower = np.tril_indices(len(leaves), -1)
ordered = np.argsort(similarities[lower], axis=0)[::-1]
return np.stack([lower[0][ordered], lower[1][ordered]], axis=-1)
@staticmethod
def _normalize_embeddings(node_embeddings: np.ndarray) -> np.ndarray:
"""Normalize embeddings for cosine operations while tolerating zero vectors."""
node_embeddings = np.asarray(node_embeddings, dtype=np.float64)
norms = np.linalg.norm(node_embeddings, axis=1, keepdims=True)
return node_embeddings / np.maximum(norms, 1e-12)
def _split_psi_buckets(self, nodes: list[_PsiTreeNode]) -> list[list[_PsiTreeNode]]:
"""Split large Psi inputs so exact pair ranking is bounded per bucket."""
if len(nodes) <= self._psi_bucket_size:
return [nodes]
node_embeddings = self._normalize_embeddings(np.asarray([node.embedding for node in nodes], dtype=np.float64))
groups = [np.arange(len(nodes), dtype=int)]
buckets = []
while groups:
group = np.asarray(groups.pop(), dtype=int)
if len(group) <= self._psi_bucket_size:
buckets.append(group.tolist())
continue
fanout = min(max(2, int(np.ceil(len(group) / self._psi_bucket_size))), len(group), 32)
group_embeddings = node_embeddings[group]
center_idx = np.linspace(0, len(group_embeddings) - 1, num=fanout, dtype=int)
centers = group_embeddings[center_idx].copy()
for _ in range(5):
labels = np.argmax(group_embeddings @ centers.T, axis=1)
for center_id in range(fanout):
mask = labels == center_id
if not np.any(mask):
continue
center = group_embeddings[mask].mean(axis=0)
norm = np.linalg.norm(center)
centers[center_id] = center / norm if norm > 0 else center
labels = np.argmax(group_embeddings @ centers.T, axis=1)
split_groups = [group[labels == center_id].tolist() for center_id in range(fanout)]
split_groups = [bucket for bucket in split_groups if bucket]
if len(split_groups) <= 1:
split_groups = [
group[start:start + self._psi_bucket_size].tolist()
for start in range(0, len(group), self._psi_bucket_size)
]
groups.extend(split_groups)
buckets = [bucket for bucket in buckets if bucket]
buckets.sort(key=lambda bucket: (len(bucket), bucket[0]))
return [[nodes[idx] for idx in bucket] for bucket in buckets]
def _assign_prototype_embeddings(self, node: _PsiTreeNode) -> np.ndarray:
"""Assign mean child embeddings to internal Psi nodes for bucket-level ranking."""
if not node.children:
return np.asarray(node.embedding, dtype=np.float64)
embeddings = np.asarray([self._assign_prototype_embeddings(child) for child in node.children], dtype=np.float64)
node.embedding = embeddings.mean(axis=0)
return node.embedding
@staticmethod
def _iter_nodes(root: _PsiTreeNode):
"""Yield nodes in a Psi tree using a stack traversal."""
stack = [root]
while stack:
node = stack.pop()
yield node
stack.extend(node.children)
def _create_psi_parent(self, index: int, children: list[_PsiTreeNode]) -> _PsiTreeNode:
"""Create a parent node and attach the provided children to it."""
parent = _PsiTreeNode(index=index, children=children)
for child in children:
child.parent = parent
return parent
def _rebalance_psi_tree(self, root: _PsiTreeNode, next_index: int) -> tuple[_PsiTreeNode, int]:
"""Group oversized Psi tree nodes so fanout stays within max_cluster."""
max_children = max(2, int(self._max_cluster or 2))
def rebalance(node: _PsiTreeNode):
"""Recursively group children when a Psi node exceeds fanout."""
nonlocal next_index
for child in list(node.children):
rebalance(child)
while len(node.children) > max_children:
original_children = len(node.children)
grouped_children = []
for start in range(0, len(node.children), max_children):
batch = node.children[start:start + max_children]
if len(batch) == 1:
grouped_children.append(batch[0])
batch[0].parent = node
else:
grouped_children.append(self._create_psi_parent(next_index, batch))
grouped_children[-1].parent = node
next_index += 1
node.children = grouped_children
logging.info(
"RAPTOR Psi rebalance: node=%s children=%d grouped_to=%d max_cluster=%d",
node.index,
original_children,
len(grouped_children),
max_children,
)
rebalance(root)
return self._root(root), next_index
def _build_exact_psi_structure(
self,
nodes: list[_PsiTreeNode],
next_index: int,
task_id: str = "",
) -> tuple[_PsiTreeNode, int, int]:
"""Build an exact Psi subtree for a bounded node set."""
if len(nodes) == 1:
return nodes[0], next_index, 0
ranked_pairs = self._rank_leaf_pairs(nodes)
union_find = _PsiUnionFind(len(nodes))
merges = 0
for left_idx, right_idx in ranked_pairs:
self._check_task_canceled(task_id, "Psi tree construction")
if union_find.union(int(left_idx), int(right_idx)):
merges += 1
if merges == len(nodes) - 1:
break
local_nodes = {idx: node for idx, node in enumerate(nodes)}
tree = union_find.tree
children_by_parent = {}
for child_idx, parent_idx in enumerate(tree):
if child_idx not in local_nodes:
local_nodes[child_idx] = _PsiTreeNode(index=next_index)
next_index += 1
if parent_idx == -1:
continue
children_by_parent.setdefault(parent_idx, []).append(child_idx)
if parent_idx not in local_nodes:
local_nodes[parent_idx] = _PsiTreeNode(index=next_index)
next_index += 1
for parent_idx, child_indices in children_by_parent.items():
parent = local_nodes[parent_idx]
parent.children = [local_nodes[child_idx] for child_idx in child_indices]
for child in parent.children:
child.parent = parent
roots = [local_nodes[idx] for idx, parent_idx in enumerate(tree) if parent_idx == -1 and idx in local_nodes]
root = max(roots, key=lambda node: node.index)
return root, next_index, merges
def _build_bucketed_psi_structure(
self,
nodes: list[_PsiTreeNode],
next_index: int,
task_id: str = "",
) -> tuple[_PsiTreeNode, int, int]:
"""Build large Psi trees by exact-ranking bounded buckets, then bucket roots."""
buckets = self._split_psi_buckets(nodes)
logging.info(
"RAPTOR Psi bucketed build: nodes=%d buckets=%d bucket_size=%d exact_max_leaves=%d",
len(nodes),
len(buckets),
self._psi_bucket_size,
self._psi_exact_max_leaves,
)
bucket_roots = []
merges = 0
for bucket in buckets:
bucket_root, next_index, bucket_merges = self._build_psi_structure_from_nodes(bucket, next_index, task_id)
self._assign_prototype_embeddings(bucket_root)
bucket_roots.append(bucket_root)
merges += bucket_merges
if len(bucket_roots) == 1:
return bucket_roots[0], next_index, merges
root, next_index, root_merges = self._build_psi_structure_from_nodes(bucket_roots, next_index, task_id)
return root, next_index, merges + root_merges
def _build_psi_structure_from_nodes(
self,
nodes: list[_PsiTreeNode],
next_index: int,
task_id: str = "",
) -> tuple[_PsiTreeNode, int, int]:
"""Build Psi structure exactly for small sets and bucket large sets."""
if len(nodes) <= self._psi_exact_max_leaves:
return self._build_exact_psi_structure(nodes, next_index, task_id)
return self._build_bucketed_psi_structure(nodes, next_index, task_id)
def _build_psi_structure(self, chunks, task_id: str = "") -> tuple[_PsiTreeNode, list[_PsiTreeNode]]:
"""Build the Psi merge tree from original chunk embeddings."""
leaves = [
_PsiTreeNode(index=i, text=text, embedding=np.asarray(embd))
for i, (text, embd) in enumerate(chunks)
]
if len(leaves) == 1:
return leaves[0], leaves
root, next_index, merges = self._build_psi_structure_from_nodes(leaves, len(leaves), task_id)
root, _ = self._rebalance_psi_tree(root, next_index)
logging.info(
"RAPTOR Psi tree built: leaves=%d merges=%d root_fanout=%d",
len(leaves),
merges,
len(root.children),
)
return root, leaves
@staticmethod
def _psi_layers(root: _PsiTreeNode) -> dict[int, list[_PsiTreeNode]]:
"""Collect non-leaf Psi nodes by height for bottom-up summarization."""
layers = {}
def height(node: _PsiTreeNode) -> int:
"""Return node height while collecting internal nodes by layer."""
if not node.children:
return 0
node_height = max(height(child) for child in node.children) + 1
layers.setdefault(node_height, []).append(node)
return node_height
height(root)
return layers
async def _build_psi_layers(self, chunks, callback=None, task_id: str = ""):
"""Materialize Psi tree layers as summary chunks."""
layers = [(0, len(chunks))]
root, _ = self._build_psi_structure(chunks, task_id=task_id)
for layer_idx, (_, nodes) in enumerate(sorted(self._psi_layers(root).items()), start=1):
layer_start = len(chunks)
async def summarize_node(node: _PsiTreeNode):
"""Summarize one Psi internal node if its children have text."""
texts = [child.text for child in node.children if child.text]
if not texts:
logging.warning("RAPTOR Psi node %s skipped because it has no child text to summarize", node.index)
return None
result = await self._summarize_texts(texts, callback, task_id)
if result is None:
logging.warning("RAPTOR Psi node %s skipped because summarization failed", node.index)
return None
node.text, node.embedding = result
return node
tasks = [asyncio.create_task(summarize_node(node)) for node in nodes]
try:
summarized_nodes = await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in RAPTOR Psi tree processing: {e}")
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
summarized_nodes = [node for node in summarized_nodes if node is not None]
for node in summarized_nodes:
chunks.append((node.text, node.embedding))
if len(chunks) > layer_start:
layers.append((layer_start, len(chunks)))
logging.info(
"RAPTOR Psi layer materialized: layer=%d nodes=%d summaries=%d",
layer_idx,
len(nodes),
len(chunks) - layer_start,
)
if callback:
callback(msg="Build one Psi-RAG layer: {} -> {}".format(len(nodes), len(chunks) - layer_start))
else:
logging.warning("RAPTOR Psi layer %d produced no summaries; stopping materialization", layer_idx)
break
return chunks, layers
async def __call__(self, chunks, random_state, callback=None, task_id: str = ""):
"""Build summary chunks and layer boundaries for RAPTOR retrieval."""
if len(chunks) <= 1:
return [], []
chunks = [(s, a) for s, a in chunks if s and a is not None and len(a) > 0]
if len(chunks) <= 1:
return chunks, [(0, len(chunks))]
if self._tree_builder == PSI_TREE_BUILDER:
logging.info("RAPTOR: using %s tree builder for %d chunks", self._tree_builder, len(chunks))
return await self._build_psi_layers(chunks, callback, task_id)
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
@timeout(60 * 20)
async def summarize(ck_idx: list[int]):
"""Summarize one classic RAPTOR cluster into the chunk list."""
nonlocal chunks
self._check_task_canceled(task_id, "summarization")
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
self._check_task_canceled(task_id, "before LLM call")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
self._check_task_canceled(task_id, "before embedding")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
result = await self._summarize_texts(texts, callback, task_id)
if result is not None:
chunks.append(result)
while end - start > 1:
self._check_task_canceled(task_id, "layer processing")
@ -167,8 +669,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
embeddings = [embd for _, embd in chunks[start:end]]
if len(embeddings) == 2:
await summarize([start, start + 1])
produced = len(chunks) - end
if produced == 0:
logging.warning("RAPTOR layer produced no summaries; stopping materialization")
break
if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
callback(msg="Cluster one layer: {} -> {}".format(end - start, produced))
layers.append((end, len(chunks)))
start = end
end = len(chunks)
@ -180,15 +686,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_components=min(12, len(embeddings) - 2),
metric="cosine",
).fit_transform(embeddings)
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id)
if self._clustering_method == AHC_CLUSTERING_METHOD:
logging.info("RAPTOR: using clustering_method=%s before _get_clusters_ahc", self._clustering_method)
raw_labels = self._get_clusters_ahc(reduced_embeddings, task_id=task_id)
raw_cluster_count = np.unique(raw_labels).size
logging.info("RAPTOR AHC: _get_clusters_ahc produced n_clusters=%d", raw_cluster_count)
if raw_cluster_count > 1:
adjusted = self._adjust_tree_nodes(reduced_embeddings, raw_labels)
adjusted_cluster_count = np.unique(adjusted).size
logging.info("RAPTOR AHC: _adjust_tree_nodes adjusted n_clusters=%d", adjusted_cluster_count)
else:
adjusted = raw_labels
logging.warning("RAPTOR AHC: _adjust_tree_nodes skipped because _get_clusters_ahc returned one cluster")
unique_labels = np.unique(adjusted)
label_map = {old: idx for idx, old in enumerate(unique_labels)}
lbls = [label_map[int(lbl)] for lbl in adjusted]
n_clusters = len(unique_labels)
else:
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id)
if n_clusters == 1:
lbls = [0 for _ in range(len(reduced_embeddings))]
else:
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
gm.fit(reduced_embeddings)
probs = gm.predict_proba(reduced_embeddings)
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
if n_clusters == 1:
lbls = [0 for _ in range(len(reduced_embeddings))]
else:
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
gm.fit(reduced_embeddings)
probs = gm.predict_proba(reduced_embeddings)
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
lbls = [int(lbl[0]) if isinstance(lbl, np.ndarray) else int(lbl) for lbl in lbls]
tasks = []
for c in range(n_clusters):
@ -205,10 +733,21 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
await asyncio.gather(*tasks, return_exceptions=True)
raise
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
produced = len(chunks) - end
assert produced <= n_clusters, "{} vs. {}".format(produced, n_clusters)
if produced < n_clusters:
logging.warning(
"RAPTOR layer produced %d/%d cluster summaries; skipped %d cluster(s) due to errors",
produced,
n_clusters,
n_clusters - produced,
)
if produced == 0:
logging.warning("RAPTOR layer produced no summaries; stopping materialization")
break
layers.append((end, len(chunks)))
if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
callback(msg="Cluster one layer: {} -> {}".format(end - start, produced))
start = end
end = len(chunks)

View File

@ -36,7 +36,15 @@ from api.db.joint_services.memory_message_service import handle_save_to_memory_t
from common.connection_utils import timeout
from common.metadata_utils import turn2jsonschema, update_metadata_to
from rag.utils.base64_image import image2id
from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason
from rag.utils.raptor_utils import (
collect_raptor_chunk_ids,
collect_raptor_methods,
get_raptor_clustering_method,
get_raptor_tree_builder,
get_skip_reason,
make_raptor_summary_chunk_id,
should_skip_raptor,
)
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from rag.graphrag.general.index import run_graphrag_for_kb
@ -70,7 +78,10 @@ from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
email, tag
from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.raptor import (
RAPTOR_TREE_BUILDER,
RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor,
)
from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.graphrag.utils import chat_limiter
@ -817,61 +828,160 @@ async def run_dataflow(task: dict):
dsl=str(pipeline))
async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str) -> bool:
"""Return True if RAPTOR chunks already exist for doc_id in the doc store.
RAPTOR_METHOD_SEARCH_LIMIT = 10000
Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading
chunk cannot produce a false-negative result. Uses thread_pool_exec so
the blocking doc-store call does not stall the event loop.
"""
async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) -> dict:
"""Return stored RAPTOR marker fields for a document."""
from common.doc_store.doc_store_base import OrderByExpr
from rag.nlp import search as nlp_search
try:
condition = {"doc_id": doc_id, "raptor_kwd": ["raptor"]}
async def search_fields(fields: list[str], condition: dict, order_by=None):
"""Search chunk fields in the current knowledge base."""
res = await thread_pool_exec(
settings.docStoreConn.search,
["raptor_kwd"], [], condition, [], OrderByExpr(),
0, 1, nlp_search.index_name(tenant_id), [kb_id]
fields, [], condition, [], order_by or OrderByExpr(),
0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id]
)
field_map = settings.docStoreConn.get_fields(res, ["raptor_kwd"])
found = bool(field_map)
if found:
return settings.docStoreConn.get_fields(res, fields)
primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]})
if collect_raptor_chunk_ids(primary):
return primary
try:
return await search_fields(
["raptor_kwd", "extra"],
{"doc_id": doc_id},
OrderByExpr().desc("create_timestamp_flt"),
)
except Exception:
logging.debug("RAPTOR fallback method lookup with extra field failed for doc %s", doc_id, exc_info=True)
return primary
async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> set[str]:
"""Return the RAPTOR tree builders already stored for doc_id.
Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading
chunk cannot produce a false-negative result. Legacy summary chunks that
do not have method metadata are treated as the original RAPTOR builder.
"""
try:
field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id)
methods = collect_raptor_methods(field_map)
if methods:
logging.info(
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s) already exist",
doc_id, tenant_id, kb_id,
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist",
doc_id, tenant_id, kb_id, sorted(methods),
)
else:
logging.info(
"Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)",
doc_id, tenant_id, kb_id,
)
return found
return methods
except Exception:
logging.exception("Failed to check RAPTOR chunks for doc %s", doc_id)
return False
raise
async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, tree_builder: str = RAPTOR_TREE_BUILDER) -> bool:
"""Return whether doc_id already has summaries for tree_builder."""
methods = await get_raptor_chunk_methods(doc_id, tenant_id, kb_id)
return tree_builder in methods
async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_method: str | None = None):
"""Delete RAPTOR summaries for doc_id, optionally preserving one method."""
from rag.nlp import search as nlp_search
if keep_method is None:
logging.info(
"delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)",
doc_id, tenant_id, kb_id,
)
await thread_pool_exec(
settings.docStoreConn.delete,
{"doc_id": doc_id, "raptor_kwd": ["raptor"]},
nlp_search.index_name(tenant_id),
kb_id,
)
return 0
field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id)
chunk_ids = collect_raptor_chunk_ids(field_map, exclude_methods={keep_method})
if not chunk_ids:
logging.debug(
"delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)",
doc_id, tenant_id, kb_id, keep_method,
)
return 0
logging.info(
"delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)",
len(chunk_ids), doc_id, tenant_id, kb_id, keep_method,
)
await thread_pool_exec(
settings.docStoreConn.delete,
{"id": list(chunk_ids)},
nlp_search.index_name(tenant_id),
kb_id,
)
return len(chunk_ids)
@timeout(3600)
async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
"""Generate RAPTOR summaries for selected documents in a knowledge base."""
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
raptor_config = kb_parser_config.get("raptor", {})
raptor_ext_config = raptor_config.get("ext") or {}
tree_builder = get_raptor_tree_builder(raptor_config)
clustering_method = get_raptor_clustering_method(raptor_config)
vctr_nm = "q_%d_vec" % vector_size
res = []
tk_count = 0
cleanup_raptor_chunks = []
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
doc_name_by_id = {}
doc_info_by_id = {}
for doc_id in set(doc_ids):
ok, source_doc = DocumentService.get_by_id(doc_id)
if not ok or not source_doc:
continue
source_name = getattr(source_doc, "name", "")
if source_name:
doc_name_by_id[doc_id] = source_name
doc_info_by_id[doc_id] = {
"name": getattr(source_doc, "name", ""),
"type": getattr(source_doc, "type", ""),
"parser_id": getattr(source_doc, "parser_id", ""),
"parser_config": getattr(source_doc, "parser_config", {}) or {},
}
def schedule_raptor_cleanup(doc_id: str, keep_method: str | None = None):
"""Queue stale RAPTOR summaries for deletion after successful insert."""
cleanup_plan = (doc_id, keep_method)
if cleanup_plan not in cleanup_raptor_chunks:
cleanup_raptor_chunks.append(cleanup_plan)
def skip_raptor_doc(doc_id: str) -> bool:
"""Return whether RAPTOR should be skipped for this source document."""
doc_info = doc_info_by_id.get(doc_id, {})
file_type = doc_info.get("type") or row.get("type", "")
parser_id = doc_info.get("parser_id") or row.get("parser_id", "")
parser_config = doc_info.get("parser_config") or row.get("parser_config", {})
if should_skip_raptor(file_type, parser_id, parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, parser_config)
doc_name = doc_info.get("name") or doc_id
logging.info("Skipping Raptor for document %s: %s", doc_name, skip_reason)
callback(msg=f"[RAPTOR] doc:{doc_id} skipped: {skip_reason}")
return True
return False
async def generate(chunks, did):
"""Run RAPTOR and append generated summary chunks for one doc id."""
nonlocal tk_count, res
logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did)
raptor = Raptor(
raptor_config.get("max_cluster", 64),
chat_mdl,
@ -880,16 +990,21 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["max_token"],
raptor_config["threshold"],
max_errors=max_errors,
tree_builder=tree_builder,
clustering_method=clustering_method,
psi_exact_max_leaves=raptor_ext_config.get("psi_exact_max_leaves", 4096),
psi_bucket_size=raptor_ext_config.get("psi_bucket_size", 1024),
)
original_length = len(chunks)
chunks, layers = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
effective_doc_name = row["name"] if did == fake_doc_id else doc_name_by_id.get(did, row["name"])
effective_doc_name = row["name"] if did == fake_doc_id else doc_info_by_id.get(did, {}).get("name") or row["name"]
doc = {
"doc_id": did,
"kb_id": [str(row["kb_id"])],
"docnm_kwd": effective_doc_name,
"title_tks": rag_tokenizer.tokenize(effective_doc_name),
"raptor_kwd": "raptor"
"raptor_kwd": "raptor",
"extra": {"raptor_method": tree_builder},
}
if row["pagerank"]:
doc[PAGERANK_FLD] = int(row["pagerank"])
@ -906,7 +1021,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for idx, (content, vctr) in enumerate(chunks[original_length:], start=original_length):
d = copy.deepcopy(doc)
d["id"] = xxhash.xxh64((content + str(fake_doc_id)).encode("utf-8")).hexdigest()
d["id"] = make_raptor_summary_chunk_id(content, did)
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
d[vctr_nm] = vctr.tolist()
@ -918,12 +1033,28 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
tk_count += num_tokens_from_string(content)
if raptor_config.get("scope", "file") == "file":
dataset_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"])
remove_dataset_summaries = bool(dataset_methods)
has_file_level_target = False
if dataset_methods:
callback(msg="[RAPTOR] will remove dataset-level summaries after file-level summaries are available.")
for x, doc_id in enumerate(doc_ids):
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
if await has_raptor_chunks(doc_id, row["tenant_id"], row["kb_id"]):
callback(msg=f"[RAPTOR] doc:{doc_id} already has RAPTOR chunks, skipping.")
if skip_raptor_doc(doc_id):
callback(prog=(x + 1.) / len(doc_ids))
continue
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
if tree_builder in existing_methods:
has_file_level_target = True
if existing_methods != {tree_builder}:
schedule_raptor_cleanup(doc_id, tree_builder)
callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.")
callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.")
callback(prog=(x + 1.) / len(doc_ids))
continue
if existing_methods:
callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.")
chunks = []
skipped_chunks = 0
@ -945,12 +1076,52 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
callback(msg=f"[WARN] No valid chunks with vectors found for doc {doc_id}, skipping")
continue
before_generate = len(res)
await generate(chunks, doc_id)
if len(res) > before_generate:
has_file_level_target = True
if existing_methods:
schedule_raptor_cleanup(doc_id, tree_builder)
callback(prog=(x + 1.) / len(doc_ids))
if remove_dataset_summaries:
if has_file_level_target:
schedule_raptor_cleanup(fake_doc_id)
else:
callback(msg="[RAPTOR] kept dataset-level summaries because no file-level summaries were built.")
else:
migrated_file_docs = 0
file_cleanup_doc_ids = []
skipped_doc_ids = set()
for doc_id in set(doc_ids):
if skip_raptor_doc(doc_id):
skipped_doc_ids.add(doc_id)
continue
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
if existing_methods:
file_cleanup_doc_ids.append(doc_id)
migrated_file_docs += 1
if migrated_file_docs:
callback(msg=f"[RAPTOR] will remove file-level summaries for {migrated_file_docs} docs after dataset-level build succeeds.")
existing_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"])
if tree_builder in existing_methods:
if existing_methods != {tree_builder}:
schedule_raptor_cleanup(fake_doc_id, tree_builder)
callback(msg="[RAPTOR] will remove old dataset-level RAPTOR summaries after insert.")
for doc_id in file_cleanup_doc_ids:
schedule_raptor_cleanup(doc_id)
callback(msg=f"[RAPTOR] dataset-level {tree_builder} summaries already exist, skipping.")
return res, tk_count, cleanup_raptor_chunks
migrate_dataset_summaries = bool(existing_methods)
if migrate_dataset_summaries:
callback(msg=f"[RAPTOR] will migrate dataset-level RAPTOR summaries to {tree_builder} after insert.")
chunks = []
skipped_chunks = 0
for doc_id in doc_ids:
if doc_id in skipped_doc_ids:
continue
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
@ -965,13 +1136,22 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
callback(msg=f"[WARN] Skipped {skipped_chunks} chunks without vector field '{vctr_nm}'. Consider re-parsing documents with the current embedding model.")
if not chunks:
if skipped_doc_ids and len(skipped_doc_ids) == len(set(doc_ids)):
callback(msg="[RAPTOR] all documents were skipped by RAPTOR auto-disable rules.")
return res, tk_count, cleanup_raptor_chunks
logging.error(f"RAPTOR: No valid chunks with vectors found in any document for kb {row['kb_id']}")
callback(msg=f"[ERROR] No valid chunks with vectors found. Please ensure documents are parsed with the current embedding model (vector size: {vector_size}).")
return res, tk_count
return res, tk_count, cleanup_raptor_chunks
before_generate = len(res)
await generate(chunks, fake_doc_id)
if len(res) > before_generate:
for doc_id in file_cleanup_doc_ids:
schedule_raptor_cleanup(doc_id)
if migrate_dataset_summaries:
schedule_raptor_cleanup(fake_doc_id, tree_builder)
return res, tk_count
return res, tk_count, cleanup_raptor_chunks
async def delete_image(kb_id, chunk_id):
@ -1029,6 +1209,29 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
# Roll back partial RAPTOR summary inserts so the next run is not
# mistaken for a completed checkpoint by get_raptor_chunk_methods.
raptor_ids_to_rollback = [
c["id"] for c in chunks[:b + settings.DOC_BULK_SIZE]
if c.get("raptor_kwd") == "raptor"
]
if raptor_ids_to_rollback:
try:
await thread_pool_exec(
settings.docStoreConn.delete,
{"id": raptor_ids_to_rollback},
search.index_name(task_tenant_id),
task_dataset_id,
)
logging.info(
"insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)",
len(raptor_ids_to_rollback), task_id,
)
except Exception:
logging.exception(
"insert_chunks: failed to roll back partial RAPTOR chunks after cancellation (task=%s)",
task_id,
)
progress_callback(-1, msg="Task has been canceled.")
return False
if b % 128 == 0:
@ -1088,6 +1291,7 @@ async def do_handle_task(task):
task_parser_config = task["parser_config"]
task_start_ts = timer()
toc_thread = None
raptor_cleanup_chunks = []
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@ -1135,7 +1339,9 @@ async def do_handle_task(task):
"threshold": 0.1,
"max_cluster": 64,
"random_seed": 0,
"scope": "file"
"scope": "file",
"clustering_method": "gmm",
"tree_builder": "raptor",
},
}
)
@ -1143,23 +1349,12 @@ async def do_handle_task(task):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
# Check if Raptor should be skipped for structured data
file_type = task.get("type", "")
parser_id = task.get("parser_id", "")
raptor_config = kb_parser_config.get("raptor", {})
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
progress_callback(prog=1.0, msg=f"Raptor skipped: {skip_reason}")
return
# bind LLM for raptor
chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id)
chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language)
# run RAPTOR
async with kg_limiter:
chunks, token_count = await run_raptor_for_kb(
chunks, token_count, raptor_cleanup_chunks = await run_raptor_for_kb(
row=task,
kb_parser_config=kb_parser_config,
chat_mdl=chat_model,
@ -1268,6 +1463,18 @@ async def do_handle_task(task):
progress_callback(-1, msg="Task has been canceled.")
return
if raptor_cleanup_chunks:
cleaned_chunks = 0
for cleanup_doc_id, keep_method in raptor_cleanup_chunks:
cleaned_chunks += await delete_raptor_chunks(
cleanup_doc_id,
task_tenant_id,
task_dataset_id,
keep_method=keep_method,
)
if cleaned_chunks:
progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.")
logging.info(
"Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(
task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts

View File

@ -46,6 +46,8 @@ column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk ord
column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval")
column_mom_id = Column("mom_id", String(256), nullable=True, comment="parent chunk id")
column_chunk_data = Column("chunk_data", JSON, nullable=True, comment="table parser row data")
column_raptor_kwd = Column("raptor_kwd", String(256), nullable=True, comment="RAPTOR summary marker")
column_raptor_layer_int = Column("raptor_layer_int", Integer, nullable=True, comment="RAPTOR summary layer")
column_definitions: list[Column] = [
Column("id", String(256), primary_key=True, comment="chunk id"),
@ -86,6 +88,8 @@ column_definitions: list[Column] = [
Column("rank_flt", Double, nullable=True, comment="rank of this entity"),
Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'",
comment="whether it has been deleted"),
column_raptor_kwd,
column_raptor_layer_int,
column_chunk_data,
Column("metadata", JSON, nullable=True, comment="metadata for this chunk"),
Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"),
@ -127,7 +131,14 @@ FTS_COLUMNS_TKS: list[str] = [
]
# Extra columns to add after table creation (for migration)
EXTRA_COLUMNS: list[Column] = [column_order_id, column_group_id, column_mom_id, column_chunk_data]
EXTRA_COLUMNS: list[Column] = [
column_order_id,
column_group_id,
column_mom_id,
column_chunk_data,
column_raptor_kwd,
column_raptor_layer_int,
]
class SearchResult(BaseModel):

View File

@ -18,15 +18,111 @@
Utility functions for Raptor processing decisions.
"""
import json
import logging
from typing import Optional
import xxhash
RAPTOR_TREE_BUILDER = "raptor"
PSI_TREE_BUILDER = "psi"
SUPPORTED_TREE_BUILDERS = {RAPTOR_TREE_BUILDER, PSI_TREE_BUILDER}
GMM_CLUSTERING_METHOD = "gmm"
AHC_CLUSTERING_METHOD = "ahc"
SUPPORTED_CLUSTERING_METHODS = {GMM_CLUSTERING_METHOD, AHC_CLUSTERING_METHOD}
# File extensions for structured data types
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
CSV_EXTENSIONS = {".csv", ".tsv"}
STRUCTURED_EXTENSIONS = EXCEL_EXTENSIONS | CSV_EXTENSIONS
def get_raptor_tree_builder(raptor_config: dict | None) -> str:
"""Return the configured RAPTOR tree builder with legacy ext fallback."""
raptor_config = raptor_config or {}
ext = raptor_config.get("ext") or {}
tree_builder = ext.get("tree_builder") or raptor_config.get("tree_builder") or RAPTOR_TREE_BUILDER
if tree_builder not in SUPPORTED_TREE_BUILDERS:
raise ValueError(f"Unsupported RAPTOR tree builder: {tree_builder}")
return tree_builder
def get_raptor_clustering_method(raptor_config: dict | None) -> str:
"""Return the configured RAPTOR clustering method with legacy ext fallback."""
raptor_config = raptor_config or {}
ext = raptor_config.get("ext") or {}
clustering_method = ext.get("clustering_method") or raptor_config.get("clustering_method") or GMM_CLUSTERING_METHOD
if clustering_method not in SUPPORTED_CLUSTERING_METHODS:
raise ValueError(f"Unsupported RAPTOR clustering method: {clustering_method}")
return clustering_method
def _as_extra_dict(extra) -> dict:
"""Normalize a chunk extra payload into a dictionary."""
if isinstance(extra, dict):
return extra
if isinstance(extra, str) and extra:
try:
parsed = json.loads(extra)
except json.JSONDecodeError:
logging.warning(
"Ignoring malformed RAPTOR extra payload while collecting chunk metadata: %s",
extra[:200],
exc_info=True,
)
return {}
return parsed if isinstance(parsed, dict) else {}
return {}
def _has_raptor_marker(marker) -> bool:
"""Return whether a chunk marker identifies a RAPTOR summary chunk."""
if isinstance(marker, list):
return any(str(item) == RAPTOR_TREE_BUILDER for item in marker)
return str(marker) == RAPTOR_TREE_BUILDER
def _raptor_methods_from_fields(fields: dict, extra: dict | None = None) -> set[str]:
"""Read RAPTOR builder methods from stored chunk fields."""
extra = extra if extra is not None else _as_extra_dict(fields.get("extra"))
method = extra.get("raptor_method") or RAPTOR_TREE_BUILDER
if isinstance(method, list):
return {str(item) for item in method if item}
return {str(method)} if method else set()
def collect_raptor_methods(field_map: dict) -> set[str]:
"""Collect tree-builder methods from RAPTOR summary chunk fields."""
methods = set()
for fields in field_map.values():
extra = _as_extra_dict(fields.get("extra"))
marker = fields.get("raptor_kwd") or extra.get("raptor_kwd")
if not _has_raptor_marker(marker):
continue
methods.update(_raptor_methods_from_fields(fields, extra))
return methods
def collect_raptor_chunk_ids(field_map: dict, exclude_methods: set[str] | None = None) -> set[str]:
"""Collect RAPTOR summary chunk IDs, optionally excluding some methods."""
chunk_ids = set()
exclude_methods = exclude_methods or set()
for chunk_id, fields in field_map.items():
extra = _as_extra_dict(fields.get("extra"))
marker = fields.get("raptor_kwd") or extra.get("raptor_kwd")
if _has_raptor_marker(marker):
if _raptor_methods_from_fields(fields, extra).issubset(exclude_methods):
continue
chunk_ids.add(chunk_id)
return chunk_ids
def make_raptor_summary_chunk_id(content: str, doc_id: str) -> str:
"""Build the stable ID used for generated RAPTOR summary chunks."""
return xxhash.xxh64((content + str(doc_id)).encode("utf-8")).hexdigest()
def is_structured_file_type(file_type: Optional[str]) -> bool:
"""
Check if a file type is structured data (Excel, CSV, etc.)

View File

@ -583,6 +583,10 @@ class TestDatasetUpdate:
{"raptor": {"max_cluster": 512}},
{"raptor": {"max_cluster": 1024}},
{"raptor": {"random_seed": 0}},
{"raptor": {"clustering_method": "gmm"}},
{"raptor": {"clustering_method": "ahc"}},
{"raptor": {"tree_builder": "raptor"}},
{"raptor": {"tree_builder": "psi"}},
],
ids=[
"auto_keywords_min",
@ -633,6 +637,10 @@ class TestDatasetUpdate:
"raptor_max_cluster_mid",
"raptor_max_cluster_max",
"raptor_random_seed_min",
"raptor_clustering_method_gmm",
"raptor_clustering_method_ahc",
"raptor_tree_builder_raptor",
"raptor_tree_builder_psi",
],
)
def test_parser_config(self, HttpApiAuth, add_dataset_func, parser_config):
@ -707,6 +715,10 @@ class TestDatasetUpdate:
({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer"),
({"raptor": {"random_seed": "string"}}, "Input should be a valid integer"),
({"raptor": {"clustering_method": "unknown"}}, "Input should be 'gmm' or 'ahc'"),
({"raptor": {"clustering_method": None}}, "Input should be 'gmm' or 'ahc'"),
({"raptor": {"tree_builder": "ahc"}}, "Input should be 'raptor' or 'psi'"),
({"raptor": {"tree_builder": None}}, "Input should be 'raptor' or 'psi'"),
({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
],
ids=[
@ -763,6 +775,10 @@ class TestDatasetUpdate:
"raptor_random_seed_min_limit",
"raptor_random_seed_float_not_allowed",
"raptor_random_seed_type_invalid",
"raptor_clustering_method_invalid",
"raptor_clustering_method_none_invalid",
"raptor_tree_builder_invalid",
"raptor_tree_builder_none_invalid",
"parser_config_type_invalid",
],
)

View File

@ -0,0 +1,375 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import sys
import types
import pytest
np = pytest.importorskip("numpy")
from api.utils.validation_utils import RaptorConfig
from pydantic import ValidationError
@pytest.fixture()
def raptor_module(monkeypatch):
class TaskCanceledException(Exception):
pass
class DummyLimiter:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
class DummyGaussianMixture:
def __init__(self, *args, **kwargs):
pass
def fit(self, embeddings):
return self
def bic(self, embeddings):
return 0
def predict_proba(self, embeddings):
return np.ones((len(embeddings), 1))
class DummyAgglomerativeClustering:
def __init__(self, n_clusters=None, distance_threshold=None, compute_distances=False, linkage="ward"):
self.n_clusters = n_clusters
self.distance_threshold = distance_threshold
self.compute_distances = compute_distances
self.linkage = linkage
self.distances_ = np.array([0.1, 0.2, 1.0])
def fit(self, embeddings):
self.labels_ = self.fit_predict(embeddings)
return self
def fit_predict(self, embeddings):
if self.n_clusters is None:
return np.zeros(len(embeddings), dtype=int)
return np.array([idx % self.n_clusters for idx in range(len(embeddings))])
class DummyUMAP:
def __init__(self, *args, **kwargs):
pass
def fit_transform(self, embeddings):
raise AssertionError("Psi tree builder must use original embeddings, not UMAP")
sklearn_module = types.ModuleType("sklearn")
mixture_module = types.ModuleType("sklearn.mixture")
mixture_module.GaussianMixture = DummyGaussianMixture
cluster_module = types.ModuleType("sklearn.cluster")
cluster_module.AgglomerativeClustering = DummyAgglomerativeClustering
umap_module = types.ModuleType("umap")
umap_module.UMAP = DummyUMAP
task_service_module = types.ModuleType("api.db.services.task_service")
task_service_module.has_canceled = lambda task_id: False
connection_utils_module = types.ModuleType("common.connection_utils")
connection_utils_module.timeout = lambda seconds: lambda fn: fn
exceptions_module = types.ModuleType("common.exceptions")
exceptions_module.TaskCanceledException = TaskCanceledException
token_utils_module = types.ModuleType("common.token_utils")
token_utils_module.truncate = lambda text, max_len: text[:max_len]
graphrag_utils_module = types.ModuleType("rag.graphrag.utils")
graphrag_utils_module.chat_limiter = DummyLimiter()
graphrag_utils_module.get_embed_cache = lambda *args, **kwargs: None
graphrag_utils_module.get_llm_cache = lambda *args, **kwargs: None
graphrag_utils_module.set_embed_cache = lambda *args, **kwargs: None
graphrag_utils_module.set_llm_cache = lambda *args, **kwargs: None
async def thread_pool_exec(fn, *args, **kwargs):
return fn(*args, **kwargs)
misc_utils_module = types.ModuleType("common.misc_utils")
misc_utils_module.thread_pool_exec = thread_pool_exec
monkeypatch.setitem(sys.modules, "sklearn", sklearn_module)
monkeypatch.setitem(sys.modules, "sklearn.mixture", mixture_module)
monkeypatch.setitem(sys.modules, "sklearn.cluster", cluster_module)
monkeypatch.setitem(sys.modules, "umap", umap_module)
monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_module)
monkeypatch.setitem(sys.modules, "common.connection_utils", connection_utils_module)
monkeypatch.setitem(sys.modules, "common.exceptions", exceptions_module)
monkeypatch.setitem(sys.modules, "common.token_utils", token_utils_module)
monkeypatch.setitem(sys.modules, "rag.graphrag.utils", graphrag_utils_module)
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_module)
monkeypatch.delitem(sys.modules, "rag.raptor", raising=False)
module = importlib.import_module("rag.raptor")
yield module
monkeypatch.delitem(sys.modules, "rag.raptor", raising=False)
class FakeChatModel:
llm_name = "fake-chat"
max_length = 4096
def __init__(self):
self.calls = []
async def async_chat(self, system, history, gen_conf):
self.calls.append(history[0]["content"])
return f"summary-{len(self.calls)}"
class FakeEmbeddingModel:
llm_name = "fake-embedding"
def encode(self, texts):
embeddings = []
for text in texts:
checksum = sum(ord(ch) for ch in text)
embeddings.append(np.array([len(text), checksum % 17 + 1], dtype=float))
return embeddings, len(texts)
_DEFAULT_TREE_BUILDER = object()
def _make_raptor(raptor_module, max_cluster=64, tree_builder=_DEFAULT_TREE_BUILDER, **kwargs):
if tree_builder is _DEFAULT_TREE_BUILDER:
kwargs["tree_builder"] = raptor_module.PSI_TREE_BUILDER
else:
kwargs["tree_builder"] = tree_builder
return raptor_module.RecursiveAbstractiveProcessing4TreeOrganizedRetrieval(
max_cluster,
FakeChatModel(),
FakeEmbeddingModel(),
"{cluster_content}",
max_token=32,
threshold=0.1,
**kwargs,
)
def _chunks():
return [
("alpha first", np.array([1.0, 0.0])),
("alpha second", np.array([0.99, 0.01])),
("alpha third", np.array([0.98, 0.02])),
]
def test_default_tree_builder_remains_original_raptor(raptor_module):
raptor = _make_raptor(raptor_module, tree_builder=None)
assert raptor._tree_builder == raptor_module.RAPTOR_TREE_BUILDER
def test_unknown_tree_builder_is_rejected(raptor_module):
with pytest.raises(ValueError, match="Unsupported RAPTOR tree builder"):
_make_raptor(raptor_module, tree_builder="ahc")
def test_raptor_config_accepts_hidden_psi_tree_builder():
assert RaptorConfig().tree_builder == "raptor"
assert RaptorConfig().clustering_method == "gmm"
assert RaptorConfig(clustering_method="ahc").clustering_method == "ahc"
assert RaptorConfig(tree_builder="psi").tree_builder == "psi"
with pytest.raises(ValidationError):
RaptorConfig(tree_builder="ahc")
with pytest.raises(ValidationError):
RaptorConfig(clustering_method="psi")
def test_ahc_clustering_method_is_supported_in_original_tree_builder(raptor_module):
raptor = _make_raptor(raptor_module, tree_builder=raptor_module.RAPTOR_TREE_BUILDER, clustering_method="ahc")
labels = raptor._get_clusters_ahc(np.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]]))
assert raptor._tree_builder == raptor_module.RAPTOR_TREE_BUILDER
assert raptor._clustering_method == "ahc"
assert len(labels) == 4
def test_unknown_clustering_method_is_rejected(raptor_module):
with pytest.raises(ValueError, match="Unsupported RAPTOR clustering method"):
_make_raptor(raptor_module, clustering_method="psi")
def test_psi_tree_builder_ranks_all_leaf_pairs_by_original_cosine_similarity(raptor_module):
raptor = _make_raptor(raptor_module)
leaves = [
raptor_module._PsiTreeNode(index=0, embedding=np.array([1.0, 0.0])),
raptor_module._PsiTreeNode(index=1, embedding=np.array([0.0, 1.0])),
raptor_module._PsiTreeNode(index=2, embedding=np.array([0.99, 0.01])),
raptor_module._PsiTreeNode(index=3, embedding=np.array([-1.0, 0.0])),
]
ranked_pairs = raptor._rank_leaf_pairs(leaves)
assert len(ranked_pairs) == 6
assert tuple(ranked_pairs[0]) == (2, 0)
def test_psi_tree_builder_uses_cosine_similarity_not_vector_magnitude(raptor_module):
raptor = _make_raptor(raptor_module)
leaves = [
raptor_module._PsiTreeNode(index=0, embedding=np.array([100.0, 0.0])),
raptor_module._PsiTreeNode(index=1, embedding=np.array([1.0, 1.0])),
raptor_module._PsiTreeNode(index=2, embedding=np.array([0.1, 0.0])),
]
ranked_pairs = raptor._rank_leaf_pairs(leaves)
assert tuple(ranked_pairs[0]) == (2, 0)
def test_psi_tree_builder_handles_zero_vectors_in_cosine_ranking(raptor_module):
raptor = _make_raptor(raptor_module)
leaves = [
raptor_module._PsiTreeNode(index=0, embedding=np.array([0.0, 0.0])),
raptor_module._PsiTreeNode(index=1, embedding=np.array([1.0, 0.0])),
raptor_module._PsiTreeNode(index=2, embedding=np.array([0.9, 0.1])),
]
ranked_pairs = raptor._rank_leaf_pairs(leaves)
assert tuple(ranked_pairs[0]) == (2, 1)
def test_psi_tree_builder_collapses_leaf_into_ranked_pair_parent(raptor_module):
raptor = _make_raptor(raptor_module, max_cluster=64)
root, leaves = raptor._build_psi_structure(_chunks())
assert len(root.children) == 3
assert {child.index for child in root.children} == {0, 1, 2}
assert all(leaf.parent is root for leaf in leaves)
def test_psi_tree_builder_collapses_leaf_at_matching_rank(monkeypatch, raptor_module):
raptor = _make_raptor(raptor_module, max_cluster=64)
chunks = [
("node 0", np.array([1.0, 0.0])),
("node 1", np.array([0.9, 0.1])),
("node 2", np.array([-1.0, 0.0])),
("node 3", np.array([-0.9, -0.1])),
("node 4", np.array([0.8, 0.2])),
]
monkeypatch.setattr(
raptor,
"_rank_leaf_pairs",
lambda _leaves: np.array([[0, 1], [2, 3], [0, 2], [4, 0]]),
)
root, leaves = raptor._build_psi_structure(chunks)
assert leaves[4].parent is leaves[0].parent
assert leaves[4].parent is not root
assert len(root.children) == 2
def test_psi_union_find_clamps_out_of_bounds_parent_rank(caplog, raptor_module):
union_find = raptor_module._PsiUnionFind(2)
union_find._node_ids[1] = [1]
union_find._rank[0] = 2
with caplog.at_level("WARNING"):
union_find._build(0, 1, insert_point=1)
assert union_find.tree[0] == 1
assert "rank index" in caplog.text
def test_psi_tree_builder_rebalances_nodes_over_max_children(raptor_module):
raptor = _make_raptor(raptor_module, max_cluster=2)
root, _ = raptor._build_psi_structure(_chunks())
assert all(len(node.children) <= 2 for node in raptor._iter_nodes(root))
assert len(root.children) == 2
assert any(child.children for child in root.children)
def test_psi_tree_builder_uses_bucketed_structure_for_large_inputs(monkeypatch, raptor_module):
chunks = [(f"node {idx}", np.array([float(idx), float(idx % 3 + 1)])) for idx in range(8)]
raptor = _make_raptor(
raptor_module,
max_cluster=3,
psi_exact_max_leaves=3,
psi_bucket_size=2,
)
ranked_sizes = []
original_rank = raptor._rank_leaf_pairs
def track_rank(nodes):
ranked_sizes.append(len(nodes))
return original_rank(nodes)
monkeypatch.setattr(raptor, "_rank_leaf_pairs", track_rank)
root, leaves = raptor._build_psi_structure(chunks)
assert len(leaves) == len(chunks)
assert all(leaf.parent is not None for leaf in leaves)
assert all(len(node.children) <= 3 for node in raptor._iter_nodes(root))
assert max(ranked_sizes) <= 3
@pytest.mark.asyncio
async def test_psi_tree_builder_materializes_rebalanced_summary_layers_without_umap(monkeypatch, raptor_module):
def fail_umap(*args, **kwargs):
raise AssertionError("Psi tree builder must use original embeddings, not UMAP")
monkeypatch.setattr(raptor_module.umap, "UMAP", fail_umap)
raptor = _make_raptor(raptor_module, max_cluster=2)
chunks, layers = await raptor(_chunks(), random_state=0)
assert len(chunks) == 5
assert layers == [(0, 3), (3, 4), (4, 5)]
assert [chunk[0] for chunk in chunks[3:]] == ["summary-1", "summary-2"]
@pytest.mark.asyncio
async def test_psi_tree_builder_skips_failed_node_summary(monkeypatch, raptor_module):
raptor = _make_raptor(raptor_module, max_cluster=2)
async def fail_summary(*args, **kwargs):
return None
monkeypatch.setattr(raptor, "_summarize_texts", fail_summary)
chunks, layers = await raptor(_chunks(), random_state=0)
assert len(chunks) == 3
assert [chunk[0] for chunk in chunks] == [chunk[0] for chunk in _chunks()]
assert layers == [(0, 3)]
@pytest.mark.asyncio
async def test_original_raptor_stops_when_transient_summary_fails(monkeypatch, raptor_module):
raptor = _make_raptor(raptor_module, tree_builder=raptor_module.RAPTOR_TREE_BUILDER)
async def fail_summary(*args, **kwargs):
return None
monkeypatch.setattr(raptor, "_summarize_texts", fail_summary)
input_chunks = _chunks()[:2]
chunks, layers = await raptor(input_chunks, random_state=0)
assert len(chunks) == 2
assert [chunk[0] for chunk in chunks] == [chunk[0] for chunk in input_chunks]
assert layers == [(0, 2)]

View File

@ -18,15 +18,22 @@
Unit tests for Raptor utility functions.
"""
import logging
import pytest
from rag.utils.raptor_utils import (
CSV_EXTENSIONS,
EXCEL_EXTENSIONS,
STRUCTURED_EXTENSIONS,
collect_raptor_chunk_ids,
collect_raptor_methods,
get_raptor_clustering_method,
get_raptor_tree_builder,
get_skip_reason,
is_structured_file_type,
is_tabular_pdf,
make_raptor_summary_chunk_id,
should_skip_raptor,
get_skip_reason,
EXCEL_EXTENSIONS,
CSV_EXTENSIONS,
STRUCTURED_EXTENSIONS
)
@ -283,5 +290,117 @@ class TestIntegrationScenarios:
assert should_skip_raptor(file_type, raptor_config=raptor_config) is False
class TestRaptorTreeBuilderConfig:
"""Test RAPTOR tree builder config resolution"""
def test_defaults_to_original_raptor_builder(self):
assert get_raptor_tree_builder({}) == "raptor"
assert get_raptor_tree_builder(None) == "raptor"
def test_reads_top_level_tree_builder(self):
assert get_raptor_tree_builder({"tree_builder": "psi"}) == "psi"
def test_reads_legacy_ext_tree_builder(self):
assert get_raptor_tree_builder({"ext": {"tree_builder": "psi"}}) == "psi"
def test_ext_tree_builder_overrides_stale_top_level_value(self):
assert get_raptor_tree_builder({"tree_builder": "psi", "ext": {"tree_builder": "raptor"}}) == "raptor"
def test_rejects_unknown_tree_builder(self):
with pytest.raises(ValueError, match="Unsupported RAPTOR tree builder"):
get_raptor_tree_builder({"tree_builder": "ahc"})
class TestRaptorClusteringMethodConfig:
"""Test RAPTOR clustering method config resolution"""
def test_defaults_to_gmm(self):
assert get_raptor_clustering_method({}) == "gmm"
assert get_raptor_clustering_method(None) == "gmm"
def test_reads_top_level_clustering_method(self):
assert get_raptor_clustering_method({"clustering_method": "gmm"}) == "gmm"
assert get_raptor_clustering_method({"clustering_method": "ahc"}) == "ahc"
def test_reads_legacy_ext_clustering_method(self):
assert get_raptor_clustering_method({"ext": {"clustering_method": "ahc"}}) == "ahc"
def test_ext_clustering_method_overrides_stale_top_level_value(self):
assert get_raptor_clustering_method({"clustering_method": "gmm", "ext": {"clustering_method": "ahc"}}) == "ahc"
def test_rejects_unknown_clustering_method(self):
with pytest.raises(ValueError, match="Unsupported RAPTOR clustering method"):
get_raptor_clustering_method({"clustering_method": "unknown"})
class TestRaptorMethodCollection:
"""Test RAPTOR summary method extraction from doc-store fields"""
def test_legacy_summary_without_method_is_original_raptor(self):
field_map = {"chunk_1": {"raptor_kwd": "raptor"}}
assert collect_raptor_methods(field_map) == {"raptor"}
assert collect_raptor_chunk_ids(field_map) == {"chunk_1"}
def test_extra_method_is_preserved(self):
field_map = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}}
assert collect_raptor_methods(field_map) == {"psi"}
assert collect_raptor_chunk_ids(field_map) == {"chunk_1"}
def test_extra_field_supports_oceanbase_legacy_rows(self):
field_map = {
"chunk_1": {
"extra": {
"raptor_kwd": "raptor",
"raptor_method": "psi",
}
},
"chunk_2": {
"extra": "{\"raptor_kwd\": \"raptor\"}",
},
"chunk_3": {
"extra": {"raptor_kwd": ""},
},
}
assert collect_raptor_methods(field_map) == {"psi", "raptor"}
assert collect_raptor_chunk_ids(field_map) == {"chunk_1", "chunk_2"}
def test_non_raptor_rows_are_ignored(self):
field_map = {
"chunk_1": {"raptor_kwd": ""},
"chunk_2": {"extra": {"raptor_kwd": "graph"}},
"chunk_3": {},
}
assert collect_raptor_methods(field_map) == set()
assert collect_raptor_chunk_ids(field_map) == set()
def test_malformed_extra_payload_is_logged_and_ignored(self, caplog):
field_map = {"chunk_1": {"extra": "{bad json"}}
with caplog.at_level(logging.WARNING):
assert collect_raptor_methods(field_map) == set()
assert collect_raptor_chunk_ids(field_map) == set()
assert "Ignoring malformed RAPTOR extra payload" in caplog.text
def test_chunk_id_collection_can_preserve_current_method(self):
field_map = {
"legacy": {"raptor_kwd": "raptor"},
"old": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}},
"current": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}},
}
assert collect_raptor_chunk_ids(field_map, exclude_methods={"psi"}) == {"legacy", "old"}
assert collect_raptor_chunk_ids(field_map, exclude_methods={"raptor"}) == {"current"}
def test_summary_chunk_ids_include_real_document_id(self):
content = "same generated summary"
assert make_raptor_summary_chunk_id(content, "doc-a") != make_raptor_summary_chunk_id(content, "doc-b")
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@ -17,7 +17,7 @@ import { DocumentParserType, ParseType } from '@/constants/knowledge';
import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request';
import { IModalProps } from '@/interfaces/common';
import { IParserConfig } from '@/interfaces/database/document';
import { IChangeParserConfigRequestBody } from '@/interfaces/request/document';
import { IChangeParserRequestBody } from '@/interfaces/request/document';
import { MetadataType } from '@/pages/dataset/components/metedata/constant';
import {
AutoMetadata,
@ -28,7 +28,6 @@ import {
} from '@/pages/dataset/dataset-setting/configuration/common-item';
import { zodResolver } from '@hookform/resolvers/zod';
import omit from 'lodash/omit';
import {} from 'module';
import { useEffect, useMemo } from 'react';
import { useForm, useWatch } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
@ -56,10 +55,7 @@ import {
const FormId = 'ChunkMethodDialogForm';
interface IProps extends IModalProps<{
parserId: string;
parserConfig: IChangeParserConfigRequestBody;
}> {
interface IProps extends IModalProps<IChangeParserRequestBody> {
loading: boolean;
parserId: string;
pipelineId?: string;
@ -126,16 +122,19 @@ export function ChunkMethodDialog({
mineru_formula_enable: z.boolean().optional(),
mineru_table_enable: z.boolean().optional(),
mineru_lang: z.string().optional(),
// raptor: z
// .object({
// use_raptor: z.boolean().optional(),
// prompt: z.string().optional().optional(),
// max_token: z.coerce.number().optional(),
// threshold: z.coerce.number().optional(),
// max_cluster: z.coerce.number().optional(),
// random_seed: z.coerce.number().optional(),
// })
// .optional(),
raptor: z
.object({
use_raptor: z.boolean().optional(),
prompt: z.string().optional(),
max_token: z.coerce.number().optional(),
threshold: z.coerce.number().optional(),
max_cluster: z.coerce.number().optional(),
random_seed: z.coerce.number().optional(),
scope: z.string().optional(),
clustering_method: z.enum(['gmm', 'ahc']).optional(),
tree_builder: z.enum(['raptor', 'psi']).optional(),
})
.optional(),
// graphrag: z.object({
// use_graphrag: z.boolean().optional(),
// }),

View File

@ -23,14 +23,17 @@ export function useDefaultParserValues() {
mineru_formula_enable: true,
mineru_table_enable: true,
mineru_lang: 'English',
// raptor: {
// use_raptor: false,
// prompt: t('knowledgeConfiguration.promptText'),
// max_token: 256,
// threshold: 0.1,
// max_cluster: 64,
// random_seed: 0,
// },
raptor: {
use_raptor: false,
prompt: t('knowledgeConfiguration.promptText'),
max_token: 256,
threshold: 0.1,
max_cluster: 64,
random_seed: 0,
scope: 'file',
clustering_method: 'gmm',
tree_builder: 'raptor',
},
// graphrag: {
// use_graphrag: false,
// },

View File

@ -8,7 +8,7 @@ import {
} from '@/pages/dataset/dataset/generate-button/generate';
import random from 'lodash/random';
import { Shuffle } from 'lucide-react';
import { useCallback } from 'react';
import { useCallback, useEffect, useMemo } from 'react';
import { useFormContext, useWatch } from 'react-hook-form';
import { SliderInputFormField } from '../slider-input-form-field';
import {
@ -50,10 +50,10 @@ export const showTagItems = (parserId: DocumentParserType) => {
const UseRaptorField = 'parser_config.raptor.use_raptor';
const RandomSeedField = 'parser_config.raptor.random_seed';
const MaxTokenField = 'parser_config.raptor.max_token';
const ThresholdField = 'parser_config.raptor.threshold';
const MaxCluster = 'parser_config.raptor.max_cluster';
const Prompt = 'parser_config.raptor.prompt';
const ClusteringMethodField = 'parser_config.raptor.clustering_method';
const ClusteringMethodExtField = 'parser_config.raptor.ext.clustering_method';
const TreeBuilderField = 'parser_config.raptor.tree_builder';
const MaxClusterMax = 1024;
// The three types "table", "resume" and "one" do not display this configuration.
@ -67,17 +67,48 @@ const RaptorFormFields = ({
const form = useFormContext();
const { t } = useTranslate('knowledgeConfiguration');
const useRaptor = useWatch({ name: UseRaptorField });
const clusteringMethod = useWatch({ name: ClusteringMethodField });
const extClusteringMethod = useWatch({ name: ClusteringMethodExtField });
const selectedClusteringMethod = useMemo(
() =>
(clusteringMethod ??
extClusteringMethod ??
form.getValues(ClusteringMethodField) ??
form.getValues(ClusteringMethodExtField) ??
'gmm') as 'gmm' | 'ahc',
[clusteringMethod, extClusteringMethod, form],
);
const handleGenerate = useCallback(() => {
form.setValue(RandomSeedField, random(10000));
}, [form]);
const handleClusteringMethodChange = useCallback(
(method: 'gmm' | 'ahc') => {
form.setValue(ClusteringMethodField, method, {
shouldDirty: true,
shouldValidate: true,
});
form.setValue(TreeBuilderField, 'raptor', {
shouldDirty: true,
shouldValidate: true,
});
},
[form],
);
useEffect(() => {
if (!clusteringMethod && !extClusteringMethod) {
handleClusteringMethodChange('gmm');
}
}, [clusteringMethod, extClusteringMethod, handleClusteringMethodChange]);
return (
<>
<FormField
control={form.control}
name={UseRaptorField}
render={({ field }) => {
render={() => {
return (
<FormItem
defaultChecked={false}
@ -209,11 +240,61 @@ const RaptorFormFields = ({
sliderTestId="ds-settings-raptor-threshold-slider"
numberInputTestId="ds-settings-raptor-threshold-input"
></SliderInputFormField>
<FormField
control={form.control}
name={ClusteringMethodField}
render={({ field }) => {
return (
<FormItem className=" items-center space-y-0 ">
<div className="flex items-start">
<FormLabel
tooltip={t('clusteringMethodTip')}
className="text-sm whitespace-nowrap w-1/4"
>
{t('clusteringMethod')}
</FormLabel>
<div className="w-3/4">
<FormControl>
<Radio.Group
{...field}
value={selectedClusteringMethod}
onChange={(value) =>
handleClusteringMethodChange(value as 'gmm' | 'ahc')
}
>
<div
className={'flex gap-4 w-full text-text-secondary '}
>
<Radio
value="gmm"
testId="ds-settings-raptor-clustering-method-option-gmm"
>
{t('clusteringMethodGmm')}
</Radio>
<Radio
value="ahc"
testId="ds-settings-raptor-clustering-method-option-ahc"
>
{t('clusteringMethodAhc')}
</Radio>
</div>
</Radio.Group>
</FormControl>
</div>
</div>
<div className="flex pt-1">
<div className="w-1/4"></div>
<FormMessage />
</div>
</FormItem>
);
}}
/>
<SliderInputFormField
name={'parser_config.raptor.max_cluster'}
label={t('maxCluster')}
tooltip={t('maxClusterTip')}
max={1024}
max={MaxClusterMax}
min={1}
layout={FormLayout.Horizontal}
sliderTestId="ds-settings-raptor-max-cluster-slider"

View File

@ -13,6 +13,7 @@ type RadioProps = {
checked?: boolean;
disabled?: boolean;
onChange?: (checked: boolean) => void;
testId?: string;
children?: React.ReactNode;
} & Omit<
React.InputHTMLAttributes<HTMLInputElement>,
@ -25,6 +26,7 @@ function Radio({
checked,
disabled,
onChange,
testId,
children,
...props
}: RadioProps) {
@ -65,6 +67,7 @@ function Radio({
onChange={handleChange}
disabled={mergedDisabled}
className={cn('peer absolute size-[1px] opacity-0', className)}
data-testid={testId}
{...props}
name={groupContext?.name}
/>
@ -151,9 +154,11 @@ const Group = React.forwardRef<HTMLDivElement, RadioGroupProps>(
)}
>
{React.Children.map(children, (child) => {
if (!React.isValidElement<RadioProps>(child)) return child;
if (!React.isValidElement<RadioProps>(child)) {
return child;
}
return React.cloneElement(child, {
disabled: disabled || child.props?.disabled,
disabled: disabled || child.props.disabled,
});
})}
</div>

View File

@ -21,10 +21,17 @@ export const extractRaptorConfigExt = (
max_cluster,
random_seed,
scope,
clustering_method,
tree_builder,
auto_disable_for_structured_data,
ext,
...raptorExt
} = raptorConfig;
const extClusteringMethod = ext?.clustering_method;
const normalizedClusteringMethod =
clustering_method ?? extClusteringMethod ?? 'gmm';
const normalizedTreeBuilder = tree_builder ?? ext?.tree_builder ?? 'raptor';
return {
use_raptor,
prompt,
@ -34,7 +41,12 @@ export const extractRaptorConfigExt = (
random_seed,
scope,
auto_disable_for_structured_data,
ext: { ...ext, ...raptorExt },
ext: {
...ext,
...raptorExt,
clustering_method: normalizedClusteringMethod,
tree_builder: normalizedTreeBuilder,
},
};
};

View File

@ -0,0 +1,45 @@
import { extractParserConfigExt } from '../parser-config-utils';
describe('extractParserConfigExt', () => {
it('serializes RAPTOR clustering fields through ext for API compatibility', () => {
const result = extractParserConfigExt({
raptor: {
use_raptor: true,
prompt: 'Summarize {cluster_content}',
max_token: 256,
threshold: 0.1,
max_cluster: 317,
random_seed: 0,
scope: 'file',
clustering_method: 'ahc',
tree_builder: 'raptor',
},
});
expect(result?.raptor).not.toHaveProperty('clustering_method');
expect(result?.raptor).not.toHaveProperty('tree_builder');
expect(result?.raptor?.ext).toMatchObject({
clustering_method: 'ahc',
tree_builder: 'raptor',
});
});
it('preserves existing RAPTOR ext clustering values when the top-level field is absent', () => {
const result = extractParserConfigExt({
raptor: {
max_cluster: 512,
ext: {
clustering_method: 'ahc',
tree_builder: 'raptor',
psi_bucket_size: 1024,
},
},
});
expect(result?.raptor?.ext).toMatchObject({
clustering_method: 'ahc',
tree_builder: 'raptor',
psi_bucket_size: 1024,
});
});
});

View File

@ -73,11 +73,13 @@ interface Parserconfig {
}
interface Raptor {
clustering_method?: 'gmm' | 'ahc';
max_cluster: number;
max_token: number;
prompt: string;
random_seed: number;
threshold: number;
tree_builder?: 'raptor' | 'psi';
use_raptor: boolean;
}

View File

@ -11,6 +11,17 @@ export interface IChangeParserConfigRequestBody {
image_table_context_window?: number;
image_context_size?: number;
table_context_size?: number;
raptor?: {
use_raptor?: boolean;
prompt?: string;
max_token?: number;
threshold?: number;
max_cluster?: number;
random_seed?: number;
scope?: string;
clustering_method?: 'gmm' | 'ahc';
tree_builder?: 'raptor' | 'psi';
};
// Metadata fields
metadata?: Array<{
key?: string;
@ -27,8 +38,8 @@ export interface IChangeParserConfigRequestBody {
export interface IChangeParserRequestBody {
parser_id: string;
pipeline_id: string;
doc_id: string;
pipeline_id?: string;
doc_id?: string;
parser_config: IChangeParserConfigRequestBody;
}

View File

@ -861,6 +861,11 @@ The above is the content you need to summarize.`,
thresholdTip:
'In RAPTOR, chunks are clustered by their semantic similarity. The Threshold parameter sets the minimum similarity required for chunks to be grouped together. A higher Threshold means fewer chunks in each cluster, while a lower one means more.',
thresholdMessage: 'Threshold is required',
clusteringMethod: 'Clustering method',
clusteringMethodTip:
'Select the RAPTOR clustering method. AHC can use a larger max cluster value, but may require more memory on large inputs.',
clusteringMethodGmm: 'GMM',
clusteringMethodAhc: 'AHC',
maxCluster: 'Max cluster',
maxClusterTip: 'The maximum number of clusters to create.',
maxClusterMessage: 'Max cluster is required',

View File

@ -772,6 +772,11 @@ export default {
maxTokenMessage: '最大token数是必填项',
threshold: '阈值',
thresholdMessage: '阈值是必填项',
clusteringMethod: '聚类方法',
clusteringMethodTip:
'选择 RAPTOR 聚类方法。AHC 可以使用更大的最大聚类数,但在大规模输入时可能占用更多内存。',
clusteringMethodGmm: 'GMM',
clusteringMethodAhc: 'AHC',
maxCluster: '最大聚类数',
maxClusterMessage: '最大聚类数是必填项',
randomSeed: '随机种子',

View File

@ -42,11 +42,14 @@ export const formSchema = z
.object({
use_raptor: z.boolean().optional(),
prompt: z.string().optional(),
max_token: z.number().optional(),
threshold: z.number().optional(),
max_cluster: z.number().optional(),
random_seed: z.number().optional(),
max_token: z.coerce.number().optional(),
threshold: z.coerce.number().optional(),
max_cluster: z.coerce.number().optional(),
random_seed: z.coerce.number().optional(),
scope: z.string().optional(),
clustering_method: z.enum(['gmm', 'ahc']).optional(),
tree_builder: z.enum(['raptor', 'psi']).optional(),
ext: z.record(z.string(), z.any()).optional(),
})
.refine(
(data) => {

View File

@ -95,6 +95,8 @@ export default function DatasetSettings() {
max_cluster: 64,
random_seed: 0,
scope: 'file',
clustering_method: 'gmm',
tree_builder: 'raptor',
prompt: t('knowledgeConfiguration.promptText'),
},
graphrag: {

View File

@ -19,7 +19,7 @@ export const useChangeDocumentParser = () => {
if (record?.id && record?.dataset_id) {
const ret = await setDocumentParser({
parserId: parserConfigInfo.parser_id,
pipelineId: parserConfigInfo.pipeline_id,
pipelineId: parserConfigInfo.pipeline_id || '',
documentId: record?.id,
datasetId: record?.dataset_id,
parserConfig: parserConfigInfo.parser_config,