mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-22 08:58:23 +08:00
feat(raptor): add Psi tree builder with original-space ranking and safe migration (#14679)
### What problem does this PR solve?

Closes #14674.

This PR improves RAPTOR configuration and tree construction while preserving the existing RAPTOR behavior as the default.

RAPTOR currently builds summary layers with the original UMAP + GMM clustering path. This PR keeps that default path, and adds:

- A hidden backend tree-builder option:
  - `tree_builder="raptor"`: default, existing RAPTOR behavior.
  - `tree_builder="psi"`: rank-aware Psi-style tree builder using original embedding-space cosine ranking.
- A user-facing clustering method option for the default RAPTOR builder:
  - `clustering_method="gmm"`: existing default.
  - `clustering_method="ahc"`: agglomerative hierarchical clustering path.
- A RAPTOR UI setting for `Clustering method` and `Max cluster`.

### What changed

#### Backend

- Added `tree_builder` support for RAPTOR/Psi.
- Added `clustering_method` support for GMM/AHC.
- Kept existing RAPTOR + GMM as the default.
- Added Psi tree building from original-space cosine similarity.
- Added bucketed Psi building controls for large inputs:
  - `raptor.ext.psi_exact_max_leaves`
  - `raptor.ext.psi_bucket_size`
- Added method-aware RAPTOR summary metadata using existing `extra.raptor_method`.
- Avoided adding a dedicated DB schema field for experimental method tracking.
- Added cleanup/migration logic to avoid mixing stale RAPTOR summary trees.
- Added defensive checks for Psi tree construction and summary failures.

#### Frontend/UI

- Added `Clustering method` in RAPTOR settings with `GMM` and `AHC`.
- Added/kept `Max cluster` in RAPTOR settings.
- Enlarged max cluster UI limit to `1024`, matching backend validation.
- Kept AHC editable even when a RAPTOR task has already finished.
- Fixed the UI save payload so `clustering_method` and `tree_builder` are serialized through `parser_config.raptor.ext`, avoiding backend validation errors for extra top-level RAPTOR fields.
Example saved RAPTOR config:

```json
{
  "raptor": {
    "max_cluster": 317,
    "ext": {
      "clustering_method": "ahc",
      "tree_builder": "raptor"
    }
  }
}
```

Co-authored-by: CaptainTimon <CaptainTimon@users.noreply.github.com>
This commit is contained in:
@ -327,10 +327,14 @@ def validate_uuid1_hex(v: Any) -> str:
|
||||
|
||||
|
||||
class Base(BaseModel):
|
||||
"""Strict base model that rejects unknown request fields."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", strict=True)
|
||||
|
||||
|
||||
class RaptorConfig(Base):
|
||||
"""Dataset parser configuration for RAPTOR summary generation."""
|
||||
|
||||
use_raptor: Annotated[bool, Field(default=False)]
|
||||
prompt: Annotated[
|
||||
str,
|
||||
@ -344,11 +348,15 @@ class RaptorConfig(Base):
|
||||
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
|
||||
random_seed: Annotated[int, Field(default=0, ge=0)]
|
||||
scope: Annotated[Literal["file", "dataset"], Field(default="file")]
|
||||
clustering_method: Annotated[Literal["gmm", "ahc"], Field(default="gmm")]
|
||||
tree_builder: Annotated[Literal["raptor", "psi"], Field(default="raptor")]
|
||||
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
|
||||
ext: Annotated[dict, Field(default={})]
|
||||
|
||||
|
||||
class GraphragConfig(Base):
|
||||
"""Dataset parser configuration for GraphRAG generation."""
|
||||
|
||||
use_graphrag: Annotated[bool, Field(default=False)]
|
||||
entity_types: Annotated[list[str], Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])]
|
||||
method: Annotated[Literal["light", "general", "ner"], Field(default="light")]
|
||||
@ -357,6 +365,8 @@ class GraphragConfig(Base):
|
||||
|
||||
|
||||
class ParentChildConfig(Base):
|
||||
"""Dataset parser configuration for parent-child chunking."""
|
||||
|
||||
use_parent_child: Annotated[bool, Field(default=False)]
|
||||
children_delimiter: Annotated[str, Field(default=r"\n", min_length=1)]
|
||||
|
||||
@ -381,6 +391,8 @@ TableColumnRole = Literal["indexing", "metadata", "both"]
|
||||
|
||||
|
||||
class ParserConfig(Base):
|
||||
"""Complete parser configuration accepted by dataset APIs."""
|
||||
|
||||
auto_keywords: Annotated[int, Field(default=0, ge=0, le=32)]
|
||||
auto_questions: Annotated[int, Field(default=0, ge=0, le=10)]
|
||||
chunk_token_num: Annotated[int, Field(default=512, ge=1, le=2048)]
|
||||
@ -439,6 +451,7 @@ class UpdateDocumentReq(Base):
|
||||
@field_validator("chunk_method", mode="after")
|
||||
@classmethod
|
||||
def validate_document_chunk_method(cls, chunk_method: str | None):
|
||||
"""Validate an optional document parser method."""
|
||||
if chunk_method:
|
||||
# Validate chunk method if present
|
||||
valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email", "tag"}
|
||||
@ -450,6 +463,7 @@ class UpdateDocumentReq(Base):
|
||||
@field_validator("enabled", mode="after")
|
||||
@classmethod
|
||||
def validate_document_enabled(cls, enabled: str | None):
|
||||
"""Validate the optional enabled flag."""
|
||||
if enabled:
|
||||
converted = int(enabled)
|
||||
if converted < 0 or converted > 1:
|
||||
@ -460,6 +474,7 @@ class UpdateDocumentReq(Base):
|
||||
@field_validator("meta_fields", mode="after")
|
||||
@classmethod
|
||||
def validate_document_meta_fields(cls, meta_fields: dict | None):
|
||||
"""Validate user-provided document metadata values."""
|
||||
if meta_fields is None:
|
||||
return None
|
||||
|
||||
@ -475,6 +490,8 @@ class UpdateDocumentReq(Base):
|
||||
|
||||
|
||||
class CreateDatasetReq(Base):
|
||||
"""Request model for creating a dataset."""
|
||||
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
|
||||
avatar: Annotated[str | None, Field(default=None, max_length=65535)]
|
||||
description: Annotated[str | None, Field(default=None, max_length=65535)]
|
||||
@ -490,6 +507,7 @@ class CreateDatasetReq(Base):
|
||||
@field_validator("pipeline_id", mode="before")
|
||||
@classmethod
|
||||
def handle_pipeline_id(cls, v: str | None, info: ValidationInfo):
|
||||
"""Drop pipeline_id when parse_type selects direct parser mode."""
|
||||
if v is None:
|
||||
return v
|
||||
if info.data.get("parse_type", 0) == 1:
|
||||
@ -743,6 +761,8 @@ class CreateDatasetReq(Base):
|
||||
|
||||
|
||||
class UpdateDatasetReq(CreateDatasetReq):
|
||||
"""Request model for updating a dataset."""
|
||||
|
||||
dataset_id: Annotated[str, Field(...)]
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
|
||||
pagerank: Annotated[int, Field(default=0, ge=0, le=100)]
|
||||
@ -752,10 +772,13 @@ class UpdateDatasetReq(CreateDatasetReq):
|
||||
@field_validator("dataset_id", mode="before")
|
||||
@classmethod
|
||||
def validate_dataset_id(cls, v: Any) -> str:
|
||||
"""Validate and normalize the dataset id."""
|
||||
return validate_uuid1_hex(v)
|
||||
|
||||
|
||||
class DeleteReq(Base):
|
||||
"""Base request model for batch delete APIs."""
|
||||
|
||||
ids: Annotated[list[str] | None, Field(default=None)]
|
||||
delete_all: Annotated[bool, Field(default=False)]
|
||||
|
||||
@ -833,10 +856,15 @@ class DeleteReq(Base):
|
||||
return ids_list
|
||||
|
||||
|
||||
class DeleteDatasetReq(DeleteReq): ...
|
||||
class DeleteDatasetReq(DeleteReq):
|
||||
"""Request model for deleting datasets."""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class DeleteDocumentReq(DeleteReq):
|
||||
"""Request model for deleting documents."""
|
||||
|
||||
@field_validator("ids", mode="after")
|
||||
@classmethod
|
||||
def validate_ids(cls, v_list: list[str] | None) -> list[str] | None:
|
||||
@ -862,6 +890,8 @@ class DeleteDocumentReq(DeleteReq):
|
||||
|
||||
|
||||
class SearchDatasetReq(BaseModel):
|
||||
"""Request model for searching one dataset."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)]
|
||||
@ -881,6 +911,8 @@ class SearchDatasetReq(BaseModel):
|
||||
|
||||
|
||||
class SearchDatasetsReq(BaseModel):
|
||||
"""Request model for searching multiple datasets."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
dataset_ids: Annotated[list[str], Field(..., min_length=1)]
|
||||
@ -901,6 +933,8 @@ class SearchDatasetsReq(BaseModel):
|
||||
|
||||
|
||||
class BaseListReq(BaseModel):
|
||||
"""Shared pagination and sorting fields for list APIs."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: Annotated[str | None, Field(default=None)]
|
||||
@ -913,10 +947,13 @@ class BaseListReq(BaseModel):
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def validate_id(cls, v: Any) -> str:
|
||||
"""Validate and normalize an optional list filter id."""
|
||||
return validate_uuid1_hex(v)
|
||||
|
||||
|
||||
class ListDatasetReq(BaseListReq):
|
||||
"""Request model for listing datasets."""
|
||||
|
||||
include_parsing_status: Annotated[bool, Field(default=False)]
|
||||
ext: Annotated[dict, Field(default={})]
|
||||
|
||||
@ -925,22 +962,29 @@ class ListDatasetReq(BaseListReq):
|
||||
|
||||
|
||||
class CreateFolderReq(Base):
|
||||
"""Request model for creating a folder."""
|
||||
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)]
|
||||
parent_id: Annotated[str | None, Field(default=None)]
|
||||
type: Annotated[str | None, Field(default=None)]
|
||||
|
||||
|
||||
class DeleteFileReq(Base):
|
||||
"""Request model for deleting files."""
|
||||
|
||||
ids: Annotated[list[str], Field(min_length=1)]
|
||||
|
||||
|
||||
class MoveFileReq(Base):
|
||||
"""Request model for moving or renaming files."""
|
||||
|
||||
src_file_ids: Annotated[list[str], Field(min_length=1)]
|
||||
dest_file_id: Annotated[str | None, Field(default=None)]
|
||||
new_name: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(default=None)]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_operation(self):
|
||||
"""Require either a destination folder or a new file name."""
|
||||
if not self.dest_file_id and not self.new_name:
|
||||
raise ValueError("At least one of dest_file_id or new_name must be provided")
|
||||
if self.new_name and len(self.src_file_ids) > 1:
|
||||
@ -949,6 +993,8 @@ class MoveFileReq(Base):
|
||||
|
||||
|
||||
class ListFileReq(BaseModel):
|
||||
"""Request model for listing files."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
parent_id: Annotated[str | None, Field(default=None)]
|
||||
|
||||
Reference in New Issue
Block a user