ragflow/rag/graphrag/utils.py
Preston Percival e8f19aa338 feat(graphrag): fix merge concurrency and add resume-from-checkpoint (#14238)
This PR addresses three related GraphRAG reliability issues that
together allow long-running GraphRAG tasks (10+ hours of LLM extraction)
to be resumed after a crash or pause without re-doing completed work. It
builds on #14096 (per-doc subgraph cache) and extends the same idea to
the resolution and community-detection phases.

Fixes #14236.

## 1. Fix concurrent merge crash

Long GraphRAG runs would crash near the end of entity resolution with:
```
RuntimeError: dictionary keys changed during iteration
```
in `Extractor._merge_graph_nodes`. Two changes:

- `rag/graphrag/general/extractor.py`: snapshot `graph.neighbors(node1)`
via `list(...)` before iterating, so concurrent `add_edge` /
`remove_node` mutations on the shared `nx.Graph` cannot invalidate the
iterator. Also tracks each redirected neighbour in `node0_neighbors` so
a later merged node sharing the same external neighbour takes the
edge-merge branch instead of overwriting via `add_edge`.
- `rag/graphrag/entity_resolution.py`: serialize the merge step with a
dedicated `asyncio.Semaphore(1)`. `nx.Graph` is not thread-safe and
concurrent merges on overlapping neighbourhoods can produce incorrect
results even with the snapshot fix.
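
A minimal sketch of both changes, using a plain adjacency dict instead of `nx.Graph` so it runs without third-party packages. The `merge_into` helper and the sample graph are hypothetical and are not the code in `extractor.py`; only the `list(...)` snapshot and the `asyncio.Semaphore(1)` mirror the description above.

```python
import asyncio

# Hypothetical adjacency map standing in for the shared nx.Graph.
graph = {
    "A": {"X", "Y"},
    "B": {"Y", "Z"},
    "C": {"Z"},
    "X": {"A"},
    "Y": {"A", "B"},
    "Z": {"B", "C"},
}


async def merge_into(lock: asyncio.Semaphore, keep: str, drop: str) -> None:
    """Redirect every neighbour of `drop` to `keep`, then remove `drop`."""
    async with lock:  # only one merge touches the graph at a time
        # Snapshot first: the loop body mutates graph[drop] while we walk it.
        for nbr in list(graph[drop]):
            graph[nbr].discard(drop)
            graph[drop].discard(nbr)
            if nbr != keep:
                graph[nbr].add(keep)
                graph[keep].add(nbr)
            await asyncio.sleep(0)  # yield point, as a real merge may await I/O
        del graph[drop]


async def main() -> None:
    lock = asyncio.Semaphore(1)
    await asyncio.gather(
        merge_into(lock, "A", "B"),
        merge_into(lock, "A", "C"),
    )


asyncio.run(main())
```

Iterating `graph[drop]` directly instead of the `list(...)` copy would raise a `RuntimeError` here, because the set shrinks during iteration.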

## 2. Don't wipe partial graph on pause

Previously the pause / cancel UI path called
`settings.docStoreConn.delete({"knowledge_graph_kwd": [...]}, ...)`,
destroying every subgraph, entity, relation, and graph row.
Re-triggering then started GraphRAG from scratch even though #14096 had
already added `load_subgraph_from_store`.

After main was merged in (which deleted `api/apps/kb_app.py` per
#14394), the pause path now lives on the new REST surface `DELETE
/v1/datasets/<id>/<index_type>`:

- `api/apps/services/dataset_api_service.py`: `delete_index` accepts a
`wipe: bool = True` parameter. When `False` the doc-store rows and
GraphRAG phase markers are left intact and only the running task is
cancelled. Default preserves historical behaviour.
- `api/apps/restful_apis/dataset_api.py`: parses `?wipe=false|0|no|off`
from the query string and forwards it.
- `web/src/utils/api.ts` + `web/src/services/knowledge-service.ts`:
`unbindPipelineTask` appends `?wipe=false` when explicitly false.
- The GraphRAG pause action in
`web/src/pages/dataset/dataset/generate-button/hook.ts` passes `wipe:
false` for `KnowledgeGraph`; raptor is unchanged.

**UX impact:** the pause icon next to a running GraphRAG task no longer
wipes graph data. The only path that still wipes is the explicit Delete
action in `GenerateLogButton` (trash icon behind a confirmation modal).
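
The query-string handling can be sketched as follows. `parse_wipe` is a hypothetical helper name; case-insensitive matching and treating unrecognised values as `wipe=true` are assumptions, chosen so that the historical default is preserved.

```python
from typing import Optional

# Values that the description lists as meaning "do not wipe".
FALSY_VALUES = {"false", "0", "no", "off"}


def parse_wipe(raw: Optional[str]) -> bool:
    """Return False only for an explicit falsy value; otherwise keep wiping."""
    if raw is None:
        return True  # parameter absent: historical behaviour
    return raw.strip().lower() not in FALSY_VALUES
```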

## 3. Phase-completion markers (`rag/graphrag/phase_markers.py`)

A small Redis-backed marker layer at
`graphrag:phase:{kb_id}:{resolution_done|community_done}` (7-day TTL).
`run_graphrag_for_kb` consults the markers on entry and skips phases
that already completed in a prior run. Markers are cleared automatically
when:
- new docs are merged into the graph (which invalidates prior resolution
and community results),
- `delete_index` wipes the graph, or
- `delete_knowledge_graph` is called.

Redis failures never block a run -- markers are an optimization, not a
gate.
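
A rough sketch of the marker layer under these rules. `FakeRedis`, `mark_done`, and `is_done` are illustrative names, not the API of `phase_markers.py`; the key format and the 7-day TTL are taken from the description above.

```python
TTL_SECONDS = 7 * 24 * 3600  # 7-day TTL


def marker_key(kb_id: str, phase: str) -> str:
    return f"graphrag:phase:{kb_id}:{phase}"


class FakeRedis:
    """In-memory stand-in; fail=True simulates an unreachable server."""

    def __init__(self, fail: bool = False):
        self.data = {}
        self.fail = fail

    def set(self, key, value, ttl):
        if self.fail:
            raise ConnectionError("redis down")
        self.data[key] = (value, ttl)

    def exists(self, key):
        if self.fail:
            raise ConnectionError("redis down")
        return key in self.data


def mark_done(conn, kb_id: str, phase: str) -> None:
    try:
        conn.set(marker_key(kb_id, phase), "1", TTL_SECONDS)
    except Exception:
        pass  # best effort: a lost marker only costs a repeated phase


def is_done(conn, kb_id: str, phase: str) -> bool:
    try:
        return bool(conn.exists(marker_key(kb_id, phase)))
    except Exception:
        return False  # marker state unknown: run the phase again
```

Treating every Redis error as "not done" means the worst case is redundant work, never a skipped phase.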

## 4. Idempotent community detection

`extract_community` previously did `delete-then-insert` on
`community_report` rows; a crash mid-insert left the dataset with no
reports. Now report IDs are derived deterministically from `(kb_id,
community.title)`, the existing report IDs are snapshotted before
insert, new rows are written, then only stale rows are pruned. A failure
at any step leaves either the prior or the new report set intact --
never a partial mix.
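
The write order can be illustrated with a dict standing in for the doc store. `report_id` and `replace_reports` are hypothetical names, and the MD5-based ID is an assumption; the description only says the ID is derived deterministically from `(kb_id, community.title)`.

```python
import hashlib


def report_id(kb_id: str, title: str) -> str:
    """The same (kb_id, title) pair always maps to the same row ID."""
    return hashlib.md5(f"{kb_id}:{title}".encode("utf-8")).hexdigest()


def replace_reports(store: dict, kb_id: str, reports: dict) -> None:
    """`store` maps row ID -> row; `reports` maps community title -> body."""
    # 1. Snapshot the IDs that exist before anything is written.
    existing = {rid for rid, row in store.items() if row["kb_id"] == kb_id}

    # 2. Write the new rows; a repeated title overwrites its own row.
    new_ids = set()
    for title, body in reports.items():
        rid = report_id(kb_id, title)
        store[rid] = {"kb_id": kb_id, "title": title, "body": body}
        new_ids.add(rid)

    # 3. Prune only rows that were present before and are not part of the new set.
    for rid in existing - new_ids:
        del store[rid]
```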

## 5. Tunable doc-store insert pipeline

The GraphRAG insert loop in `rag/graphrag/utils.py` and the
`community_report` insert in `rag/graphrag/general/index.py` were both
hardcoded to `es_bulk_size = 4` and ran strictly sequentially. On a real
KB with 1077 chunks, inserting a single 100-chunk slice took ~21 minutes,
almost all of it round-trip overhead.

- New `insert_chunks_bounded()` helper in `rag/graphrag/utils.py`
batches inserts via a bounded `asyncio.Semaphore`. Same retry / timeout
semantics as the prior loop.
- Defaults: 64 docs per batch, 4 batches in flight (matches the regular
ingest pipeline in `document_service.py`). Tunable per-deployment via
`GRAPHRAG_INSERT_BULK_SIZE` and `GRAPHRAG_INSERT_CONCURRENCY`.
- Both `set_graph` and `extract_community` now use the helper.

This dropped the same 1077-chunk insert from minutes to seconds in local
testing without measurable extra pressure on Infinity (total in-flight
docs ≤ `BULK_SIZE × CONCURRENCY` = 256 by default).
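
To make the bound concrete, the simulation below counts batches and peak concurrency for 1077 chunks at the default settings. It is a sketch only: `insert_all` is a hypothetical name and the `asyncio.sleep` call stands in for the doc-store round trip.

```python
import asyncio

BULK_SIZE = 64    # default GRAPHRAG_INSERT_BULK_SIZE
CONCURRENCY = 4   # default GRAPHRAG_INSERT_CONCURRENCY


async def insert_all(total: int) -> tuple:
    sem = asyncio.Semaphore(CONCURRENCY)
    in_flight = 0
    peak = 0
    batches = 0

    async def one(offset: int) -> None:
        nonlocal in_flight, peak, batches
        async with sem:
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.001)  # stand-in for the insert round trip
            in_flight -= 1
            batches += 1

    await asyncio.gather(*(one(o) for o in range(0, total, BULK_SIZE)))
    return batches, peak


batches, peak = asyncio.run(insert_all(1077))
print(f"{batches} batches, at most {peak} in flight")
```

With 64 docs per batch, 1077 chunks split into 17 batches (16 full batches plus one of 53), and the semaphore keeps at most 4 of them in flight.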

## Tests

- `test/unit_test/rag/graphrag/test_merge_graph_nodes.py` (3 tests):
dense neighbourhood merge, neighbour-snapshot regression, concurrent
serialized merges.
- `test/unit_test/rag/graphrag/test_phase_markers.py` (4 tests): set/has
round-trip, kb-scoped clear, no-op on empty input, graceful Redis
failure.
- `test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py`:
new `test_delete_index_wipe_flag_unit` covers `wipe=false` for both
GraphRAG and raptor on the new REST route, and confirms the default
still wipes and clears phase markers.

## Compatibility

- Backward compatible: tasks queued before this change behave
identically (default `wipe=true`, no markers expected).
- No schema/migration changes; all new state lives in Redis.
- New optional REST query param `wipe` on `DELETE
/v1/datasets/<id>/<index_type>`.
- New optional env vars `GRAPHRAG_INSERT_BULK_SIZE` and
`GRAPHRAG_INSERT_CONCURRENCY`; defaults preserve safe behaviour.

## Example of resume

Screenshot below shows a test resuming knowledge graph generation after
applying the concurrency fix and re-deploying.

<img width="521" height="677" alt="image"
src="https://github.com/user-attachments/assets/9ef0d405-cbb3-420d-a1a1-e51f3e7e9b7a"
/>

### Type of change

- [X] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
2026-05-06 15:01:01 +08:00

766 lines
28 KiB
Python

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
- [LightRag](https://github.com/HKUDS/LightRAG)
"""
import asyncio
import dataclasses
import html
import json
import logging
import os
import re
import time
from collections import defaultdict
from hashlib import md5
from typing import Any, Callable, Set, Tuple
import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph
from common.misc_utils import get_uuid
from common.connection_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
GRAPH_FIELD_SEP = "<SEP>"
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
# Doc-store insert batching for GraphRAG subgraph/node/edge/community_report
# chunks. Defaults (64 docs per batch, up to 4 batches in flight) mirror the
# regular ingest pipeline in document_service.py while still keeping the total
# number of simultaneous requests to ES/Infinity bounded. Override with
# GRAPHRAG_INSERT_BULK_SIZE and GRAPHRAG_INSERT_CONCURRENCY.
_INSERT_BULK_SIZE = max(1, int(os.environ.get("GRAPHRAG_INSERT_BULK_SIZE", 64)))
_INSERT_CONCURRENCY = max(1, int(os.environ.get("GRAPHRAG_INSERT_CONCURRENCY", 4)))
async def insert_chunks_bounded(chunks, tenant_id, kb_id, *, callback=None, label="Insert chunks"):
    """Insert ``chunks`` into the doc store in batches with bounded concurrency and retries.

    Batch size is controlled by ``GRAPHRAG_INSERT_BULK_SIZE`` (default 64) and
    the number of batches in flight by ``GRAPHRAG_INSERT_CONCURRENCY``
    (default 4). Each batch has the same retry / timeout behaviour as the
    previous hand-rolled loop (3 attempts, exponential backoff).

    Raises the first unrecoverable error. Note that ``asyncio.gather`` does
    not cancel the remaining batch tasks in that case; they keep running in
    the background.
    """
    if not chunks:
        return
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    sem = asyncio.Semaphore(_INSERT_CONCURRENCY)
    total = len(chunks)
    progress = {"done": 0, "next_report": 100}
    progress_lock = asyncio.Lock()

    async def _one(offset: int) -> None:
        batch = chunks[offset : offset + _INSERT_BULK_SIZE]
        timeout_s = 3 if enable_timeout_assertion else 30000000
        max_retries = 3
        async with sem:
            for attempt in range(max_retries):
                try:
                    result = await asyncio.wait_for(
                        thread_pool_exec(
                            settings.docStoreConn.insert,
                            batch,
                            search.index_name(tenant_id),
                            kb_id,
                        ),
                        timeout=timeout_s,
                    )
                    if result:
                        raise Exception(f"Insert chunk error: {result}, please check log file and Elasticsearch/Infinity status!")
                    break
                except asyncio.TimeoutError:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} timed out, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} failed: {e}, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise
        # Report progress roughly every 100 chunks and once at the end.
        if callback:
            async with progress_lock:
                progress["done"] += len(batch)
                if progress["done"] >= progress["next_report"] or progress["done"] == total:
                    callback(msg=f"{label}: {progress['done']}/{total}")
                    progress["next_report"] = progress["done"] + 100

    await asyncio.gather(*(asyncio.create_task(_one(o)) for o in range(0, total, _INSERT_BULK_SIZE)))
@dataclasses.dataclass
class GraphChange:
    removed_nodes: Set[str] = dataclasses.field(default_factory=set)
    added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
    removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
    added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)


def perform_variable_replacements(input: str, history: list[dict] | None = None, variables: dict | None = None) -> str:
    """Perform variable replacements on the input string and in a chat log."""
    if history is None:
        history = []
    if variables is None:
        variables = {}
    result = input

    def replace_all(input: str) -> str:
        result = input
        for k, v in variables.items():
            result = result.replace(f"{{{k}}}", str(v))
        return result

    result = replace_all(result)
    for i, entry in enumerate(history):
        if entry.get("role") == "system":
            entry["content"] = replace_all(entry.get("content") or "")

    return result


def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)


def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]]) -> bool:
    """Return True if the given dictionary has the given keys with the given types."""
    for field, field_type in expected_fields:
        if field not in data:
            return False

        value = data[field]
        if not isinstance(value, field_type):
            return False
    return True
def get_llm_cache(llmnm, txt, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return None
    return bin


def set_llm_cache(llmnm, txt, v, history, genconf):
    hasher = xxhash.xxh64()
    hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
    k = hasher.hexdigest()
    REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600)


def get_embed_cache(llmnm, txt):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return np.array(json.loads(bin))


def set_embed_cache(llmnm, txt, arr):
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))

    k = hasher.hexdigest()
    arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
    REDIS_CONN.set(k, arr.encode("utf-8"), 24 * 3600)


def get_tags_from_cache(kb_ids):
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    bin = REDIS_CONN.get(k)
    if not bin:
        return
    return bin


def set_tags_to_cache(kb_ids, tags):
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    """
    Ensure all nodes and edges in the graph have some essential attribute.
    """

    def is_valid_item(node_attrs: dict) -> bool:
        valid_node = True
        for attr in ["description", "source_id"]:
            if attr not in node_attrs:
                valid_node = False
                break
        return valid_node

    if check_attribute:
        purged_nodes = []
        for node, node_attrs in graph.nodes(data=True):
            if not is_valid_item(node_attrs):
                purged_nodes.append(node)
        for node in purged_nodes:
            graph.remove_node(node)
        if purged_nodes and callback:
            callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")

    purged_edges = []
    for source, target, attr in graph.edges(data=True):
        if check_attribute:
            if not is_valid_item(attr):
                purged_edges.append((source, target))
        if "keywords" not in attr:
            attr["keywords"] = []
    for source, target in purged_edges:
        graph.remove_edge(source, target)
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place."""
    for node_name, attr in g2.nodes(data=True):
        change.added_updated_nodes.add(node_name)
        if not g1.has_node(node_name):
            g1.add_node(node_name, **attr)
            continue
        node = g1.nodes[node_name]
        node["description"] += GRAPH_FIELD_SEP + attr["description"]
        # A node's source_id indicates which chunks it came from.
        node["source_id"] += attr["source_id"]

    for source, target, attr in g2.edges(data=True):
        change.added_updated_edges.add(get_from_to(source, target))
        edge = g1.get_edge_data(source, target)
        if edge is None:
            g1.add_edge(source, target, **attr)
            continue
        edge["weight"] += attr.get("weight", 0)
        edge["description"] += GRAPH_FIELD_SEP + attr["description"]
        edge["keywords"] += attr["keywords"]
        # An edge's source_id indicates which chunks it came from.
        edge["source_id"] += attr["source_id"]

    # Refresh each node's rank from its degree in the merged graph.
    for node_degree in g1.degree:
        g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    # A graph's source_id indicates which documents it came from.
    if "source_id" not in g1.graph:
        g1.graph["source_id"] = []
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1


def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()
def handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # add this record as a node in the G
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key
    return dict(
        entity_name=entity_name.upper(),
        entity_type=entity_type.upper(),
        description=entity_description,
        source_id=entity_source_id,
    )


def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as edge
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])

    edge_keywords = clean_str(record_attributes[4])
    edge_source_id = chunk_key
    weight = float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    pair = sorted([source.upper(), target.upper()])
    return dict(
        src_id=pair[0],
        tgt_id=pair[1],
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
        metadata={"created_at": time.time()},
    )


def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [{"role": roles[i % 2], "content": content} for i, content in enumerate(args)]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]


def is_float_regex(value):
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def chunk_id(chunk):
    return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
    global chat_limiter
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    chunk = {
        "id": get_uuid(),
        "important_kwd": [ent_name],
        "title_tks": rag_tokenizer.tokenize(ent_name),
        "entity_kwd": ent_name,
        "knowledge_graph_kwd": "entity",
        "entity_type_kwd": meta["entity_type"],
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "source_id": meta["source_id"],
        "kb_id": kb_id,
        "available_int": 0,
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
    if ebd is None:
        async with chat_limiter:
            timeout = 3 if enable_timeout_assertion else 30000000
            ebd, _ = await asyncio.wait_for(
                thread_pool_exec(embd_mdl.encode, [ent_name]),
                timeout=timeout
            )
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


@timeout(3, 3)
async def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
    ents = from_ent_name
    if isinstance(ents, str):
        ents = [from_ent_name]
    if isinstance(to_ent_name, str):
        to_ent_name = [to_ent_name]
    ents.extend(to_ent_name)
    ents = list(set(ents))
    conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
    res = []
    es_res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
    for id in es_res.ids:
        try:
            if size == 1:
                return json.loads(es_res.field[id]["content_with_weight"])
            res.append(json.loads(es_res.field[id]["content_with_weight"]))
        except Exception:
            continue
    return res


async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    chunk = {
        "id": get_uuid(),
        "from_entity_kwd": from_ent_name,
        "to_entity_kwd": to_ent_name,
        "knowledge_graph_kwd": "relation",
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "important_kwd": meta["keywords"],
        "source_id": meta["source_id"],
        "weight_int": int(meta["weight"]),
        "kb_id": kb_id,
        "available_int": 0,
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    txt = f"{from_ent_name}->{to_ent_name}"
    ebd = get_embed_cache(embd_mdl.llm_name, txt)
    if ebd is None:
        async with chat_limiter:
            timeout = 3 if enable_timeout_assertion else 300000000
            ebd, _ = await asyncio.wait_for(
                thread_pool_exec(
                    embd_mdl.encode,
                    [txt + f": {meta['description']}"]
                ),
                timeout=timeout
            )
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, txt, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)
async def does_graph_contains(tenant_id, kb_id, doc_id):
    # Get doc_ids of graph
    fields = ["source_id"]
    condition = {
        "knowledge_graph_kwd": ["graph"],
        "removed_kwd": "N",
    }
    res = await thread_pool_exec(
        settings.docStoreConn.search,
        fields, [], condition, [], OrderByExpr(),
        0, 1, search.index_name(tenant_id), [kb_id]
    )
    fields2 = settings.docStoreConn.get_fields(res, fields)
    graph_doc_ids = set()
    for chunk_id in fields2.keys():
        graph_doc_ids = set(fields2[chunk_id]["source_id"])
    return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
    conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
    res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])
    doc_ids = []
    if res.total == 0:
        return doc_ids
    for id in res.ids:
        doc_ids = res.field[id]["source_id"]
    return doc_ids


async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
    conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
    res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])
    if not res.total == 0:
        for id in res.ids:
            try:
                if res.field[id]["removed_kwd"] == "N":
                    g = json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges")
                    if "source_id" not in g.graph:
                        g.graph["source_id"] = res.field[id]["source_id"]
                else:
                    g = await rebuild_graph(tenant_id, kb_id, exclude_rebuild)
                return g
            except Exception:
                continue
    result = None
    return result
async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
    global chat_limiter
    start = asyncio.get_running_loop().time()

    # Build all new chunks first (graph, subgraphs, node/edge embeddings) before
    # deleting anything. This ensures that if embedding generation or any other
    # step crashes, the old graph and per-doc subgraph checkpoints remain intact
    # so the pipeline can resume without re-running earlier phases.
    chunks = [
        {
            "id": get_uuid(),
            "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
            "knowledge_graph_kwd": "graph",
            "kb_id": kb_id,
            "source_id": graph.graph.get("source_id", []),
            "available_int": 0,
            "removed_kwd": "N",
        }
    ]

    # generate updated subgraphs
    for source in graph.graph["source_id"]:
        subgraph = graph.subgraph([n for n in graph.nodes if source in graph.nodes[n]["source_id"]]).copy()
        subgraph.graph["source_id"] = [source]
        for n in subgraph.nodes:
            subgraph.nodes[n]["source_id"] = [source]
        chunks.append(
            {
                "id": get_uuid(),
                "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
                "knowledge_graph_kwd": "subgraph",
                "kb_id": kb_id,
                "source_id": [source],
                "available_int": 0,
                "removed_kwd": "N",
            }
        )

    tasks = []
    for ii, node in enumerate(change.added_updated_nodes):
        node_attrs = graph.nodes[node]
        tasks.append(asyncio.create_task(
            graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
        ))
        if ii % 100 == 9 and callback:
            callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as e:
        logging.error(f"Error in get_embedding_of_nodes: {e}")
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    tasks = []
    for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
        edge_attrs = graph.get_edge_data(from_node, to_node)
        if not edge_attrs:
            continue
        tasks.append(asyncio.create_task(
            graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
        ))
        if ii % 100 == 9 and callback:
            callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as e:
        logging.error(f"Error in get_embedding_of_edges: {e}")
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
    start = now

    # All new chunks are ready. Now delete old data and insert the new data.
    # Deleting only after chunks are built ensures that a crash during embedding
    # generation above does not destroy the old graph/subgraph checkpoints.
    await thread_pool_exec(
        settings.docStoreConn.delete,
        {"knowledge_graph_kwd": ["graph", "subgraph"]},
        search.index_name(tenant_id),
        kb_id
    )

    if change.removed_nodes:
        BATCH_SIZE = 100
        sorted_nodes = sorted(change.removed_nodes)
        for i in range(0, len(sorted_nodes), BATCH_SIZE):
            batch = sorted_nodes[i:i + BATCH_SIZE]
            await thread_pool_exec(
                settings.docStoreConn.delete,
                {"knowledge_graph_kwd": ["entity"], "entity_kwd": batch},
                search.index_name(tenant_id),
                kb_id
            )

    if change.removed_edges:

        async def del_edges(from_node, to_node):
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    async with chat_limiter:
                        await thread_pool_exec(
                            settings.docStoreConn.delete,
                            {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
                            search.index_name(tenant_id),
                            kb_id
                        )
                    return
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"del_edges({from_node}, {to_node}) attempt {attempt + 1} failed: {e}, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise

        tasks = []
        for from_node, to_node in change.removed_edges:
            tasks.append(asyncio.create_task(del_edges(from_node, to_node)))
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error(f"Error while deleting edges: {e}")
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    del_now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {del_now - start:.2f}s.")
    start = del_now

    await insert_chunks_bounded(chunks, tenant_id, kb_id, callback=callback, label="Insert chunks")

    now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
def is_continuous_subsequence(subseq, seq):
    def find_all_indexes(tup, value):
        indexes = []
        start = 0
        while True:
            try:
                index = tup.index(value, start)
                indexes.append(index)
                start = index + 1
            except ValueError:
                break
        return indexes

    index_list = find_all_indexes(seq, subseq[0])
    for idx in index_list:
        if idx != len(seq) - 1:
            if seq[idx + 1] == subseq[-1]:
                return True
    return False


def merge_tuples(list1, list2):
    result = []
    for tup in list1:
        last_element = tup[-1]
        if last_element in tup[:-1]:
            result.append(tup)
        else:
            matching_tuples = [t for t in list2 if t[0] == last_element]
            already_match_flag = 0
            for match in matching_tuples:
                matchh = (match[1], match[0])
                if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
                    continue
                already_match_flag = 1
                merged_tuple = tup + match[1:]
                result.append(merged_tuple)
            if not already_match_flag:
                result.append(tup)
    return result
async def get_entity_type2samples(idxnms, kb_ids: list):
    es_res = await settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)

    res = defaultdict(list)
    for id in es_res.ids:
        smp = es_res.field[id].get("content_with_weight")
        if not smp:
            continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            logging.exception(e)
            # Skip malformed samples: `smp` is still a str here and has no .items().
            continue

        for ty, ents in smp.items():
            res[ty].extend(ents)
    return res


def flat_uniq_list(arr, key):
    res = []
    for a in arr:
        a = a[key]
        if isinstance(a, list):
            res.extend(a)
        else:
            res.append(a)
    return list(set(res))
async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
    graph = nx.Graph()
    flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
    bs = 256
    for i in range(0, 1024 * bs, bs):
        es_res = await thread_pool_exec(
            settings.docStoreConn.search,
            flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
            [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]
        )
        # tot = settings.docStoreConn.get_total(es_res)
        es_res = settings.docStoreConn.get_fields(es_res, flds)

        if len(es_res) == 0:
            break

        for id, d in es_res.items():
            assert d["knowledge_graph_kwd"] == "subgraph"
            if isinstance(exclude_rebuild, list):
                if sum([n in d["source_id"] for n in exclude_rebuild]):
                    continue
            elif exclude_rebuild in d["source_id"]:
                continue

            next_graph = json_graph.node_link_graph(json.loads(d["content_with_weight"]), edges="edges")
            merged_graph = nx.compose(graph, next_graph)
            merged_source = {n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"] for n in graph.nodes & next_graph.nodes}
            nx.set_node_attributes(merged_graph, merged_source, "source_id")
            if "source_id" in graph.graph:
                merged_graph.graph["source_id"] = graph.graph["source_id"] + next_graph.graph["source_id"]
            else:
                merged_graph.graph["source_id"] = next_graph.graph["source_id"]
            graph = merged_graph

    if len(graph.nodes) == 0:
        return None
    graph.graph["source_id"] = sorted(graph.graph["source_id"])
    return graph