Files
ragflow/test/unit_test/rag/graphrag/test_graphrag_utils.py
tunsuy e1f1184b01 test: add unit tests for graphrag/utils.py (87 test cases) (#13328)
Add comprehensive unit tests for `graphrag/utils.py`, covering 15
functions/classes with 87 test cases.

Tested functions:
- clean_str, dict_has_keys_with_types, perform_variable_replacements
- get_from_to, compute_args_hash, is_float_regex
- GraphChange dataclass
- handle_single_entity_extraction, handle_single_relationship_extraction
- graph_merge, tidy_graph
- split_string_by_multi_markers, pack_user_ass_to_openai_messages
- is_continuous_subsequence, merge_tuples, flat_uniq_list

All 327 existing + new tests pass with no regressions.
2026-03-05 15:30:43 +08:00

533 lines
17 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import networkx as nx
import pytest
from rag.graphrag.utils import (
GRAPH_FIELD_SEP,
GraphChange,
clean_str,
compute_args_hash,
dict_has_keys_with_types,
flat_uniq_list,
get_from_to,
graph_merge,
handle_single_entity_extraction,
handle_single_relationship_extraction,
is_continuous_subsequence,
is_float_regex,
merge_tuples,
pack_user_ass_to_openai_messages,
perform_variable_replacements,
split_string_by_multi_markers,
tidy_graph,
)
class TestCleanStr:
"""Tests for clean_str function."""
def test_basic_string(self):
assert clean_str("hello world") == "hello world"
def test_strips_whitespace(self):
assert clean_str(" hello ") == "hello"
def test_removes_html_escapes(self):
assert clean_str("&amp; &lt; &gt;") == "& < >"
def test_removes_control_characters(self):
assert clean_str("hello\x00world") == "helloworld"
assert clean_str("test\x1f") == "test"
assert clean_str("\x7fdata") == "data"
def test_removes_double_quotes(self):
assert clean_str('"quoted"') == "quoted"
def test_non_string_passthrough(self):
assert clean_str(123) == 123
assert clean_str(None) is None
assert clean_str([1, 2]) == [1, 2]
def test_empty_string(self):
assert clean_str("") == ""
def test_combined_html_and_control(self):
assert clean_str(" &amp;\x00test\x1f ") == "&test"
class TestDictHasKeysWithTypes:
"""Tests for dict_has_keys_with_types function."""
def test_matching_keys_and_types(self):
data = {"name": "Alice", "age": 30}
assert dict_has_keys_with_types(data, [("name", str), ("age", int)]) is True
def test_missing_key(self):
data = {"name": "Alice"}
assert dict_has_keys_with_types(data, [("name", str), ("age", int)]) is False
def test_wrong_type(self):
data = {"name": "Alice", "age": "thirty"}
assert dict_has_keys_with_types(data, [("name", str), ("age", int)]) is False
def test_empty_expected_fields(self):
assert dict_has_keys_with_types({"a": 1}, []) is True
def test_empty_data(self):
assert dict_has_keys_with_types({}, [("key", str)]) is False
def test_subclass_type_match(self):
assert dict_has_keys_with_types({"val": True}, [("val", int)]) is True
class TestPerformVariableReplacements:
"""Tests for perform_variable_replacements function."""
def test_simple_replacement(self):
result = perform_variable_replacements("Hello {name}!", variables={"name": "World"})
assert result == "Hello World!"
def test_multiple_replacements(self):
result = perform_variable_replacements(
"{greeting} {name}!",
variables={"greeting": "Hi", "name": "Alice"},
)
assert result == "Hi Alice!"
def test_no_variables(self):
result = perform_variable_replacements("No vars here")
assert result == "No vars here"
def test_empty_variables_dict(self):
result = perform_variable_replacements("{keep}", variables={})
assert result == "{keep}"
def test_history_system_message_replacement(self):
history = [
{"role": "system", "content": "You are {role}"},
{"role": "user", "content": "Hello {role}"},
]
perform_variable_replacements("input", history=history, variables={"role": "assistant"})
assert history[0]["content"] == "You are assistant"
assert history[1]["content"] == "Hello {role}"
def test_none_defaults(self):
result = perform_variable_replacements("text")
assert result == "text"
def test_non_string_variable_value(self):
result = perform_variable_replacements("count: {n}", variables={"n": 42})
assert result == "count: 42"
class TestGetFromTo:
"""Tests for get_from_to function."""
def test_ordered_pair(self):
assert get_from_to("A", "B") == ("A", "B")
def test_reversed_pair(self):
assert get_from_to("B", "A") == ("A", "B")
def test_equal_values(self):
assert get_from_to("X", "X") == ("X", "X")
def test_numeric_strings(self):
assert get_from_to("2", "1") == ("1", "2")
class TestComputeArgsHash:
"""Tests for compute_args_hash function."""
def test_deterministic(self):
h1 = compute_args_hash("a", "b", "c")
h2 = compute_args_hash("a", "b", "c")
assert h1 == h2
def test_different_args_different_hash(self):
h1 = compute_args_hash("a", "b")
h2 = compute_args_hash("a", "c")
assert h1 != h2
def test_returns_hex_string(self):
result = compute_args_hash("test")
assert isinstance(result, str)
assert len(result) == 32
int(result, 16)
def test_empty_args(self):
result = compute_args_hash()
assert isinstance(result, str)
class TestIsFloatRegex:
"""Tests for is_float_regex function."""
@pytest.mark.parametrize(
"value",
["1.0", "0.5", "100", "-3.14", "+2.7", ".5", "0"],
)
def test_valid_floats(self, value):
assert is_float_regex(value)
@pytest.mark.parametrize(
"value",
["abc", "", "1.2.3", "1e10", "inf", "NaN", " 1.0", "1.0 "],
)
def test_invalid_floats(self, value):
assert not is_float_regex(value)
class TestGraphChange:
"""Tests for GraphChange dataclass."""
def test_default_empty_sets(self):
change = GraphChange()
assert change.removed_nodes == set()
assert change.added_updated_nodes == set()
assert change.removed_edges == set()
assert change.added_updated_edges == set()
def test_mutable_default_independence(self):
c1 = GraphChange()
c2 = GraphChange()
c1.removed_nodes.add("A")
assert "A" not in c2.removed_nodes
class TestHandleSingleEntityExtraction:
"""Tests for handle_single_entity_extraction function."""
def test_valid_entity(self):
attrs = ['"entity"', "Alice", "Person", "A character"]
result = handle_single_entity_extraction(attrs, "chunk1")
assert result is not None
assert result["entity_name"] == "ALICE"
assert result["entity_type"] == "PERSON"
assert result["description"] == "A character"
assert result["source_id"] == "chunk1"
def test_not_entity_type(self):
attrs = ['"relationship"', "A", "B", "desc"]
assert handle_single_entity_extraction(attrs, "c1") is None
def test_too_few_attributes(self):
attrs = ['"entity"', "name", "type"]
assert handle_single_entity_extraction(attrs, "c1") is None
def test_empty_entity_name(self):
attrs = ['"entity"', '""', "Type", "Desc"]
assert handle_single_entity_extraction(attrs, "c1") is None
def test_entity_name_uppercased(self):
attrs = ['"entity"', "alice", "person", "desc"]
result = handle_single_entity_extraction(attrs, "c1")
assert result["entity_name"] == "ALICE"
assert result["entity_type"] == "PERSON"
class TestHandleSingleRelationshipExtraction:
"""Tests for handle_single_relationship_extraction function."""
def test_valid_relationship(self):
attrs = ['"relationship"', "Alice", "Bob", "friends with", "friendship", "2.0"]
result = handle_single_relationship_extraction(attrs, "chunk1")
assert result is not None
assert result["src_id"] == "ALICE"
assert result["tgt_id"] == "BOB"
assert result["weight"] == 2.0
assert result["description"] == "friends with"
assert result["keywords"] == "friendship"
assert result["source_id"] == "chunk1"
assert "created_at" in result["metadata"]
def test_not_relationship_type(self):
attrs = ['"entity"', "A", "B", "desc", "kw"]
assert handle_single_relationship_extraction(attrs, "c1") is None
def test_too_few_attributes(self):
attrs = ['"relationship"', "A", "B", "desc"]
assert handle_single_relationship_extraction(attrs, "c1") is None
def test_non_float_weight_defaults_to_one(self):
attrs = ['"relationship"', "A", "B", "desc", "kw", "not_a_number"]
result = handle_single_relationship_extraction(attrs, "c1")
assert result["weight"] == 1.0
def test_source_target_sorted(self):
attrs = ['"relationship"', "Zebra", "Apple", "desc", "kw", "1.0"]
result = handle_single_relationship_extraction(attrs, "c1")
assert result["src_id"] == "APPLE"
assert result["tgt_id"] == "ZEBRA"
class TestPackUserAssToOpenaiMessages:
"""Tests for pack_user_ass_to_openai_messages function."""
def test_single_message(self):
result = pack_user_ass_to_openai_messages("hello")
assert result == [{"role": "user", "content": "hello"}]
def test_alternating_roles(self):
result = pack_user_ass_to_openai_messages("q1", "a1", "q2")
assert result == [
{"role": "user", "content": "q1"},
{"role": "assistant", "content": "a1"},
{"role": "user", "content": "q2"},
]
def test_empty(self):
result = pack_user_ass_to_openai_messages()
assert result == []
class TestSplitStringByMultiMarkers:
"""Tests for split_string_by_multi_markers function."""
def test_single_marker(self):
result = split_string_by_multi_markers("a|b|c", ["|"])
assert result == ["a", "b", "c"]
def test_multiple_markers(self):
result = split_string_by_multi_markers("a|b;c", ["|", ";"])
assert result == ["a", "b", "c"]
def test_no_markers(self):
result = split_string_by_multi_markers("abc", [])
assert result == ["abc"]
def test_strips_whitespace(self):
result = split_string_by_multi_markers("a | b | c", ["|"])
assert result == ["a", "b", "c"]
def test_empty_segments_removed(self):
result = split_string_by_multi_markers("a||b", ["|"])
assert result == ["a", "b"]
def test_regex_special_chars_escaped(self):
result = split_string_by_multi_markers("a.b.c", ["."])
assert result == ["a", "b", "c"]
class TestGraphMerge:
"""Tests for graph_merge function."""
def _make_node(self, description="desc", source_id=None):
return {"description": description, "source_id": source_id or ["s1"]}
def _make_edge(self, weight=1.0, description="edge", keywords=None, source_id=None):
return {
"weight": weight,
"description": description,
"keywords": keywords or [],
"source_id": source_id or ["s1"],
}
def test_merge_disjoint_graphs(self):
g1 = nx.Graph()
g1.add_node("A", **self._make_node("A desc"))
g1.graph["source_id"] = ["doc1"]
g2 = nx.Graph()
g2.add_node("B", **self._make_node("B desc"))
g2.graph["source_id"] = ["doc2"]
change = GraphChange()
result = graph_merge(g1, g2, change)
assert result.has_node("A")
assert result.has_node("B")
assert "B" in change.added_updated_nodes
assert result.graph["source_id"] == ["doc1", "doc2"]
def test_merge_overlapping_nodes(self):
g1 = nx.Graph()
g1.add_node("A", description="first", source_id=["s1"])
g1.graph["source_id"] = ["doc1"]
g2 = nx.Graph()
g2.add_node("A", description="second", source_id=["s2"])
g2.graph["source_id"] = ["doc2"]
change = GraphChange()
graph_merge(g1, g2, change)
assert f"first{GRAPH_FIELD_SEP}second" == g1.nodes["A"]["description"]
assert g1.nodes["A"]["source_id"] == ["s1", "s2"]
def test_merge_overlapping_edges(self):
g1 = nx.Graph()
g1.add_node("A", **self._make_node())
g1.add_node("B", **self._make_node())
g1.add_edge("A", "B", **self._make_edge(weight=1.0, description="e1", keywords=["k1"], source_id=["s1"]))
g1.graph["source_id"] = ["doc1"]
g2 = nx.Graph()
g2.add_node("A", **self._make_node())
g2.add_node("B", **self._make_node())
g2.add_edge("A", "B", **self._make_edge(weight=2.0, description="e2", keywords=["k2"], source_id=["s2"]))
g2.graph["source_id"] = ["doc2"]
change = GraphChange()
graph_merge(g1, g2, change)
edge = g1.get_edge_data("A", "B")
assert edge["weight"] == 3.0
assert f"e1{GRAPH_FIELD_SEP}e2" == edge["description"]
assert edge["keywords"] == ["k1", "k2"]
assert edge["source_id"] == ["s1", "s2"]
def test_merge_tracks_changes(self):
g1 = nx.Graph()
g1.graph["source_id"] = []
g2 = nx.Graph()
g2.add_node("X", **self._make_node())
g2.add_node("Y", **self._make_node())
g2.add_edge("X", "Y", **self._make_edge())
g2.graph["source_id"] = ["doc1"]
change = GraphChange()
graph_merge(g1, g2, change)
assert {"X", "Y"} == change.added_updated_nodes
assert {("X", "Y")} == change.added_updated_edges
def test_merge_sets_rank(self):
g1 = nx.Graph()
g1.graph["source_id"] = []
g2 = nx.Graph()
g2.add_node("A", **self._make_node())
g2.add_node("B", **self._make_node())
g2.add_edge("A", "B", **self._make_edge())
g2.graph["source_id"] = ["doc1"]
change = GraphChange()
graph_merge(g1, g2, change)
assert g1.nodes["A"]["rank"] == 1
assert g1.nodes["B"]["rank"] == 1
class TestTidyGraph:
"""Tests for tidy_graph function."""
def test_removes_nodes_missing_attributes(self):
g = nx.Graph()
g.add_node("good", description="d", source_id="s")
g.add_node("bad")
messages = []
tidy_graph(g, lambda msg: messages.append(msg))
assert g.has_node("good")
assert not g.has_node("bad")
assert len(messages) == 1
def test_removes_edges_missing_attributes(self):
g = nx.Graph()
g.add_node("A", description="d", source_id="s")
g.add_node("B", description="d", source_id="s")
g.add_edge("A", "B")
messages = []
tidy_graph(g, lambda msg: messages.append(msg))
assert not g.has_edge("A", "B")
def test_adds_keywords_to_edges_without_it(self):
g = nx.Graph()
g.add_node("A", description="d", source_id="s")
g.add_node("B", description="d", source_id="s")
g.add_edge("A", "B", description="d", source_id="s")
tidy_graph(g, None)
assert g.edges["A", "B"]["keywords"] == []
def test_skip_attribute_check(self):
g = nx.Graph()
g.add_node("no_attrs")
g.add_edge("no_attrs", "no_attrs")
tidy_graph(g, None, check_attribute=False)
assert g.has_node("no_attrs")
def test_none_callback_no_error(self):
g = nx.Graph()
g.add_node("bad")
tidy_graph(g, None)
assert not g.has_node("bad")
class TestIsContinuousSubsequence:
"""Tests for is_continuous_subsequence function."""
def test_basic_match(self):
assert is_continuous_subsequence(("A", "B"), ("A", "B", "C")) is True
def test_no_match(self):
assert is_continuous_subsequence(("A", "C"), ("A", "B", "C")) is False
def test_at_end(self):
assert is_continuous_subsequence(("B", "C"), ("A", "B", "C")) is True
def test_single_element_sequence(self):
assert is_continuous_subsequence(("A", "B"), ("A",)) is False
class TestMergeTuples:
"""Tests for merge_tuples function."""
def test_basic_merge(self):
list1 = [("A", "B")]
list2 = [("B", "C")]
result = merge_tuples(list1, list2)
assert ("A", "B", "C") in result
def test_no_merge_possible(self):
list1 = [("A", "B")]
list2 = [("C", "D")]
result = merge_tuples(list1, list2)
assert ("A", "B") in result
def test_self_loop_kept(self):
list1 = [("A", "B", "A")]
list2 = []
result = merge_tuples(list1, list2)
assert ("A", "B", "A") in result
def test_empty_lists(self):
assert merge_tuples([], []) == []
class TestFlatUniqList:
"""Tests for flat_uniq_list function."""
def test_flat_lists(self):
arr = [{"k": [1, 2]}, {"k": [2, 3]}]
result = flat_uniq_list(arr, "k")
assert set(result) == {1, 2, 3}
def test_scalar_values(self):
arr = [{"k": "a"}, {"k": "b"}, {"k": "a"}]
result = flat_uniq_list(arr, "k")
assert set(result) == {"a", "b"}
def test_empty_list(self):
assert flat_uniq_list([], "k") == []
def test_mixed_list_and_scalar(self):
arr = [{"k": [1, 2]}, {"k": 3}]
result = flat_uniq_list(arr, "k")
assert set(result) == {1, 2, 3}