mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-02 16:27:48 +08:00
### What problem does this PR solve? Addresses review feedback on #14074 (Checkpoint mechanism for long-running workflow jobs, issue #12494). **Changes based on @yuzhichang's review:** 1. **Renamed `checkpoint_service.py` → `task_checkpoint.py`** as suggested. 2. **Replaced Redis with direct docEngine queries** as suggested — the subgraph already gets persisted to the doc store by `generate_subgraph()`, so we just query for it instead of maintaining a separate checkpoint in Redis. This is simpler, has no extra dependency, and uses a single source of truth. **Changes based on CodeRabbit review:** 3. **Fixed `source_id` query format mismatch** — subgraphs are stored with `source_id: [doc_id]` (list), but the original query used `source_id: doc_id` (string). Now follows the same pattern as `does_graph_contains()` in `rag/graphrag/utils.py`: filter by `knowledge_graph_kwd` only, then match `source_id` in Python. This avoids ambiguity across Elasticsearch / Infinity / OceanBase backends. ### Changes | File | Change | |---|---| | `api/db/services/task_checkpoint.py` (new) | `load_subgraph_from_store()` and `has_raptor_chunks()` — docEngine-based checkpoint queries | | `rag/graphrag/general/index.py` | `build_one()` calls `load_subgraph_from_store()` before running LLM extraction | | `rag/svr/task_executor.py` | RAPTOR per-doc loop calls `has_raptor_chunks()` before processing | | `test/unit_test/rag/graphrag/test_checkpoint_resume.py` (new) | 10 unit tests covering subgraph loading, source_id filtering, edge cases | ### How it works - **GraphRAG:** Before running expensive LLM entity/relation extraction for a doc, checks the doc store for an existing subgraph (saved by a previous interrupted run). If found, loads it directly and skips LLM calls. - **RAPTOR:** Before processing a doc, checks if RAPTOR chunks (`raptor_kwd="raptor"`) already exist for it. If yes, skips. 
### Testing - 10 new unit tests — all passing - Full existing suite: 617 passed ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring
667 lines
24 KiB
Python
667 lines
24 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
|
|
import networkx as nx
|
|
|
|
from api.db.services.document_service import DocumentService
|
|
from api.db.services.task_service import has_canceled
|
|
from common.exceptions import TaskCanceledException
|
|
from common.misc_utils import get_uuid
|
|
from common.connection_utils import timeout
|
|
from rag.graphrag.entity_resolution import EntityResolution
|
|
from rag.graphrag.general.community_reports_extractor import CommunityReportsExtractor
|
|
from rag.graphrag.general.extractor import Extractor
|
|
from rag.graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
|
from rag.graphrag.light.graph_extractor import GraphExtractor as LightKGExt
|
|
from rag.graphrag.utils import (
|
|
GraphChange,
|
|
chunk_id,
|
|
does_graph_contains,
|
|
get_graph,
|
|
graph_merge,
|
|
set_graph,
|
|
tidy_graph,
|
|
)
|
|
from common.misc_utils import thread_pool_exec
|
|
from rag.nlp import rag_tokenizer, search
|
|
from rag.utils.redis_conn import RedisDistributedLock
|
|
from common import settings
|
|
from common.doc_store.doc_store_base import OrderByExpr
|
|
|
|
|
|
|
|
async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str):
    """Load a previously saved subgraph for ``doc_id`` from the doc store.

    Acts as a resume checkpoint: ``generate_subgraph`` persists one
    ``knowledge_graph_kwd="subgraph"`` chunk per document, and reusing it lets
    the caller skip re-running LLM extraction.

    The query filters on ``knowledge_graph_kwd``, ``removed_kwd`` and
    ``source_id`` so the doc store does the heavy lifting. The returned
    ``source_id`` is then re-checked in Python, so a backend that interprets
    the list-valued condition differently cannot hand back another
    document's subgraph. At most one matching chunk is requested.

    Args:
        tenant_id: Tenant whose index is searched.
        kb_id: Knowledge base to search in.
        doc_id: Document whose subgraph is wanted.

    Returns:
        A ``networkx.Graph`` whose ``graph["source_id"]`` is ``[doc_id]`` on a
        hit, or ``None`` on a miss. Lookup and parse failures are logged and
        also yield ``None``; this function never raises.
    """
    fields = ["content_with_weight", "source_id"]
    condition = {
        "knowledge_graph_kwd": ["subgraph"],
        "removed_kwd": "N",
        "source_id": [doc_id],
    }
    try:
        res = await thread_pool_exec(
            settings.docStoreConn.search,
            fields, [], condition, [], OrderByExpr(),
            0, 1, search.index_name(tenant_id), [kb_id]
        )
        field_map = settings.docStoreConn.get_fields(res, fields)
        for cid, row in field_map.items():
            # Safety net on top of the store-side filter: the subgraph chunk is
            # written with source_id=[doc_id]; tolerate a scalar string too.
            stored_source = row.get("source_id")
            if isinstance(stored_source, str):
                stored_source = [stored_source]
            if stored_source and doc_id not in stored_source:
                logging.warning(
                    "Ignoring subgraph chunk %s: source_id %r does not match doc %s",
                    cid, stored_source, doc_id,
                )
                continue

            content = row.get("content_with_weight", "")
            if not content:
                continue
            try:
                data = json.loads(content)
                sg = nx.node_link_graph(data, edges="edges")
                sg.graph["source_id"] = [doc_id]
                logging.info(
                    "Checkpoint hit: subgraph for doc %s (tenant=%s kb=%s) found at chunk %s",
                    doc_id, tenant_id, kb_id, cid,
                )
                return sg
            except Exception:
                logging.exception(
                    "Failed to parse subgraph JSON for doc %s chunk %s", doc_id, cid
                )
    except Exception:
        logging.exception("Failed to load subgraph from store for doc %s", doc_id)
        return None
    logging.info(
        "Checkpoint miss: no subgraph for doc %s (tenant=%s kb=%s)",
        doc_id, tenant_id, kb_id,
    )
    return None
|
|
|
|
|
|
async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
    """Build the knowledge-graph contribution of a single document.

    Loads the chunks of ``row["doc_id"]`` and extracts a subgraph for them.
    ``generate_subgraph`` returns nothing when the KB graph already contains
    the doc, and in that case this function returns early. Otherwise the
    subgraph is merged into the KB graph while holding the KB-wide Redis
    lock. When requested, entity resolution and community extraction then
    run under the same lock.

    Args:
        row: Task row. Reads ``tenant_id``, ``kb_id``, ``doc_id`` and
            ``kb_parser_config``; reads ``id`` only for resolution/community.
        language: Language passed to the extractor.
        with_resolution: Run entity resolution after the merge.
        with_community: Run community report extraction after the merge.
        chat_model: LLM bundle used for extraction, resolution and reports.
        embedding_model: Embedding bundle used when persisting the graph.
        callback: Progress callback accepting ``msg=``.

    Raises:
        asyncio.TimeoutError: if subgraph generation exceeds its deadline.
    """
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = asyncio.get_running_loop().time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = []
    for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
        chunks.append(d["content_with_weight"])

    # Per-chunk budget when timeouts are enforced, otherwise effectively unlimited.
    timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000

    # "graphrag" may be missing (or None) in the parser config. Read both the
    # method and the entity types from the same defensively-fetched dict;
    # the entity types used to be read with ["graphrag"] and raised KeyError.
    graphrag_conf = row["kb_parser_config"].get("graphrag") or {}
    kg_extractor = GeneralKGExt if graphrag_conf.get("method") == "general" else LightKGExt

    try:
        subgraph = await asyncio.wait_for(
            generate_subgraph(
                kg_extractor,
                tenant_id,
                kb_id,
                doc_id,
                chunks,
                language,
                graphrag_conf.get("entity_types", []),
                chat_model,
                embedding_model,
                callback,
            ),
            timeout=timeout_sec,
        )
    except asyncio.TimeoutError:
        logging.error("generate_subgraph timeout")
        raise

    if not subgraph:
        return

    graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
    await graphrag_task_lock.spin_acquire()
    callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")

    try:
        subgraph_nodes = set(subgraph.nodes())
        new_graph = await merge_subgraph(
            tenant_id,
            kb_id,
            doc_id,
            subgraph,
            embedding_model,
            callback,
        )
        assert new_graph is not None

        if not with_resolution and not with_community:
            return

        if with_resolution:
            # NOTE(review): re-acquires the lock this task already holds;
            # RedisDistributedLock semantics are not visible here — confirm it
            # is reentrant for the same lock_value.
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await resolve_entities(
                new_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
                task_id=row["id"],
            )
        if with_community:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await extract_community(
                new_graph,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
                task_id=row["id"],
            )
    finally:
        graphrag_task_lock.release()
    now = asyncio.get_running_loop().time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return
|
|
|
|
|
|
async def run_graphrag_for_kb(
    row: dict,
    doc_ids: list[str],
    language: str,
    kb_parser_config: dict,
    chat_model,
    embedding_model,
    callback,
    *,
    with_resolution: bool = True,
    with_community: bool = True,
    max_parallel_docs: int = 4,
) -> dict:
    """Build or extend the knowledge graph of a whole knowledge base.

    Steps:
      1. Resolve the document list (every doc of the KB when ``doc_ids`` is
         empty, de-duplicated in order). Load each document's chunks and pack
         them into windows of fewer than 4096 tokens.
      2. Build one subgraph per document concurrently, at most
         ``max_parallel_docs`` at a time. A subgraph persisted by an earlier
         interrupted run is reused instead of calling the LLM, but only while
         the document is not yet part of the KB graph.
      3. Merge the subgraphs into the KB graph while holding the KB-wide
         Redis lock.
      4. Optionally run entity resolution and community extraction, again
         under the lock.

    Returns:
        A dict with these keys:
        - ``ok_docs``: docs whose subgraph was built or reused.
        - ``failed_docs``: a list of ``(doc_id, error)`` tuples once subgraph
          building ran; the plain ``doc_ids`` list when no doc had any chunk.
        - ``total_docs``, ``total_chunks`` and ``seconds``.

    Raises:
        TaskCanceledException: if task ``row["id"]`` is cancelled.
    """
    tenant_id, kb_id = row["tenant_id"], row["kb_id"]
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    start = asyncio.get_running_loop().time()
    fields_for_chunks = ["content_with_weight", "doc_id"]

    # Loop-invariant extraction settings; "graphrag" may be missing or None.
    graphrag_conf = kb_parser_config.get("graphrag") or {}
    kg_extractor = GeneralKGExt if graphrag_conf.get("method") == "general" else LightKGExt
    entity_types = graphrag_conf.get("entity_types", [])

    if not doc_ids:
        logging.info(f"Fetching all docs for {kb_id}")
        docs, _ = DocumentService.get_by_kb_id(
            kb_id=kb_id,
            page_number=0,
            items_per_page=0,
            orderby="create_time",
            desc=False,
            keywords="",
            run_status=[],
            types=[],
            suffix=[],
        )
        doc_ids = [doc["id"] for doc in docs]

    # De-duplicate while keeping the original order.
    doc_ids = list(dict.fromkeys(doc_ids))
    if not doc_ids:
        callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
        return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}

    def load_doc_chunks(doc_id: str) -> list[str]:
        """Return the doc's chunk texts, concatenated into windows of < 4096 tokens."""
        from common.token_utils import num_tokens_from_string

        chunks = []
        current_chunk = ""

        # Fetch every chunk of the document first, in document order.
        raw_chunks = list(settings.retriever.chunk_list(
            doc_id,
            tenant_id,
            [kb_id],
            max_count=10000,  # high limit so large documents are processed in full
            fields=fields_for_chunks,
            sort_by_position=True,
        ))

        callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")

        for d in raw_chunks:
            content = d["content_with_weight"]
            if num_tokens_from_string(current_chunk + content) < 4096:
                current_chunk += content
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = content

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    all_doc_chunks: dict[str, list[str]] = {}
    total_chunks = 0
    for doc_id in doc_ids:
        chunks = load_doc_chunks(doc_id)
        all_doc_chunks[doc_id] = chunks
        total_chunks += len(chunks)

    if total_chunks == 0:
        callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
        return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}

    semaphore = asyncio.Semaphore(max_parallel_docs)

    subgraphs: dict[str, object] = {}
    failed_docs: list[tuple[str, str]] = []  # (doc_id, error)

    async def build_one(doc_id: str):
        """Store the doc's subgraph in ``subgraphs`` or record why it failed."""
        if has_canceled(row["id"]):
            callback(msg=f"Task {row['id']} cancelled, stopping execution.")
            raise TaskCanceledException(f"Task {row['id']} was cancelled")

        chunks = all_doc_chunks.get(doc_id, [])
        if not chunks:
            callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
            return

        deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000

        async with semaphore:
            msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
            try:
                # CHECKPOINT: done inside the semaphore so doc-store lookups
                # respect max_parallel_docs, and inside the try so a lookup error
                # only fails this doc. A stored subgraph is only reused while the
                # doc is not yet in the KB graph. Nothing here deletes the stored
                # subgraph after a merge, so resuming from it unconditionally
                # would re-merge already merged docs. That would bypass the
                # does_graph_contains() guard that generate_subgraph applies.
                if not await does_graph_contains(tenant_id, kb_id, doc_id):
                    existing_sg = await load_subgraph_from_store(tenant_id, kb_id, doc_id)
                    if existing_sg:
                        subgraphs[doc_id] = existing_sg
                        callback(msg=f"[GraphRAG] doc:{doc_id} subgraph found in store, skipping LLM extraction.")
                        return

                callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")

                try:
                    sg = await asyncio.wait_for(
                        generate_subgraph(
                            kg_extractor,
                            tenant_id,
                            kb_id,
                            doc_id,
                            chunks,
                            language,
                            entity_types,
                            chat_model,
                            embedding_model,
                            callback,
                            task_id=row["id"]
                        ),
                        timeout=deadline,
                    )
                except asyncio.TimeoutError:
                    failed_docs.append((doc_id, "timeout"))
                    callback(msg=f"{msg} FAILED: timeout")
                    return
                if sg:
                    subgraphs[doc_id] = sg
                    callback(msg=f"{msg} done")
                else:
                    failed_docs.append((doc_id, "subgraph is empty"))
                    callback(msg=f"{msg} empty")
            except TaskCanceledException as canceled:
                callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {canceled}")
            except Exception as e:
                failed_docs.append((doc_id, repr(e)))
                callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")

    if has_canceled(row["id"]):
        callback(msg=f"Task {row['id']} cancelled before processing documents.")
        raise TaskCanceledException(f"Task {row['id']} was cancelled")

    tasks = [asyncio.create_task(build_one(doc_id)) for doc_id in doc_ids]
    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as e:
        logging.error(f"Error in asyncio.gather: {e}")
        # Cancel the siblings still running and wait for them before re-raising.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    if has_canceled(row["id"]):
        callback(msg=f"Task {row['id']} cancelled after document processing.")
        raise TaskCanceledException(f"Task {row['id']} was cancelled")

    ok_docs = [d for d in doc_ids if d in subgraphs]
    if not ok_docs:
        callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
        now = asyncio.get_running_loop().time()
        return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
    await kb_lock.spin_acquire()
    callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")

    final_graph = None
    try:
        # Checked inside the try so a cancellation still releases the KB lock
        # (it used to raise before the try and leak the lock until timeout).
        if has_canceled(row["id"]):
            callback(msg=f"Task {row['id']} cancelled before merging subgraphs.")
            raise TaskCanceledException(f"Task {row['id']} was cancelled")

        for doc_id in ok_docs:
            new_graph = await merge_subgraph(
                tenant_id,
                kb_id,
                doc_id,
                subgraphs[doc_id],
                embedding_model,
                callback,
            )
            if new_graph is not None:
                final_graph = new_graph

        if final_graph is None:
            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
        else:
            callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
    finally:
        kb_lock.release()

    if not with_resolution and not with_community:
        now = asyncio.get_running_loop().time()
        callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
        return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}

    if has_canceled(row["id"]):
        callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.")
        raise TaskCanceledException(f"Task {row['id']} was cancelled")

    await kb_lock.spin_acquire()
    callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")

    try:
        # Nodes contributed by this run; resolution focuses on these.
        subgraph_nodes = set()
        for sg in subgraphs.values():
            subgraph_nodes.update(sg.nodes())

        if with_resolution:
            await resolve_entities(
                final_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                None,
                chat_model,
                embedding_model,
                callback,
                task_id=row["id"],
            )

        if with_community:
            await extract_community(
                final_graph,
                tenant_id,
                kb_id,
                None,
                chat_model,
                embedding_model,
                callback,
                task_id=row["id"],
            )
    finally:
        kb_lock.release()

    now = asyncio.get_running_loop().time()
    callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
    return {
        "ok_docs": ok_docs,
        "failed_docs": failed_docs,  # [(doc_id, error), ...]
        "total_docs": len(doc_ids),
        "total_chunks": total_chunks,
        "seconds": now - start,
    }
|
|
|
|
|
|
async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
    task_id: str = "",
):
    """Extract a per-document subgraph with the LLM and persist it.

    The persisted chunk (``knowledge_graph_kwd="subgraph"``,
    ``source_id=[doc_id]``) replaces any earlier subgraph chunk for the same
    doc. ``load_subgraph_from_store`` reads it back as a resume checkpoint.

    Args:
        extractor: Extractor class, instantiated with ``llm_bdl``, ``language``
            and ``entity_types``.
        tenant_id: Tenant whose index receives the subgraph chunk.
        kb_id: Knowledge base the doc belongs to.
        doc_id: Document being processed.
        chunks: Text windows of the document.
        language: Extraction language.
        entity_types: Entity types to extract.
        llm_bdl: LLM bundle for extraction.
        embed_bdl: Not used in this function; kept for signature compatibility.
        callback: Progress callback accepting ``msg=``.
        task_id: Task id polled for cancellation; empty disables the checks.

    Returns:
        The ``networkx.Graph`` built for the document, or ``None`` when the KB
        graph already contains ``doc_id``.

    Raises:
        TaskCanceledException: if ``task_id`` is cancelled during the run.
        AssertionError: if an extracted entity or relation lacks a description.
    """
    if task_id and has_canceled(task_id):
        callback(msg=f"Task {task_id} cancelled during subgraph generation for doc {doc_id}.")
        raise TaskCanceledException(f"Task {task_id} was cancelled")

    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
    if contains:
        callback(msg=f"Graph already contains {doc_id}")
        return None
    start = asyncio.get_running_loop().time()
    ext = extractor(
        llm_bdl,
        language=language,
        entity_types=entity_types,
    )
    ents, rels = await ext(doc_id, chunks, callback, task_id=task_id)
    subgraph = nx.Graph()

    for ent in ents:
        if task_id and has_canceled(task_id):
            callback(msg=f"Task {task_id} cancelled during entity processing for doc {doc_id}.")
            raise TaskCanceledException(f"Task {task_id} was cancelled")

        assert "description" in ent, f"entity {ent} does not have description"
        ent["source_id"] = [doc_id]
        subgraph.add_node(ent["entity_name"], **ent)

    ignored_rels = 0
    for rel in rels:
        if task_id and has_canceled(task_id):
            callback(msg=f"Task {task_id} cancelled during relationship processing for doc {doc_id}.")
            raise TaskCanceledException(f"Task {task_id} was cancelled")

        assert "description" in rel, f"relation {rel} does not have description"
        # Drop edges whose endpoints were not extracted as entities.
        if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
            ignored_rels += 1
            continue
        rel["source_id"] = [doc_id]
        subgraph.add_edge(
            rel["src_id"],
            rel["tgt_id"],
            **rel,
        )
    if ignored_rels:
        callback(msg=f"ignored {ignored_rels} relations due to missing entities.")
    tidy_graph(subgraph, callback, check_attribute=False)

    subgraph.graph["source_id"] = [doc_id]
    chunk = {
        "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)
    await thread_pool_exec(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
    insert_errors = await thread_pool_exec(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
    if insert_errors:
        # Non-fatal: the in-memory subgraph is still returned and merged, but
        # no checkpoint exists for a later resume, so do not fail silently.
        logging.warning("Failed to persist subgraph for doc %s: %s", doc_id, insert_errors)
    now = asyncio.get_running_loop().time()
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    return subgraph
|
|
|
|
|
|
@timeout(60 * 3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    subgraph: nx.Graph,
    embedding_model,
    callback,
):
    """Merge ``subgraph`` into the stored KB graph and persist the result.

    The existing graph is fetched with ``get_graph``. When none exists, the
    subgraph itself becomes the graph and all of its nodes and edges are
    recorded as changed. PageRank is recomputed on the merged graph and
    stored as each node's ``pagerank`` attribute before ``set_graph`` writes
    it back.

    Returns:
        The merged ``networkx.Graph``.
    """
    loop = asyncio.get_running_loop()
    t0 = loop.time()
    change = GraphChange()
    existing = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
    if existing is None:
        merged = subgraph
        change.added_updated_nodes = set(merged.nodes())
        change.added_updated_edges = set(merged.edges())
    else:
        logging.info("Merge with an exiting graph...................")
        tidy_graph(existing, callback)
        merged = graph_merge(existing, subgraph, change)

    for node_name, score in nx.pagerank(merged).items():
        merged.nodes[node_name]["pagerank"] = score

    await set_graph(tenant_id, kb_id, embedding_model, merged, change, callback)
    elapsed = loop.time() - t0
    callback(msg=f"merging subgraph for doc {doc_id} into the global graph done in {elapsed:.2f} seconds.")
    return merged
|
|
|
|
|
|
@timeout(60 * 30, 1)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
    task_id: str = "",
):
    """Run LLM entity resolution on ``graph`` and persist the outcome.

    ``subgraph_nodes`` is handed to ``EntityResolution`` together with the
    graph. The resolved graph and its change set are written back with
    ``set_graph``. ``doc_id`` is not read here.

    Raises:
        TaskCanceledException: if ``task_id`` is cancelled before or after
            resolution.
    """
    def _abort_if_canceled(stage_msg: str):
        # An empty task_id disables cancellation checks.
        if task_id and has_canceled(task_id):
            callback(msg=stage_msg)
            raise TaskCanceledException(f"Task {task_id} was cancelled")

    _abort_if_canceled(f"Task {task_id} cancelled during entity resolution.")

    loop = asyncio.get_running_loop()
    t0 = loop.time()
    resolver = EntityResolution(llm_bdl)
    outcome = await resolver(graph, subgraph_nodes, callback=callback, task_id=task_id)
    resolved_graph, change = outcome.graph, outcome.change
    callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
    callback(msg="Graph resolution updated pagerank.")

    _abort_if_canceled(f"Task {task_id} cancelled after entity resolution.")

    await set_graph(tenant_id, kb_id, embed_bdl, resolved_graph, change, callback)
    callback(msg=f"Graph resolution done in {loop.time() - t0:.2f}s.")
|
|
|
|
|
|
@timeout(60 * 30, 1)
async def extract_community(
    graph,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
    task_id: str = "",
):
    """Generate community reports for ``graph`` and index them as chunks.

    All existing ``community_report`` chunks of the KB are deleted first.
    The new report chunks are then inserted in batches of 4. Every chunk
    carries the graph's ``source_id`` list. ``doc_id`` and ``embed_bdl`` are
    not read here.

    Returns:
        ``(community_structure, community_reports)`` as produced by
        ``CommunityReportsExtractor``.

    Raises:
        TaskCanceledException: if ``task_id`` is cancelled at any checkpoint.
        Exception: if the doc store reports an insert error.
    """
    def _abort_if_canceled(stage_msg: str):
        # An empty task_id disables cancellation checks.
        if task_id and has_canceled(task_id):
            callback(msg=stage_msg)
            raise TaskCanceledException(f"Task {task_id} was cancelled")

    _abort_if_canceled(f"Task {task_id} cancelled before community extraction.")

    loop = asyncio.get_running_loop()
    t0 = loop.time()
    extractor = CommunityReportsExtractor(llm_bdl)
    cr = await extractor(graph, callback=callback, task_id=task_id)

    _abort_if_canceled(f"Task {task_id} cancelled during community extraction.")

    community_structure = cr.structured_output
    community_reports = cr.output
    doc_ids = graph.graph["source_id"]

    t1 = loop.time()
    callback(msg=f"Graph extracted {len(cr.structured_output)} communities in {t1 - t0:.2f}s.")
    t0 = t1
    _abort_if_canceled(f"Task {task_id} cancelled during community indexing.")

    def _to_chunk(stru, rep):
        # One searchable chunk per community: report text plus findings as evidence.
        obj = {
            "report": rep,
            "evidences": "\n".join([f.get("explanation", "") for f in stru["findings"]]),
        }
        chunk = {
            "id": get_uuid(),
            "docnm_kwd": stru["title"],
            "title_tks": rag_tokenizer.tokenize(stru["title"]),
            "content_with_weight": json.dumps(obj, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(obj["report"] + " " + obj["evidences"]),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": stru["weight"],
            "entities_kwd": stru["entities"],
            "important_kwd": stru["entities"],
            "kb_id": kb_id,
            "source_id": list(doc_ids),
            "available_int": 0,
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        return chunk

    chunks = [_to_chunk(stru, rep) for stru, rep in zip(community_structure, community_reports)]

    index = search.index_name(tenant_id)
    await thread_pool_exec(settings.docStoreConn.delete, {"knowledge_graph_kwd": "community_report", "kb_id": kb_id}, index, kb_id)
    es_bulk_size = 4
    for offset in range(0, len(chunks), es_bulk_size):
        batch_errors = await thread_pool_exec(settings.docStoreConn.insert, chunks[offset : offset + es_bulk_size], index, kb_id)
        if batch_errors:
            raise Exception(f"Insert chunk error: {batch_errors}, please check log file and Elasticsearch/Infinity status!")

    _abort_if_canceled(f"Task {task_id} cancelled after community indexing.")

    callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {loop.time() - t0:.2f}s.")
    return community_structure, community_reports
|