mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-03-07 16:46:03 +08:00
Add comprehensive unit tests for `rag/graphrag/utils.py`, covering 16 functions/classes with 87 test cases. Tested functions: - clean_str, dict_has_keys_with_types, perform_variable_replacements - get_from_to, compute_args_hash, is_float_regex - GraphChange dataclass - handle_single_entity_extraction, handle_single_relationship_extraction - graph_merge, tidy_graph - split_string_by_multi_markers, pack_user_ass_to_openai_messages - is_continuous_subsequence, merge_tuples, flat_uniq_list All 327 existing + new tests pass with no regressions.
533 lines
17 KiB
Python
533 lines
17 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import networkx as nx
|
|
import pytest
|
|
|
|
from rag.graphrag.utils import (
|
|
GRAPH_FIELD_SEP,
|
|
GraphChange,
|
|
clean_str,
|
|
compute_args_hash,
|
|
dict_has_keys_with_types,
|
|
flat_uniq_list,
|
|
get_from_to,
|
|
graph_merge,
|
|
handle_single_entity_extraction,
|
|
handle_single_relationship_extraction,
|
|
is_continuous_subsequence,
|
|
is_float_regex,
|
|
merge_tuples,
|
|
pack_user_ass_to_openai_messages,
|
|
perform_variable_replacements,
|
|
split_string_by_multi_markers,
|
|
tidy_graph,
|
|
)
|
|
|
|
|
|
class TestCleanStr:
    """Tests for clean_str function."""

    def test_basic_string(self):
        assert clean_str("hello world") == "hello world"

    def test_strips_whitespace(self):
        assert clean_str(" hello ") == "hello"

    def test_removes_html_escapes(self):
        # The input must contain real entity references: a literal "& < >"
        # would come back unchanged and never exercise the unescaping step.
        assert clean_str("&amp; &lt; &gt;") == "& < >"

    def test_removes_control_characters(self):
        assert clean_str("hello\x00world") == "helloworld"
        assert clean_str("test\x1f") == "test"
        assert clean_str("\x7fdata") == "data"

    def test_removes_double_quotes(self):
        assert clean_str('"quoted"') == "quoted"

    def test_non_string_passthrough(self):
        # Non-string inputs are returned as-is.
        assert clean_str(123) == 123
        assert clean_str(None) is None
        assert clean_str([1, 2]) == [1, 2]

    def test_empty_string(self):
        assert clean_str("") == ""

    def test_combined_html_and_control(self):
        # "&amp;" must decode to "&" while the surrounding whitespace and the
        # NUL / unit-separator control characters are dropped.
        assert clean_str(" &amp;\x00test\x1f ") == "&test"
|
|
|
|
|
|
class TestDictHasKeysWithTypes:
    """Checks that dict_has_keys_with_types requires every listed key with a matching type."""

    @staticmethod
    def _person_schema():
        """Return a fresh ``(key, type)`` list describing a name/age record."""
        return [("name", str), ("age", int)]

    def test_matching_keys_and_types(self):
        record = {"name": "Alice", "age": 30}
        assert dict_has_keys_with_types(record, self._person_schema()) is True

    def test_missing_key(self):
        record = {"name": "Alice"}
        assert dict_has_keys_with_types(record, self._person_schema()) is False

    def test_wrong_type(self):
        record = {"name": "Alice", "age": "thirty"}
        assert dict_has_keys_with_types(record, self._person_schema()) is False

    def test_empty_expected_fields(self):
        # Nothing is required, so any dict qualifies.
        assert dict_has_keys_with_types({"a": 1}, []) is True

    def test_empty_data(self):
        assert dict_has_keys_with_types({}, [("key", str)]) is False

    def test_subclass_type_match(self):
        # bool is a subclass of int, so True satisfies an int requirement.
        assert dict_has_keys_with_types({"val": True}, [("val", int)]) is True
|
|
|
|
|
|
class TestPerformVariableReplacements:
    """Tests for perform_variable_replacements function."""

    def test_simple_replacement(self):
        result = perform_variable_replacements("Hello {name}!", variables={"name": "World"})
        assert result == "Hello World!"

    def test_multiple_replacements(self):
        result = perform_variable_replacements(
            "{greeting} {name}!",
            variables={"greeting": "Hi", "name": "Alice"},
        )
        assert result == "Hi Alice!"

    def test_no_variables(self):
        # Both optional arguments omitted: the input comes back untouched.
        result = perform_variable_replacements("No vars here")
        assert result == "No vars here"

    def test_empty_variables_dict(self):
        # Unknown placeholders are left in place.
        result = perform_variable_replacements("{keep}", variables={})
        assert result == "{keep}"

    def test_history_system_message_replacement(self):
        history = [
            {"role": "system", "content": "You are {role}"},
            {"role": "user", "content": "Hello {role}"},
        ]
        result = perform_variable_replacements("input", history=history, variables={"role": "assistant"})
        # The input has no placeholders, so it is returned unchanged.
        assert result == "input"
        # Only the system message is rewritten in place; user messages are not.
        assert history[0]["content"] == "You are assistant"
        assert history[1]["content"] == "Hello {role}"

    def test_none_defaults(self):
        # Passing None explicitly (test_no_variables already covers omitting
        # the arguments) must behave exactly like the defaults.
        result = perform_variable_replacements("text", history=None, variables=None)
        assert result == "text"

    def test_non_string_variable_value(self):
        result = perform_variable_replacements("count: {n}", variables={"n": 42})
        assert result == "count: 42"
|
|
|
|
|
|
class TestGetFromTo:
    """get_from_to should hand back its two arguments with the smaller one first."""

    def test_ordered_pair(self):
        """Already-ordered input is returned unchanged."""
        expected = ("A", "B")
        assert get_from_to("A", "B") == expected

    def test_reversed_pair(self):
        """Reversed input is swapped so the smaller value comes first."""
        expected = ("A", "B")
        assert get_from_to("B", "A") == expected

    def test_equal_values(self):
        """Identical values yield that value twice."""
        expected = ("X", "X")
        assert get_from_to("X", "X") == expected

    def test_numeric_strings(self):
        """Digit strings are ordered as strings too."""
        expected = ("1", "2")
        assert get_from_to("2", "1") == expected
|
|
|
|
|
|
class TestComputeArgsHash:
    """Tests for compute_args_hash function."""

    # Hex digits in either case. Unlike int(value, 16), a membership check
    # rejects signs, "0x" prefixes, "_" separators and surrounding whitespace.
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")

    def _assert_hex_digest(self, value):
        """Assert *value* is a 32-character string made only of hex digits."""
        assert isinstance(value, str)
        assert len(value) == 32
        assert set(value) <= self._HEX_DIGITS

    def test_deterministic(self):
        h1 = compute_args_hash("a", "b", "c")
        h2 = compute_args_hash("a", "b", "c")
        assert h1 == h2

    def test_different_args_different_hash(self):
        h1 = compute_args_hash("a", "b")
        h2 = compute_args_hash("a", "c")
        assert h1 != h2

    def test_returns_hex_string(self):
        self._assert_hex_digest(compute_args_hash("test"))

    def test_empty_args(self):
        result = compute_args_hash()
        # A call with no arguments still yields a well-formed, stable digest.
        self._assert_hex_digest(result)
        assert result == compute_args_hash()
|
|
|
|
|
|
class TestIsFloatRegex:
    """Checks is_float_regex against plain decimal literals."""

    # Plain decimal forms, optionally signed, with or without a fractional part.
    _ACCEPTED = ["1.0", "0.5", "100", "-3.14", "+2.7", ".5", "0"]
    # Non-numeric text, malformed decimals, exponent/special spellings, and
    # values padded with whitespace.
    _REJECTED = ["abc", "", "1.2.3", "1e10", "inf", "NaN", " 1.0", "1.0 "]

    @pytest.mark.parametrize("value", _ACCEPTED)
    def test_valid_floats(self, value):
        assert is_float_regex(value)

    @pytest.mark.parametrize("value", _REJECTED)
    def test_invalid_floats(self, value):
        assert not is_float_regex(value)
|
|
|
|
|
|
class TestGraphChange:
    """Checks the defaults of the GraphChange dataclass."""

    _FIELDS = ("removed_nodes", "added_updated_nodes", "removed_edges", "added_updated_edges")

    def test_default_empty_sets(self):
        change = GraphChange()
        for field_name in self._FIELDS:
            assert getattr(change, field_name) == set()

    def test_mutable_default_independence(self):
        # Each instance must own its set; mutating one must not leak into another.
        first, second = GraphChange(), GraphChange()
        first.removed_nodes.add("A")
        assert "A" not in second.removed_nodes
|
|
|
|
|
|
class TestHandleSingleEntityExtraction:
    """Checks parsing of a single '"entity"' record into a node dict."""

    def test_valid_entity(self):
        record = ['"entity"', "Alice", "Person", "A character"]
        node = handle_single_entity_extraction(record, "chunk1")
        assert node is not None
        expected_fields = {
            "entity_name": "ALICE",
            "entity_type": "PERSON",
            "description": "A character",
            "source_id": "chunk1",
        }
        for key, value in expected_fields.items():
            assert node[key] == value

    def test_not_entity_type(self):
        record = ['"relationship"', "A", "B", "desc"]
        assert handle_single_entity_extraction(record, "c1") is None

    def test_too_few_attributes(self):
        # Only three fields: the description is missing.
        record = ['"entity"', "name", "type"]
        assert handle_single_entity_extraction(record, "c1") is None

    def test_empty_entity_name(self):
        # A name that is just a pair of quotes is rejected.
        record = ['"entity"', '""', "Type", "Desc"]
        assert handle_single_entity_extraction(record, "c1") is None

    def test_entity_name_uppercased(self):
        node = handle_single_entity_extraction(['"entity"', "alice", "person", "desc"], "c1")
        assert node["entity_name"] == "ALICE"
        assert node["entity_type"] == "PERSON"
|
|
|
|
|
|
class TestHandleSingleRelationshipExtraction:
    """Checks parsing of a single '"relationship"' record into an edge dict."""

    def test_valid_relationship(self):
        record = ['"relationship"', "Alice", "Bob", "friends with", "friendship", "2.0"]
        edge = handle_single_relationship_extraction(record, "chunk1")
        assert edge is not None
        expected_fields = {
            "src_id": "ALICE",
            "tgt_id": "BOB",
            "weight": 2.0,
            "description": "friends with",
            "keywords": "friendship",
            "source_id": "chunk1",
        }
        for key, value in expected_fields.items():
            assert edge[key] == value
        assert "created_at" in edge["metadata"]

    def test_not_relationship_type(self):
        record = ['"entity"', "A", "B", "desc", "kw"]
        assert handle_single_relationship_extraction(record, "c1") is None

    def test_too_few_attributes(self):
        # Four fields: the keywords column is missing.
        record = ['"relationship"', "A", "B", "desc"]
        assert handle_single_relationship_extraction(record, "c1") is None

    def test_non_float_weight_defaults_to_one(self):
        record = ['"relationship"', "A", "B", "desc", "kw", "not_a_number"]
        edge = handle_single_relationship_extraction(record, "c1")
        assert edge["weight"] == 1.0

    def test_source_target_sorted(self):
        record = ['"relationship"', "Zebra", "Apple", "desc", "kw", "1.0"]
        edge = handle_single_relationship_extraction(record, "c1")
        assert (edge["src_id"], edge["tgt_id"]) == ("APPLE", "ZEBRA")
|
|
|
|
|
|
class TestPackUserAssToOpenaiMessages:
    """pack_user_ass_to_openai_messages turns positional strings into chat messages."""

    def test_single_message(self):
        """A lone argument becomes one user message."""
        expected = [{"role": "user", "content": "hello"}]
        assert pack_user_ass_to_openai_messages("hello") == expected

    def test_alternating_roles(self):
        """Roles alternate user/assistant, starting with user."""
        contents = ("q1", "a1", "q2")
        roles = ("user", "assistant", "user")
        expected = [{"role": role, "content": text} for role, text in zip(roles, contents)]
        assert pack_user_ass_to_openai_messages(*contents) == expected

    def test_empty(self):
        """No arguments yields an empty message list."""
        assert pack_user_ass_to_openai_messages() == []
|
|
|
|
|
|
class TestSplitStringByMultiMarkers:
    """split_string_by_multi_markers splits on any marker, trimming pieces and dropping blanks."""

    # Expected output shared by most cases below.
    _ABC = ["a", "b", "c"]

    def test_single_marker(self):
        pieces = split_string_by_multi_markers("a|b|c", ["|"])
        assert pieces == self._ABC

    def test_multiple_markers(self):
        pieces = split_string_by_multi_markers("a|b;c", ["|", ";"])
        assert pieces == self._ABC

    def test_no_markers(self):
        # With no markers the whole string is returned as the only piece.
        pieces = split_string_by_multi_markers("abc", [])
        assert pieces == ["abc"]

    def test_strips_whitespace(self):
        pieces = split_string_by_multi_markers("a | b | c", ["|"])
        assert pieces == self._ABC

    def test_empty_segments_removed(self):
        pieces = split_string_by_multi_markers("a||b", ["|"])
        assert pieces == ["a", "b"]

    def test_regex_special_chars_escaped(self):
        # "." must be matched literally, not as the regex wildcard.
        pieces = split_string_by_multi_markers("a.b.c", ["."])
        assert pieces == self._ABC
|
|
|
|
|
|
class TestGraphMerge:
    """Tests for graph_merge function."""

    @staticmethod
    def _make_node(description="desc", source_id=None):
        """Build node attributes; ``source_id`` defaults to ``["s1"]`` only when omitted."""
        # Compare against None rather than truthiness so an explicit empty list is
        # kept, and build a fresh list per call so graphs never share mutable state.
        return {"description": description, "source_id": ["s1"] if source_id is None else source_id}

    @staticmethod
    def _make_edge(weight=1.0, description="edge", keywords=None, source_id=None):
        """Build edge attributes with fresh per-call list defaults."""
        return {
            "weight": weight,
            "description": description,
            "keywords": [] if keywords is None else keywords,
            "source_id": ["s1"] if source_id is None else source_id,
        }

    def test_merge_disjoint_graphs(self):
        g1 = nx.Graph()
        g1.add_node("A", **self._make_node("A desc"))
        g1.graph["source_id"] = ["doc1"]

        g2 = nx.Graph()
        g2.add_node("B", **self._make_node("B desc"))
        g2.graph["source_id"] = ["doc2"]

        change = GraphChange()
        result = graph_merge(g1, g2, change)

        assert result.has_node("A")
        assert result.has_node("B")
        # Only nodes coming from g2 are reported as added/updated.
        assert change.added_updated_nodes == {"B"}
        assert result.graph["source_id"] == ["doc1", "doc2"]

    def test_merge_overlapping_nodes(self):
        g1 = nx.Graph()
        g1.add_node("A", **self._make_node("first", ["s1"]))
        g1.graph["source_id"] = ["doc1"]

        g2 = nx.Graph()
        g2.add_node("A", **self._make_node("second", ["s2"]))
        g2.graph["source_id"] = ["doc2"]

        change = GraphChange()
        graph_merge(g1, g2, change)

        assert g1.nodes["A"]["description"] == f"first{GRAPH_FIELD_SEP}second"
        assert g1.nodes["A"]["source_id"] == ["s1", "s2"]

    def test_merge_overlapping_edges(self):
        g1 = nx.Graph()
        g1.add_node("A", **self._make_node())
        g1.add_node("B", **self._make_node())
        g1.add_edge("A", "B", **self._make_edge(weight=1.0, description="e1", keywords=["k1"], source_id=["s1"]))
        g1.graph["source_id"] = ["doc1"]

        g2 = nx.Graph()
        g2.add_node("A", **self._make_node())
        g2.add_node("B", **self._make_node())
        g2.add_edge("A", "B", **self._make_edge(weight=2.0, description="e2", keywords=["k2"], source_id=["s2"]))
        g2.graph["source_id"] = ["doc2"]

        change = GraphChange()
        graph_merge(g1, g2, change)

        edge = g1.get_edge_data("A", "B")
        assert edge["weight"] == 3.0
        assert edge["description"] == f"e1{GRAPH_FIELD_SEP}e2"
        assert edge["keywords"] == ["k1", "k2"]
        assert edge["source_id"] == ["s1", "s2"]

    def test_merge_tracks_changes(self):
        g1 = nx.Graph()
        g1.graph["source_id"] = []

        g2 = nx.Graph()
        g2.add_node("X", **self._make_node())
        g2.add_node("Y", **self._make_node())
        g2.add_edge("X", "Y", **self._make_edge())
        g2.graph["source_id"] = ["doc1"]

        change = GraphChange()
        graph_merge(g1, g2, change)

        assert change.added_updated_nodes == {"X", "Y"}
        assert change.added_updated_edges == {("X", "Y")}

    def test_merge_sets_rank(self):
        g1 = nx.Graph()
        g1.graph["source_id"] = []

        g2 = nx.Graph()
        g2.add_node("A", **self._make_node())
        g2.add_node("B", **self._make_node())
        g2.add_edge("A", "B", **self._make_edge())
        g2.graph["source_id"] = ["doc1"]

        change = GraphChange()
        graph_merge(g1, g2, change)

        assert g1.nodes["A"]["rank"] == 1
        assert g1.nodes["B"]["rank"] == 1
|
|
|
|
|
|
class TestTidyGraph:
    """Tests for tidy_graph function."""

    def test_removes_nodes_missing_attributes(self):
        g = nx.Graph()
        g.add_node("good", description="d", source_id="s")
        g.add_node("bad")
        messages = []
        tidy_graph(g, lambda msg: messages.append(msg))
        assert g.has_node("good")
        assert not g.has_node("bad")
        assert len(messages) == 1

    def test_removes_edges_missing_attributes(self):
        g = nx.Graph()
        g.add_node("A", description="d", source_id="s")
        g.add_node("B", description="d", source_id="s")
        g.add_edge("A", "B")
        messages = []
        tidy_graph(g, lambda msg: messages.append(msg))
        assert not g.has_edge("A", "B")
        # The valid endpoints survive; only the bare edge is purged and reported.
        assert g.has_node("A")
        assert g.has_node("B")
        assert len(messages) == 1

    def test_adds_keywords_to_edges_without_it(self):
        g = nx.Graph()
        g.add_node("A", description="d", source_id="s")
        g.add_node("B", description="d", source_id="s")
        g.add_edge("A", "B", description="d", source_id="s")
        tidy_graph(g, None)
        assert g.edges["A", "B"]["keywords"] == []

    def test_skip_attribute_check(self):
        g = nx.Graph()
        g.add_node("no_attrs")
        g.add_edge("no_attrs", "no_attrs")
        tidy_graph(g, None, check_attribute=False)
        # With checking disabled, neither the bare node nor its bare edge is purged.
        assert g.has_node("no_attrs")
        assert g.has_edge("no_attrs", "no_attrs")

    def test_none_callback_no_error(self):
        g = nx.Graph()
        g.add_node("bad")
        tidy_graph(g, None)
        assert not g.has_node("bad")
|
|
|
|
|
|
class TestIsContinuousSubsequence:
    """is_continuous_subsequence reports whether a tuple occurs contiguously in a sequence
    (exercised here with two-element tuples)."""

    # Haystack shared by the first three cases.
    _SEQ = ("A", "B", "C")

    def test_basic_match(self):
        assert is_continuous_subsequence(("A", "B"), self._SEQ) is True

    def test_no_match(self):
        # "A" and "C" both appear, but not next to each other.
        assert is_continuous_subsequence(("A", "C"), self._SEQ) is False

    def test_at_end(self):
        assert is_continuous_subsequence(("B", "C"), self._SEQ) is True

    def test_single_element_sequence(self):
        assert is_continuous_subsequence(("A", "B"), ("A",)) is False
|
|
|
|
|
|
class TestMergeTuples:
    """Tests for merge_tuples function.

    Results are compared exactly: a membership check would not notice extra
    or duplicated tuples in the output.
    """

    def test_basic_merge(self):
        list1 = [("A", "B")]
        list2 = [("B", "C")]
        result = merge_tuples(list1, list2)
        # The chained path replaces the original pair.
        assert result == [("A", "B", "C")]

    def test_no_merge_possible(self):
        list1 = [("A", "B")]
        list2 = [("C", "D")]
        result = merge_tuples(list1, list2)
        assert result == [("A", "B")]

    def test_self_loop_kept(self):
        list1 = [("A", "B", "A")]
        list2 = []
        result = merge_tuples(list1, list2)
        assert result == [("A", "B", "A")]

    def test_empty_lists(self):
        assert merge_tuples([], []) == []
|
|
|
|
|
|
class TestFlatUniqList:
    """Tests for flat_uniq_list function.

    Output order is not asserted, so results are compared via ``sorted``.
    Comparing ``set(result)`` would hide duplicates and never test the
    "uniq" part of the contract.
    """

    def test_flat_lists(self):
        arr = [{"k": [1, 2]}, {"k": [2, 3]}]
        result = flat_uniq_list(arr, "k")
        # The shared 2 must appear only once.
        assert sorted(result) == [1, 2, 3]

    def test_scalar_values(self):
        arr = [{"k": "a"}, {"k": "b"}, {"k": "a"}]
        result = flat_uniq_list(arr, "k")
        assert sorted(result) == ["a", "b"]

    def test_empty_list(self):
        assert flat_uniq_list([], "k") == []

    def test_mixed_list_and_scalar(self):
        arr = [{"k": [1, 2]}, {"k": 3}]
        result = flat_uniq_list(arr, "k")
        assert sorted(result) == [1, 2, 3]
|