#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
spaCy-based entity and relationship extractor for GraphRAG.

Combines techniques from **LinearRAG** and **MGranRAG**:

* **Entity extraction** uses MGranRAG's multi-pass stacking algorithm
  (hyphen/apostrophe merging → capitalised-word merging → continuous
  noun/number merging) combined with spaCy NER, then deduplicated via
  ``ner_all_keywords``.

* **Relationship inference** follows LinearRAG's *relation-free* approach:
  entities co-occurring in the same sentence (or nearby sentences) are linked
  by implicit semantic edges whose description is the shared sentence text
  (semantic bridging).  Edge weights are optionally TF-normalised.

No LLM calls are needed for the extraction step itself.  The LLM is only used
downstream (inherited from ``Extractor``) for merging / summarising duplicate
entity descriptions when the same entity appears in multiple chunks.
"""

import logging
from collections import defaultdict

from rag.graphrag.general.extractor import Extractor
from rag.llm.chat_model import Base as CompletionLLM

# ---------------------------------------------------------------------------
# spaCy model loading (lazy, module-level singleton)
# ---------------------------------------------------------------------------

# Cache of the most recently loaded pipeline and the name it was loaded under.
# Only one model is cached at a time; requesting a different name reloads.
_nlp = None
_nlp_model_name = ""


def _load_spacy_model(model_name: str = "en_core_web_sm"):
    """Load (or return cached) spaCy language model.

    Automatically downloads the model if it is not yet installed.

    Args:
        model_name: Name of the spaCy pipeline package to load.

    Returns:
        The loaded spaCy pipeline.  The same object is returned on
        subsequent calls with the same ``model_name``.

    Raises:
        ImportError: If the ``spacy`` package itself cannot be imported.
            The original import failure is attached as ``__cause__``.

    Errors raised by the download step or by the second ``spacy.load``
    attempt are not caught here and propagate to the caller.
    """
    global _nlp, _nlp_model_name
    if _nlp is not None and _nlp_model_name == model_name:
        return _nlp

    # Imported lazily so that merely importing this module does not require
    # spaCy to be installed.
    try:
        import spacy
    except ImportError as err:
        raise ImportError(
            "spaCy is required for the spacy GraphRAG method. "
            "Install it with: pip install spacy && python -m spacy download en_core_web_sm"
        ) from err

    try:
        _nlp = spacy.load(model_name)
        logging.info("Loaded spaCy model '%s'", model_name)
    except OSError:
        # ``spacy.load`` signals a missing model package with ``OSError``;
        # fetch it once and retry.
        logging.warning(
            "spaCy model '%s' not found; downloading automatically …", model_name
        )
        from spacy.cli import download as spacy_download

        spacy_download(model_name)
        _nlp = spacy.load(model_name)
        logging.info("Downloaded and loaded spaCy model '%s'", model_name)

    _nlp_model_name = model_name
    return _nlp


# ---------------------------------------------------------------------------
# spaCy ↔ application entity-type mapping
# ---------------------------------------------------------------------------

# spaCy's built-in entity labels → the application-level types used by
# ``DEFAULT_ENTITY_TYPES``.  Labels not listed here fall through to
# ``"category"``.
SPACY_TO_APP_ENTITY_TYPE: dict[str, str] = {
    "PERSON": "person",
    "ORG": "organization",
    "GPE": "geo",
    "LOC": "geo",
    "FAC": "geo",
    "EVENT": "event",
    "PRODUCT": "category",
    "WORK_OF_ART": "category",
    "LAW": "category",
    "LANGUAGE": "category",
    "NORP": "category",
    "MONEY": "category",
    "QUANTITY": "category",
    "TIME": "event",
    "DATE": "event",
}

# Labels to skip entirely (from LinearRAG: ordinals / cardinals are rarely
# useful as graph nodes).
_SKIP_SPACY_LABELS = {"ORDINAL", "CARDINAL"}


# ---------------------------------------------------------------------------
# MGranRAG-style multi-pass keyword extraction
# ---------------------------------------------------------------------------

def _has_uppercase(text: str) -> bool:
    """Return ``True`` if *text* contains at least one uppercase character."""
    return any(c.isupper() for c in text)


def _replace_word(word: str) -> str:
    """Normalise spaces around hyphens and apostrophes (from MGranRAG)."""
    return (
        word.replace(" - ", "-")
        .replace(" -", "-")
        .replace("- ", "-")
        .replace(" 's", "'s")
        .replace(" 'S", "'S")
    )


def extract_keywords(spacy_doc) -> set[str]:
    """MGranRAG-style 3-pass stacking keyword extraction.

    Only ``text``, ``shape_`` and ``pos_`` of each token are read, so any
    iterable of objects exposing those attributes works.

    Phase 1 — Hyphen / apostrophe merging:
        Tokens connected by ``-`` or ``'s`` are merged into a single phrase
        labelled ``NP`` (e.g. ``New-York``, ``cat's``).  A token that follows
        a hyphen with no preceding word (e.g. a leading bullet dash) is kept
        as its own entry rather than discarded.

    Phase 2 — Capitalised-word merging:
        Consecutive tokens whose ``shape_`` contains ``X`` (i.e. contain an
        uppercase letter) are merged.  Function words (ADP, CCONJ, DET, PART)
        following them are absorbed as well, producing phrases like
        ``King of England``.  Merged results are labelled ``NX`` unless
        already ``PROPN``.

    Phase 3 — Continuous noun / number merging:
        Consecutive tokens with POS in ``[PROPN, NOUN, NUM, NX, NP]`` are
        merged and labelled ``NNN`` (unless already ``PROPN``).

    Finally, results with a trailing lowercase non-noun word are truncated,
    and coordinating conjunctions (``and``, ``or``) inside a merged phrase
    cause it to be split so that each proper noun is extracted individually
    (e.g. ``Bob and Lucy`` → ``Bob``, ``Lucy``).

    Returns:
        The set of extracted keyword strings (original casing).
    """
    # ── Phase 1: hyphen / apostrophe ──────────────────────────────────
    f1_word: list[str] = []
    f1_shape: list[str] = []
    f1_pos: list[str] = []
    f1_pos_list: list[list[str]] = []
    f1_word_list: list[list[str]] = []
    is_right = False  # True while the previous token was a hyphen
    for token in spacy_doc:
        is_joiner = token.shape_ in ("'x", "-") and token.pos_ in ("PUNCT", "PART")
        if is_joiner and token.shape_ == "-":
            is_right = True
        glue_to_previous = is_joiner or is_right
        if not is_joiner:
            # The right-hand side of a hyphen is consumed by this token.
            is_right = False

        if glue_to_previous and f1_word:
            f1_word[-1] += token.text
            f1_pos[-1] = "NP"
            f1_pos_list[-1].append(token.pos_)
            f1_word_list[-1].append(token.text)
        elif not is_joiner:
            # Ordinary token, or the word after a leading hyphen when there
            # is nothing to glue it to: start a new entry instead of
            # dropping it.
            f1_word.append(token.text)
            f1_shape.append(token.shape_)
            f1_pos.append(token.pos_)
            f1_pos_list.append([token.pos_])
            f1_word_list.append([token.text])
        # A joiner with no preceding word is dropped.

    # ── Phase 2: capitalised-word merging ───────────────────────────
    f2_word: list[str] = []
    f2_shape: list[str] = []
    f2_pos: list[str] = []
    f2_pos_list: list[list[str]] = []
    f2_word_list: list[list[str]] = []
    for cw, cs, cp, cpl, cwl in zip(
        f1_word, f1_shape, f1_pos, f1_pos_list, f1_word_list
    ):
        mergeable = "X" in cs or cp in ("ADP", "CCONJ", "DET", "PART")
        if mergeable and f2_word and "X" in f2_shape[-1]:
            # Merge with previous capitalised token.
            f2_word[-1] += " " + cw
            f2_shape[-1] += "X"
            if f2_pos[-1] != "PROPN":
                f2_pos[-1] = "NX"
            f2_pos_list[-1].extend(cpl)
            f2_word_list[-1].extend(cwl)
        else:
            f2_word.append(cw)
            # Capitalised tokens that open a new entry get a "Start" marker
            # appended to their shape; other shapes are copied unchanged.
            f2_shape.append(cs + "Start" if "X" in cs else cs)
            f2_pos.append(cp)
            f2_pos_list.append(cpl)
            f2_word_list.append(cwl)

    # ── Phase 3: continuous noun / number merging ───────────────────
    f3_word: list[str] = []
    f3_pos: list[str] = []
    f3_pos_list: list[list[str]] = []
    f3_word_list: list[list[str]] = []
    _noun_pos = {"PROPN", "NOUN", "NUM", "NX", "NP"}
    _noun_pos_ext = _noun_pos | {"NNN"}
    for cw, cp, cpl, cwl in zip(f2_word, f2_pos, f2_pos_list, f2_word_list):
        if cp in _noun_pos and f3_word and f3_pos[-1] in _noun_pos_ext:
            f3_word[-1] += " " + cw
            if f3_pos[-1] != "PROPN":
                f3_pos[-1] = "NNN"
            f3_pos_list[-1].extend(cpl)
            f3_word_list[-1].extend(cwl)
        else:
            f3_word.append(cw)
            f3_pos.append(cp)
            f3_pos_list.append(cpl)
            f3_word_list.append(cwl)

    # ── Final keyword collection ────────────────────────────────────
    keywords: set[str] = set()
    keep_pos = ("PROPN", "NOUN", "NUM", "PART")
    for cw, cp, cpl, cwl in zip(f3_word, f3_pos, f3_pos_list, f3_word_list):
        if cp not in _noun_pos_ext:
            continue

        # Truncate trailing lowercase non-noun / non-number words.
        if cwl and not _has_uppercase(cwl[-1]) and cpl[-1] not in keep_pos:
            # Scan right-to-left (index 0 is not examined) for the last token
            # worth keeping.  When nothing matches, the scan ends on index 1,
            # so the first two tokens are kept.
            # NOTE(review): that no-match result keeps e.g. "King of" from
            # "King of the"; confirm against MGranRAG whether only the first
            # token was intended.
            keep_upto = min(1, len(cpl) - 1)
            for i in range(len(cpl) - 1, 0, -1):
                if cpl[i] in keep_pos or _has_uppercase(cwl[i]):
                    keep_upto = i
                    break
            keywords.add(_replace_word(" ".join(cwl[: keep_upto + 1])))
        else:
            keywords.add(_replace_word(cw))

        # Split on coordinating conjunctions (and/or) inside merged
        # phrases so that individual proper nouns are also extracted
        # (e.g. ``Bob and Lucy`` → ``Bob``, ``Lucy``).
        if any(p in ("PROPN", "NOUN", "NUM") for p in cpl):
            cur_kws: list[str] = []
            for pos, piece in zip(cpl, cwl):
                if pos == "CCONJ" and piece and piece[0].islower():
                    if cur_kws:
                        keywords.add(_replace_word(" ".join(cur_kws)))
                    cur_kws = []
                else:
                    cur_kws.append(piece)
            if cur_kws:
                keywords.add(_replace_word(" ".join(cur_kws)))

    return keywords


def get_ner(spacy_doc) -> dict[str, str]:
    """Return ``{entity_text: spaCy_label}`` for all NER entities.

    Entities whose label is in ``_SKIP_SPACY_LABELS`` are ignored.  Entity
    text spanning several lines is split on newlines and every non-blank
    line becomes its own key; a later entity with the same text overwrites
    the label of an earlier one.
    """
    entities_dict: dict[str, str] = {}
    for ent in spacy_doc.ents:
        if ent.label_ in _SKIP_SPACY_LABELS:
            continue
        for line in ent.text.strip().split("\n"):
            line = line.strip()
            if line:
                entities_dict[line] = ent.label_
    return entities_dict


def ner_all_keywords(spacy_doc) -> set[str]:
    """Combine rule-based keyword extraction with spaCy NER (MGranRAG).

    Returns the union of:
    - keywords from the 3-pass stacking algorithm (``extract_keywords``)
    - entity texts from spaCy NER (``get_ner``)
    """
    keywords = extract_keywords(spacy_doc)
    ner_dict = get_ner(spacy_doc)
    return keywords.union(ner_dict.keys())


# ---------------------------------------------------------------------------
# Main extractor class
# ---------------------------------------------------------------------------

class GraphExtractor(Extractor):
    """Extract entities and relationships using spaCy (no LLM calls).

    Entity extraction
        MGranRAG's ``ner_all_keywords`` combines a 3-pass stacking keyword
        algorithm with spaCy NER, yielding broader coverage than NER alone
        (e.g. it catches compound nouns, hyphenated terms, and multi-word
        proper nouns that NER might miss).

    Relationship inference
        LinearRAG's *relation-free* semantic bridging: entities co-occurring
        in the same sentence (or within ``max_sentence_distance`` sentences)
        are linked by an implicit edge.  Node and edge descriptions are
        currently emitted as empty strings; the sentence-based description
        helpers on this class are present but not wired in.

        Optionally, edge weights are TF-normalised (LinearRAG).  Each
        entity's TF is
        ``count(entity_in_chunk) / sum(all_entity_counts_in_chunk)`` and the
        edge weight is the sum of the TFs of its two endpoints, with a floor
        of ``0.01``.

    The ``llm_invoker`` is only used downstream for merging / summarising
    duplicate descriptions (inherited from ``Extractor``).

    Parameters
    ----------
    llm_invoker : CompletionLLM
        LLM handle (used only for description summarisation, not extraction).
    language : str
        Language hint.
    entity_types : list[str] | None
        Application-level entity types to keep.  Entities whose mapped type
        is not in this list are discarded.
    spacy_model : str
        Name of the spaCy model to load (default ``en_core_web_sm``).
    max_sentence_distance : int
        When inferring relationships, entities are paired if the indices of
        their sentences differ by at most this value.  ``0`` restricts
        pairing to the same sentence; the default ``1`` also pairs entities
        in adjacent sentences.
    relationship_strength : int
        Default weight assigned to every inferred relationship when
        ``use_tf_weight`` is ``False``.
    use_tf_weight : bool
        If ``True``, use TF-normalised weighting (LinearRAG-style) for edge
        weights instead of the fixed ``relationship_strength``.
    """

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        language: str | None = "English",
        entity_types: list[str] | None = None,
        spacy_model: str = "en_core_web_sm",
        max_sentence_distance: int = 1,
        relationship_strength: int = 1,
        use_tf_weight: bool = False,
    ):
        super().__init__(llm_invoker, language, entity_types)
        self._spacy_model_name = spacy_model
        self._max_sentence_distance = max_sentence_distance
        self._relationship_strength = relationship_strength
        self._use_tf_weight = use_tf_weight
        # Eagerly load the model so import errors surface early.
        self._nlp = _load_spacy_model(spacy_model)

    # ------------------------------------------------------------------
    # Public interface – called by ``Extractor.__call__``
    # ------------------------------------------------------------------

    async def _process_single_content(
        self,
        chunk_key_dp: tuple[str, str],
        chunk_seq: int,
        num_chunks: int,
        out_results,
        task_id="",
    ):
        """Process one chunk through spaCy NER + keyword stacking + co-occurrence.

        Args:
            chunk_key_dp: ``(chunk_key, content)`` pair.
            chunk_seq: Zero-based index of the chunk (used in the progress
                message only).
            num_chunks: Total number of chunks (used for progress reporting).
            out_results: List to which a
                ``(maybe_nodes, maybe_edges, token_count)`` tuple is appended.
            task_id: Unused by this implementation.
        """
        chunk_key = chunk_key_dp[0]
        content = chunk_key_dp[1]

        doc = self._nlp(content)

        # ── 1. Entity extraction (MGranRAG: ner_all_keywords) ────────
        # Mapping from NER entity text → spaCy label.  The keyword set is
        # the same union that ``ner_all_keywords`` builds, reusing this map
        # instead of running ``get_ner`` a second time.
        ner_label_map: dict[str, str] = get_ner(doc)
        all_keywords = extract_keywords(doc).union(ner_label_map.keys())

        # The property rebuilds the set on every access; read it once.
        allowed_types = self._entity_types_set

        # For each keyword, determine its app-level entity type.
        #   - If the keyword matches a NER entity, use that label.
        #   - Otherwise, infer from POS heuristics.
        ent_records: dict[str, dict] = {}  # entity_name_upper → record
        ent_by_sent: dict[int, list[dict]] = defaultdict(list)

        # Iterate in sorted order: set iteration order depends on string
        # hashing, and "keep the first" below would otherwise pick a
        # different case variant from run to run.
        for kw in sorted(all_keywords):
            kw_upper = kw.strip().upper()
            if not kw_upper:
                continue

            # Determine entity type.
            spacy_label = ner_label_map.get(kw)
            if spacy_label:
                app_type = SPACY_TO_APP_ENTITY_TYPE.get(spacy_label, "category")
            else:
                app_type = self._infer_type_from_pos(doc, kw)

            if app_type not in allowed_types:
                continue

            # Determine which sentence this keyword belongs to.
            sent_idx = self._keyword_sent_idx(doc, kw)

            # The description is currently left empty; the sentence-based
            # description (``_keyword_sent_text``) is disabled.
            ent_record = dict(
                entity_name=kw_upper,
                entity_type=app_type.upper(),
                description="",
                source_id=chunk_key,
            )
            # A keyword may appear multiple times; keep the first.
            if kw_upper not in ent_records:
                ent_records[kw_upper] = ent_record
                ent_by_sent[sent_idx].append(ent_record)

        maybe_nodes: dict[str, list[dict]] = defaultdict(list)
        for name, rec in ent_records.items():
            maybe_nodes[name].append(rec)

        # ── 2. Relationship inference (LinearRAG: sentence co-occurrence) ─
        maybe_edges: dict[tuple, list[dict]] = defaultdict(list)

        # Pre-compute TF weights if needed (LinearRAG).
        entity_tf: dict[str, float] = {}
        if self._use_tf_weight:
            content_upper = content.upper()
            counts = {name: content_upper.count(name) for name in ent_records}
            total_count = sum(counts.values())
            entity_tf = {
                name: (count / total_count if total_count > 0 else 0.0)
                for name, count in counts.items()
            }

        seen_pairs: set[tuple[str, str]] = set()

        for si in sorted(ent_by_sent.keys()):
            ents_in_range = list(ent_by_sent[si])
            # Expand with nearby sentences.
            for offset in range(1, self._max_sentence_distance + 1):
                for nb_si in (si + offset, si - offset):
                    if nb_si in ent_by_sent:
                        ents_in_range.extend(ent_by_sent[nb_si])

            # Deduplicate by entity name.
            unique: dict[str, dict] = {}
            for e in ents_in_range:
                unique[e["entity_name"]] = e
            ent_list = list(unique.values())

            for a_idx, ea in enumerate(ent_list):
                for eb in ent_list[a_idx + 1:]:
                    # Undirected edge: the pair key is order-independent.
                    pair = tuple(sorted([ea["entity_name"], eb["entity_name"]]))
                    if pair in seen_pairs:
                        continue
                    seen_pairs.add(pair)

                    # Edge weight: TF-normalised (LinearRAG) or fixed.
                    if self._use_tf_weight:
                        w = (
                            entity_tf.get(ea["entity_name"], 0.0)
                            + entity_tf.get(eb["entity_name"], 0.0)
                        )
                        weight = max(w, 0.01)
                    else:
                        weight = self._relationship_strength

                    # Keywords for the edge: the two entity names.  The
                    # description is currently left empty; the shared
                    # sentence description (``_cooccurrence_description``)
                    # is disabled.
                    edge_record = dict(
                        src_id=pair[0],
                        tgt_id=pair[1],
                        weight=weight,
                        description="",
                        keywords=[ea["entity_name"], eb["entity_name"]],
                        source_id=chunk_key,
                    )
                    maybe_edges[pair].append(edge_record)

        token_count = len(doc)
        out_results.append((dict(maybe_nodes), dict(maybe_edges), token_count))

        if self.callback:
            self.callback(
                0.5 + 0.1 * len(out_results) / num_chunks,
                msg=f"[spacy] Entities extraction of chunk {chunk_seq+1} "
                f"{len(out_results)}/{num_chunks} done, "
                f"{len(maybe_nodes)} nodes, {len(maybe_edges)} edges, "
                f"{token_count} tokens.",
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @property
    def _entity_types_set(self) -> set[str]:
        """Lower-cased set of ``self._entity_types`` (rebuilt on each access)."""
        return {t.lower() for t in self._entity_types}

    @staticmethod
    def _infer_type_from_pos(doc, keyword: str) -> str:
        """Infer an application-level entity type from POS tags when the
        keyword was found by the stacking algorithm but not by NER.

        The first token whose upper-cased text equals the keyword, or starts
        with the keyword's first word, decides: ``PROPN`` → ``"person"``,
        ``NOUN`` → ``"category"``, ``NUM`` → ``"event"``.  If that token has
        another POS, or no token matches, a keyword containing an uppercase
        letter maps to ``"person"`` and anything else to ``"category"``.
        A blank keyword yields ``"category"``.
        """
        kw_upper = keyword.upper()
        words = kw_upper.split()
        if words:  # a blank keyword has no first word to match against
            first_word = words[0]
            for token in doc:
                token_upper = token.text.upper()
                if token_upper == kw_upper or token_upper.startswith(first_word):
                    if token.pos_ == "PROPN":
                        return "person"
                    if token.pos_ == "NOUN":
                        return "category"
                    if token.pos_ == "NUM":
                        return "event"
                    # Only the first matching token is consulted.
                    break
        # Fallback: check for uppercase → likely a named entity.
        if _has_uppercase(keyword):
            return "person"
        return "category"

    @staticmethod
    def _keyword_sent_idx(doc, keyword: str) -> int:
        """Return the index of the first sentence that contains *keyword*.

        Matching is a case-insensitive substring test.  Returns ``0`` when no
        sentence contains the keyword.
        """
        kw_lower = keyword.lower()
        for i, sent in enumerate(doc.sents):
            if kw_lower in sent.text.lower():
                return i
        return 0

    @staticmethod
    def _keyword_sent_text(doc, keyword: str) -> str | None:
        """Return the sentence text containing *keyword* (LinearRAG semantic bridging).

        Returns the stripped text of the first matching sentence
        (case-insensitive substring test), or ``None`` if there is none.
        """
        kw_lower = keyword.lower()
        for sent in doc.sents:
            if kw_lower in sent.text.lower():
                return sent.text.strip()
        return None

    @staticmethod
    def _cooccurrence_description(doc, head_name: str, tail_name: str) -> str:
        """Derive a relationship description using sentence co-occurrence
        (LinearRAG) with dependency-path enhancement as fallback.

        If both entities appear in the same sentence, that sentence is used
        as the description (semantic bridging).  Otherwise, try to find a
        lowest common ancestor in the dependency tree.  Failing that, the
        first sentence containing the head entity is returned, and as a last
        resort a generic statement.
        """
        head_lower = head_name.lower()
        tail_lower = tail_name.lower()

        # Primary: shared sentence text (LinearRAG semantic bridging).
        for sent in doc.sents:
            sent_lower = sent.text.lower()
            if head_lower in sent_lower and tail_lower in sent_lower:
                return sent.text.strip()

        # Fallback: dependency path via LCA.
        head_tok = GraphExtractor._find_token_by_text(doc, head_name)
        tail_tok = GraphExtractor._find_token_by_text(doc, tail_name)
        if head_tok is not None and tail_tok is not None:
            path_head = list(GraphExtractor._ancestor_path(head_tok))
            path_tail = list(GraphExtractor._ancestor_path(tail_tok))
            # First node on the head's path that also lies on the tail's path.
            lca = next(
                (h for h in path_head if any(h == t for t in path_tail)),
                None,
            )
            # Compare by equality, as in the search above: an ancestor may be
            # a different Python object from the token found by text, so an
            # identity check could miss that the LCA is an endpoint itself.
            if lca is not None and lca != head_tok and lca != tail_tok:
                return f"{head_name} is related to {tail_name} via '{lca.lemma_}'"

        # Final fallback: the sentence containing the head entity.
        head_sent = GraphExtractor._find_sent_for_text(doc, head_lower)
        if head_sent is not None:
            return head_sent.text.strip()

        return f"{head_name} is related to {tail_name}"

    @staticmethod
    def _find_token_by_text(doc, ent_name: str):
        """Return the head token of the first spaCy entity matching *ent_name*.

        Matching is case-insensitive on the stripped text.  If no entity
        matches, the first single token with that text is returned; ``None``
        if nothing matches.
        """
        target = ent_name.upper()
        for ent in doc.ents:
            if ent.text.strip().upper() == target:
                return ent.root
        # Fallback: token-level match for keywords not in doc.ents.
        for token in doc:
            if token.text.strip().upper() == target:
                return token
        return None

    @staticmethod
    def _find_sent_for_text(doc, text_lower: str):
        """Return the first sentence ``Span`` whose lower-cased text contains
        *text_lower*, or ``None`` if there is none."""
        for sent in doc.sents:
            if text_lower in sent.text.lower():
                return sent
        return None

    @staticmethod
    def _ancestor_path(token):
        """Yield *token* then each ancestor up to the root."""
        yield token
        yield from token.ancestors