diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index a041ee081..5e034f0c5 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -439,6 +439,7 @@ def get_parser_config(chunk_method, parser_config): "category", ], "method": "light", + "batch_chunk_token_size": 4096, }, "parent_child": { "use_parent_child": False, diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 1e6c0056b..861f94ee2 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -362,6 +362,7 @@ class GraphragConfig(Base): method: Annotated[Literal["light", "general", "ner"], Field(default="light")] community: Annotated[bool, Field(default=False)] resolution: Annotated[bool, Field(default=False)] + batch_chunk_token_size: Annotated[int, Field(default=4096, ge=512, le=8196)] class ParentChildConfig(Base): diff --git a/rag/graphrag/general/index.py b/rag/graphrag/general/index.py index 9898b19a3..396f3aae0 100644 --- a/rag/graphrag/general/index.py +++ b/rag/graphrag/general/index.py @@ -54,6 +54,22 @@ from common import settings from common.doc_store.doc_store_base import OrderByExpr +DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE = 4096 + + +def _positive_int_config(config: dict, key: str, default: int) -> int: + value = config.get(key, default) + try: + value = int(value) + except (TypeError, ValueError): + logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default) + return default + if value < 512 or value > 8196: + logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default) + return default + return value + + def _select_extractor(graphrag_config: dict): """Return the extractor class matching ``graphrag_config["method"]``. @@ -121,100 +137,6 @@ async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str): return None -async def run_graphrag( - row: dict, - language, - with_resolution: bool, - with_community: bool, - chat_model, - embedding_model, - callback, -): - enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - start = asyncio.get_running_loop().time() - tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] - chunks = [] - for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True): - chunks.append(d["content_with_weight"]) - - timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000 - - try: - subgraph = await asyncio.wait_for( - generate_subgraph( - _select_extractor(row["kb_parser_config"].get("graphrag", {})), - tenant_id, - kb_id, - doc_id, - chunks, - language, - row["kb_parser_config"]["graphrag"].get("entity_types", []), - chat_model, - embedding_model, - callback, - ), - timeout=timeout_sec, - ) - except asyncio.TimeoutError: - logging.error("generate_subgraph timeout") - raise - - if not subgraph: - return - - graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200) - await graphrag_task_lock.spin_acquire() - callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired") - - try: - subgraph_nodes = set(subgraph.nodes()) - new_graph = await merge_subgraph( - tenant_id, - kb_id, - doc_id, - subgraph, - embedding_model, - callback, - ) - assert new_graph is not None - - if not with_resolution and not with_community: - return - - if with_resolution: - await graphrag_task_lock.spin_acquire() - callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired") - await resolve_entities( - new_graph, - subgraph_nodes, - tenant_id, - kb_id, - doc_id, - chat_model, - embedding_model, - callback, - task_id=row["id"], - ) - if with_community: - await graphrag_task_lock.spin_acquire() - callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired") - await extract_community( - new_graph, - tenant_id, - kb_id, - doc_id, - chat_model, - embedding_model, - callback, - task_id=row["id"], - ) - finally: - graphrag_task_lock.release() - now = asyncio.get_running_loop().time() - callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.") - return - - async def run_graphrag_for_kb( row: dict, doc_ids: list[str], @@ -232,6 +154,8 @@ async def run_graphrag_for_kb( enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") start = asyncio.get_running_loop().time() fields_for_chunks = ["content_with_weight", "doc_id"] + graphrag_config = kb_parser_config.get("graphrag", {}) + batch_chunk_token_size = _positive_int_config(graphrag_config, "batch_chunk_token_size", DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE) if not doc_ids: logging.info(f"Fetching all docs for {kb_id}") @@ -259,21 +183,20 @@ async def run_graphrag_for_kb( chunks = [] current_chunk = "" - # DEBUG: Obtener todos los chunks primero raw_chunks = list(settings.retriever.chunk_list( doc_id, tenant_id, [kb_id], - max_count=10000, # FIX: Aumentar límite para procesar todos los chunks fields=fields_for_chunks, sort_by_position=True, + retrieve_all=True )) - callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}") + callback(msg=f"[GraphRAG] chunk_list returned {len(raw_chunks)} raw chunks for doc:{doc_id}") for d in raw_chunks: content = d["content_with_weight"] - if num_tokens_from_string(current_chunk + content) < 4096: + if num_tokens_from_string(current_chunk + content) < batch_chunk_token_size: current_chunk += content else: if current_chunk: @@ -285,16 +208,7 @@ async def run_graphrag_for_kb( return chunks - all_doc_chunks: dict[str, list[str]] = {} total_chunks = 0 - for doc_id in doc_ids: - chunks = load_doc_chunks(doc_id) - all_doc_chunks[doc_id] = chunks - total_chunks += len(chunks) - - if total_chunks == 0: - callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.") - return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0} semaphore = asyncio.Semaphore(max_parallel_docs) @@ -302,18 +216,13 @@ async def run_graphrag_for_kb( failed_docs: list[tuple[str, str]] = [] # (doc_id, error) async def build_one(doc_id: str): + nonlocal total_chunks + if has_canceled(row["id"]): callback(msg=f"Task {row['id']} cancelled, stopping execution.") raise TaskCanceledException(f"Task {row['id']} was cancelled") - chunks = all_doc_chunks.get(doc_id, []) - if not chunks: - callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.") - return - - kg_extractor = _select_extractor(kb_parser_config.get("graphrag", {})) - - deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000 + kg_extractor = _select_extractor(graphrag_config) async with semaphore: # CHECKPOINT: bounded by semaphore so doc-store lookups respect max_parallel_docs @@ -323,6 +232,13 @@ async def run_graphrag_for_kb( callback(msg=f"[GraphRAG] doc:{doc_id} subgraph found in store, skipping LLM extraction.") return try: + chunks = load_doc_chunks(doc_id) + total_chunks += len(chunks) + if not chunks: + callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.") + return + + deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000 msg = f"[GraphRAG] build_subgraph doc:{doc_id}" callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)") @@ -373,6 +289,10 @@ async def run_graphrag_for_kb( await asyncio.gather(*tasks, return_exceptions=True) raise + if total_chunks == 0 and not subgraphs: + callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.") + return {"ok_docs": [], "failed_docs": [(doc_id, "no available chunks") for doc_id in doc_ids], "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0} + if has_canceled(row["id"]): callback(msg=f"Task {row['id']} cancelled after document processing.") raise TaskCanceledException(f"Task {row['id']} was cancelled") diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 980dba04d..e79671f04 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -753,7 +753,13 @@ class Dealer: kb_ids: list[str], max_count=1024, offset=0, fields=["docnm_kwd", "content_with_weight", "img_id"], - sort_by_position: bool = False): + sort_by_position: bool = False, + retrieve_all: bool = False): + """Return chunks for a document. + + By default, preserve the historical max_count cap. When retrieve_all is + True, keep paging until the doc store returns fewer rows than requested. + """ condition = {"doc_id": doc_id} fields_set = set(fields or []) @@ -771,8 +777,9 @@ class Dealer: res = [] bs = 128 - for p in range(offset, max_count, bs): - limit = min(bs, max_count - p) + p = offset + while retrieve_all or p < max_count: + limit = bs if retrieve_all else min(bs, max_count - p) if limit <= 0: break es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, limit, index_name(tenant_id), @@ -785,6 +792,7 @@ class Dealer: chunk_count = len(dict_chunks) if chunk_count == 0 or chunk_count < limit: break + p += limit return res def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000): diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 548d88ab1..e639ba6e4 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -1390,6 +1390,7 @@ async def do_handle_task(task): "category", ], "method": "light", + "batch_chunk_token_size": 4096, } } ) diff --git a/test/testcases/configs.py b/test/testcases/configs.py index 546cd378c..a4711bf15 100644 --- a/test/testcases/configs.py +++ b/test/testcases/configs.py @@ -65,6 +65,7 @@ DEFAULT_PARSER_CONFIG = { "category", ], "method": "light", + "batch_chunk_token_size": 4096, }, "parent_child": { "use_parent_child": False, diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py index b24d9deea..de0b4189b 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py @@ -387,6 +387,7 @@ DEFAULT_PARSER_CONFIG_FOR_TEST = { "category", ], "method": "light", + "batch_chunk_token_size": 4096, }, } diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py index f174f0e54..2b02c0b19 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py @@ -313,6 +313,7 @@ DEFAULT_PARSER_CONFIG_FOR_TEST = { "category", ], "method": "light", + "batch_chunk_token_size": 4096, }, } diff --git a/web/src/components/parse-configuration/graph-rag-form-fields.tsx b/web/src/components/parse-configuration/graph-rag-form-fields.tsx index d85c88364..f3791c575 100644 --- a/web/src/components/parse-configuration/graph-rag-form-fields.tsx +++ b/web/src/components/parse-configuration/graph-rag-form-fields.tsx @@ -1,3 +1,4 @@ +import { FormLayout } from '@/constants/form'; import { DocumentParserType } from '@/constants/knowledge'; import { useTranslate } from '@/hooks/common-hooks'; import { cn } from '@/lib/utils'; @@ -12,6 +13,7 @@ import { useCallback, useMemo } from 'react'; import { useFormContext, useWatch } from 'react-hook-form'; import { EntityTypesFormField } from '../entity-types-form-field'; import { FormContainer } from '../form-container'; +import { SliderInputFormField } from '../slider-input-form-field'; import { FormControl, FormField, @@ -191,6 +193,19 @@ const GraphRagItems = ({ )} /> + + General: Use prompts provided by github.com/microsoft/graphrag to extract entities and relationships.
NER: Use spaCy NER and rule-based keyword extraction to extract entities and relationships. No LLM is required for extraction itself, making it fast and resource-efficient.`, + graphRagBatchChunkTokenSize: 'Batch chunk token size', + graphRagBatchChunkTokenSizeTip: + 'The token limit for each batch of chunks sent to the LLM for knowledge graph entity and relation extraction. Not applied to NER.', resolution: 'Entity resolution', resolutionTip: `An entity deduplication switch. When enabled, the LLM will combine similar entities - e.g., '2025' and 'the year of 2025', or 'IT' and 'Information Technology' - to construct a more accurate graph`, community: 'Community reports', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index dcd8f5871..81b95ee88 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -818,6 +818,9 @@ export default { graphRagMethodTip: `Light:实体和关系提取提示来自 GitHub - HKUDS/LightRAG:“LightRAG:简单快速的检索增强生成”
General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于图的模块化检索增强生成 (RAG) 系统
NER:使用 spaCy NER 和基于规则的关键词提取来抽取实体和关系,无需 LLM 参与提取过程,速度快且资源消耗低`, + graphRagBatchChunkTokenSize: '批量chunk token 大小', + graphRagBatchChunkTokenSizeTip: + '发送给 LLM 进行知识图谱实体和关系抽取时,每批文本块的 token 上限。NER 不适用。', resolution: '实体归一化', resolutionTip: `解析过程会将具有相同含义的实体合并在一起,从而使知识图谱更简洁、更准确。应合并以下实体:特朗普总统、唐纳德·特朗普、唐纳德·J·特朗普、唐纳德·约翰·特朗普`, community: '社区报告生成', diff --git a/web/src/pages/dataset/dataset-setting/form-schema.ts b/web/src/pages/dataset/dataset-setting/form-schema.ts index 03424921c..acb0eaf10 100644 --- a/web/src/pages/dataset/dataset-setting/form-schema.ts +++ b/web/src/pages/dataset/dataset-setting/form-schema.ts @@ -70,6 +70,12 @@ export const formSchema = z method: z.string().optional(), resolution: z.boolean().optional(), community: z.boolean().optional(), + batch_chunk_token_size: z + .number() + .int() + .min(512) + .max(8196) + .optional(), }) .refine( (data) => { diff --git a/web/src/pages/dataset/dataset-setting/index.tsx b/web/src/pages/dataset/dataset-setting/index.tsx index 930ec8f51..072e84f87 100644 --- a/web/src/pages/dataset/dataset-setting/index.tsx +++ b/web/src/pages/dataset/dataset-setting/index.tsx @@ -103,6 +103,7 @@ export default function DatasetSettings() { use_graphrag: true, entity_types: initialEntityTypes, method: MethodValue.Light, + batch_chunk_token_size: 4096, }, metadata: { type: 'object',