Files
ragflow/rag/svr/task_executor.py
CaptainTimon 2717ee283f feat(raptor): add Psi tree builder with original-space ranking and safe migration (#14679)
### What problem does this PR solve?

Closes #14674.

This PR improves RAPTOR configuration and tree construction while
preserving the existing RAPTOR behavior as the default.

RAPTOR currently builds summary layers with the original UMAP + GMM
clustering path. This PR keeps that default path, and adds:

- A hidden backend tree-builder option:
  - `tree_builder="raptor"`: default, existing RAPTOR behavior.
  - `tree_builder="psi"`: rank-aware Psi-style tree builder using original
    embedding-space cosine ranking.
- A user-facing clustering method option for the default RAPTOR builder:
  - `clustering_method="gmm"`: existing default.
  - `clustering_method="ahc"`: agglomerative hierarchical clustering path.
- A RAPTOR UI setting for `Clustering method` and `Max cluster`.

### What changed

#### Backend

- Added `tree_builder` support for RAPTOR/Psi.
- Added `clustering_method` support for GMM/AHC.
- Kept existing RAPTOR + GMM as the default.
- Added Psi tree building from original-space cosine similarity.
- Added bucketed Psi building controls for large inputs:
  - `raptor.ext.psi_exact_max_leaves`
  - `raptor.ext.psi_bucket_size`
- Added method-aware RAPTOR summary metadata using existing
`extra.raptor_method`.
- Avoided adding a dedicated DB schema field for experimental method
tracking.
- Added cleanup/migration logic to avoid mixing stale RAPTOR summary
trees.
- Added defensive checks for Psi tree construction and summary failures.

#### Frontend/UI

- Added `Clustering method` in RAPTOR settings with `GMM` and `AHC`.
- Added/kept `Max cluster` in RAPTOR settings.
- Raised the max cluster UI limit to `1024`, matching backend validation.
- Kept AHC editable even when a RAPTOR task has already finished.
- Fixed the UI save payload so `clustering_method` and `tree_builder`
are serialized through `parser_config.raptor.ext`, avoiding backend
validation errors for extra top-level RAPTOR fields.

Example saved RAPTOR config:

```json
{
  "raptor": {
    "max_cluster": 317,
    "ext": {
      "clustering_method": "ahc",
      "tree_builder": "raptor"
    }
  }
}
```

Co-authored-by: CaptainTimon <CaptainTimon@users.noreply.github.com>
2026-05-12 09:42:31 +08:00

1775 lines
76 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from common.misc_utils import thread_pool_exec
start_ts = time.time()
import asyncio
import socket
# from beartype import BeartypeConf
# from beartype.claw import beartype_all # <-- you didn't sign up for this
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import random
import sys
import threading
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.joint_services.memory_message_service import handle_save_to_memory_task
from common.connection_utils import timeout
from common.metadata_utils import turn2jsonschema, update_metadata_to
from rag.utils.base64_image import image2id
from rag.utils.raptor_utils import (
collect_raptor_chunk_ids,
collect_raptor_methods,
get_raptor_clustering_method,
get_raptor_tree_builder,
get_skip_reason,
make_raptor_summary_chunk_id,
should_skip_raptor,
)
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from rag.graphrag.general.index import run_graphrag_for_kb
from rag.graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, \
gen_metadata
import logging
import os
from datetime import datetime
import json
import xxhash
import copy
import re
from functools import partial
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import signal
import exceptiongroup
import faulthandler
import numpy as np
from peewee import DoesNotExist
from common.constants import LLMType, ParserType, PipelineTaskType
from api.db.services.document_service import DocumentService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.file2document_service import File2DocumentService
from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_tenant_default_model_by_type
from common.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
email, tag
from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import (
RAPTOR_TREE_BUILDER,
RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor,
)
from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.graphrag.utils import chat_limiter
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.exceptions import TaskCanceledException
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME
from rag.utils.table_es_metadata import (
aggregate_table_manual_doc_metadata,
merge_table_parser_config_from_kb,
table_parser_strip_doc_metadata_keys,
)
# Batch size constant; its consumers are outside this part of the file.
BATCH_SIZE = 64
# Parser registry: build_chunks looks up task["parser_id"].lower() here and calls
# the selected module's ``chunk`` function. "general" and KG reuse the naive parser.
FACTORY = {
    "general": naive,
    ParserType.NAIVE.value: naive,
    ParserType.PAPER.value: paper,
    ParserType.BOOK.value: book,
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
    ParserType.TABLE.value: table,
    ParserType.RESUME.value: resume,
    ParserType.PICTURE.value: picture,
    ParserType.ONE.value: one,
    ParserType.AUDIO.value: audio,
    ParserType.EMAIL.value: email,
    ParserType.KG.value: naive,
    ParserType.TAG.value: tag
}
# Maps a queue message's task_type string to its PipelineTaskType
# (consumers of this mapping are outside this view).
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
    "dataflow": PipelineTaskType.PARSE,
    "raptor": PipelineTaskType.RAPTOR,
    "graphrag": PipelineTaskType.GRAPH_RAG,
    "mindmap": PipelineTaskType.MINDMAP,
    "memory": PipelineTaskType.MEMORY,
}
# Lazily created by collect(): iterator over this consumer's unacknowledged Redis
# messages, which collect() drains before polling the queues for new work.
UNACKED_ITERATOR = None
# Consumer index comes from the first CLI argument (default "0"); CONSUMER_NAME is
# the Redis consumer name collect() uses for this worker.
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
# Local-time ISO-8601 boot timestamp with millisecond precision.
BOOT_AT = datetime.now().astimezone().isoformat(timespec="milliseconds")
# Worker counters. collect() increments FAILED_TASKS for unknown/canceled tasks;
# the others are maintained outside this view.
PENDING_TASKS = 0
LAG_TASKS = 0
DONE_TASKS = 0
FAILED_TASKS = 0
CURRENT_TASKS = {}
# Concurrency limits, each overridable through an environment variable.
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))
MAX_CONCURRENT_MINIO = int(os.environ.get('MAX_CONCURRENT_MINIO', '10'))
# chunk_limiter gates parser ``chunk`` calls in build_chunks. embed_limiter gates
# embedding batches; NOTE(review): it is sized by MAX_CONCURRENT_CHUNK_BUILDERS rather
# than a dedicated setting, so confirm that is intended. task_limiter, minio_limiter
# and kg_limiter are used outside this view.
task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
chunk_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
embed_limiter = asyncio.Semaphore(MAX_CONCURRENT_CHUNK_BUILDERS)
minio_limiter = asyncio.Semaphore(MAX_CONCURRENT_MINIO)
kg_limiter = asyncio.Semaphore(2)
# Presumably seconds (inferred from the name); its consumers are outside this view.
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))
# Set by signal_handler to request shutdown.
stop_event = threading.Event()
def signal_handler(sig, frame):
    """Handle a termination signal by flagging shutdown and exiting.

    ``stop_event`` is set first, so anything watching it gets a short grace
    period (one second) before the process exits with status 0.

    Args:
        sig: Signal number (unused).
        frame: Current stack frame (unused).
    """
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    grace_seconds = 1
    time.sleep(grace_seconds)
    sys.exit(0)
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
    """Persist a progress update for ``task_id`` and surface cancellation.

    Args:
        task_id: Task whose progress row is updated.
        from_page: Zero-based first page of the range this update concerns.
        to_page: Zero-based last page; values <= 0 suppress the page prefix.
        prog: Progress value. A negative value marks an error and prefixes the
            message with "[ERROR]". ``None`` omits the progress field from the update.
        msg: Human-readable progress message; it is timestamped when non-empty.

    Raises:
        TaskCanceledException: if the task has been canceled (after the update
        is written). A missing task row and any other failure are only logged.
    """
    try:
        if prog is not None and prog < 0:
            msg = "[ERROR]" + msg

        canceled = has_canceled(task_id)
        if canceled:
            msg += " [Canceled]"
            prog = -1

        # Add a 1-based page-range prefix only for a real, non-empty range.
        if to_page > 0 and msg and from_page < to_page:
            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg

        if msg:
            msg = datetime.now().strftime("%H:%M:%S") + " " + msg

        payload = {"progress_msg": msg}
        if prog is not None:
            payload["progress"] = prog

        TaskService.update_progress(task_id, payload)
        close_connection()

        if canceled:
            raise TaskCanceledException(msg)
        logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
    except TaskCanceledException:
        raise
    except DoesNotExist:
        logging.warning(f"set_progress({task_id}) got exception DoesNotExist")
    except Exception as e:
        logging.exception(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}, got exception: {e}")
async def collect():
    """Fetch the next task message from Redis and resolve it to a task dict.

    This consumer's unacknowledged messages are drained first. Once that
    iterator is exhausted, each service queue is polled in order until one
    yields a message.

    Returns:
        ``(redis_msg, task)`` when a runnable task was found; otherwise
        ``(None, None)``. Empty messages and messages whose task is unknown or
        canceled are acknowledged and dropped. Unknown or canceled tasks also
        increment ``FAILED_TASKS``.
    """
    global FAILED_TASKS
    global UNACKED_ITERATOR

    svr_queue_names = settings.get_svr_queue_names()
    redis_msg = None
    try:
        if not UNACKED_ITERATOR:
            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
        try:
            redis_msg = next(UNACKED_ITERATOR)
        except StopIteration:
            for svr_queue_name in svr_queue_names:
                redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
                if redis_msg:
                    break
    except Exception as e:
        logging.exception(f"collect got exception: {e}")
        return None, None

    if not redis_msg:
        return None, None
    msg = redis_msg.get_message()
    if not msg:
        logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
        redis_msg.ack()
        return None, None

    canceled = False
    if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]:
        # Synthetic doc ids carry the task payload in the message itself.
        task = msg
        if task["task_type"] in PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES:
            task = TaskService.get_task(msg["id"], msg["doc_ids"])
            if task:
                task["doc_id"] = msg["doc_id"]
                task["doc_ids"] = msg.get("doc_ids", []) or []
    elif msg.get("task_type") == PipelineTaskType.MEMORY.lower():
        # get_by_id reports a missing row through its first return value. Guard
        # it so a stale message is acked below as "unknown" instead of raising
        # AttributeError here and leaving the message unacknowledged.
        found, task_obj = TaskService.get_by_id(msg["id"])
        task = task_obj.to_dict() if found and task_obj is not None else None
    else:
        task = TaskService.get_task(msg["id"])

    if task:
        canceled = has_canceled(task["id"])
    if not task or canceled:
        state = "is unknown" if not task else "has been cancelled"
        FAILED_TASKS += 1
        logging.warning(f"collect task {msg['id']} {state}")
        redis_msg.ack()
        return None, None

    task_type = msg.get("task_type", "")
    task["task_type"] = task_type
    if task_type.startswith("dataflow"):
        task["tenant_id"] = msg["tenant_id"]
        task["dataflow_id"] = msg["dataflow_id"]
        task["kb_id"] = msg.get("kb_id", "")
    if task_type.startswith("memory"):
        task["memory_id"] = msg["memory_id"]
        task["source_id"] = msg["source_id"]
        task["message_dict"] = msg["message_dict"]
    return redis_msg, task
async def get_storage_binary(bucket, name):
    """Read object ``name`` from ``bucket`` in the configured storage backend.

    The storage ``get`` call is dispatched through ``thread_pool_exec`` and its
    result is returned unchanged.
    """
    fetch = settings.STORAGE_IMPL.get
    blob = await thread_pool_exec(fetch, bucket, name)
    return blob
@timeout(60 * 80, 1)
async def build_chunks(task, progress_callback):
    """Chunk one document task and enrich the resulting chunks.

    Steps, in order:
      1. Fetch the source file from object storage.
      2. Split it with the parser selected by ``task["parser_id"]``.
      3. Persist an optional PDF outline as document metadata.
      4. Build chunk dicts and upload chunk images.
      5. Depending on ``task["parser_config"]`` / ``task["kb_parser_config"]``,
         generate keywords, questions, metadata and tags with the tenant's
         chat model.

    Args:
        task: Task dict. Keys read include ``id``, ``doc_id``, ``kb_id``,
            ``tenant_id``, ``parser_id``, ``parser_config``,
            ``kb_parser_config``, ``llm_id``, ``language``, ``size``,
            ``location`` and ``name``.
        progress_callback: Callable used to report progress messages.

    Returns:
        The list of chunk dicts. Returns ``[]`` when the file exceeds
        ``settings.DOC_MAXIMUM_SIZE``, and ``None`` if cancellation is detected
        while selecting chunks to tag.

    Raises:
        Re-raises storage, chunking, upload and LLM-enrichment failures after
        logging them.
    """
    if task["size"] > settings.DOC_MAXIMUM_SIZE:
        set_progress(task["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
                     (int(settings.DOC_MAXIMUM_SIZE / 1024 / 1024)))
        return []

    chunker = FACTORY[task["parser_id"].lower()]
    try:
        st = timer()
        bucket, name = File2DocumentService.get_storage_address(doc_id=task["doc_id"])
        binary = await get_storage_binary(bucket, name)
        logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
    except TimeoutError:
        progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
        logging.exception(
            "Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
        raise
    except Exception as e:
        if re.search("(No such file|not found)", str(e)):
            progress_callback(-1, "Can not find file <%s> from minio. Could you try it again?" % task["name"])
        else:
            progress_callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
        logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
        raise

    # Table parser column roles / mode are stored on the dataset (KB) parser_config;
    # chunk tasks carry document-level parser_config only — merge KB keys so manual roles apply.
    parser_config_for_chunk = merge_table_parser_config_from_kb(task)
    if task.get("parser_id", "").lower() == "table" and task.get("kb_parser_config"):
        logging.debug(
            "[TASK_EXECUTOR_DEBUG] table parser: merged KB keys into parser_config for chunk; "
            f"mode={parser_config_for_chunk.get('table_column_mode')}, "
            f"roles_keys={list((parser_config_for_chunk.get('table_column_roles') or {}).keys())}"
        )

    try:
        async with chunk_limiter:
            cks = await thread_pool_exec(
                chunker.chunk,
                task["name"],
                binary=binary,
                from_page=task["from_page"],
                to_page=task["to_page"],
                lang=task["language"],
                callback=progress_callback,
                kb_id=task["kb_id"],
                parser_config=parser_config_for_chunk,
                tenant_id=task["tenant_id"],
            )
        logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
    except TaskCanceledException:
        raise
    except Exception as e:
        progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", ""))
        logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"]))
        raise

    # Extract and persist PDF outline if the parser attached it.
    if cks and cks[0].get("__outline__"):
        outline = cks[0].pop("__outline__")
        try:
            DocMetadataService.update_document_metadata(
                task["doc_id"],
                update_metadata_to({"outline": outline},
                                   DocMetadataService.get_document_metadata(task["doc_id"]) or {})
            )
            logging.info("Persisted PDF outline (%d entries) for doc %s", len(outline), task["doc_id"])
        except Exception as e:
            # Best effort: a failed outline write must not fail the whole chunking task.
            logging.warning("Failed to persist PDF outline for doc %s: %s", task["doc_id"], e)

    docs = []
    # Fields shared by every chunk of this document; copied into each chunk below.
    doc = {
        "doc_id": task["doc_id"],
        "kb_id": str(task["kb_id"])
    }
    if task["pagerank"]:
        doc[PAGERANK_FLD] = int(task["pagerank"])
    st = timer()

    @timeout(60)
    async def upload_to_minio(document, chunk):
        """Merge ``chunk`` into a copy of ``document``, store its image if any, and collect it."""
        d = None
        try:
            d = copy.deepcopy(document)
            d.update(chunk)
            # Content-addressed id: identical text in the same document maps to the same chunk id.
            d["id"] = xxhash.xxh64(
                (chunk["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
            d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
            d["create_timestamp_flt"] = datetime.now().timestamp()
            if d.get("img_id"):
                docs.append(d)
                return

            if not d.get("image"):
                _ = d.pop("image", None)
                d["img_id"] = ""
                docs.append(d)
                return
            await image2id(d, partial(settings.STORAGE_IMPL.put, tenant_id=task["tenant_id"]), d["id"], task["kb_id"])
            docs.append(d)
        except Exception:
            # ``d`` (or its id) may not exist yet if the failure happened early.
            chunk_id = (d or {}).get("id", "")
            logging.exception(
                "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
            raise

    tasks = []
    for ck in cks:
        tasks.append(asyncio.create_task(upload_to_minio(doc, ck)))
    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as e:
        logging.error(f"MINIO PUT({task['name']}) got exception: {e}")
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    el = timer() - st
    logging.info("MINIO PUT({}) cost {:.3f} s".format(task["name"], el))

    if task["parser_config"].get("auto_keywords", 0):
        st = timer()
        progress_callback(msg="Start to generate keywords for every chunk ...")
        chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"])
        chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])

        async def doc_keyword_extraction(chat_mdl, d, topn):
            """Attach LLM-extracted keywords to chunk ``d``, using the LLM cache when possible."""
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
            if not cached:
                if has_canceled(task["id"]):
                    progress_callback(-1, msg="Task has been canceled.")
                    return
                async with chat_limiter:
                    cached = await keyword_extraction(chat_mdl, d["content_with_weight"], topn)
                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
            if cached:
                d["important_kwd"] = [k for k in re.split(r"[,;;、\r\n]+", cached) if k.strip()]
                d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
            return

        tasks = []
        for d in docs:
            tasks.append(
                asyncio.create_task(doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"])))
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error("Error in doc_keyword_extraction: {}".format(e))
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise
        progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

    if task["parser_config"].get("auto_questions", 0):
        st = timer()
        progress_callback(msg="Start to generate questions for every chunk ...")
        chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"])
        chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])

        async def doc_question_proposal(chat_mdl, d, topn):
            """Attach LLM-proposed questions to chunk ``d``, using the LLM cache when possible."""
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
            if not cached:
                if has_canceled(task["id"]):
                    progress_callback(-1, msg="Task has been canceled.")
                    return
                async with chat_limiter:
                    cached = await question_proposal(chat_mdl, d["content_with_weight"], topn)
                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
            if cached:
                d["question_kwd"] = cached.split("\n")
                d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))

        tasks = []
        for d in docs:
            tasks.append(
                asyncio.create_task(doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"])))
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error("Error in doc_question_proposal", exc_info=e)
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise
        progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

    if task["parser_config"].get("enable_metadata", False) and (task["parser_config"].get("metadata") or task["parser_config"].get("built_in_metadata")):
        st = timer()
        progress_callback(msg="Start to generate meta-data for every chunk ...")
        chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"])
        chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])

        # The metadata schema depends only on the task config, so build it once for all chunks.
        # Dict configs are JSON-schema-like (built-ins merged into "properties"); list configs
        # are field lists (built-ins appended).
        metadata_conf = task["parser_config"].get("metadata", [])
        built_in_metadata = list(task["parser_config"].get("built_in_metadata") or [])
        if isinstance(metadata_conf, dict):
            if not isinstance(metadata_conf.get("properties"), dict):
                metadata_conf = {"type": "object", "properties": {}}
            if built_in_metadata:
                metadata_conf = {
                    **metadata_conf,
                    "properties": {
                        **metadata_conf.get("properties", {}),
                        **turn2jsonschema(built_in_metadata).get("properties", {}),
                    },
                }
        elif isinstance(metadata_conf, list):
            metadata_conf = metadata_conf + built_in_metadata
        else:
            metadata_conf = built_in_metadata

        async def gen_metadata_task(chat_mdl, d):
            """Attach LLM-generated metadata to chunk ``d`` as ``metadata_obj`` when any is produced."""
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "metadata",
                                   metadata_conf)
            if not cached:
                if has_canceled(task["id"]):
                    progress_callback(-1, msg="Task has been canceled.")
                    return
                async with chat_limiter:
                    cached = await gen_metadata(chat_mdl,
                                                turn2jsonschema(metadata_conf),
                                                d["content_with_weight"])
                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "metadata",
                              metadata_conf)
            if cached:
                d["metadata_obj"] = cached

        tasks = []
        for d in docs:
            tasks.append(asyncio.create_task(gen_metadata_task(chat_mdl, d)))
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error("Error in gen_metadata_task", exc_info=e)
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

        metadata = {}
        for chunk_doc in docs:
            # A chunk has no "metadata_obj" when generation produced nothing (or it bailed
            # out on cancellation); skip it rather than raising KeyError.
            chunk_meta = chunk_doc.pop("metadata_obj", None)
            if chunk_meta is not None:
                metadata = update_metadata_to(metadata, chunk_meta)
        if metadata:
            existing_meta = DocMetadataService.get_document_metadata(task["doc_id"])
            existing_meta = existing_meta if isinstance(existing_meta, dict) else {}
            metadata = update_metadata_to(metadata, existing_meta)
            DocMetadataService.update_document_metadata(task["doc_id"], metadata)
        progress_callback(msg="Metadata generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

    if task["kb_parser_config"].get("tag_kb_ids", []):
        progress_callback(msg="Start to tag for every chunk ...")
        kb_ids = task["kb_parser_config"]["tag_kb_ids"]
        tenant_id = task["tenant_id"]
        topn_tags = task["kb_parser_config"].get("topn_tags", 3)
        S = 1000
        st = timer()
        examples = []
        all_tags = get_tags_from_cache(kb_ids)
        if not all_tags:
            all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
            set_tags_to_cache(kb_ids, all_tags)
        else:
            all_tags = json.loads(all_tags)

        chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, task["llm_id"])
        chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])

        # Chunks the retriever can tag become few-shot examples; the rest go to the LLM.
        docs_to_tag = []
        for d in docs:
            task_canceled = has_canceled(task["id"])
            if task_canceled:
                progress_callback(-1, msg="Task has been canceled.")
                return None
            if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(
                    d[TAG_FLD]) > 0:
                examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
            else:
                docs_to_tag.append(d)

        async def doc_content_tagging(chat_mdl, d, topn_tags):
            """Attach LLM-generated tags to chunk ``d``, using the LLM cache when possible."""
            cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
            if not cached:
                if has_canceled(task["id"]):
                    progress_callback(-1, msg="Task has been canceled.")
                    return
                # Copy, so the placeholder example below never leaks into the shared ``examples`` list.
                picked_examples = random.choices(examples, k=2) if len(examples) > 2 else list(examples)
                if not picked_examples:
                    picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
                async with chat_limiter:
                    cached = await content_tagging(
                        chat_mdl,
                        d["content_with_weight"],
                        all_tags,
                        picked_examples,
                        topn_tags,
                    )
                if cached:
                    cached = json.dumps(cached)
            if cached:
                set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
                d[TAG_FLD] = json.loads(cached)

        tasks = []
        for d in docs_to_tag:
            tasks.append(asyncio.create_task(doc_content_tagging(chat_mdl, d, topn_tags)))
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error("Error tagging docs: {}".format(e))
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise
        progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))

    return docs
def build_TOC(task, docs, progress_callback):
    """Generate a table of contents for ``docs`` and return it as an extra chunk.

    Chunks are ordered by page and vertical position, and their texts are sent
    to ``run_toc_from_text`` with the task's chat model. Each TOC entry's
    ``chunk_id`` (an index into the sorted chunks) is replaced by an ``ids``
    list. That list covers the entry's chunk up to and including the next
    entry's chunk.

    Returns:
        A copy of the last sorted chunk that carries the JSON-encoded TOC in
        ``content_with_weight``, with ``toc_kwd="toc"`` and ``available_int=0``.
        Returns ``None`` when there are no chunks or no TOC was produced.
    """
    progress_callback(msg="Start to generate table of content ...")
    if not docs:
        # Nothing to summarize, and the returned TOC chunk is cloned from docs[-1].
        return None
    chat_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.CHAT, task["llm_id"])
    chat_mdl = LLMBundle(task["tenant_id"], chat_model_config, lang=task["language"])

    def _sort_position(value):
        # Position fields may be scalars or lists. Use the first element of a list,
        # and treat an empty list or None as position 0 so sorting never fails.
        if isinstance(value, list):
            value = value[0] if value else 0
        return 0 if value is None else value

    docs = sorted(docs, key=lambda d: (
        _sort_position(d.get("page_num_int", 0)),
        _sort_position(d.get("top_int", 0)),
    ))
    toc: list[dict] = asyncio.run(
        run_toc_from_text([d["content_with_weight"] for d in docs], chat_mdl, progress_callback))
    logging.info("------------ T O C -------------\n" + json.dumps(toc, ensure_ascii=False, indent=' '))
    for ii, item in enumerate(toc):
        try:
            chunk_val = item.pop("chunk_id", None)
            if chunk_val is None or str(chunk_val).strip() == "":
                logging.warning(f"Index {ii}: chunk_id is missing or empty. Skipping.")
                continue
            curr_idx = int(chunk_val)
            if curr_idx < 0:
                # A negative index would silently select chunks from the end of the list.
                logging.error(f"Index {ii}: chunk_id {curr_idx} is negative.")
                continue
            if curr_idx >= len(docs):
                logging.error(f"Index {ii}: chunk_id {curr_idx} exceeds docs length {len(docs)}.")
                continue
            item["ids"] = [docs[curr_idx]["id"]]
            if ii + 1 < len(toc):
                # The next entry's chunk_id has not been popped yet, so it is still readable here.
                next_chunk_val = toc[ii + 1].get("chunk_id", "")
                if str(next_chunk_val).strip() != "":
                    next_idx = int(next_chunk_val)
                    for jj in range(curr_idx + 1, min(next_idx + 1, len(docs))):
                        item["ids"].append(docs[jj]["id"])
                else:
                    logging.warning(f"Index {ii + 1}: next chunk_id is empty, range fill skipped.")
        except (ValueError, TypeError) as e:
            logging.error(f"Index {ii}: Data conversion error - {e}")
        except Exception as e:
            logging.exception(f"Index {ii}: Unexpected error - {e}")

    if toc:
        d = copy.deepcopy(docs[-1])
        d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
        d["toc_kwd"] = "toc"
        d["available_int"] = 0
        d["page_num_int"] = [100000000]
        d["id"] = xxhash.xxh64(
            (d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
        return d
    return None
def init_kb(row, vector_size: int):
    """Create the tenant's doc-store index for the row's knowledge base.

    Calls ``settings.docStoreConn.create_idx`` with the tenant index name, the
    KB id (``""`` if absent), the embedding dimension and the optional parser
    id, and returns its result.
    """
    index_name = search.index_name(row["tenant_id"])
    kb_id = row.get("kb_id", "")
    parser = row.get("parser_id", None)
    return settings.docStoreConn.create_idx(index_name, kb_id, vector_size, parser)
async def embedding(docs, mdl, parser_config=None, callback=None):
    """Embed chunks in place, blending a title vector with each content vector.

    Each chunk's text is its joined ``question_kwd`` when present, otherwise its
    ``content_with_weight`` with table tags stripped ("None" if that ends up
    empty). The first chunk's ``docnm_kwd`` is embedded once and blended as
    ``w * title + (1 - w) * content``. Here ``w`` is
    ``parser_config["filename_embd_weight"]``, or 0.1 when unset or falsy.
    Each resulting vector is stored on its chunk under ``q_<dim>_vec``.

    Args:
        docs: Chunk dicts; mutated in place.
        mdl: Embedding model exposing ``encode(texts) -> (vectors, token_count)``
            and ``max_length``.
        parser_config: Optional parser configuration dict.
        callback: Optional progress callback accepting ``prog`` and ``msg`` keywords.

    Returns:
        ``(token_count, vector_size)``. Returns ``(0, 0)`` when ``docs`` is empty.
    """
    if parser_config is None:
        parser_config = {}
    if not docs:
        # Nothing to embed; also avoids indexing into an empty title encoding below.
        return 0, 0

    tts, cnts = [], []
    for d in docs:
        tts.append(d.get("docnm_kwd", "Title"))
        c = "\n".join(d.get("question_kwd", []))
        if not c:
            c = d["content_with_weight"]
        c = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c)
        if not c:
            c = "None"
        cnts.append(c)

    tk_count = 0
    # Only the first title is encoded; its vector is reused for every chunk.
    vts, c = await thread_pool_exec(mdl.encode, tts[0:1])
    tts = np.tile(vts[0], (len(cnts), 1))
    tk_count += c

    @timeout(60)
    def batch_encode(txts):
        return mdl.encode([truncate(t, mdl.max_length - 10) for t in txts])

    batch_size = settings.EMBEDDING_BATCH_SIZE
    cnts_batches = []
    for i in range(0, len(cnts), batch_size):
        async with embed_limiter:
            vts, c = await thread_pool_exec(batch_encode, cnts[i: i + batch_size])
        cnts_batches.append(vts)
        tk_count += c
        if callback is not None:
            # Count the chunks finished so far (the end of this batch), not the batch start index.
            done = min(i + batch_size, len(cnts))
            callback(prog=0.7 + 0.2 * done / len(cnts), msg="")
    cnts = np.vstack(cnts_batches) if cnts_batches else np.array([])

    filename_embd_weight = parser_config.get("filename_embd_weight", 0.1)  # due to the db support none value
    if not filename_embd_weight:
        filename_embd_weight = 0.1
    title_w = float(filename_embd_weight)
    if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
        vects = title_w * tts + (1 - title_w) * cnts
    else:
        vects = cnts

    assert len(vects) == len(docs)
    vector_size = 0
    for i, d in enumerate(docs):
        v = vects[i].tolist()
        vector_size = len(v)
        d["q_%d_vec" % len(v)] = v
    return tk_count, vector_size
async def run_dataflow(task: dict):
    """Run a user pipeline (or replay a logged one) for a document and index its chunks.

    When ``task["task_type"] == "dataflow"``, the DSL is loaded from the user
    canvas. For any other type it is loaded from ``PipelineOperationLogService``,
    and ``dataflow_id`` switches to the logged pipeline id.

    The pipeline output is then:
      1. normalized to a list of chunk dicts;
      2. embedded with the KB's embedding model when no ``q_<n>_vec`` field is
         present;
      3. converted to index fields;
      4. inserted via ``insert_chunks``.

    After the pipeline runs, every exit path records a pipeline operation log
    entry. The exceptions are the canvas debug document and a re-raised
    ``TaskCanceledException``.
    """
    from api.db.services.canvas_service import UserCanvasService
    from rag.flow.pipeline import Pipeline

    task_start_ts = timer()
    dataflow_id = task["dataflow_id"]
    doc_id = task["doc_id"]
    task_id = task["id"]
    task_dataset_id = task["kb_id"]

    if task["task_type"] == "dataflow":
        e, cvs = UserCanvasService.get_by_id(dataflow_id)
        assert e, "User pipeline not found."
        dsl = cvs.dsl
    else:
        e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
        assert e, "Pipeline log not found."
        dsl = pipeline_log.dsl
        dataflow_id = pipeline_log.pipeline_id
    pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
    chunks = await pipeline.run(file=task["file"]) if task.get("file") else await pipeline.run()
    if doc_id == CANVAS_DEBUG_DOC_ID:
        return

    if not chunks:
        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
                                           task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
        return

    embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
    # The output key may exist with an empty payload; check presence, not truthiness.
    if "chunks" in chunks:
        chunks = copy.deepcopy(chunks["chunks"])
    elif "json" in chunks:
        chunks = copy.deepcopy(chunks["json"])
    elif "markdown" in chunks:
        chunks = [{"text": [chunks["markdown"]]}] if chunks["markdown"] else []
    elif "text" in chunks:
        chunks = [{"text": [chunks["text"]]}] if chunks["text"] else []
    elif "html" in chunks:
        chunks = [{"text": [chunks["html"]]}] if chunks["html"] else []
    else:
        chunks = []

    # An empty normalized payload means "nothing parsed", so stop before embedding/indexing.
    if not chunks:
        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
                                           task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
        return

    if not any(re.match(r"q_[0-9]+_vec", k) for o in chunks for k in o):
        try:
            set_progress(task_id, prog=0.82, msg="\n-------------------------------------\nStart to embedding...")
            found, kb = KnowledgebaseService.get_by_id(task["kb_id"])
            if not found or kb is None:
                raise LookupError(f"Knowledgebase {task['kb_id']} not found.")
            embedding_id = kb.embd_id
            embd_model_config = get_model_config_by_type_and_name(task["tenant_id"], LLMType.EMBEDDING, embedding_id)
            embedding_model = LLMBundle(task["tenant_id"], embd_model_config)

            @timeout(60)
            def batch_encode(txts):
                return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])

            batch_size = settings.EMBEDDING_BATCH_SIZE
            texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
            total_batches = (len(texts) + batch_size - 1) // batch_size
            # Throttle progress writes to about every 1% of batches (every batch for small inputs).
            report_every = max(1, total_batches // 100)
            delta = 0.20 / (len(texts) // batch_size + 1)
            prog = 0.8
            vects_batches = []
            for batch_no, start in enumerate(range(0, len(texts), batch_size), start=1):
                async with embed_limiter:
                    vts, c = await thread_pool_exec(batch_encode, texts[start: start + batch_size])
                vects_batches.append(vts)
                embedding_token_consumption += c
                prog += delta
                if batch_no % report_every == 0 or batch_no == total_batches:
                    set_progress(task_id, prog=prog, msg=f"{batch_no} / {total_batches}")
            vects = np.vstack(vects_batches) if vects_batches else np.array([])
            if len(vects) != len(chunks):
                raise ValueError(f"Embedding count {len(vects)} does not match chunk count {len(chunks)}.")
            for i, ck in enumerate(chunks):
                v = vects[i].tolist()
                ck["q_%d_vec" % len(v)] = v
        except TaskCanceledException:
            raise
        except Exception as e:
            set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
            PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
                                               task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
            return

    # Convert pipeline output fields into the doc-store field layout.
    metadata = {}
    for ck in chunks:
        ck["doc_id"] = doc_id
        ck["kb_id"] = [str(task["kb_id"])]
        ck["docnm_kwd"] = task["name"]
        ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
        ck["create_timestamp_flt"] = datetime.now().timestamp()
        if not ck.get("id"):
            ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
        if "questions" in ck:
            if "question_tks" not in ck:
                ck["question_kwd"] = ck["questions"].split("\n")
                ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
            del ck["questions"]
        if "keywords" in ck:
            if "important_tks" not in ck:
                ck["important_kwd"] = [k for k in re.split(r"[,;;、\r\n]+", ck["keywords"]) if k.strip()]
                ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
            del ck["keywords"]
        if "summary" in ck:
            if "content_ltks" not in ck:
                ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
                ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
            del ck["summary"]
        if "metadata" in ck:
            metadata = update_metadata_to(metadata, ck["metadata"])
            del ck["metadata"]
        if "content_with_weight" not in ck:
            ck["content_with_weight"] = ck["text"]
        del ck["text"]
        if "positions" in ck:
            add_positions(ck, ck["positions"])
            del ck["positions"]

    if metadata:
        existing_meta = DocMetadataService.get_document_metadata(doc_id)
        existing_meta = existing_meta if isinstance(existing_meta, dict) else {}
        metadata = update_metadata_to(metadata, existing_meta)
        DocMetadataService.update_document_metadata(doc_id, metadata)

    index_start_ts = timer()
    set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
    e = await insert_chunks(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
    if not e:
        PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id,
                                           task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
        return

    time_cost = timer() - index_start_ts
    task_time_cost = timer() - task_start_ts
    set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
    DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks),
                                        task_time_cost)
    logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption,
                                                                          task_time_cost))
    PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE,
                                       dsl=str(pipeline))
# Row cap for the doc-store searches in get_raptor_chunk_field_map; it is passed
# right after the offset argument (0), presumably as the result limit.
RAPTOR_METHOD_SEARCH_LIMIT = 10000
async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) -> dict:
    """Fetch the ``raptor_kwd``/``extra`` fields of RAPTOR chunks stored for ``doc_id``.

    A search filtered on ``raptor_kwd="raptor"`` is tried first. If it yields no
    RAPTOR chunk ids, the document's chunks are re-read without that filter,
    newest ``create_timestamp_flt`` first. If the fallback search raises, the
    filtered result is returned instead.

    Args:
        doc_id: Document whose chunks are inspected.
        tenant_id: Tenant owning the search index.
        kb_id: Knowledge base the search is scoped to.

    Returns:
        The field map produced by ``settings.docStoreConn.get_fields``.
    """
    from common.doc_store.doc_store_base import OrderByExpr
    from rag.nlp import search as nlp_search

    async def search_fields(fields: list[str], condition: dict, order_by=None):
        """Search chunk fields inside ``kb_id`` and return them as a field map."""
        ordering = order_by if order_by else OrderByExpr()
        raw_result = await thread_pool_exec(
            settings.docStoreConn.search,
            fields,
            [],
            condition,
            [],
            ordering,
            0,
            RAPTOR_METHOD_SEARCH_LIMIT,
            nlp_search.index_name(tenant_id),
            [kb_id],
        )
        return settings.docStoreConn.get_fields(raw_result, fields)

    filtered_map = await search_fields(
        ["raptor_kwd", "extra"],
        {"doc_id": doc_id, "raptor_kwd": ["raptor"]},
    )
    if collect_raptor_chunk_ids(filtered_map):
        return filtered_map

    # Fallback: re-read the document's chunks without the raptor_kwd filter.
    try:
        newest_first = OrderByExpr().desc("create_timestamp_flt")
        return await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id}, newest_first)
    except Exception:
        logging.debug("RAPTOR fallback method lookup with extra field failed for doc %s", doc_id, exc_info=True)
        return filtered_map
async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> set[str]:
    """Return the set of RAPTOR tree builders whose summaries exist for ``doc_id``.

    The stored marker fields come from ``get_raptor_chunk_field_map``, which
    first searches for ``raptor_kwd="raptor"`` rows directly. That way a leading
    non-RAPTOR chunk cannot hide existing summaries. The field map is turned
    into method names by ``collect_raptor_methods``.

    Raises:
        Any exception from the lookup, after logging it.
    """
    try:
        stored_fields = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id)
        found = collect_raptor_methods(stored_fields)
        if not found:
            logging.info(
                "Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)",
                doc_id, tenant_id, kb_id,
            )
            return found
        logging.info(
            "Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist",
            doc_id, tenant_id, kb_id, sorted(found),
        )
        return found
    except Exception:
        logging.exception("Failed to check RAPTOR chunks for doc %s", doc_id)
        raise
async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, tree_builder: str = RAPTOR_TREE_BUILDER) -> bool:
    """Tell whether summaries built by ``tree_builder`` are already stored for ``doc_id``."""
    return tree_builder in (await get_raptor_chunk_methods(doc_id, tenant_id, kb_id))
async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_method: str | None = None) -> int:
    """Delete RAPTOR summary chunks stored for ``doc_id``.

    Args:
        doc_id: Document (or dataset-level fake document) whose summaries are removed.
        tenant_id: Tenant owning the search index.
        kb_id: Knowledge base the delete is scoped to.
        keep_method: When given, summaries produced by this tree builder are kept
            and only summaries of other methods are deleted. When ``None``, every
            RAPTOR summary of the document is deleted.

    Returns:
        Number of chunks removed. On the delete-all path this is the count
        reported by the doc store when it returns an integer, otherwise 0. On the
        keep-method path it is the number of chunk ids submitted for deletion.
    """
    from rag.nlp import search as nlp_search

    if keep_method is None:
        logging.info(
            "delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)",
            doc_id, tenant_id, kb_id,
        )
        deleted = await thread_pool_exec(
            settings.docStoreConn.delete,
            {"doc_id": doc_id, "raptor_kwd": ["raptor"]},
            nlp_search.index_name(tenant_id),
            kb_id,
        )
        # Report what was actually removed so callers that sum cleanup counts do
        # not under-report. Non-integer results (or bools) fall back to 0, as before.
        if isinstance(deleted, int) and not isinstance(deleted, bool):
            return deleted
        return 0

    field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id)
    chunk_ids = collect_raptor_chunk_ids(field_map, exclude_methods={keep_method})
    if not chunk_ids:
        logging.debug(
            "delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)",
            doc_id, tenant_id, kb_id, keep_method,
        )
        return 0
    logging.info(
        "delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)",
        len(chunk_ids), doc_id, tenant_id, kb_id, keep_method,
    )
    await thread_pool_exec(
        settings.docStoreConn.delete,
        {"id": list(chunk_ids)},
        nlp_search.index_name(tenant_id),
        kb_id,
    )
    return len(chunk_ids)
@timeout(3600)
async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=None):
    """Generate RAPTOR summaries for selected documents in a knowledge base.

    With ``raptor.scope == "file"`` (the default), each document gets its own
    summary tree. Otherwise, chunks of all non-skipped documents are summarized
    together under the dataset-level fake document id. Existing summaries built
    by the configured tree builder are treated as a checkpoint and skipped.
    Summaries of other builders, or of the other scope, are scheduled for
    cleanup once new summaries exist.

    Args:
        row: Task row. Reads ``id``, ``name``, ``kb_id``, ``tenant_id``,
            ``pagerank`` and, as a fallback for per-document info, ``type``,
            ``parser_id`` and ``parser_config``.
        kb_parser_config: Knowledge-base parser config; its ``raptor`` section
            drives the build.
        chat_mdl: Chat model used to write summaries.
        embd_mdl: Embedding model passed to ``Raptor``.
        vector_size: Embedding dimension; selects the ``q_<n>_vec`` vector field.
        callback: Optional progress callback accepting ``prog=``/``msg=`` keyword
            arguments. When omitted, progress is discarded.
        doc_ids: Source document ids. Duplicates are ignored; order is kept.

    Returns:
        ``(summary_chunks, token_count, cleanup_plan)``. ``cleanup_plan`` is a
        list of ``(doc_id, keep_method)`` pairs meant for
        ``delete_raptor_chunks`` once the summaries have been inserted.
    """
    def _discard_progress(*_args, **_kwargs):
        """Progress sink used when the caller supplies no callback."""
        return None

    if callback is None:
        callback = _discard_progress
    # Deduplicate once (order-preserving) so every phase below sees the same ids;
    # duplicates would otherwise feed the same chunks into RAPTOR twice.
    doc_ids = list(dict.fromkeys(doc_ids or []))

    fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
    raptor_config = kb_parser_config.get("raptor", {})
    raptor_ext_config = raptor_config.get("ext") or {}
    tree_builder = get_raptor_tree_builder(raptor_config)
    clustering_method = get_raptor_clustering_method(raptor_config)
    vctr_nm = "q_%d_vec" % vector_size
    res = []
    tk_count = 0
    cleanup_raptor_chunks = []
    max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))

    doc_info_by_id = {}
    for doc_id in doc_ids:
        ok, source_doc = DocumentService.get_by_id(doc_id)
        if not ok or not source_doc:
            continue
        doc_info_by_id[doc_id] = {
            "name": getattr(source_doc, "name", ""),
            "type": getattr(source_doc, "type", ""),
            "parser_id": getattr(source_doc, "parser_id", ""),
            "parser_config": getattr(source_doc, "parser_config", {}) or {},
        }

    def schedule_raptor_cleanup(doc_id: str, keep_method: str | None = None):
        """Queue stale RAPTOR summaries for deletion after successful insert."""
        cleanup_plan = (doc_id, keep_method)
        if cleanup_plan not in cleanup_raptor_chunks:
            cleanup_raptor_chunks.append(cleanup_plan)

    def skip_raptor_doc(doc_id: str) -> bool:
        """Return whether RAPTOR should be skipped for this source document."""
        doc_info = doc_info_by_id.get(doc_id, {})
        file_type = doc_info.get("type") or row.get("type", "")
        parser_id = doc_info.get("parser_id") or row.get("parser_id", "")
        parser_config = doc_info.get("parser_config") or row.get("parser_config", {})
        if should_skip_raptor(file_type, parser_id, parser_config, raptor_config):
            skip_reason = get_skip_reason(file_type, parser_id, parser_config)
            doc_name = doc_info.get("name") or doc_id
            logging.info("Skipping Raptor for document %s: %s", doc_name, skip_reason)
            callback(msg=f"[RAPTOR] doc:{doc_id} skipped: {skip_reason}")
            return True
        return False

    def load_doc_vectors(doc_id: str):
        """Return ``([(content, vector), ...], missing)`` for a document's chunks in position order."""
        loaded = []
        missing = 0
        for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
                                              fields=["content_with_weight", vctr_nm],
                                              sort_by_position=True):
            # Skip chunks that don't have the required vector field (may have been indexed with different embedding model)
            if vctr_nm not in d or d[vctr_nm] is None:
                missing += 1
                logging.warning(f"RAPTOR: Chunk missing vector field '{vctr_nm}' in doc {doc_id}, skipping")
                continue
            loaded.append((d["content_with_weight"], np.array(d[vctr_nm])))
        return loaded, missing

    async def generate(chunks, did):
        """Run RAPTOR and append generated summary chunks for one doc id."""
        nonlocal tk_count
        logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did)
        raptor = Raptor(
            raptor_config.get("max_cluster", 64),
            chat_mdl,
            embd_mdl,
            raptor_config["prompt"],
            raptor_config["max_token"],
            raptor_config["threshold"],
            max_errors=max_errors,
            tree_builder=tree_builder,
            clustering_method=clustering_method,
            psi_exact_max_leaves=raptor_ext_config.get("psi_exact_max_leaves", 4096),
            psi_bucket_size=raptor_ext_config.get("psi_bucket_size", 1024),
        )
        original_length = len(chunks)
        chunks, layers = await raptor(chunks, raptor_config["random_seed"], callback, row["id"])
        effective_doc_name = row["name"] if did == fake_doc_id else doc_info_by_id.get(did, {}).get("name") or row["name"]
        doc = {
            "doc_id": did,
            "kb_id": [str(row["kb_id"])],
            "docnm_kwd": effective_doc_name,
            "title_tks": rag_tokenizer.tokenize(effective_doc_name),
            "raptor_kwd": "raptor",
            "extra": {"raptor_method": tree_builder},
        }
        if row.get("pagerank"):
            doc[PAGERANK_FLD] = int(row["pagerank"])
        # Build index→layer mapping from RAPTOR layer boundaries.
        # layers is [(start, end), ...] where layer 0 is the original chunks
        # and layer 1+ are summary layers. We skip layer 0 (original chunks).
        chunk_layer = {}
        for layer_idx, (layer_start, layer_end) in enumerate(layers):
            if layer_idx == 0:
                continue  # layer 0 = original input chunks, not summaries
            for ci in range(layer_start, layer_end):
                chunk_layer[ci] = layer_idx
        for idx, (content, vctr) in enumerate(chunks[original_length:], start=original_length):
            d = copy.deepcopy(doc)
            d["id"] = make_raptor_summary_chunk_id(content, did)
            # One clock read so the string and float timestamps agree.
            now = datetime.now()
            d["create_time"] = str(now).replace("T", " ")[:19]
            d["create_timestamp_flt"] = now.timestamp()
            d[vctr_nm] = vctr.tolist()
            d["content_with_weight"] = content
            d["content_ltks"] = rag_tokenizer.tokenize(content)
            d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
            d["raptor_layer_int"] = chunk_layer.get(idx, 1)
            res.append(d)
            tk_count += num_tokens_from_string(content)

    if raptor_config.get("scope", "file") == "file":
        dataset_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"])
        remove_dataset_summaries = bool(dataset_methods)
        has_file_level_target = False
        if dataset_methods:
            callback(msg="[RAPTOR] will remove dataset-level summaries after file-level summaries are available.")
        total_docs = len(doc_ids)
        for x, doc_id in enumerate(doc_ids):
            doc_progress = (x + 1.) / total_docs
            if skip_raptor_doc(doc_id):
                callback(prog=doc_progress)
                continue
            # CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
            existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
            if tree_builder in existing_methods:
                has_file_level_target = True
                if existing_methods != {tree_builder}:
                    schedule_raptor_cleanup(doc_id, tree_builder)
                    callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.")
                callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.")
                callback(prog=doc_progress)
                continue
            if existing_methods:
                callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.")
            chunks, skipped_chunks = load_doc_vectors(doc_id)
            if skipped_chunks > 0:
                callback(msg=f"[WARN] Skipped {skipped_chunks} chunks without vector field '{vctr_nm}' for doc {doc_id}. Consider re-parsing the document with the current embedding model.")
            if not chunks:
                logging.warning(f"RAPTOR: No valid chunks with vectors found for doc {doc_id}")
                callback(msg=f"[WARN] No valid chunks with vectors found for doc {doc_id}, skipping")
                # Still advance progress so the bar does not stall on empty docs.
                callback(prog=doc_progress)
                continue
            before_generate = len(res)
            await generate(chunks, doc_id)
            if len(res) > before_generate:
                has_file_level_target = True
                # Only drop old summaries once replacements exist.
                if existing_methods:
                    schedule_raptor_cleanup(doc_id, tree_builder)
            callback(prog=doc_progress)
        if remove_dataset_summaries:
            if has_file_level_target:
                schedule_raptor_cleanup(fake_doc_id)
            else:
                callback(msg="[RAPTOR] kept dataset-level summaries because no file-level summaries were built.")
    else:
        migrated_file_docs = 0
        file_cleanup_doc_ids = []
        skipped_doc_ids = set()
        for doc_id in doc_ids:
            if skip_raptor_doc(doc_id):
                skipped_doc_ids.add(doc_id)
                continue
            existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
            if existing_methods:
                file_cleanup_doc_ids.append(doc_id)
                migrated_file_docs += 1
        if migrated_file_docs:
            callback(msg=f"[RAPTOR] will remove file-level summaries for {migrated_file_docs} docs after dataset-level build succeeds.")
        existing_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"])
        if tree_builder in existing_methods:
            if existing_methods != {tree_builder}:
                schedule_raptor_cleanup(fake_doc_id, tree_builder)
                callback(msg="[RAPTOR] will remove old dataset-level RAPTOR summaries after insert.")
            for doc_id in file_cleanup_doc_ids:
                schedule_raptor_cleanup(doc_id)
            callback(msg=f"[RAPTOR] dataset-level {tree_builder} summaries already exist, skipping.")
            return res, tk_count, cleanup_raptor_chunks
        migrate_dataset_summaries = bool(existing_methods)
        if migrate_dataset_summaries:
            callback(msg=f"[RAPTOR] will migrate dataset-level RAPTOR summaries to {tree_builder} after insert.")
        chunks = []
        skipped_chunks = 0
        for doc_id in doc_ids:
            if doc_id in skipped_doc_ids:
                continue
            doc_chunks, doc_missing = load_doc_vectors(doc_id)
            chunks.extend(doc_chunks)
            skipped_chunks += doc_missing
        if skipped_chunks > 0:
            callback(msg=f"[WARN] Skipped {skipped_chunks} chunks without vector field '{vctr_nm}'. Consider re-parsing documents with the current embedding model.")
        if not chunks:
            if skipped_doc_ids and len(skipped_doc_ids) == len(doc_ids):
                callback(msg="[RAPTOR] all documents were skipped by RAPTOR auto-disable rules.")
                return res, tk_count, cleanup_raptor_chunks
            logging.error(f"RAPTOR: No valid chunks with vectors found in any document for kb {row['kb_id']}")
            callback(msg=f"[ERROR] No valid chunks with vectors found. Please ensure documents are parsed with the current embedding model (vector size: {vector_size}).")
            return res, tk_count, cleanup_raptor_chunks
        before_generate = len(res)
        await generate(chunks, fake_doc_id)
        if len(res) > before_generate:
            for doc_id in file_cleanup_doc_ids:
                schedule_raptor_cleanup(doc_id)
            if migrate_dataset_summaries:
                schedule_raptor_cleanup(fake_doc_id, tree_builder)
    return res, tk_count, cleanup_raptor_chunks
async def delete_image(kb_id, chunk_id):
    """Remove the stored image of ``chunk_id`` from storage bucket ``kb_id``.

    The delete runs while holding ``minio_limiter``. Failures are logged and re-raised.
    """
    try:
        async with minio_limiter:
            storage = settings.STORAGE_IMPL
            storage.delete(kb_id, chunk_id)
    except Exception:
        logging.exception(f"Deleting image of chunk {chunk_id} got exception")
        raise
async def _rollback_partial_raptor_chunks(task_id, task_tenant_id, task_dataset_id, inserted_chunks, reason):
    """Best-effort delete of the RAPTOR summary chunks among ``inserted_chunks``.

    A partially inserted RAPTOR tree would otherwise be mistaken for a completed
    checkpoint by get_raptor_chunk_methods on the next run. Failures are logged
    and never raised, so the caller's own outcome (cancel / error) is preserved.
    """
    raptor_ids_to_rollback = [
        c["id"] for c in inserted_chunks
        if c.get("raptor_kwd") == "raptor"
    ]
    if not raptor_ids_to_rollback:
        return
    try:
        await thread_pool_exec(
            settings.docStoreConn.delete,
            {"id": raptor_ids_to_rollback},
            search.index_name(task_tenant_id),
            task_dataset_id,
        )
        logging.info(
            "insert_chunks: rolled back %d partial RAPTOR chunks after %s (task=%s)",
            len(raptor_ids_to_rollback), reason, task_id,
        )
    except Exception:
        logging.exception(
            "insert_chunks: failed to roll back partial RAPTOR chunks after %s (task=%s)",
            reason, task_id,
        )


async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
    """
    Insert chunks into document store (Elasticsearch OR Infinity).

    Parent ("mother") chunks, derived from each chunk's ``mom``/``mom_with_weight``
    text, are inserted first with ``available_int = 0``. Then the chunks themselves
    are inserted in batches of ``settings.DOC_BULK_SIZE``. On cancellation or an
    insert error, RAPTOR summaries already written are rolled back.

    Args:
        task_id: Task identifier
        task_tenant_id: Tenant ID
        task_dataset_id: Dataset/knowledge base ID
        chunks: List of chunk dictionaries to insert (each gains ``mom_id`` when it has a parent text)
        progress_callback: Callback function for progress updates

    Returns:
        True on success; False if the task was canceled or is no longer known.

    Raises:
        Exception: when the doc store reports an insert error for a chunk batch.
    """
    # Fields copied from a chunk onto its parent chunk.
    mother_fields = frozenset({"id", "content_with_weight", "doc_id", "docnm_kwd", "kb_id", "available_int",
                               "position_int", "create_timestamp_flt", "page_num_int", "top_int"})
    mothers = []
    mother_ids = set()
    for ck in chunks:
        mom = ck.get("mom") or ck.get("mom_with_weight") or ""
        if not mom:
            continue
        mom_id = xxhash.xxh64(mom.encode("utf-8")).hexdigest()
        ck["mom_id"] = mom_id
        if mom_id in mother_ids:
            continue
        mother_ids.add(mom_id)
        # Copy only the kept fields (not the whole chunk with its vectors).
        mom_ck = copy.deepcopy({k: v for k, v in ck.items() if k in mother_fields})
        mom_ck["id"] = mom_id
        mom_ck["content_with_weight"] = mom
        mom_ck["available_int"] = 0
        mothers.append(mom_ck)

    for b in range(0, len(mothers), settings.DOC_BULK_SIZE):
        mother_result = await thread_pool_exec(settings.docStoreConn.insert, mothers[b:b + settings.DOC_BULK_SIZE],
                                               search.index_name(task_tenant_id), task_dataset_id, )
        if mother_result:
            # Non-fatal, but no longer silent.
            logging.warning(f"insert_chunks: inserting parent chunks failed for task {task_id}: {mother_result}")
        task_canceled = has_canceled(task_id)
        if task_canceled:
            progress_callback(-1, msg="Task has been canceled.")
            return False

    for b in range(0, len(chunks), settings.DOC_BULK_SIZE):
        doc_store_result = await thread_pool_exec(settings.docStoreConn.insert, chunks[b:b + settings.DOC_BULK_SIZE],
                                                  search.index_name(task_tenant_id), task_dataset_id, )
        task_canceled = has_canceled(task_id)
        if task_canceled:
            # Roll back partial RAPTOR summary inserts so the next run is not
            # mistaken for a completed checkpoint by get_raptor_chunk_methods.
            await _rollback_partial_raptor_chunks(task_id, task_tenant_id, task_dataset_id,
                                                  chunks[:b + settings.DOC_BULK_SIZE], "cancellation")
            progress_callback(-1, msg="Task has been canceled.")
            return False
        if b % 128 == 0:
            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            # Earlier batches (and possibly part of this one) are already stored; a
            # partial RAPTOR tree must not look like a finished checkpoint.
            await _rollback_partial_raptor_chunks(task_id, task_tenant_id, task_dataset_id,
                                                  chunks[:b + settings.DOC_BULK_SIZE], "insert error")
            progress_callback(-1, msg=error_message)
            raise Exception(error_message)
        chunk_ids = [chunk["id"] for chunk in chunks[:b + settings.DOC_BULK_SIZE]]
        chunk_ids_str = " ".join(chunk_ids)
        try:
            TaskService.update_chunk_ids(task_id, chunk_ids_str)
        except DoesNotExist:
            logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
            doc_store_result = await thread_pool_exec(settings.docStoreConn.delete, {"id": chunk_ids},
                                                      search.index_name(task_tenant_id), task_dataset_id, )
            tasks = [asyncio.create_task(delete_image(task_dataset_id, chunk_id)) for chunk_id in chunk_ids]
            try:
                await asyncio.gather(*tasks, return_exceptions=False)
            except Exception as e:
                logging.error(f"delete_image failed: {e}")
                for t in tasks:
                    t.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
                raise
            progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
            return False
    return True
@timeout(60 * 60 * 3, 1)
async def do_handle_task(task):
    """Execute one queued task.

    Dispatches on ``task["task_type"]``:

    - ``memory`` and dataflow tasks are delegated to their handlers.
    - ``raptor`` builds RAPTOR summaries.
    - ``graphrag`` builds a knowledge graph.
    - ``mindmap`` is a placeholder.
    - Anything else runs standard chunking plus embedding.

    Chunks produced by the raptor/standard paths are then inserted into the doc
    store. Next, scheduled stale-RAPTOR cleanup runs, followed (for table and
    naive parsers) by document-metadata aggregation and TOC insertion.

    If the task turns out to be canceled once insertion has started, the
    ``finally`` clause removes what this task wrote.
    """
    task_type = task.get("task_type", "")

    if task_type == "memory":
        await handle_save_to_memory_task(task)
        return

    if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
        await run_dataflow(task)
        return

    task_id = task["id"]
    task_from_page = task["from_page"]
    task_to_page = task["to_page"]
    task_tenant_id = task["tenant_id"]
    task_embedding_id = task["embd_id"]
    task_language = task["language"]
    doc_task_llm_id = task["parser_config"].get("llm_id") or task["llm_id"]
    kb_task_llm_id = task['kb_parser_config'].get("llm_id") or task["llm_id"]
    task['llm_id'] = kb_task_llm_id
    task_dataset_id = task["kb_id"]
    task_doc_id = task["doc_id"]
    task_document_name = task["name"]
    task_parser_config = task["parser_config"]
    task_start_ts = timer()
    toc_thread = None
    raptor_cleanup_chunks = []

    # prepare the progress callback function
    progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

    task_canceled = has_canceled(task_id)
    if task_canceled:
        progress_callback(-1, msg="Task has been canceled.")
        return

    try:
        # bind embedding model
        if task_embedding_id:
            embd_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.EMBEDDING, task_embedding_id)
        else:
            embd_model_config = get_tenant_default_model_by_type(task_tenant_id, LLMType.EMBEDDING)
        embedding_model = LLMBundle(task_tenant_id, embd_model_config, lang=task_language)
        vts, _ = embedding_model.encode(["ok"])
        vector_size = len(vts[0])
    except Exception as e:
        error_message = f'Fail to bind embedding model: {str(e)}'
        progress_callback(-1, msg=error_message)
        logging.exception(error_message)
        raise

    init_kb(task, vector_size)

    if task_type[:len("dataflow")] == "dataflow":
        await run_dataflow(task)
        return

    if task_type == "raptor":
        ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
        if not ok:
            progress_callback(prog=-1.0, msg="Cannot found valid dataset for RAPTOR task")
            return

        kb_parser_config = kb.parser_config
        if not kb_parser_config.get("raptor", {}).get("use_raptor", False):
            kb_parser_config.update(
                {
                    "raptor": {
                        "use_raptor": True,
                        "prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n      {cluster_content}\nThe above is the content you need to summarize.",
                        "max_token": 256,
                        "threshold": 0.1,
                        "max_cluster": 64,
                        "random_seed": 0,
                        "scope": "file",
                        "clustering_method": "gmm",
                        "tree_builder": "raptor",
                    },
                }
            )
            if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
                progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
                return

        # bind LLM for raptor
        chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id)
        chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language)
        # run RAPTOR
        async with kg_limiter:
            chunks, token_count, raptor_cleanup_chunks = await run_raptor_for_kb(
                row=task,
                kb_parser_config=kb_parser_config,
                chat_mdl=chat_model,
                embd_mdl=embedding_model,
                vector_size=vector_size,
                callback=progress_callback,
                doc_ids=task.get("doc_ids", []),
            )
        if fake_doc_ids := task.get("doc_ids", []):
            task_doc_id = fake_doc_ids[0]  # use the first document ID to represent this task for logging purposes
        # Either using graphrag or Standard chunking methods
    elif task_type == "graphrag":
        ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
        if not ok:
            progress_callback(prog=-1.0, msg="Cannot found valid dataset for GraphRAG task")
            return

        kb_parser_config = kb.parser_config
        if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False):
            kb_parser_config.update(
                {
                    "graphrag": {
                        "use_graphrag": True,
                        "entity_types": [
                            "organization",
                            "person",
                            "geo",
                            "event",
                            "category",
                        ],
                        "method": "light",
                    }
                }
            )
            if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": kb_parser_config}):
                progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
                return

        graphrag_conf = kb_parser_config.get("graphrag", {})
        start_ts = timer()
        chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id)
        chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language)
        with_resolution = graphrag_conf.get("resolution", False)
        with_community = graphrag_conf.get("community", False)
        async with kg_limiter:
            result = await run_graphrag_for_kb(
                row=task,
                doc_ids=task.get("doc_ids", []),
                language=task_language,
                kb_parser_config=kb_parser_config,
                chat_model=chat_model,
                embedding_model=embedding_model,
                callback=progress_callback,
                with_resolution=with_resolution,
                with_community=with_community,
            )
            logging.info(f"GraphRAG task result for task {task}:\n{result}")
        progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
        return
    elif task_type == "mindmap":
        progress_callback(1, "place holder")
        return
    else:
        # Standard chunking methods
        task['llm_id'] = doc_task_llm_id
        start_ts = timer()
        chunks = await build_chunks(task, progress_callback)
        logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
        if not chunks:
            progress_callback(1., msg=f"No chunk built from {task_document_name}")
            return

        progress_callback(msg="Generate {} chunks".format(len(chunks)))
        start_ts = timer()
        try:
            token_count, vector_size = await embedding(chunks, embedding_model, task_parser_config, progress_callback)
        except TaskCanceledException:
            raise
        except Exception as e:
            error_message = "Generate embedding error:{}".format(str(e))
            progress_callback(-1, error_message)
            logging.exception(error_message)
            raise
        progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
        logging.info(progress_message)
        progress_callback(msg=progress_message)
        if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False):
            toc_thread = asyncio.create_task(asyncio.to_thread(build_TOC, task, chunks, progress_callback))

    chunk_count = len({chunk["id"] for chunk in chunks})
    start_ts = timer()

    async def _maybe_insert_chunks(_chunks):
        """Insert ``_chunks`` unless the task was canceled; return whether it succeeded."""
        if has_canceled(task_id):
            progress_callback(-1, msg="Task has been canceled.")
            return False
        insert_result = await insert_chunks(task_id, task_tenant_id, task_dataset_id, _chunks, progress_callback)
        return bool(insert_result)

    try:
        if not await _maybe_insert_chunks(chunks):
            return
        if has_canceled(task_id):
            progress_callback(-1, msg="Task has been canceled.")
            return

        if raptor_cleanup_chunks:
            cleaned_chunks = 0
            for cleanup_doc_id, keep_method in raptor_cleanup_chunks:
                cleaned_chunks += await delete_raptor_chunks(
                    cleanup_doc_id,
                    task_tenant_id,
                    task_dataset_id,
                    keep_method=keep_method,
                )
            if cleaned_chunks:
                progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.")

        logging.info(
            "Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(
                task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts
            )
        )

        DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)

        # Table parser (manual): push metadata/both column values to document-level metadata for UI / chat filters
        if task.get("parser_id", "").lower() == "table":
            eff_pc = merge_table_parser_config_from_kb(task)
            logging.debug(
                f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}"
            )
            if eff_pc.get("table_column_mode") == "manual":
                try:
                    agg = aggregate_table_manual_doc_metadata(chunks, task)
                    logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}")
                    strip_keys = table_parser_strip_doc_metadata_keys(eff_pc)
                    existing = DocMetadataService.get_document_metadata(task_doc_id)
                    existing = existing if isinstance(existing, dict) else {}
                    preserved = {k: v for k, v in existing.items() if k not in strip_keys}
                    merged = update_metadata_to(dict(preserved), agg)
                    logging.debug(
                        f"[TABLE_META_DEBUG] calling update_document_metadata for doc_id={task_doc_id}, "
                        f"meta_fields keys={list(merged.keys())}, "
                        f"table_strip_key_count={len(strip_keys)}, agg_keys={list(agg.keys())}"
                    )
                    try:
                        DocMetadataService.update_document_metadata(task_doc_id, merged)
                        logging.debug("[TABLE_META_DEBUG] update_document_metadata succeeded")
                    except Exception as ue:
                        logging.error(
                            "update_document_metadata failed (table parser, doc_id=%s): %s",
                            task_doc_id,
                            ue,
                            exc_info=True,
                        )
                except Exception as e:
                    logging.exception(
                        "Table parser document metadata aggregation failed (doc_id=%s): %s",
                        task_doc_id,
                        e,
                    )

        progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts))

        if toc_thread:
            d = await toc_thread
            if d:
                if not await _maybe_insert_chunks([d]):
                    return
                DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, 0, 1, 0)

        if has_canceled(task_id):
            progress_callback(-1, msg="Task has been canceled.")
            return

        task_time_cost = timer() - task_start_ts
        progress_callback(prog=1.0, msg="Task done ({:.2f}s)".format(task_time_cost))
        logging.info(
            "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(
                task_document_name, task_from_page, task_to_page, len(chunks), token_count, task_time_cost
            )
        )
    finally:
        if toc_thread is not None and not toc_thread.done():
            toc_thread.cancel()
        if has_canceled(task_id):
            try:
                exists = await thread_pool_exec(
                    settings.docStoreConn.index_exist,
                    search.index_name(task_tenant_id),
                    task_dataset_id,
                )
                if exists:
                    if task_type == "raptor":
                        # For RAPTOR tasks, task_doc_id is either the first source
                        # document (rebound above) or the task's own doc_id.
                        # Deleting by doc_id would wipe chunks this task never wrote
                        # (e.g. that document's parsed chunks), so only the summary
                        # chunks produced by this run are removed.
                        cancel_condition = {"id": [ck["id"] for ck in chunks]} if chunks else None
                    else:
                        cancel_condition = {"doc_id": task_doc_id}
                    if cancel_condition:
                        await thread_pool_exec(
                            settings.docStoreConn.delete,
                            cancel_condition,
                            search.index_name(task_tenant_id),
                            task_dataset_id,
                        )
            except Exception as e:
                logging.exception(
                    f"Remove doc({task_doc_id}) from docStore failed when task({task_id}) canceled, exception: {e}")
async def handle_task():
    """Pull one task from the queue, run it, and record the outcome.

    - Updates the DONE_TASKS / FAILED_TASKS counters and CURRENT_TASKS.
    - On failure, reports the (innermost ExceptionGroup) error through set_progress.
    - For tasks without a ``dataflow_id``, records a pipeline operation log.
    - Always acknowledges the Redis message once a task was collected.
    """
    global DONE_TASKS, FAILED_TASKS
    redis_msg, task = await collect()
    if not task:
        await asyncio.sleep(5)
        return

    # Use .get like do_handle_task: a malformed message must still reach the ack below.
    task_type = task.get("task_type", "")
    pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type,
                                                             PipelineTaskType.PARSE) or PipelineTaskType.PARSE
    task_id = task["id"]
    try:
        logging.info(f"handle_task begin for task {json.dumps(task)}")
        CURRENT_TASKS[task_id] = copy.deepcopy(task)
        await do_handle_task(task)
        DONE_TASKS += 1
        CURRENT_TASKS.pop(task_id, None)
        logging.info(f"handle_task done for task {json.dumps(task)}")
    except TaskCanceledException as e:
        DONE_TASKS += 1
        CURRENT_TASKS.pop(task_id, None)
        logging.info(
            f"handle_task canceled for task {task_id}: {getattr(e, 'msg', str(e))}"
        )
    except Exception as e:
        FAILED_TASKS += 1
        CURRENT_TASKS.pop(task_id, None)
        try:
            err_msg = str(e)
            while isinstance(e, exceptiongroup.ExceptionGroup):
                e = e.exceptions[0]
                err_msg += ' -- ' + str(e)
            set_progress(task_id, prog=-1, msg=f"[Exception]: {err_msg}")
        except Exception as progress_err:
            logging.exception(f"[Exception]: {str(progress_err)}")
        logging.exception(f"handle_task got exception for task {json.dumps(task)}")
    finally:
        try:
            if not task.get("dataflow_id", ""):
                referred_document_id = None
                if task_type in ["graphrag", "raptor", "mindmap"]:
                    # A missing or empty doc_ids list must not raise here, or the ack is skipped.
                    kb_doc_ids = task.get("doc_ids") or []
                    referred_document_id = kb_doc_ids[0] if kb_doc_ids else None
                PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="",
                                                                      task_type=pipeline_task_type,
                                                                      task_id=task_id,
                                                                      referred_document_id=referred_document_id)
        except Exception:
            logging.exception(f"handle_task failed to record pipeline operation for task {task_id}")
        finally:
            # Always ack so the message is not redelivered because of a logging failure.
            redis_msg.ack()
async def get_server_ip() -> str:
    """Best-effort lookup of this host's outbound IPv4 address.

    Connecting a UDP socket transmits nothing. It only makes the OS pick the
    local address it would use to reach 8.8.8.8, which getsockname() then
    reports. Any failure is logged and yields the string 'Unknown'.
    """
    probe_target = ("8.8.8.8", 80)
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(probe_target)
            local_address = probe.getsockname()[0]
    except Exception as exc:
        logging.error(str(exc))
        return 'Unknown'
    return local_address
async def report_status():
    """
    Periodically reports the executor's heartbeat

    Every 30 seconds this:

    - refreshes the pending/lag counters from the queue's consumer group;
    - writes a JSON heartbeat into this consumer's Redis sorted set and trims
      entries older than 30 minutes;
    - while holding a distributed lock, removes other executors whose last
      heartbeat is older than WORKER_HEARTBEAT_TIMEOUT.

    Redis failures are logged and never end the loop.
    """
    global PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
    ip_address = await get_server_ip()
    pid = os.getpid()

    # Register the executor in Redis
    REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
    redis_lock = RedisDistributedLock("clean_task_executor", lock_value=CONSUMER_NAME, timeout=60)
    while True:
        now = datetime.now()
        now_ts = now.timestamp()

        try:
            group_info = REDIS_CONN.queue_info(settings.get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME) or {}
            PENDING_TASKS = int(group_info.get("pending", 0))
            LAG_TASKS = int(group_info.get("lag", 0))
        except Exception as e:
            # Keep the previous counters; a transient error must not stop the heartbeat.
            logging.warning(f"Failed to read queue info: {e}")

        current = copy.deepcopy(CURRENT_TASKS)
        heartbeat = json.dumps({
            "ip_address": ip_address,
            "pid": pid,
            "name": CONSUMER_NAME,
            "now": now.astimezone().isoformat(timespec="milliseconds"),
            "boot_at": BOOT_AT,
            "pending": PENDING_TASKS,
            "lag": LAG_TASKS,
            "done": DONE_TASKS,
            "failed": FAILED_TASKS,
            "current": current,
        })

        # Report heartbeat to Redis
        try:
            REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now_ts)
        except Exception as e:
            logging.warning(f"Failed to report heartbeat: {e}")
        else:
            logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")

        # Clean up own expired heartbeat
        try:
            REDIS_CONN.zremrangebyscore(CONSUMER_NAME, 0, now_ts - 60 * 30)
        except Exception as e:
            logging.warning(f"Failed to clean heartbeat: {e}")

        # Clean other executors
        lock_acquired = False
        try:
            lock_acquired = redis_lock.acquire()
        except Exception as e:
            logging.warning(f"Failed to acquire Redis lock: {e}")
        if lock_acquired:
            try:
                task_executors = REDIS_CONN.smembers("TASKEXE") or set()
                for worker_name in task_executors:
                    if worker_name == CONSUMER_NAME:
                        continue
                    try:
                        last_heartbeat = REDIS_CONN.REDIS.zrevrange(worker_name, 0, 0, withscores=True)
                    except Exception as e:
                        logging.warning(f"Failed to read zset for {worker_name}: {e}")
                        continue

                    if not last_heartbeat or now_ts - last_heartbeat[0][1] > WORKER_HEARTBEAT_TIMEOUT:
                        logging.info(f"{worker_name} expired, removed")
                        REDIS_CONN.srem("TASKEXE", worker_name)
                        REDIS_CONN.delete(worker_name)
            except Exception as e:
                logging.warning(f"Failed to clean other executors: {e}")
            finally:
                # An exception escaping this finally would end the heartbeat loop for good.
                try:
                    redis_lock.release()
                except Exception as e:
                    logging.warning(f"Failed to release Redis lock: {e}")
        await asyncio.sleep(30)
async def task_manager():
    """Process a single queued task, always giving back its ``task_limiter`` slot."""
    release_slot = task_limiter.release
    try:
        await handle_task()
    finally:
        release_slot()
async def main():
    """Boot the ingestion task executor and run its dispatch loop.

    Steps, in order:
      1. Optionally sleep for a staggered startup delay derived from the
         numeric suffix of ``CONSUMER_NAME``.
      2. Log the banner, version and configuration, then initialize settings.
      3. Install signal handlers (tracemalloc toggles on non-Windows,
         ``signal_handler`` for SIGINT/SIGTERM).
      4. Start ``report_status`` as a background task.
      5. Repeatedly acquire a slot from ``task_limiter`` and spawn a
         ``task_manager`` task, until ``stop_event`` is set. ``stop_event``
         is only checked between acquisitions.

    On exit, in-flight tasks and the status reporter are cancelled and awaited.
    """
    # Stagger executor startup to prevent connection storm to Infinity.
    # Extract worker number from CONSUMER_NAME (e.g., "task_executor_abc123_5" -> 5).
    try:
        worker_num = int(CONSUMER_NAME.rsplit("_", 1)[-1])
    except (ValueError, IndexError):
        worker_num = None  # Non-standard consumer name, skip delay
    if worker_num is not None:
        # Delay = worker_num * 2.0s plus up to 0.5s of random jitter, which
        # spreads connection attempts out over several seconds.
        startup_delay = worker_num * 2.0 + random.uniform(0, 0.5)
        if startup_delay > 0:
            logging.info(f"Staggering startup by {startup_delay:.2f}s to prevent connection storm")
            await asyncio.sleep(startup_delay)
    logging.info(r"""
____ __ _
/ _/___ ____ ____ _____/ /_(_)___ ____ ________ ______ _____ _____
/ // __ \/ __ `/ _ \/ ___/ __/ / __ \/ __ \ / ___/ _ \/ ___/ | / / _ \/ ___/
_/ // / / / /_/ / __(__ ) /_/ / /_/ / / / / (__ ) __/ / | |/ / __/ /
/___/_/ /_/\__, /\___/____/\__/_/\____/_/ /_/ /____/\___/_/ |___/\___/_/
/____/
""")
    logging.info(f'RAGFlow ingestion version: {get_ragflow_version()}')
    show_configs()
    settings.init_settings()
    settings.check_and_install_torch()
    logging.info(f'default embedding config: {settings.EMBEDDING_CFG}')
    settings.print_rag_settings()
    if sys.platform != "win32":
        # SIGUSR1/SIGUSR2 are not available on Windows.
        signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
        signal.signal(signal.SIGUSR2, stop_tracemalloc)
    TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0"))
    if TRACE_MALLOC_ENABLED:
        start_tracemalloc_and_snapshot(None, None)
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    report_task = asyncio.create_task(report_status())
    # In-flight task_manager tasks. A strong reference is required while a task
    # runs (the event loop only keeps weak references to tasks). Each task is
    # removed once it finishes, so the collection does not grow for the whole
    # lifetime of the executor. The previous append-only list retained every
    # finished Task, together with its result and exception traceback.
    tasks = set()
    logging.info(f"RAGFlow ingestion is ready after {time.time() - start_ts}s initialization.")
    try:
        while not stop_event.is_set():
            await task_limiter.acquire()
            t = asyncio.create_task(task_manager())
            tasks.add(t)
            t.add_done_callback(tasks.discard)
    finally:
        # Snapshot first: done-callbacks discard from `tasks` while we await.
        pending = list(tasks)
        for t in pending:
            t.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        report_task.cancel()
        await asyncio.gather(report_task, return_exceptions=True)
    logging.error("BUG!!! You should not reach here!!!")
if __name__ == "__main__":
    # Dump Python tracebacks on fatal errors (e.g. crashes inside native extensions).
    faulthandler.enable()
    init_root_logger(CONSUMER_NAME)
    try:
        asyncio.run(main())
    except Exception as exc:
        # Anything escaping the event loop is fatal: log it with traceback and exit non-zero.
        logging.exception(f"Unhandled exception: {exc}")
        sys.exit(1)